from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Set, Tuple
import json
from datetime import datetime, timedelta


class PositionSync(BaseSync):
    """Position data synchroniser: copies per-account positions from Redis into the DB."""

    async def sync(self):
        """Sync positions for every account returned by ``get_accounts_from_redis``.

        Accounts whose ``k_id`` or ``st_id`` is not positive are skipped. A
        failure for one account is logged and does not stop the others; a
        failure outside the per-account loop is logged and swallowed.
        """
        try:
            # Fetch all accounts (mapping of k_id string -> account info dict).
            accounts = self.get_accounts_from_redis()

            for k_id_str, account_info in accounts.items():
                try:
                    k_id = int(k_id_str)
                    # Values read back from Redis may be strings; normalise
                    # before the numeric comparison below (same as k_id).
                    st_id = int(account_info.get('st_id', 0))
                    exchange_id = account_info['exchange_id']

                    if k_id <= 0 or st_id <= 0:
                        continue

                    # Read this account's positions from Redis.
                    positions = await self._get_positions_from_redis(k_id, exchange_id)

                    # NOTE(review): an empty result (missing key, "[]" payload or
                    # a read error) skips the DB sync entirely, so positions
                    # already stored for this account are never cleared here.
                    # Confirm whether fully closed accounts should be wiped.
                    if positions:
                        success = self._sync_positions_to_db(k_id, st_id, positions)
                        if success:
                            logger.debug(f"持仓同步成功: k_id={k_id}, 持仓数={len(positions)}")

                except Exception as e:
                    logger.error(f"同步账号 {k_id_str} 持仓失败: {e}")

            logger.info("持仓数据同步完成")

        except Exception as e:
            logger.error(f"持仓同步失败: {e}")

    async def _get_positions_from_redis(self, k_id: int, exchange_id: str) -> List[Dict]:
        """Read one account's JSON-encoded position list from Redis.

        Reads field ``positions`` of hash ``{exchange_id}:positions:{k_id}``,
        keeps only dict entries, and tags each with ``k_id``.

        Returns:
            The decoded positions, or ``[]`` when the field is empty/missing,
            the payload is not a JSON list, or any error occurs (logged).
        """
        try:
            redis_key = f"{exchange_id}:positions:{k_id}"
            # NOTE(review): this call is not awaited, so the client is presumably
            # synchronous and blocks the event loop while Redis responds.
            redis_data = self.redis_client.client.hget(redis_key, 'positions')

            if not redis_data:
                return []

            decoded = json.loads(redis_data)
            if not isinstance(decoded, list):
                logger.warning(f"Redis持仓数据格式错误(非列表): k_id={k_id}")
                return []

            # Ignore malformed (non-dict) entries instead of dropping the whole list.
            positions = [item for item in decoded if isinstance(item, dict)]

            # Attach the account id to every position.
            for position in positions:
                position['k_id'] = k_id

            return positions

        except Exception as e:
            logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}")
            return []

    def _build_position_rows(
        self, k_id: int, st_id: int, positions_data: List[Dict]
    ) -> Tuple[List[Dict], Set[Tuple]]:
        """Convert raw positions into DB rows plus the ``(symbol, side)`` keys to keep.

        Rows without ``symbol`` or ``side`` are skipped. Entries that fail
        conversion are logged and, when their raw symbol/side are readable,
        still added to the keep set so their stored row is not deleted.
        """
        rows: List[Dict] = []
        keep_keys: Set[Tuple] = set()  # (symbol, side) combinations to keep

        for pos_data in positions_data:
            try:
                pos_dict = self._convert_position_data(pos_data)
            except Exception as e:
                logger.error(f"转换持仓数据失败: {pos_data}, error={e}")
                # An unparseable entry still means the position exists; keep the
                # stored row rather than deleting it as stale.
                if isinstance(pos_data, dict) and pos_data.get('symbol') and pos_data.get('side'):
                    keep_keys.add((pos_data['symbol'], pos_data['side']))
                continue

            if not (pos_dict.get('symbol') and pos_dict.get('side')):
                continue

            # The table stores the quantity in column ``sum``.
            if 'qty' in pos_dict:
                pos_dict['sum'] = pos_dict.pop('qty')

            # Rows must carry the same ids the stale-row filter uses; otherwise
            # the upsert and the delete would target different rows.
            pos_dict['k_id'] = k_id
            pos_dict['st_id'] = st_id

            rows.append(pos_dict)
            keep_keys.add((pos_dict['symbol'], pos_dict['side']))

        return rows, keep_keys

    def _sync_positions_to_db(self, k_id: int, st_id: int, positions_data: List[Dict]) -> bool:
        """Upsert one account's positions and delete stored rows no longer present.

        Converted rows are written with MySQL ``INSERT ... ON DUPLICATE KEY
        UPDATE``; afterwards every stored row for ``(k_id, st_id)`` whose
        ``(symbol, side)`` is not in the new data is deleted. An empty
        ``positions_data`` clears the account's rows. All writes run in a
        single transaction.

        NOTE(review): ``StrategyPosition`` is not imported in this module as
        shown; confirm the ORM model is in scope.

        Returns:
            True on success; False when a non-empty input yields no valid rows
            or a database error occurs (logged).
        """
        # Local imports, matching the module's existing style for SQLAlchemy.
        from sqlalchemy import and_, delete, select
        from sqlalchemy.dialects.mysql import insert

        insert_data, keep_keys = self._build_position_rows(k_id, st_id, positions_data)

        if positions_data and not insert_data:
            # Every entry was rejected: clearing the account here would turn a
            # parse problem into data loss, so leave the DB untouched.
            logger.warning(f"持仓数据全部无效, 跳过同步: k_id={k_id}")
            return False

        session = self.db_manager.get_session()
        try:
            account_filter = and_(
                StrategyPosition.k_id == k_id,
                StrategyPosition.st_id == st_id,
            )

            with session.begin():
                if not insert_data:
                    # No positions: clear this account's positions.
                    session.execute(delete(StrategyPosition).where(account_filter))
                    return True

                # Bulk insert / update.
                stmt = insert(StrategyPosition.__table__).values(insert_data)
                update_dict = {
                    'price': stmt.inserted.price,
                    'sum': stmt.inserted.sum,
                    'asset_num': stmt.inserted.asset_num,
                    'asset_profit': stmt.inserted.asset_profit,
                    'leverage': stmt.inserted.leverage,
                    'uptime': stmt.inserted.uptime,
                    'profit_price': stmt.inserted.profit_price,
                    'stop_price': stmt.inserted.stop_price,
                    'liquidation_price': stmt.inserted.liquidation_price,
                }
                session.execute(stmt.on_duplicate_key_update(**update_dict))

                # Delete stored positions that are absent from the new snapshot.
                existing = session.execute(
                    select(
                        StrategyPosition.id,
                        StrategyPosition.symbol,
                        StrategyPosition.side,
                    ).where(account_filter)
                ).all()
                to_delete_ids = [
                    row.id for row in existing if (row.symbol, row.side) not in keep_keys
                ]
                if to_delete_ids:
                    session.execute(
                        delete(StrategyPosition).where(StrategyPosition.id.in_(to_delete_ids))
                    )

            return True

        except Exception as e:
            logger.error(f"同步持仓到数据库失败: k_id={k_id}, error={e}")
            return False
        finally:
            session.close()

    @staticmethod
    def _optional_number(value, cast):
        """Return ``cast(value)``, or None when *value* is None or an empty string."""
        if value is None or value == '':
            return None
        return cast(value)

    def _convert_position_data(self, data: Dict) -> Dict:
        """Map one raw Redis position dict to the DB row shape.

        ``asset`` is always ``'USDT'``. Numeric fields that are missing, None
        or ``""`` become None; other values are cast with ``float``/``int``
        (unparseable values raise ``ValueError``/``TypeError``).
        """
        def num(key, cast):
            return self._optional_number(data.get(key), cast)

        return {
            'st_id': int(data.get('st_id') or 0),
            'k_id': int(data.get('k_id') or 0),
            'asset': 'USDT',
            'symbol': data.get('symbol', ''),
            'side': data.get('side', ''),
            'price': num('price', float),
            'qty': num('qty', float),
            'asset_num': num('asset_num', float),
            'asset_profit': num('asset_profit', float),
            'leverage': num('leverage', int),
            'uptime': num('uptime', int),
            'profit_price': num('profit_price', float),
            'stop_price': num('stop_price', float),
            'liquidation_price': num('liquidation_price', float),
        }