exchange_monitor_sync/sync/position_sync.py
lz_db c8a6cfead1 1
2025-12-02 22:36:52 +08:00


from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict
import json
import asyncio
from concurrent.futures import ThreadPoolExecutor
from sqlalchemy import select, and_, delete
from sqlalchemy.dialects.mysql import insert
# NOTE: StrategyPosition is referenced below but its import is not shown in this
# excerpt; the path used here is an assumption and may need adjusting to the
# actual project layout.
from ..models import StrategyPosition
class PositionSync(BaseSync):
    """Position data synchronizer (batch version)."""

    def __init__(self):
        super().__init__()
        self.max_concurrent = 10  # Maximum concurrency per synchronizer
    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Batch-sync position data for all accounts."""
        try:
            logger.info(f"Starting batch position sync for {len(accounts)} accounts")
            # Group accounts by exchange
            account_groups = self._group_accounts_by_exchange(accounts)
            # Process each exchange's accounts concurrently
            tasks = []
            for exchange_id, account_list in account_groups.items():
                task = self._sync_exchange_accounts(exchange_id, account_list)
                tasks.append(task)
            # Wait for all tasks to finish
            results = await asyncio.gather(*tasks, return_exceptions=True)
            # Tally the results
            success_count = sum(1 for r in results if isinstance(r, bool) and r)
            logger.info(f"Batch position sync finished: {success_count}/{len(results)} exchange groups succeeded")
        except Exception as e:
            logger.error(f"Batch position sync failed: {e}")
    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group accounts by exchange."""
        groups = {}
        for account_id, account_info in accounts.items():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                if exchange_id not in groups:
                    groups[exchange_id] = []
                groups[exchange_id].append(account_info)
        return groups
    async def _sync_exchange_accounts(self, exchange_id: str, account_list: List[Dict]):
        """Sync all accounts of a single exchange."""
        try:
            # Collect position data for every account
            all_positions = []
            for account_info in account_list:
                k_id = int(account_info['k_id'])
                st_id = account_info.get('st_id', 0)
                # Fetch position data from Redis
                positions = await self._get_positions_from_redis(k_id, exchange_id)
                if positions:
                    # Attach account information to each position
                    for position in positions:
                        position['k_id'] = k_id
                        position['st_id'] = st_id
                    all_positions.extend(positions)
            if not all_positions:
                logger.debug(f"Exchange {exchange_id} has no position data")
                return True
            # Batch-sync to the database (blocking call inside a coroutine)
            success = self._sync_positions_batch_to_db(all_positions)
            if success:
                logger.info(f"Exchange {exchange_id} position sync succeeded: {len(all_positions)} positions")
            return success
        except Exception as e:
            logger.error(f"Failed to sync positions for exchange {exchange_id}: {e}")
            return False
    def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> bool:
        """Batch-sync position data to the database (optimized version)."""
        session = self.db_manager.get_session()
        try:
            # Group positions by k_id
            positions_by_account = {}
            for position in all_positions:
                k_id = position['k_id']
                if k_id not in positions_by_account:
                    positions_by_account[k_id] = []
                positions_by_account[k_id].append(position)
            success_count = 0
            with session.begin():
                for k_id, positions in positions_by_account.items():
                    try:
                        st_id = positions[0]['st_id'] if positions else 0
                        # Prepare the rows to upsert
                        insert_data = []
                        keep_keys = set()
                        for pos_data in positions:
                            try:
                                pos_dict = self._convert_position_data(pos_data)
                                if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
                                    continue
                                # Rename qty to sum (the column name on StrategyPosition)
                                if 'qty' in pos_dict:
                                    pos_dict['sum'] = pos_dict.pop('qty')
                                insert_data.append(pos_dict)
                                keep_keys.add((pos_dict['symbol'], pos_dict['side']))
                            except Exception as e:
                                logger.error(f"Failed to convert position data: {pos_data}, error={e}")
                                continue
                        if not insert_data:
                            continue
                        # Bulk insert/update (MySQL upsert)
                        stmt = insert(StrategyPosition.__table__).values(insert_data)
                        update_dict = {
                            'price': stmt.inserted.price,
                            'sum': stmt.inserted.sum,
                            'asset_num': stmt.inserted.asset_num,
                            'asset_profit': stmt.inserted.asset_profit,
                            'leverage': stmt.inserted.leverage,
                            'uptime': stmt.inserted.uptime,
                            'profit_price': stmt.inserted.profit_price,
                            'stop_price': stmt.inserted.stop_price,
                            'liquidation_price': stmt.inserted.liquidation_price
                        }
                        stmt = stmt.on_duplicate_key_update(**update_dict)
                        session.execute(stmt)
                        # Delete stale positions no longer present in Redis
                        if keep_keys:
                            existing_positions = session.execute(
                                select(StrategyPosition).where(
                                    and_(
                                        StrategyPosition.k_id == k_id,
                                        StrategyPosition.st_id == st_id
                                    )
                                )
                            ).scalars().all()
                            to_delete_ids = []
                            for existing in existing_positions:
                                key = (existing.symbol, existing.side)
                                if key not in keep_keys:
                                    to_delete_ids.append(existing.id)
                            if to_delete_ids:
                                # Delete in chunks
                                chunk_size = 100
                                for i in range(0, len(to_delete_ids), chunk_size):
                                    chunk = to_delete_ids[i:i + chunk_size]
                                    session.execute(
                                        delete(StrategyPosition).where(
                                            StrategyPosition.id.in_(chunk)
                                        )
                                    )
                        success_count += 1
                    except Exception as e:
                        logger.error(f"Failed to sync positions for account {k_id}: {e}")
                        continue
            logger.info(f"Batch sync finished: {success_count}/{len(positions_by_account)} accounts succeeded")
            return success_count > 0
        except Exception as e:
            logger.error(f"Failed to batch-sync positions to the database: {e}")
            return False
        finally:
            session.close()
    # Remaining methods unchanged...
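
# --- Hedged usage sketch (not part of the original file) ---
# Minimal example of how PositionSync.sync_batch might be driven, assuming the
# accounts dict is keyed by account id and each entry carries at least
# 'exchange_id', 'k_id' and 'st_id'. The field names mirror the lookups made in
# sync_batch/_sync_exchange_accounts; the concrete account ids and exchange ids
# below are purely illustrative assumptions.
#
# async def main():
#     accounts = {
#         "acc_1": {"exchange_id": "binance", "k_id": 101, "st_id": 7},
#         "acc_2": {"exchange_id": "okx", "k_id": 102, "st_id": 0},
#     }
#     syncer = PositionSync()
#     await syncer.sync_batch(accounts)
#
# if __name__ == "__main__":
#     asyncio.run(main())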