from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Set, Tuple
import json
import asyncio
import utils.helpers as helpers
from datetime import datetime
from sqlalchemy import text, and_, select, delete
from models.orm_models import StrategyPosition
import time


class PositionSyncBatch(BaseSync):
    """Batch synchronizer that copies account positions from Redis into MySQL.

    Positions are read from the Redis hash ``{exchange_id}:positions:{k_id}``
    (field ``positions``), normalised by ``_convert_position_data`` and written
    to ``deh_strategy_position_new``; rows that are no longer present in the
    snapshot are deleted.
    """

    # UPSERT shared by both DB sync strategies. ON DUPLICATE KEY UPDATE depends
    # on a unique key of deh_strategy_position_new that is not visible here —
    # presumably (k_id, st_id, symbol, side), the tuple the delete logic keys
    # on; TODO confirm against the schema.
    _UPSERT_SQL = text("""
        INSERT INTO deh_strategy_position_new
        (st_id, k_id, asset, symbol, side, price, `sum`,
         asset_num, asset_profit, leverage, uptime,
         profit_price, stop_price, liquidation_price)
        VALUES
        (:st_id, :k_id, :asset, :symbol, :side, :price, :sum,
         :asset_num, :asset_profit, :leverage, :uptime,
         :profit_price, :stop_price, :liquidation_price)
        ON DUPLICATE KEY UPDATE
        price = VALUES(price),
        `sum` = VALUES(`sum`),
        asset_num = VALUES(asset_num),
        asset_profit = VALUES(asset_profit),
        leverage = VALUES(leverage),
        uptime = VALUES(uptime),
        profit_price = VALUES(profit_price),
        stop_price = VALUES(stop_price),
        liquidation_price = VALUES(liquidation_price)
    """)

    # Accounts per group (and per commit) in _sync_positions_batch_to_db_optimized.
    _ACCOUNT_GROUP_SIZE = 10

    def __init__(self):
        super().__init__()
        # Rows per batch. NOTE(review): not read by any method of this class.
        self.batch_size = 500

    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Collect every account's positions from Redis and sync them to the DB.

        Args:
            accounts: Mapping of account id -> account info dict; the info dicts
                are read for ``exchange_id``, ``k_id`` and optionally ``st_id``.

        Errors are logged (with the full traceback) and never raised.
        """
        try:
            logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号")
            start_time = time.time()

            # 1. Collect positions of all accounts.
            all_positions = await self._collect_all_positions(accounts)

            if not all_positions:
                # NOTE(review): returning here means stale DB rows are never
                # removed when every account is empty (or every Redis read
                # failed). Confirm this is the intended safeguard.
                logger.info("无持仓数据需要同步")
                return

            logger.info(f"收集到 {len(all_positions)} 条持仓数据")

            # 2. UPSERT into the DB and delete positions that no longer exist.
            success, stats = await self._sync_positions_batch_to_db_optimized_v3(all_positions)

            elapsed = time.time() - start_time

            if success:
                logger.info(f"持仓批量同步完成: 处理 {stats['total']} 条,受影响 {stats['affected']} 条,"
                            f"删除 {stats['deleted']} 条,耗时 {elapsed:.2f}秒")
            else:
                logger.error("持仓批量同步失败")

        except Exception as e:
            logger.error(f"持仓批量同步失败: {e}")
            # Log the full stack trace. The template is a literal, so passing the
            # traceback as a format kwarg is safe even if it contains braces.
            import traceback
            logger.error("完整堆栈跟踪:\n{traceback}", traceback=traceback.format_exc())

    async def _sync_positions_batch_to_db_optimized(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
        """
        Sync positions group-by-group of accounts (no temporary table).

        Accounts are processed in groups of ``_ACCOUNT_GROUP_SIZE``. For each
        group: UPSERT every valid position; then, for every account in the group,
        delete rows whose (st_id, symbol, side) is absent from the snapshot;
        then commit. Accounts that do not appear in ``all_positions`` are not
        touched. A failed UPSERT rolls the group back and skips its deletes.

        Args:
            all_positions: Raw position dicts; each must carry ``k_id`` (account ID).

        Returns:
            Tuple[bool, Dict]: (success, stats). ``stats`` holds ``total``,
            ``affected``, ``deleted`` and ``errors``; success means ``errors``
            is empty.
        """
        if not all_positions:
            return True, {'total': 0, 'affected': 0, 'deleted': 0, 'errors': []}

        results = {
            'total': 0,
            'affected': 0,
            'deleted': 0,
            'errors': []
        }

        # Group raw positions by account (dict preserves first-seen order).
        positions_by_account: Dict[Any, List[Dict]] = {}
        for position in all_positions:
            positions_by_account.setdefault(position['k_id'], []).append(position)

        logger.info(f"开始批量处理 {len(positions_by_account)} 个账号的持仓数据")

        # Opened only after grouping so a malformed input cannot leak the session.
        session = self.db_manager.get_session()
        try:
            account_ids = list(positions_by_account)
            group_size = self._ACCOUNT_GROUP_SIZE
            for group_idx in range(0, len(account_ids), group_size):
                group_no = group_idx // group_size + 1
                group_account_ids = account_ids[group_idx:group_idx + group_size]
                logger.info(f"处理第 {group_no} 组账号: {group_account_ids}")

                group_positions = [
                    position
                    for k_id in group_account_ids
                    for position in positions_by_account[k_id]
                ]
                if not group_positions:
                    continue

                processed_positions = []
                # k_id -> {(st_id, symbol, side)} present in this snapshot. Tuples
                # (not "&"-joined strings) keep symbols containing "&" intact when
                # the delete conditions are built below.
                account_position_keys: Dict[Any, Set[Tuple]] = {}

                for raw_position in group_positions:
                    try:
                        k_id = raw_position['k_id']
                        processed = self._convert_position_data(raw_position)

                        # Skip rows lacking the required identifying fields.
                        if not all([processed.get('symbol'), processed.get('side')]):
                            continue

                        processed.setdefault('st_id', raw_position.get('st_id', 0))
                        processed.setdefault('k_id', k_id)

                        # The DB column is `sum`; the converter emits it as `qty`.
                        if 'qty' in processed:
                            processed['sum'] = processed.pop('qty')

                        processed_positions.append(processed)
                        account_position_keys.setdefault(k_id, set()).add(
                            (processed['st_id'], processed['symbol'], processed['side'])
                        )
                    except Exception as e:
                        logger.error(f"处理持仓数据失败: {raw_position}, error={e}")
                        continue

                # Batch insert-or-update of this group.
                if processed_positions:
                    try:
                        result = session.execute(self._UPSERT_SQL, processed_positions)

                        total_affected = result.rowcount  # rows reported by MySQL
                        batch_size = len(processed_positions)  # rows submitted

                        results['total'] += batch_size
                        results['affected'] += total_affected

                        logger.debug(f"第 {group_no} 组: "
                                     f"处理 {batch_size} 条, "
                                     f"受影响 {total_affected} 条")
                    except Exception as e:
                        # logger.exception attaches the traceback. loguru has no
                        # exc_info parameter, and passing any kwarg makes it run
                        # str.format on the message, which raises when the DB
                        # error text contains braces (e.g. parameter dicts).
                        logger.exception(f"批量插入/更新失败: {e}")
                        session.rollback()
                        results['errors'].append(f"批量插入/更新失败: {str(e)}")
                        # Skip this group's deletes/commit and continue with the next.
                        continue

                # Delete rows of each account in the group that are no longer held.
                for k_id in group_account_ids:
                    try:
                        current_keys = account_position_keys.get(k_id)
                        if not current_keys:
                            # No valid positions in the snapshot: delete all of the account's rows.
                            delete_sql = text("""
                                DELETE FROM deh_strategy_position_new
                                WHERE k_id = :k_id
                            """)
                            result = session.execute(delete_sql, {'k_id': k_id})
                            deleted_count = result.rowcount
                            results['deleted'] += deleted_count

                            if deleted_count > 0:
                                logger.debug(f"账号 {k_id}: 删除所有旧持仓,共 {deleted_count} 条")
                        else:
                            # Keep rows matching any current (st_id, symbol, side);
                            # OR-ed bound conditions instead of a tuple IN clause.
                            conditions = []
                            params = {'k_id': k_id}
                            for idx, (st_id_val, symbol_val, side_val) in enumerate(current_keys):
                                conditions.append(
                                    f"(st_id = :st_id_{idx} AND symbol = :symbol_{idx} AND side = :side_{idx})"
                                )
                                params[f'st_id_{idx}'] = st_id_val
                                params[f'symbol_{idx}'] = symbol_val
                                params[f'side_{idx}'] = side_val

                            conditions_str = " OR ".join(conditions)
                            delete_sql = text(f"""
                                DELETE FROM deh_strategy_position_new
                                WHERE k_id = :k_id
                                AND NOT ({conditions_str})
                            """)
                            result = session.execute(delete_sql, params)
                            deleted_count = result.rowcount
                            results['deleted'] += deleted_count

                            if deleted_count > 0:
                                logger.debug(f"账号 {k_id}: 删除 {deleted_count} 条过期持仓")

                    except Exception as e:
                        # Record the error but keep going with the other accounts.
                        logger.error(f"删除账号 {k_id} 旧持仓失败: {e}")
                        results['errors'].append(f"删除账号 {k_id} 旧持仓失败: {str(e)}")

                # Commit once per group.
                try:
                    session.commit()
                    logger.debug(f"第 {group_no} 组处理完成并提交")
                except Exception as e:
                    session.rollback()
                    logger.error(f"第 {group_no} 组提交失败: {e}")
                    results['errors'].append(f"第 {group_no} 组提交失败: {str(e)}")

            logger.info(f"批量同步完成: "
                        f"总数={results['total']}, "
                        f"受影响={results['affected']}, "
                        f"删除={results['deleted']}, "
                        f"错误数={len(results['errors'])}")

            success = len(results['errors']) == 0
            return success, results

        except Exception as e:
            session.rollback()
            logger.exception(f"批量同步过程中发生错误: {e}")
            results['errors'].append(f"同步过程错误: {str(e)}")
            return False, results

        finally:
            session.close()

    @staticmethod
    def _build_current_positions_union(records: Set[Tuple]) -> Tuple[str, Dict[str, Any]]:
        """Build a parameterized ``UNION ALL`` derived table of current positions.

        Args:
            records: (k_id, st_id, symbol, side) tuples.

        Returns:
            (sql, params): ``SELECT ... UNION ALL SELECT ...`` text whose every
            value is a bound parameter, plus the matching parameter dict.
        """
        selects = []
        params: Dict[str, Any] = {}
        for idx, (k_id, st_id, symbol, side) in enumerate(records):
            selects.append(
                f"SELECT :k_id_{idx} AS k_id, :st_id_{idx} AS st_id, "
                f":symbol_{idx} AS symbol, :side_{idx} AS side"
            )
            params[f'k_id_{idx}'] = k_id
            params[f'st_id_{idx}'] = st_id
            params[f'symbol_{idx}'] = symbol
            params[f'side_{idx}'] = side
        return " UNION ALL ".join(selects), params

    async def _sync_positions_batch_to_db_optimized_v3(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
        """
        Sync all positions in a single transaction (compatible with all MySQL versions).

        Strategy:
            1. UPSERT every valid position in one executemany.
            2. Build a derived table of the current (k_id, st_id, symbol, side)
               tuples with UNION ALL and LEFT JOIN-delete every row not in it.

        NOTE(review): step 2 is table-wide. Rows of any account absent from
        ``all_positions`` are deleted as well, including accounts whose Redis
        read failed (``_get_positions_from_redis`` returns [] on error).
        Confirm this is intended.

        Args:
            all_positions: Raw position dicts as produced by ``_collect_all_positions``.

        Returns:
            Tuple[bool, Dict]: (success, stats) with ``total``, ``affected``,
            ``deleted`` and ``errors``.
        """
        if not all_positions:
            return True, {'total': 0, 'affected': 0, 'deleted': 0, 'errors': []}

        session = self.db_manager.get_session()
        results = {
            'total': 0,
            'affected': 0,
            'deleted': 0,
            'errors': []
        }

        try:
            session.begin()

            processed_positions = []
            current_position_records = set()  # unique (k_id, st_id, symbol, side)

            for raw_position in all_positions:
                try:
                    processed = self._convert_position_data(raw_position)
                    if not all([processed.get('symbol'), processed.get('side')]):
                        continue

                    # The DB column is `sum`; the converter emits it as `qty`.
                    if 'qty' in processed:
                        processed['sum'] = processed.pop('qty')

                    k_id = processed.get('k_id', raw_position['k_id'])
                    st_id = processed.get('st_id', raw_position.get('st_id', 0))
                    symbol = processed.get('symbol')
                    side = processed.get('side')

                    processed_positions.append(processed)
                    current_position_records.add((k_id, st_id, symbol, side))

                except Exception as e:
                    logger.error(f"处理持仓数据失败: {raw_position}, error={e}")
                    continue

            if not processed_positions:
                session.commit()
                return True, results

            results['total'] = len(processed_positions)
            logger.info(f"准备同步 {results['total']} 条持仓数据,去重后 {len(current_position_records)} 条唯一记录")

            # Batch UPSERT.
            result = session.execute(self._UPSERT_SQL, processed_positions)
            results['affected'] = result.rowcount
            logger.info(f"UPSERT完成: 总数 {results['total']} 条, 受影响 {results['affected']} 条")

            # Batch delete via a UNION ALL derived table. All values are bound
            # parameters: symbol/side come from Redis, and the former manual
            # quote-doubling did not handle backslashes (an escape character in
            # MySQL unless NO_BACKSLASH_ESCAPES is set), so e.g. a trailing "\"
            # broke out of the string literal.
            if current_position_records:
                union_sql, union_params = self._build_current_positions_union(current_position_records)

                delete_sql_join = text(f"""
                    DELETE p FROM deh_strategy_position_new p
                    LEFT JOIN (
                        {union_sql}
                    ) AS current_pos
                    ON p.k_id = current_pos.k_id
                    AND p.st_id = current_pos.st_id
                    AND p.symbol = current_pos.symbol
                    AND p.side = current_pos.side
                    WHERE current_pos.k_id IS NULL
                """)

                result = session.execute(delete_sql_join, union_params)
                deleted_count = result.rowcount
                results['deleted'] = deleted_count
                logger.info(f"删除 {deleted_count} 条过期持仓")

            session.commit()

            logger.info(f"批量同步V3完成: 总数={results['total']}, "
                        f"受影响={results['affected']}, "
                        f"删除={results['deleted']}")

            return True, results

        except Exception as e:
            session.rollback()
            # logger.exception instead of error(..., exc_info=True): loguru would
            # otherwise str.format the message and raise on braces in the error text.
            logger.exception(f"批量同步V3失败: {e}")
            results['errors'].append(f"同步失败: {str(e)}")
            return False, results
        finally:
            session.close()

    async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]:
        """Collect positions of all accounts, one concurrent task per exchange.

        Failures are logged; positions of exchanges that failed are omitted.
        """
        all_positions = []

        try:
            account_groups = self._group_accounts_by_exchange(accounts)
            exchange_ids = list(account_groups)

            results = await asyncio.gather(
                *(self._collect_exchange_positions(exchange_id, account_groups[exchange_id])
                  for exchange_id in exchange_ids),
                return_exceptions=True,
            )

            for exchange_id, result in zip(exchange_ids, results):
                if isinstance(result, list):
                    all_positions.extend(result)
                elif isinstance(result, BaseException):
                    # return_exceptions=True hands exceptions back as results;
                    # log them instead of dropping them silently.
                    logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {result}")

        except Exception as e:
            logger.error(f"收集持仓数据失败: {e}")

        return all_positions

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group account info dicts by ``exchange_id``; accounts without one are skipped."""
        groups: Dict[str, List[Dict]] = {}
        for account_info in accounts.values():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                groups.setdefault(exchange_id, []).append(account_info)
        return groups

    async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
        """Collect positions of every account on one exchange.

        An account whose ``k_id`` is missing or not an int is skipped on its own
        (previously it aborted the whole exchange). Failures are logged.
        """
        positions_list = []

        try:
            tasks = []
            task_k_ids = []
            for account_info in account_list:
                try:
                    k_id = int(account_info['k_id'])
                except (KeyError, TypeError, ValueError) as e:
                    logger.error(f"账号信息无效,跳过: {account_info}, error={e}")
                    continue
                st_id = account_info.get('st_id', 0)
                tasks.append(self._get_positions_from_redis(k_id, st_id, exchange_id))
                task_k_ids.append(k_id)

            results = await asyncio.gather(*tasks, return_exceptions=True)

            for k_id, result in zip(task_k_ids, results):
                if isinstance(result, list):
                    positions_list.extend(result)
                elif isinstance(result, BaseException):
                    logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={result}")

        except Exception as e:
            logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {e}")

        return positions_list

    async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
        """Read one account's positions from Redis and tag them with account info.

        Returns [] when the key/field is absent or on any error (logged).
        """
        try:
            redis_key = f"{exchange_id}:positions:{k_id}"
            # NOTE(review): hget is not awaited, so this is a blocking call inside
            # a coroutine; the surrounding asyncio.gather gives no real concurrency.
            redis_data = self.redis_client.client.hget(redis_key, 'positions')

            if not redis_data:
                return []

            positions = json.loads(redis_data)

            # Attach account identifiers (overwriting any values from Redis).
            for position in positions:
                position['k_id'] = k_id
                position['st_id'] = st_id
                position['exchange_id'] = exchange_id

            return positions

        except Exception as e:
            logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}")
            return []

    def _convert_position_data(self, data: Dict) -> Dict:
        """Normalise a raw position dict into DB column values.

        Returns {} if conversion raises (logged); callers treat that as invalid.
        """
        try:
            return {
                'st_id': helpers.safe_int(data.get('st_id'), 0),
                'k_id': helpers.safe_int(data.get('k_id'), 0),
                'asset': data.get('asset', 'USDT'),
                'symbol': data.get('symbol', ''),
                'side': data.get('side', ''),
                'price': helpers.safe_float(data.get('price')),
                'qty': helpers.safe_float(data.get('qty')),  # renamed to `sum` by the callers
                'asset_num': helpers.safe_float(data.get('asset_num')),
                'asset_profit': helpers.safe_float(data.get('asset_profit')),
                'leverage': helpers.safe_int(data.get('leverage')),
                'uptime': helpers.safe_int(data.get('uptime')),
                'profit_price': helpers.safe_float(data.get('profit_price')),
                'stop_price': helpers.safe_float(data.get('stop_price')),
                'liquidation_price': helpers.safe_float(data.get('liquidation_price'))
            }
        except Exception as e:
            logger.error(f"转换持仓数据异常: {data}, error={e}")
            return {}

    async def sync(self):
        """Backward-compatible entry point: load accounts from Redis and batch-sync them."""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)