# exchange_monitor_sync/sync/position_sync.py
# (repository snapshot: 174 lines, 7.3 KiB, Python — 2025-12-02 22:05:54 +08:00)
from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict
import json
from datetime import datetime, timedelta
class PositionSync(BaseSync):
    """Synchronize per-account position snapshots from Redis into the database.

    For every account returned by ``get_accounts_from_redis()``, the JSON list
    stored in field ``positions`` of the Redis hash
    ``{exchange_id}:positions:{k_id}`` is upserted into the ``StrategyPosition``
    table. Rows for that ``(k_id, st_id)`` whose ``(symbol, side)`` is absent
    from the snapshot are then deleted.

    ``get_accounts_from_redis``, ``redis_client`` and ``db_manager`` are
    presumably provided by ``BaseSync`` — confirm against that class.

    NOTE(review): ``StrategyPosition`` is referenced here but is not imported
    by this module's visible imports; it must be imported from the project's
    ORM models module (path not visible here — TODO confirm).
    """

    # Columns refreshed when a (k_id, st_id, symbol, side)-style unique key
    # already exists; ``k_id``/``st_id``/``symbol``/``side``/``asset`` are not
    # overwritten. NOTE(review): the exact unique key is defined by the table,
    # not visible here — confirm.
    _UPSERT_COLUMNS = (
        'price',
        'sum',
        'asset_num',
        'asset_profit',
        'leverage',
        'uptime',
        'profit_price',
        'stop_price',
        'liquidation_price',
    )

    async def sync(self):
        """Sync positions for every account known to Redis.

        Accounts with a non-positive ``k_id`` or ``st_id`` are skipped. A
        failure for one account is logged and does not stop the others; a
        failure while listing accounts is logged and ends the run.
        """
        try:
            accounts = self.get_accounts_from_redis()
            for k_id_str, account_info in accounts.items():
                try:
                    k_id = int(k_id_str)
                    # Coerce so a numeric string from Redis compares correctly below.
                    st_id = int(account_info.get('st_id') or 0)
                    exchange_id = account_info['exchange_id']
                    if k_id <= 0 or st_id <= 0:
                        continue

                    positions = await self._get_positions_from_redis(k_id, exchange_id)

                    # NOTE(review): the Redis reader returns [] for "missing key",
                    # "read/parse error" and "no open positions" alike, so an account
                    # that closed all its positions is never cleared from the DB here.
                    # Fixing that needs the reader to distinguish those cases.
                    if positions:
                        success = self._sync_positions_to_db(k_id, st_id, positions)
                        if success:
                            logger.debug(f"持仓同步成功: k_id={k_id}, 持仓数={len(positions)}")
                except Exception as e:
                    logger.error(f"同步账号 {k_id_str} 持仓失败: {e}")
            logger.info("持仓数据同步完成")
        except Exception as e:
            logger.error(f"持仓同步失败: {e}")

    async def _get_positions_from_redis(self, k_id: int, exchange_id: str) -> List[Dict]:
        """Read one account's cached position list from Redis.

        Args:
            k_id: Account id; used in the key and stamped onto every entry.
            exchange_id: Key prefix identifying the exchange.

        Returns:
            The parsed list of position dicts, each with ``k_id`` set, or ``[]``
            when the field is missing/empty, is not a JSON list of objects, or
            reading/parsing fails (failures are logged).
        """
        try:
            redis_key = f"{exchange_id}:positions:{k_id}"
            # NOTE(review): called without ``await``, so this is presumably a
            # synchronous client and blocks the event loop for the round-trip.
            redis_data = self.redis_client.client.hget(redis_key, 'positions')
            if not redis_data:
                return []

            positions = json.loads(redis_data)
            if not isinstance(positions, list) or not all(isinstance(p, dict) for p in positions):
                logger.error(f"Redis持仓数据格式错误: k_id={k_id}")
                return []

            for position in positions:
                position['k_id'] = k_id
            return positions
        except Exception as e:
            logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}")
            return []

    def _sync_positions_to_db(self, k_id: int, st_id: int, positions_data: List[Dict]) -> bool:
        """Upsert one account's positions and delete rows no longer present.

        All writes happen in a single transaction. An empty ``positions_data``
        clears every stored position for ``(k_id, st_id)``.

        Args:
            k_id: Account id; written onto every row and used to scope deletes.
            st_id: Strategy id; written onto every row and used to scope deletes.
            positions_data: Raw position dicts (e.g. as read from Redis).

        Returns:
            True if the transaction committed. False if a database error
            occurred (logged) or if ``positions_data`` was non-empty but no
            entry could be converted — in that case the DB is left untouched.
        """
        rows, keep_keys = self._build_position_rows(k_id, st_id, positions_data)
        if positions_data and not rows:
            # A batch in which every entry is malformed is not evidence that the
            # account has no positions; do not wipe its rows because of it.
            logger.error(f"持仓数据全部无效, 跳过数据库同步: k_id={k_id}, 条数={len(positions_data)}")
            return False

        session = self.db_manager.get_session()
        try:
            from sqlalchemy import and_, delete
            from sqlalchemy.dialects.mysql import insert

            with session.begin():
                if not rows:
                    # Explicitly empty snapshot: the account holds no positions.
                    session.execute(
                        delete(StrategyPosition).where(
                            and_(
                                StrategyPosition.k_id == k_id,
                                StrategyPosition.st_id == st_id,
                            )
                        )
                    )
                    return True

                # One multi-row INSERT ... ON DUPLICATE KEY UPDATE for the batch.
                stmt = insert(StrategyPosition.__table__).values(rows)
                stmt = stmt.on_duplicate_key_update(
                    **{column: stmt.inserted[column] for column in self._UPSERT_COLUMNS}
                )
                session.execute(stmt)

                self._delete_stale_positions(session, k_id, st_id, keep_keys)
            return True
        except Exception as e:
            logger.error(f"同步持仓到数据库失败: k_id={k_id}, error={e}")
            return False
        finally:
            session.close()

    def _build_position_rows(self, k_id: int, st_id: int, positions_data: List[Dict]):
        """Convert raw position dicts into insertable rows.

        Entries that fail conversion are logged and skipped; entries with an
        empty ``symbol`` or ``side`` are skipped silently.

        Returns:
            ``(rows, keep_keys)`` where ``rows`` is a list of column dicts and
            ``keep_keys`` is the set of ``(symbol, side)`` pairs they cover.
        """
        rows = []
        keep_keys = set()
        for pos_data in positions_data:
            try:
                row = self._convert_position_data(pos_data)
            except Exception as e:
                logger.error(f"转换持仓数据失败: {pos_data}, error={e}")
                continue

            if not (row.get('symbol') and row.get('side')):
                continue

            # Pin the ids to the account being synced: the stale-row query and the
            # clear-all delete are scoped by these values, and the raw data does not
            # reliably carry st_id (the Redis reader only injects k_id).
            row['k_id'] = k_id
            row['st_id'] = st_id
            # The table stores quantity in a column named ``sum``.
            if 'qty' in row:
                row['sum'] = row.pop('qty')

            rows.append(row)
            keep_keys.add((row['symbol'], row['side']))
        return rows, keep_keys

    def _delete_stale_positions(self, session, k_id: int, st_id: int, keep_keys) -> None:
        """Delete ``(k_id, st_id)`` rows whose ``(symbol, side)`` is not in ``keep_keys``.

        Runs inside the caller's open transaction on ``session``.
        """
        from sqlalchemy import and_, delete, select

        existing = session.execute(
            select(StrategyPosition.id, StrategyPosition.symbol, StrategyPosition.side).where(
                and_(
                    StrategyPosition.k_id == k_id,
                    StrategyPosition.st_id == st_id,
                )
            )
        ).all()

        stale_ids = [row_id for row_id, symbol, side in existing if (symbol, side) not in keep_keys]
        if stale_ids:
            session.execute(
                delete(StrategyPosition).where(StrategyPosition.id.in_(stale_ids))
            )

    def _convert_position_data(self, data: Dict) -> Dict:
        """Map one raw position dict onto the StrategyPosition column layout.

        ``asset`` is always ``'USDT'``. ``symbol``/``side`` default to ``''``.
        ``st_id``/``k_id`` default to 0 when missing or ``None``. Every other
        numeric field is cast with ``float``/``int`` when present and becomes
        ``None`` when missing or ``None``. Quantity is returned under ``qty``;
        ``_build_position_rows`` renames it to the ``sum`` column.

        Raises:
            ValueError, TypeError: if a present field cannot be cast.
        """
        def optional(key, cast):
            # None / missing stays None; anything else must be castable.
            value = data.get(key)
            return None if value is None else cast(value)

        return {
            'st_id': int(data.get('st_id') or 0),
            'k_id': int(data.get('k_id') or 0),
            'asset': 'USDT',
            'symbol': data.get('symbol', ''),
            'side': data.get('side', ''),
            'price': optional('price', float),
            'qty': optional('qty', float),
            'asset_num': optional('asset_num', float),
            'asset_profit': optional('asset_profit', float),
            'leverage': optional('leverage', int),
            'uptime': optional('uptime', int),
            'profit_price': optional('profit_price', float),
            'stop_price': optional('stop_price', float),
            'liquidation_price': optional('liquidation_price', float),
        }