1
This commit is contained in:
372
sync/manager.py
372
sync/manager.py
@@ -3,26 +3,30 @@ from loguru import logger
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
from typing import Dict
|
||||
import re
|
||||
|
||||
from utils.redis_client import RedisClient
|
||||
from config.settings import SYNC_CONFIG
|
||||
from .position_sync_batch import PositionSyncBatch
|
||||
from .order_sync_batch import OrderSyncBatch # 使用批量版本
|
||||
from .account_sync_batch import AccountSyncBatch
|
||||
from utils.batch_position_sync import BatchPositionSync
|
||||
from utils.batch_order_sync import BatchOrderSync
|
||||
from utils.batch_account_sync import BatchAccountSync
|
||||
from .position_sync import PositionSyncBatch
|
||||
from .order_sync import OrderSyncBatch # 使用批量版本
|
||||
from .account_sync import AccountSyncBatch
|
||||
from utils.redis_batch_helper import RedisBatchHelper
|
||||
from config.settings import COMPUTER_NAMES, COMPUTER_NAME_PATTERN
|
||||
from typing import List, Dict, Any, Set, Optional
|
||||
|
||||
class SyncManager:
|
||||
"""同步管理器(完整批量版本)"""
|
||||
|
||||
def __init__(self):
|
||||
self.is_running = True
|
||||
self.redis_client = RedisClient()
|
||||
self.sync_interval = SYNC_CONFIG['interval']
|
||||
self.computer_names = self._get_computer_names()
|
||||
self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)
|
||||
|
||||
# 初始化批量同步工具
|
||||
self.batch_tools = {}
|
||||
self.redis_helper = None
|
||||
|
||||
# 初始化同步器
|
||||
@@ -31,24 +35,16 @@ class SyncManager:
|
||||
if SYNC_CONFIG['enable_position_sync']:
|
||||
position_sync = PositionSyncBatch()
|
||||
self.syncers.append(position_sync)
|
||||
self.batch_tools['position'] = BatchPositionSync(position_sync.db_manager)
|
||||
logger.info("启用持仓批量同步")
|
||||
|
||||
if SYNC_CONFIG['enable_order_sync']:
|
||||
order_sync = OrderSyncBatch()
|
||||
self.syncers.append(order_sync)
|
||||
self.batch_tools['order'] = BatchOrderSync(order_sync.db_manager)
|
||||
|
||||
# 初始化Redis批量助手
|
||||
if order_sync.redis_client:
|
||||
self.redis_helper = RedisBatchHelper(order_sync.redis_client.client)
|
||||
|
||||
logger.info("启用订单批量同步")
|
||||
|
||||
if SYNC_CONFIG['enable_account_sync']:
|
||||
account_sync = AccountSyncBatch()
|
||||
self.syncers.append(account_sync)
|
||||
self.batch_tools['account'] = BatchAccountSync(account_sync.db_manager)
|
||||
logger.info("启用账户信息批量同步")
|
||||
|
||||
# 性能统计
|
||||
@@ -71,21 +67,25 @@ class SyncManager:
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
self.stats['total_syncs'] += 1
|
||||
sync_start = time.time()
|
||||
|
||||
# 获取所有账号(只获取一次)
|
||||
accounts = await self._get_all_accounts()
|
||||
accounts = await self.get_accounts_from_redis()
|
||||
|
||||
if not accounts:
|
||||
logger.warning("未获取到任何账号,等待下次同步")
|
||||
await asyncio.sleep(self.sync_interval)
|
||||
continue
|
||||
|
||||
|
||||
self.stats['total_syncs'] += 1
|
||||
sync_start = time.time()
|
||||
|
||||
logger.info(f"第{self.stats['total_syncs']}次同步开始,共 {len(accounts)} 个账号")
|
||||
|
||||
# 并发执行所有同步
|
||||
await self._execute_all_syncers_concurrent(accounts)
|
||||
# 执行所有同步器
|
||||
tasks = [syncer.sync(accounts) for syncer in self.syncers]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 更新统计
|
||||
sync_time = time.time() - sync_start
|
||||
@@ -101,137 +101,255 @@ class SyncManager:
|
||||
logger.error(f"同步任务异常: {e}")
|
||||
await asyncio.sleep(30)
|
||||
|
||||
async def _get_all_accounts(self) -> Dict[str, Dict]:
|
||||
"""获取所有账号"""
|
||||
if not self.syncers:
|
||||
def get_accounts_from_redis(self) -> Dict[str, Dict]:
|
||||
"""从Redis获取所有计算机名的账号配置"""
|
||||
try:
|
||||
accounts_dict = {}
|
||||
total_keys_processed = 0
|
||||
|
||||
# 方法1:使用配置的计算机名列表
|
||||
for computer_name in self.computer_names:
|
||||
accounts = self._get_accounts_by_computer_name(computer_name)
|
||||
total_keys_processed += 1
|
||||
accounts_dict.update(accounts)
|
||||
|
||||
# 方法2:如果配置的计算机名没有数据,尝试自动发现(备用方案)
|
||||
if not accounts_dict:
|
||||
logger.warning("配置的计算机名未找到数据,尝试自动发现...")
|
||||
accounts_dict = self._discover_all_accounts()
|
||||
|
||||
self.sync_stats['total_accounts'] = len(accounts_dict)
|
||||
logger.info(f"从 {len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号")
|
||||
|
||||
return accounts_dict
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取账户信息失败: {e}")
|
||||
return {}
|
||||
|
||||
def _get_computer_names(self) -> List[str]:
|
||||
"""获取计算机名列表"""
|
||||
if ',' in COMPUTER_NAMES:
|
||||
names = [name.strip() for name in COMPUTER_NAMES.split(',')]
|
||||
logger.info(f"使用配置的计算机名列表: {names}")
|
||||
return names
|
||||
return [COMPUTER_NAMES.strip()]
|
||||
|
||||
def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]:
|
||||
"""获取指定计算机名的账号"""
|
||||
accounts_dict = {}
|
||||
|
||||
# 使用第一个同步器获取账号
|
||||
return self.syncers[0].get_accounts_from_redis()
|
||||
|
||||
async def _execute_all_syncers_concurrent(self, accounts: Dict[str, Dict]):
|
||||
"""并发执行所有同步器"""
|
||||
tasks = []
|
||||
|
||||
# 持仓批量同步
|
||||
if 'position' in self.batch_tools:
|
||||
task = self._sync_positions_batch(accounts)
|
||||
tasks.append(task)
|
||||
|
||||
# 订单批量同步
|
||||
if 'order' in self.batch_tools:
|
||||
task = self._sync_orders_batch(accounts)
|
||||
tasks.append(task)
|
||||
|
||||
# 账户信息批量同步
|
||||
if 'account' in self.batch_tools:
|
||||
task = self._sync_accounts_batch(accounts)
|
||||
tasks.append(task)
|
||||
|
||||
# 并发执行所有任务
|
||||
if tasks:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 检查结果
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"同步任务 {i} 失败: {result}")
|
||||
|
||||
async def _sync_positions_batch(self, accounts: Dict[str, Dict]):
|
||||
"""批量同步持仓数据"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
# 构建key
|
||||
redis_key = f"{computer_name}_strategy_api"
|
||||
|
||||
# 收集所有持仓数据
|
||||
position_sync = next((s for s in self.syncers if isinstance(s, PositionSyncBatch)), None)
|
||||
if not position_sync:
|
||||
return
|
||||
# 从Redis获取数据
|
||||
result = self.redis_client.client.hgetall(redis_key)
|
||||
if not result:
|
||||
logger.debug(f"未找到 {redis_key} 的策略API配置")
|
||||
return {}
|
||||
|
||||
all_positions = await position_sync._collect_all_positions(accounts)
|
||||
logger.info(f"从 {redis_key} 获取到 {len(result)} 个交易所配置")
|
||||
|
||||
if not all_positions:
|
||||
self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0}
|
||||
return
|
||||
for exchange_name, accounts_json in result.items():
|
||||
try:
|
||||
accounts = json.loads(accounts_json)
|
||||
if not accounts:
|
||||
continue
|
||||
|
||||
# 格式化交易所ID
|
||||
exchange_id = self.format_exchange_id(exchange_name)
|
||||
|
||||
for account_id, account_info in accounts.items():
|
||||
parsed_account = self.parse_account(exchange_id, account_id, account_info)
|
||||
if parsed_account:
|
||||
# 添加计算机名标记
|
||||
parsed_account['computer_name'] = computer_name
|
||||
accounts_dict[account_id] = parsed_account
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"处理交易所 {exchange_name} 数据异常: {e}")
|
||||
continue
|
||||
|
||||
# 使用批量工具同步
|
||||
batch_tool = self.batch_tools['position']
|
||||
success, stats = batch_tool.sync_positions_batch(all_positions)
|
||||
logger.info(f"从 {redis_key} 解析到 {len(accounts_dict)} 个账号")
|
||||
|
||||
if success:
|
||||
elapsed = time.time() - start_time
|
||||
self.stats['position'] = {
|
||||
'accounts': len(accounts),
|
||||
'positions': stats['total'],
|
||||
'time': elapsed
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量同步持仓失败: {e}")
|
||||
self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0}
|
||||
|
||||
async def _sync_orders_batch(self, accounts: Dict[str, Dict]):
|
||||
"""批量同步订单数据"""
|
||||
logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}")
|
||||
|
||||
return accounts_dict
|
||||
|
||||
def _discover_all_accounts(self) -> Dict[str, Dict]:
|
||||
"""自动发现所有匹配的账号key"""
|
||||
accounts_dict = {}
|
||||
discovered_keys = []
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
# 获取所有匹配模式的key
|
||||
pattern = "*_strategy_api"
|
||||
cursor = 0
|
||||
|
||||
# 收集所有订单数据
|
||||
order_sync = next((s for s in self.syncers if isinstance(s, OrderSyncBatch)), None)
|
||||
if not order_sync:
|
||||
return
|
||||
|
||||
all_orders = await order_sync._collect_all_orders(accounts)
|
||||
|
||||
if not all_orders:
|
||||
self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0}
|
||||
return
|
||||
|
||||
# 使用批量工具同步
|
||||
batch_tool = self.batch_tools['order']
|
||||
success, processed_count = batch_tool.sync_orders_batch(all_orders)
|
||||
|
||||
if success:
|
||||
elapsed = time.time() - start_time
|
||||
self.stats['order'] = {
|
||||
'accounts': len(accounts),
|
||||
'orders': processed_count,
|
||||
'time': elapsed
|
||||
}
|
||||
while True:
|
||||
cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100)
|
||||
|
||||
for key in keys:
|
||||
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
|
||||
discovered_keys.append(key_str)
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
logger.info(f"自动发现 {len(discovered_keys)} 个策略API key")
|
||||
|
||||
# 处理每个发现的key
|
||||
for key_str in discovered_keys:
|
||||
# 提取计算机名
|
||||
computer_name = key_str.replace('_strategy_api', '')
|
||||
|
||||
# 验证计算机名格式
|
||||
if self.computer_name_pattern.match(computer_name):
|
||||
accounts = self._get_accounts_by_computer_name(computer_name)
|
||||
accounts_dict.update(accounts)
|
||||
else:
|
||||
logger.warning(f"跳过不符合格式的计算机名: {computer_name}")
|
||||
|
||||
logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量同步订单失败: {e}")
|
||||
self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0}
|
||||
logger.error(f"自动发现账号失败: {e}")
|
||||
|
||||
return accounts_dict
|
||||
|
||||
async def _sync_accounts_batch(self, accounts: Dict[str, Dict]):
|
||||
"""批量同步账户信息数据"""
|
||||
def _discover_all_accounts(self) -> Dict[str, Dict]:
|
||||
"""自动发现所有匹配的账号key"""
|
||||
accounts_dict = {}
|
||||
discovered_keys = []
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
# 获取所有匹配模式的key
|
||||
pattern = "*_strategy_api"
|
||||
cursor = 0
|
||||
|
||||
# 收集所有账户数据
|
||||
account_sync = next((s for s in self.syncers if isinstance(s, AccountSyncBatch)), None)
|
||||
if not account_sync:
|
||||
return
|
||||
while True:
|
||||
cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100)
|
||||
|
||||
for key in keys:
|
||||
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
|
||||
discovered_keys.append(key_str)
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
all_account_data = await account_sync._collect_all_account_data(accounts)
|
||||
logger.info(f"自动发现 {len(discovered_keys)} 个策略API key")
|
||||
|
||||
if not all_account_data:
|
||||
self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0}
|
||||
return
|
||||
# 处理每个发现的key
|
||||
for key_str in discovered_keys:
|
||||
# 提取计算机名
|
||||
computer_name = key_str.replace('_strategy_api', '')
|
||||
|
||||
# 验证计算机名格式
|
||||
if self.computer_name_pattern.match(computer_name):
|
||||
accounts = self._get_accounts_by_computer_name(computer_name)
|
||||
accounts_dict.update(accounts)
|
||||
else:
|
||||
logger.warning(f"跳过不符合格式的计算机名: {computer_name}")
|
||||
|
||||
# 使用批量工具同步
|
||||
batch_tool = self.batch_tools['account']
|
||||
updated, inserted = batch_tool.sync_accounts_batch(all_account_data)
|
||||
logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
self.stats['account'] = {
|
||||
'accounts': len(accounts),
|
||||
'records': len(all_account_data),
|
||||
'time': elapsed
|
||||
except Exception as e:
|
||||
logger.error(f"自动发现账号失败: {e}")
|
||||
|
||||
return accounts_dict
|
||||
|
||||
def format_exchange_id(self, key: str) -> str:
|
||||
"""格式化交易所ID"""
|
||||
key = key.lower().strip()
|
||||
|
||||
# 交易所名称映射
|
||||
exchange_mapping = {
|
||||
'metatrader': 'mt5',
|
||||
'binance_spot_test': 'binance',
|
||||
'binance_spot': 'binance',
|
||||
'binance': 'binance',
|
||||
'gate_spot': 'gate',
|
||||
'okex': 'okx',
|
||||
'okx': 'okx',
|
||||
'bybit': 'bybit',
|
||||
'bybit_spot': 'bybit',
|
||||
'bybit_test': 'bybit',
|
||||
'huobi': 'huobi',
|
||||
'huobi_spot': 'huobi',
|
||||
'gate': 'gate',
|
||||
'gateio': 'gate',
|
||||
'kucoin': 'kucoin',
|
||||
'kucoin_spot': 'kucoin',
|
||||
'mexc': 'mexc',
|
||||
'mexc_spot': 'mexc',
|
||||
'bitget': 'bitget',
|
||||
'bitget_spot': 'bitget'
|
||||
}
|
||||
|
||||
normalized_key = exchange_mapping.get(key, key)
|
||||
|
||||
# 记录未映射的交易所
|
||||
if normalized_key == key and key not in exchange_mapping.values():
|
||||
logger.debug(f"未映射的交易所名称: {key}")
|
||||
|
||||
return normalized_key
|
||||
|
||||
def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Optional[Dict]:
|
||||
"""解析账号信息"""
|
||||
try:
|
||||
source_account_info = json.loads(account_info)
|
||||
|
||||
# 基础信息
|
||||
account_data = {
|
||||
'exchange_id': exchange_id,
|
||||
'k_id': account_id,
|
||||
'st_id': self._safe_int(source_account_info.get('st_id'), 0),
|
||||
'add_time': self._safe_int(source_account_info.get('add_time'), 0),
|
||||
'account_type': source_account_info.get('account_type', 'real'),
|
||||
'api_key': source_account_info.get('api_key', ''),
|
||||
'secret_key': source_account_info.get('secret_key', ''),
|
||||
'password': source_account_info.get('password', ''),
|
||||
'access_token': source_account_info.get('access_token', ''),
|
||||
'remark': source_account_info.get('remark', '')
|
||||
}
|
||||
|
||||
# 合并原始信息
|
||||
result = {**source_account_info, **account_data}
|
||||
|
||||
# 验证必要字段
|
||||
if not result.get('st_id') or not result.get('exchange_id'):
|
||||
logger.warning(f"账号 {account_id} 缺少必要字段: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}")
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {account_info[:100]}...")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"批量同步账户信息失败: {e}")
|
||||
self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0}
|
||||
logger.error(f"处理账号 {account_id} 数据异常: {e}")
|
||||
return None
|
||||
|
||||
def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
|
||||
"""按交易所分组账号"""
|
||||
groups = {}
|
||||
for account_id, account_info in accounts.items():
|
||||
exchange_id = account_info.get('exchange_id')
|
||||
if exchange_id:
|
||||
if exchange_id not in groups:
|
||||
groups[exchange_id] = []
|
||||
groups[exchange_id].append(account_info)
|
||||
return groups
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _update_stats(self, sync_time: float):
|
||||
"""更新统计信息"""
|
||||
self.stats['last_sync_time'] = sync_time
|
||||
|
||||
Reference in New Issue
Block a user