This commit is contained in:
lz_db
2025-12-04 19:44:22 +08:00
parent f85f4ef152
commit 8dc0f0dbc3
11 changed files with 1128 additions and 519 deletions

View File

@@ -1,12 +1,9 @@
from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Set, Tuple
import json
import asyncio
import utils.helpers as helpers
from datetime import datetime
from sqlalchemy import text, and_, select, delete
from models.orm_models import StrategyPosition
import utils.helpers as helpers
import time
class PositionSyncBatch(BaseSync):
@@ -18,13 +15,15 @@ class PositionSyncBatch(BaseSync):
async def sync_batch(self, accounts: Dict[str, Dict]):
"""批量同步所有账号的持仓数据"""
return
try:
logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号")
start_time = time.time()
# 1. 收集所有账号的持仓数据
all_positions = await self._collect_all_positions(accounts)
all_positions = await self.redis_client._collect_all_positions(accounts)
if not all_positions:
logger.info("无持仓数据需要同步")
@@ -421,91 +420,6 @@ class PositionSyncBatch(BaseSync):
finally:
session.close()
async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]:
"""收集所有账号的持仓数据"""
all_positions = []
try:
# 按交易所分组账号
account_groups = self._group_accounts_by_exchange(accounts)
# 并发收集每个交易所的数据
tasks = []
for exchange_id, account_list in account_groups.items():
task = self._collect_exchange_positions(exchange_id, account_list)
tasks.append(task)
# 等待所有任务完成并合并结果
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, list):
all_positions.extend(result)
except Exception as e:
logger.error(f"收集持仓数据失败: {e}")
return all_positions
def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
"""按交易所分组账号"""
groups = {}
for account_id, account_info in accounts.items():
exchange_id = account_info.get('exchange_id')
if exchange_id:
if exchange_id not in groups:
groups[exchange_id] = []
groups[exchange_id].append(account_info)
return groups
async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
"""收集某个交易所的持仓数据"""
positions_list = []
try:
tasks = []
for account_info in account_list:
k_id = int(account_info['k_id'])
st_id = account_info.get('st_id', 0)
task = self._get_positions_from_redis(k_id, st_id, exchange_id)
tasks.append(task)
# 并发获取
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, list):
positions_list.extend(result)
except Exception as e:
logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {e}")
return positions_list
async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
"""从Redis获取持仓数据"""
try:
redis_key = f"{exchange_id}:positions:{k_id}"
redis_data = self.redis_client.client.hget(redis_key, 'positions')
if not redis_data:
return []
positions = json.loads(redis_data)
# 添加账号信息
for position in positions:
# print(position['symbol'])
position['k_id'] = k_id
position['st_id'] = st_id
position['exchange_id'] = exchange_id
return positions
except Exception as e:
logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}")
return []
def _convert_position_data(self, data: Dict) -> Dict:
"""转换持仓数据格式"""
try:
@@ -531,7 +445,3 @@ class PositionSyncBatch(BaseSync):
logger.error(f"转换持仓数据异常: {data}, error={e}")
return {}
async def sync(self):
"""兼容旧接口"""
accounts = self.get_accounts_from_redis()
await self.sync_batch(accounts)