191 lines
8.3 KiB
Python
191 lines
8.3 KiB
Python
from .base_sync import BaseSync
|
|
from loguru import logger
|
|
from typing import List, Dict
|
|
import json
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
class PositionSync(BaseSync):
|
|
"""持仓数据同步器(批量版本)"""
|
|
|
|
def __init__(self):
    """Initialise the base synchroniser and set this synchroniser's concurrency cap."""
    super().__init__()
    # Maximum number of concurrent operations allowed per synchroniser.
    self.max_concurrent = 10
|
|
|
|
async def sync_batch(self, accounts: Dict[str, Dict]):
    """Sync position data for all accounts, one concurrent task per exchange.

    Accounts are grouped with ``_group_accounts_by_exchange``; each group is
    handed to ``_sync_exchange_accounts`` and all groups are awaited together.
    Any ``Exception`` raised here is logged instead of being propagated.

    Args:
        accounts: Mapping of account id to that account's info dict.

    Returns:
        None. The outcome is reported through log messages only.
    """
    try:
        logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号")

        account_groups = self._group_accounts_by_exchange(accounts)
        # Freeze the ordering so every result can be matched to its exchange.
        exchange_ids = list(account_groups)

        results = await asyncio.gather(
            *(
                self._sync_exchange_accounts(exchange_id, account_groups[exchange_id])
                for exchange_id in exchange_ids
            ),
            return_exceptions=True,
        )

        success_count = 0
        for exchange_id, result in zip(exchange_ids, results):
            if isinstance(result, BaseException):
                # return_exceptions=True returns exceptions as ordinary
                # results; log them so a failed group is not lost silently.
                logger.error(f"交易所 {exchange_id} 持仓同步异常: {result!r}")
            elif result is True:
                success_count += 1

        logger.info(f"持仓批量同步完成: 成功 {success_count}/{len(results)} 个交易所组")

    except Exception as e:
        logger.error(f"持仓批量同步失败: {e}")
|
|
|
|
def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
|
|
"""按交易所分组账号"""
|
|
groups = {}
|
|
for account_id, account_info in accounts.items():
|
|
exchange_id = account_info.get('exchange_id')
|
|
if exchange_id:
|
|
if exchange_id not in groups:
|
|
groups[exchange_id] = []
|
|
groups[exchange_id].append(account_info)
|
|
return groups
|
|
|
|
async def _sync_exchange_accounts(self, exchange_id: str, account_list: List[Dict]):
    """Collect positions for every account of one exchange and persist them.

    Each account's positions are read through ``_get_positions_from_redis``,
    tagged with the account's ``k_id`` and ``st_id``, and the combined list is
    written with ``_sync_positions_batch_to_db``.

    Args:
        exchange_id: Identifier of the exchange being synced.
        account_list: Account info dicts; each must contain ``k_id`` and may
            contain ``st_id`` (defaults to 0).

    Returns:
        True when there was nothing to sync, otherwise the result of the
        database sync. False if any exception occurred.
    """
    try:
        all_positions = []

        for account in account_list:
            k_id = int(account['k_id'])
            st_id = account.get('st_id', 0)

            fetched = await self._get_positions_from_redis(k_id, exchange_id)
            if not fetched:
                continue

            # Stamp every position with the account it belongs to.
            for entry in fetched:
                entry['k_id'] = k_id
                entry['st_id'] = st_id
            all_positions.extend(fetched)

        if all_positions:
            # NOTE(review): this is a synchronous call made from a coroutine,
            # so the event loop is blocked while it runs — confirm intended.
            success = self._sync_positions_batch_to_db(all_positions)
            if success:
                logger.info(f"交易所 {exchange_id} 持仓同步成功: {len(all_positions)} 条持仓")
            return success

        logger.debug(f"交易所 {exchange_id} 无持仓数据")
        return True

    except Exception as e:
        logger.error(f"同步交易所 {exchange_id} 持仓失败: {e}")
        return False
