This commit is contained in:
lz_db
2025-12-04 22:03:02 +08:00
parent a7152f58ea
commit 878901826c
8 changed files with 604 additions and 191 deletions

View File

@@ -1,10 +1,7 @@
from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Set
import json
import time
from datetime import datetime, timedelta
from sqlalchemy import text, and_
from typing import List, Dict
from sqlalchemy import text
from models.orm_models import StrategyKX
class AccountSyncBatch(BaseSync):
@@ -185,8 +182,3 @@ class AccountSyncBatch(BaseSync):
logger.error(f"批量查询现有记录失败: {e}")
return existing_records
async def sync(self):
"""兼容旧接口"""
accounts = self.get_accounts_from_redis()
await self.sync_batch(accounts)

View File

@@ -1,12 +1,9 @@
from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Tuple
import json
import asyncio
import time
from datetime import datetime, timedelta
from sqlalchemy import text
import redis
class OrderSyncBatch(BaseSync):
"""订单数据批量同步器"""
@@ -20,11 +17,11 @@ class OrderSyncBatch(BaseSync):
"""批量同步所有账号的订单数据"""
try:
logger.info(f"开始批量同步订单数据,共 {len(accounts)} 个账号")
return
start_time = time.time()
# 1. Collect order data from all accounts
all_orders = await self._collect_all_orders(accounts)
all_orders = await self.redis_client._collect_all_orders(accounts)
if not all_orders:
logger.info("无订单数据需要同步")
@@ -33,8 +30,12 @@ class OrderSyncBatch(BaseSync):
logger.info(f"收集到 {len(all_orders)} 条订单数据")
# 2. 批量同步到数据库
# 使用基本版本
success, processed_count = await self._sync_orders_batch_to_db(all_orders)
# 或者使用增强版本
# success, results = await self._sync_orders_batch_to_db_enhanced(all_orders)
elapsed = time.time() - start_time
if success:
logger.info(f"订单批量同步完成: 处理 {processed_count} 条订单,耗时 {elapsed:.2f}")
@@ -43,189 +44,273 @@ class OrderSyncBatch(BaseSync):
except Exception as e:
logger.error(f"订单批量同步失败: {e}")
async def _collect_all_orders(self, accounts: Dict[str, Dict]) -> List[Dict]:
"""收集所有账号的订单数据"""
all_orders = []
try:
# Group accounts by exchange
account_groups = self._group_accounts_by_exchange(accounts)
# Collect each exchange's data concurrently
tasks = []
for exchange_id, account_list in account_groups.items():
task = self._collect_exchange_orders(exchange_id, account_list)
tasks.append(task)
# 等待所有任务完成并合并结果
results = await asyncio.gather(*tasks, return_exceptions=True)
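# Note: with return_exceptions=True, a failed task comes back as an exception object;
# the isinstance(result, list) check below silently drops those, so logging non-list
# results here would make per-exchange failures visible (hedged suggestion, not in the original).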
for result in results:
if isinstance(result, list):
all_orders.extend(result)
except Exception as e:
logger.error(f"收集订单数据失败: {e}")
return all_orders
def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
"""按交易所分组账号"""
groups = {}
for account_id, account_info in accounts.items():
exchange_id = account_info.get('exchange_id')
if exchange_id:
if exchange_id not in groups:
groups[exchange_id] = []
groups[exchange_id].append(account_info)
return groups
async def _collect_exchange_orders(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
"""收集某个交易所的订单数据"""
orders_list = []
try:
# Fetch each account's data concurrently
tasks = []
for account_info in account_list:
k_id = int(account_info['k_id'])
st_id = account_info.get('st_id', 0)
task = self._get_recent_orders_from_redis(k_id, st_id, exchange_id)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, list):
orders_list.extend(result)
logger.debug(f"交易所 {exchange_id}: 收集到 {len(orders_list)} 条订单")
except Exception as e:
logger.error(f"收集交易所 {exchange_id} 订单数据失败: {e}")
return orders_list
async def _get_recent_orders_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
"""从Redis获取最近N天的订单数据"""
try:
redis_key = f"{exchange_id}:orders:{k_id}"
# Build the dates for the most recent N days
today = datetime.now()
recent_dates = []
for i in range(self.recent_days):
date = today - timedelta(days=i)
date_format = date.strftime('%Y-%m-%d')
recent_dates.append(date_format)
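# Assumption: hash field names look like 'YYYY-MM-DD_<order id>', which is what the
# startswith(date_format + '_') filter below relies on; fields named differently are ignored.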
# Use HSCAN to collect all matching hash fields
cursor = 0
recent_keys = []
while True:
cursor, keys = self.redis_client.client.hscan(redis_key, cursor, count=1000)
for key, _ in keys.items():
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
if key_str == 'positions':
continue
# Keep fields whose names start with one of the recent dates
for date_format in recent_dates:
if key_str.startswith(date_format + '_'):
recent_keys.append(key_str)
break
if cursor == 0:
break
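# Note: SCAN-family commands only guarantee that each field is returned at least once,
# so recent_keys may contain duplicates; deduplicating (e.g. list(dict.fromkeys(recent_keys)))
# would avoid redundant HMGET lookups.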
if not recent_keys:
return []
# Fetch the order payloads in bulk
orders_list = []
# Fetch in chunks so no single call pulls too much data at once
chunk_size = 500
for i in range(0, len(recent_keys), chunk_size):
chunk_keys = recent_keys[i:i + chunk_size]
# Bulk fetch with HMGET
chunk_values = self.redis_client.client.hmget(redis_key, chunk_keys)
for key, order_json in zip(chunk_keys, chunk_values):
if not order_json:
continue
try:
order = json.loads(order_json)
# Validate the timestamp
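# (Assumption: order['time'] is a Unix timestamp in seconds, matching the comparison
# with time.time() below; millisecond timestamps would need to be divided by 1000 first.)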
order_time = order.get('time', 0)
if order_time >= int(time.time()) - self.recent_days * 24 * 3600:
# Attach account info
order['k_id'] = k_id
order['st_id'] = st_id
order['exchange_id'] = exchange_id
orders_list.append(order)
except json.JSONDecodeError as e:
logger.debug(f"解析订单JSON失败: key={key}, error={e}")
continue
return orders_list
except Exception as e:
logger.error(f"获取Redis订单数据失败: k_id={k_id}, error={e}")
return []
async def _sync_orders_batch_to_db(self, all_orders: List[Dict]) -> Tuple[bool, int]:
"""批量同步订单数据到数据库"""
"""批量同步订单数据到数据库
Args:
all_orders: 订单数据列表
Returns:
Tuple[bool, int]: (是否成功, 处理的订单数量)
"""
if not all_orders:
return True, 0
session = self.db_manager.get_session()
processed_count = 0
errors = []
try:
if not all_orders:
return True, 0
# Convert the data
converted_orders = []
for order in all_orders:
# Process in batches
for i in range(0, len(all_orders), self.batch_size):
batch_orders = all_orders[i:i + self.batch_size]
try:
order_dict = self._convert_order_data(order)
session.begin()
# Check completeness
required_fields = ['order_id', 'symbol', 'side', 'time']
if not all(order_dict.get(field) for field in required_fields):
# Convert the data and prepare the bulk insert
converted_orders = []
for raw_order in batch_orders:
try:
converted = self._convert_order_data(raw_order)
# Check required fields
if not all([
converted.get('order_id'),
converted.get('symbol'),
converted.get('k_id'),
converted.get('side')
]):
logger.warning(f"订单缺少必要字段: {raw_order}")
continue
converted_orders.append(converted)
except Exception as e:
logger.error(f"转换订单数据失败: {raw_order}, error={e}")
continue
if not converted_orders:
session.commit()
continue
converted_orders.append(order_dict)
# Bulk insert or update (upsert)
upsert_sql = text("""
INSERT INTO deh_strategy_order_new
(st_id, k_id, asset, order_id, symbol, side, price, time,
order_qty, last_qty, avg_price, exchange_id)
VALUES
(:st_id, :k_id, :asset, :order_id, :symbol, :side, :price, :time,
:order_qty, :last_qty, :avg_price, :exchange_id)
ON DUPLICATE KEY UPDATE
price = VALUES(price),
time = VALUES(time),
order_qty = VALUES(order_qty),
last_qty = VALUES(last_qty),
avg_price = VALUES(avg_price)
""")
result = session.execute(upsert_sql, converted_orders)
processed_count += len(converted_orders)
# Compute statistics
batch_size = len(converted_orders)
total_affected = result.rowcount
updated_count = max(0, total_affected - batch_size)
inserted_count = batch_size - updated_count
logger.debug(f"订单批次 {i//self.batch_size + 1}: "
f"处理 {batch_size} 条, "
f"插入 {inserted_count} 条, "
f"更新 {updated_count}")
session.commit()
except Exception as e:
logger.error(f"转换订单数据失败: {order}, error={e}")
continue
session.rollback()
error_msg = f"订单批次 {i//self.batch_size + 1} 处理失败: {str(e)}"
logger.error(error_msg, exc_info=True)
errors.append(error_msg)
# Continue with the next batch
if not converted_orders:
return True, 0
# Sync via the batch helper
from utils.batch_order_sync import BatchOrderSync
batch_tool = BatchOrderSync(self.db_manager, self.batch_size)
success, processed_count = batch_tool.sync_orders_batch(converted_orders)
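# Note: sync_orders_batch is called without await, so it is presumably synchronous; if it
# does heavy database I/O it will block the event loop. Wrapping the call in
# asyncio.to_thread(...) (Python 3.9+) would keep other sync tasks responsive (suggestion,
# not part of the original change).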
if errors:
logger.error(f"订单同步完成但有错误: {len(errors)} 个错误")
for error in errors[:5]: # 只打印前5个错误
logger.error(f"错误详情: {error}")
if len(errors) > 5:
logger.error(f"...还有 {len(errors) - 5} 个错误")
success = len(errors) == 0
return success, processed_count
except Exception as e:
logger.error(f"批量同步订单到数据库失败: {e}")
return False, 0
logger.error(f"订单批量同步失败: {e}", exc_info=True)
return False, processed_count
finally:
session.close()
async def _sync_orders_batch_to_db_enhanced(self, all_orders: List[Dict]) -> Tuple[bool, Dict]:
"""增强版:批量同步订单数据到数据库(带详细统计)
Args:
all_orders: 订单数据列表
Returns:
Tuple[bool, Dict]: (是否成功, 统计结果)
"""
if not all_orders:
return True, {'total': 0, 'processed': 0, 'inserted': 0, 'updated': 0, 'errors': []}
session = self.db_manager.get_session()
results = {
'total': len(all_orders),
'processed': 0,
'inserted': 0,
'updated': 0,
'errors': [],
'invalid_orders': 0
}
try:
logger.info(f"开始同步 {results['total']} 条订单数据,批次大小: {self.batch_size}")
# 按批次处理
total_batches = (len(all_orders) + self.batch_size - 1) // self.batch_size
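# Ceiling division: e.g. 1050 orders with batch_size = 500 -> (1050 + 499) // 500 = 3 batches.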
for batch_idx in range(total_batches):
start_idx = batch_idx * self.batch_size
end_idx = start_idx + self.batch_size
batch_orders = all_orders[start_idx:end_idx]
logger.debug(f"处理批次 {batch_idx + 1}/{total_batches}: "
f"订单 {start_idx + 1}-{min(end_idx, len(all_orders))}")
try:
session.begin()
# Convert the data
converted_orders = []
batch_invalid = 0
for raw_order in batch_orders:
try:
converted = self._convert_order_data(raw_order)
# Validate required fields
required_fields = ['order_id', 'symbol', 'k_id', 'side']
missing_fields = [field for field in required_fields if not converted.get(field)]
if missing_fields:
logger.warning(f"订单缺少必要字段 {missing_fields}: {raw_order}")
batch_invalid += 1
continue
# Validate field lengths (to guard against database errors)
order_id = converted.get('order_id', '')
if len(order_id) > 765: # limit from the table schema
converted['order_id'] = order_id[:765]
logger.warning(f"order_id too long, truncated: {order_id}")
symbol = converted.get('symbol', '')
if len(symbol) > 120:
converted['symbol'] = symbol[:120]
side = converted.get('side', '')
if len(side) > 120:
converted['side'] = side[:120]
converted_orders.append(converted)
except Exception as e:
logger.error(f"处理订单失败: {raw_order}, error={e}")
batch_invalid += 1
continue
results['invalid_orders'] += batch_invalid
if not converted_orders:
session.commit()
continue
# Bulk insert or update (upsert)
upsert_sql = text("""
INSERT INTO deh_strategy_order_new
(st_id, k_id, asset, order_id, symbol, side, price, time,
order_qty, last_qty, avg_price, exchange_id)
VALUES
(:st_id, :k_id, :asset, :order_id, :symbol, :side, :price, :time,
:order_qty, :last_qty, :avg_price, :exchange_id)
ON DUPLICATE KEY UPDATE
price = VALUES(price),
time = VALUES(time),
order_qty = VALUES(order_qty),
last_qty = VALUES(last_qty),
avg_price = VALUES(avg_price),
updated_at = CURRENT_TIMESTAMP
""")
result = session.execute(upsert_sql, converted_orders)
# Tally this batch's results
batch_size = len(converted_orders)
total_affected = result.rowcount
batch_updated = max(0, total_affected - batch_size)
batch_inserted = batch_size - batch_updated
# Accumulate into the totals
results['processed'] += batch_size
results['inserted'] += batch_inserted
results['updated'] += batch_updated
logger.info(f"批次 {batch_idx + 1} 完成: "
f"有效 {batch_size} 条, "
f"无效 {batch_invalid} 条, "
f"插入 {batch_inserted} 条, "
f"更新 {batch_updated}")
session.commit()
except Exception as e:
session.rollback()
error_msg = f"批次 {batch_idx + 1} 处理失败: {str(e)}"
logger.error(error_msg, exc_info=True)
results['errors'].append(error_msg)
# Continue with the next batch
# Final statistics
success_rate = results['processed'] / results['total'] * 100 if results['total'] > 0 else 0
logger.info(f"订单同步完成: "
f"总数={results['total']}, "
f"处理={results['processed']}({success_rate:.1f}%), "
f"插入={results['inserted']}, "
f"更新={results['updated']}, "
f"无效={results['invalid_orders']}, "
f"错误={len(results['errors'])}")
success = len(results['errors']) == 0
return success, results
except Exception as e:
logger.error(f"订单批量同步失败: {e}", exc_info=True)
results['errors'].append(f"同步过程失败: {str(e)}")
return False, results
finally:
session.close()
def _convert_order_data(self, data: Dict) -> Dict:
"""转换订单数据格式"""
try:
# Safe conversion helpers
def safe_float(value):
if value is None:
if value is None or value == '':
return None
try:
return float(value)
@@ -233,7 +318,7 @@ class OrderSyncBatch(BaseSync):
return None
def safe_int(value):
if value is None:
if value is None or value == '':
return None
try:
return int(float(value))
@@ -246,9 +331,9 @@ class OrderSyncBatch(BaseSync):
return str(value)
return {
'st_id': safe_int(data.get('st_id'), 0),
'k_id': safe_int(data.get('k_id'), 0),
'asset': 'USDT',
'st_id': safe_int(data.get('st_id')) or 0,
'k_id': safe_int(data.get('k_id')) or 0,
'asset': safe_str(data.get('asset')) or 'USDT',
'order_id': safe_str(data.get('order_id')),
'symbol': safe_str(data.get('symbol')),
'side': safe_str(data.get('side')),
@@ -263,8 +348,4 @@ class OrderSyncBatch(BaseSync):
except Exception as e:
logger.error(f"转换订单数据异常: {data}, error={e}")
return {}
async def sync(self):
"""兼容旧接口"""
accounts = self.get_accounts_from_redis()
await self.sync_batch(accounts)