|
|
|
|
def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> bool:
    """Upsert a batch of positions and delete rows that are no longer held.

    Positions are grouped by ``k_id``. For each account the converted rows
    are written with a MySQL ``INSERT ... ON DUPLICATE KEY UPDATE``, then rows
    stored for the same ``(k_id, st_id)`` whose ``(symbol, side)`` is absent
    from the batch are deleted. All accounts share one transaction, and each
    account runs inside its own SAVEPOINT so that a failure rolls back only
    that account's statements instead of committing a half-applied account.

    Args:
        all_positions: Position dicts, each carrying ``k_id`` and ``st_id``.

    Returns:
        True if at least one account was synced, otherwise False (including
        when any exception escapes the per-account handling).
    """
    session = self.db_manager.get_session()
    try:
        # Group the flat list by account.
        positions_by_account = {}
        for position in all_positions:
            positions_by_account.setdefault(position['k_id'], []).append(position)

        success_count = 0

        with session.begin():
            for k_id, positions in positions_by_account.items():
                try:
                    # Every group holds at least one position by construction.
                    st_id = positions[0]['st_id']

                    insert_data, keep_keys = self._batch_build_position_rows(positions)
                    if not insert_data:
                        continue

                    # SAVEPOINT: the upsert and the delete succeed or fail together.
                    with session.begin_nested():
                        self._batch_upsert_position_rows(session, insert_data)
                        self._batch_delete_stale_positions(session, k_id, st_id, keep_keys)

                    success_count += 1

                except Exception as e:
                    logger.error(f"同步账号 {k_id} 持仓失败: {e}")

        logger.info(f"批量同步完成: 成功 {success_count}/{len(positions_by_account)} 个账号")
        return success_count > 0

    except Exception as e:
        logger.error(f"批量同步持仓到数据库失败: {e}")
        return False
    finally:
        session.close()


def _batch_build_position_rows(self, positions: List[Dict]):
    """Convert raw positions into insertable rows plus their (symbol, side) keys.

    Rows that fail conversion or lack ``symbol``/``side`` are skipped; a
    ``qty`` field is renamed to ``sum``.
    """
    insert_data = []
    keep_keys = set()
    for pos_data in positions:
        try:
            pos_dict = self._convert_position_data(pos_data)
            if not (pos_dict.get('symbol') and pos_dict.get('side')):
                continue

            # Rename qty to sum.
            if 'qty' in pos_dict:
                pos_dict['sum'] = pos_dict.pop('qty')

            insert_data.append(pos_dict)
            keep_keys.add((pos_dict['symbol'], pos_dict['side']))
        except Exception as e:
            # NOTE(review): a skipped row is not in keep_keys, so an existing
            # DB row for that position is treated as stale — confirm intended.
            logger.error(f"转换持仓数据失败: {pos_data}, error={e}")
    return insert_data, keep_keys


def _batch_upsert_position_rows(self, session, insert_data: List[Dict]) -> None:
    """Insert the rows, updating the mutable columns on duplicate key."""
    from sqlalchemy.dialects.mysql import insert

    # NOTE(review): `StrategyPosition` is not imported in the visible part of
    # this file — confirm it is in scope at module level.
    stmt = insert(StrategyPosition.__table__).values(insert_data)

    updatable_columns = (
        'price',
        'sum',
        'asset_num',
        'asset_profit',
        'leverage',
        'uptime',
        'profit_price',
        'stop_price',
        'liquidation_price',
    )
    # Index access avoids any clash between a column name such as "sum" and
    # attributes of the column collection itself.
    stmt = stmt.on_duplicate_key_update(
        **{column: stmt.inserted[column] for column in updatable_columns}
    )
    session.execute(stmt)


def _batch_delete_stale_positions(self, session, k_id, st_id, keep_keys, chunk_size: int = 100) -> None:
    """Delete stored rows of (k_id, st_id) whose (symbol, side) is not in keep_keys."""
    # Imported locally: these names are not imported in the visible file header.
    from sqlalchemy import and_, delete, select

    if not keep_keys:
        return

    # Only the three columns needed for the comparison are fetched.
    existing_rows = session.execute(
        select(
            StrategyPosition.id,
            StrategyPosition.symbol,
            StrategyPosition.side,
        ).where(
            and_(
                StrategyPosition.k_id == k_id,
                StrategyPosition.st_id == st_id,
            )
        )
    ).all()

    to_delete_ids = [
        row_id
        for row_id, symbol, side in existing_rows
        if (symbol, side) not in keep_keys
    ]

    # Delete in chunks to keep each IN (...) list bounded.
    for start in range(0, len(to_delete_ids), chunk_size):
        chunk = to_delete_ids[start:start + chunk_size]
        session.execute(
            delete(StrategyPosition).where(StrategyPosition.id.in_(chunk))
        )
|
|
|
|
# Other methods remain unchanged...