# sync/base_sync.py
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
import json
import math
import re
import time
from typing import List, Dict, Any, Set, Optional

from loguru import logger

from utils.redis_client import RedisClient
from utils.database_manager import DatabaseManager
from config.settings import COMPUTER_NAMES, COMPUTER_NAME_PATTERN


class BaseSync(ABC):
    """Abstract base class for sync jobs.

    Provides loading of account configuration from Redis hashes named
    ``<computer_name>_strategy_api``, normalisation helpers, SQL value
    formatting helpers and simple success/error/timing statistics.
    Subclasses implement ``sync`` and ``sync_batch``.
    """

    # Suffix of the Redis hash keys that hold per-computer account configuration.
    _STRATEGY_KEY_SUFFIX = "_strategy_api"

    # Raw exchange name (lower-cased) -> canonical exchange id.
    _EXCHANGE_MAPPING = {
        'metatrader': 'mt5',
        'binance_spot_test': 'binance',
        'binance_spot': 'binance',
        'binance': 'binance',
        'gate_spot': 'gate',
        'okex': 'okx',
        'okx': 'okx',
        'bybit': 'bybit',
        'bybit_spot': 'bybit',
        'bybit_test': 'bybit',
        'huobi': 'huobi',
        'huobi_spot': 'huobi',
        'gate': 'gate',
        'gateio': 'gate',
        'kucoin': 'kucoin',
        'kucoin_spot': 'kucoin',
        'mexc': 'mexc',
        'mexc_spot': 'mexc',
        'bitget': 'bitget',
        'bitget_spot': 'bitget',
    }
    # Canonical ids; a name that is already canonical is not reported as unmapped.
    _KNOWN_EXCHANGE_IDS = frozenset(_EXCHANGE_MAPPING.values())

    def __init__(self):
        self.redis_client = RedisClient()
        self.db_manager = DatabaseManager()
        self.computer_names = self._get_computer_names()
        self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)
        self.sync_stats = self._empty_stats()

    @staticmethod
    def _empty_stats() -> Dict[str, Any]:
        """Return a fresh, zeroed statistics dict."""
        return {
            'total_accounts': 0,
            'success_count': 0,
            'error_count': 0,
            'last_sync_time': 0,
            'avg_sync_time': 0,
        }

    def _get_computer_names(self) -> List[str]:
        """Return the configured computer names.

        ``COMPUTER_NAMES`` is either a single name or a comma-separated list.
        Surrounding whitespace is stripped and empty entries (for example from
        a trailing comma) are dropped, so no lookup is made for an empty name.
        """
        if ',' in COMPUTER_NAMES:
            names = [name.strip() for name in COMPUTER_NAMES.split(',')]
            names = [name for name in names if name]
            logger.info(f"使用配置的计算机名列表: {names}")
            return names
        single = COMPUTER_NAMES.strip()
        return [single] if single else []

    @abstractmethod
    async def sync(self):
        """Run a sync (kept for compatibility with the old interface)."""
        pass

    @abstractmethod
    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Sync data for a batch of accounts."""
        pass

    def get_accounts_from_redis(self) -> Dict[str, Dict]:
        """Load account configuration for all computer names from Redis.

        The configured computer names are tried first. If they yield no
        accounts, all ``*_strategy_api`` keys are discovered as a fallback.
        ``sync_stats['total_accounts']`` is updated with the result size.

        Returns:
            Mapping of account id to parsed account dict; ``{}`` on any error.
        """
        try:
            accounts_dict: Dict[str, Dict] = {}

            # Primary path: the configured computer names.
            for computer_name in self.computer_names:
                accounts_dict.update(
                    self._get_accounts_by_computer_name(computer_name)
                )

            # Fallback: nothing found for the configured names, scan Redis.
            if not accounts_dict:
                logger.warning("配置的计算机名未找到数据,尝试自动发现...")
                accounts_dict = self._discover_all_accounts()

            self.sync_stats['total_accounts'] = len(accounts_dict)
            logger.info(f"从 {len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号")
            return accounts_dict
        except Exception as e:
            logger.error(f"获取账户信息失败: {e}")
            return {}

    def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]:
        """Load and parse the accounts stored for one computer name.

        Reads the Redis hash ``<computer_name>_strategy_api``. Each hash field
        is an exchange name whose value is a JSON object mapping account id to
        a JSON-encoded account record.

        Returns:
            Mapping of account id to parsed account dict, each tagged with
            ``computer_name``. Entries that fail to parse are logged and
            skipped. Returns whatever was collected if an error occurs.
        """
        accounts_dict: Dict[str, Dict] = {}
        try:
            redis_key = f"{computer_name}{self._STRATEGY_KEY_SUFFIX}"
            result = self.redis_client.client.hgetall(redis_key)
            if not result:
                logger.debug(f"未找到 {redis_key} 的策略API配置")
                return {}

            logger.info(f"从 {redis_key} 获取到 {len(result)} 个交易所配置")

            for exchange_name, accounts_json in result.items():
                # The client may return bytes field names (the key scan in
                # _discover_all_accounts already allows for that); decode so
                # the exchange mapping lookup works.
                if isinstance(exchange_name, bytes):
                    exchange_name = exchange_name.decode('utf-8')
                try:
                    accounts = json.loads(accounts_json)
                    if not accounts:
                        continue

                    exchange_id = self.format_exchange_id(exchange_name)

                    for account_id, account_info in accounts.items():
                        parsed_account = self.parse_account(
                            exchange_id, account_id, account_info
                        )
                        if parsed_account:
                            # Tag the account with the computer it came from.
                            parsed_account['computer_name'] = computer_name
                            accounts_dict[account_id] = parsed_account
                except json.JSONDecodeError as e:
                    logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}")
                    continue
                except Exception as e:
                    logger.error(f"处理交易所 {exchange_name} 数据异常: {e}")
                    continue

            logger.info(f"从 {redis_key} 解析到 {len(accounts_dict)} 个账号")
        except Exception as e:
            logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}")
        return accounts_dict

    def _discover_all_accounts(self) -> Dict[str, Dict]:
        """Discover every ``*_strategy_api`` key in Redis and load its accounts.

        Keys whose computer-name part does not match
        ``self.computer_name_pattern`` are skipped with a warning.

        Returns:
            Mapping of account id to parsed account dict; partial or empty on
            error.
        """
        accounts_dict: Dict[str, Dict] = {}
        suffix = self._STRATEGY_KEY_SUFFIX
        # dict used as an ordered set: SCAN may return a key more than once.
        discovered_keys: Dict[str, None] = {}
        try:
            pattern = "*_strategy_api"
            cursor = 0
            while True:
                cursor, keys = self.redis_client.client.scan(
                    cursor, match=pattern, count=100
                )
                for key in keys:
                    key_str = key.decode('utf-8') if isinstance(key, bytes) else key
                    discovered_keys[key_str] = None
                if cursor == 0:
                    break

            logger.info(f"自动发现 {len(discovered_keys)} 个策略API key")

            for key_str in discovered_keys:
                # Strip only the trailing suffix; str.replace would also remove
                # occurrences in the middle of the computer name.
                if key_str.endswith(suffix):
                    computer_name = key_str[:-len(suffix)]
                else:
                    computer_name = key_str

                if self.computer_name_pattern.match(computer_name):
                    accounts_dict.update(
                        self._get_accounts_by_computer_name(computer_name)
                    )
                else:
                    logger.warning(f"跳过不符合格式的计算机名: {computer_name}")

            logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号")
        except Exception as e:
            logger.error(f"自动发现账号失败: {e}")
        return accounts_dict

    def format_exchange_id(self, key: str) -> str:
        """Normalise a raw exchange name to its canonical exchange id.

        The name is lower-cased and stripped, then looked up in the exchange
        mapping. Unknown names are returned as-is (lower-cased and stripped)
        and reported at debug level.
        """
        key = key.lower().strip()
        normalized_key = self._EXCHANGE_MAPPING.get(key, key)

        # Report names that are neither mapped nor already canonical.
        if normalized_key == key and key not in self._KNOWN_EXCHANGE_IDS:
            logger.debug(f"未映射的交易所名称: {key}")
        return normalized_key

    def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Optional[Dict]:
        """Parse one JSON-encoded account record.

        Args:
            exchange_id: Canonical exchange id the account belongs to.
            account_id: Account identifier; stored under ``k_id``.
            account_info: JSON text of the account record.

        Returns:
            The original record merged with the normalised fields (normalised
            fields win), or ``None`` if the JSON is invalid, an error occurs,
            or ``st_id`` / ``exchange_id`` is missing or falsy.
        """
        try:
            source_account_info = json.loads(account_info)

            account_data = {
                'exchange_id': exchange_id,
                'k_id': account_id,
                'st_id': self._safe_int(source_account_info.get('st_id'), 0),
                'add_time': self._safe_int(source_account_info.get('add_time'), 0),
                'account_type': source_account_info.get('account_type', 'real'),
                'api_key': source_account_info.get('api_key', ''),
                'secret_key': source_account_info.get('secret_key', ''),
                'password': source_account_info.get('password', ''),
                'access_token': source_account_info.get('access_token', ''),
                'remark': source_account_info.get('remark', ''),
            }

            if exchange_id == 'mt5':
                # For MT5 the "host:port" server address is read from the
                # secret_key field.
                # NOTE(review): presumably MT5 records reuse secret_key for the
                # server address; confirm against the producer of this data.
                server_info = source_account_info.get('secret_key', '')
                if ':' in server_info:
                    host, port = server_info.split(':', 1)
                    account_data['mt5_host'] = host
                    account_data['mt5_port'] = self._safe_int(port, 0)

            # Keep the original fields; normalised values take precedence.
            result = {**source_account_info, **account_data}

            if not result.get('st_id') or not result.get('exchange_id'):
                logger.warning(f"账号 {account_id} 缺少必要字段: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}")
                return None
            return result
        except json.JSONDecodeError as e:
            logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {account_info[:100]}...")
            return None
        except Exception as e:
            logger.error(f"处理账号 {account_id} 数据异常: {e}")
            return None

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group account dicts by ``exchange_id``.

        Accounts without a truthy ``exchange_id`` are omitted.
        """
        groups: Dict[str, List[Dict]] = {}
        for account_info in accounts.values():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                groups.setdefault(exchange_id, []).append(account_info)
        return groups

    def _safe_float(self, value: Any, default: float = 0.0) -> float:
        """Convert ``value`` to float, returning ``default`` on failure.

        ``None``, blank strings and unconvertible values yield ``default``.
        """
        if value is None:
            return default
        try:
            if isinstance(value, str):
                value = value.strip()
                if value == '':
                    return default
            return float(value)
        except (ValueError, TypeError):
            return default

    def _safe_int(self, value: Any, default: int = 0) -> int:
        """Convert ``value`` to int, returning ``default`` on failure.

        Integers and integer strings are converted exactly (no round trip
        through float). Other numeric input such as ``"12.7"`` or ``1e3`` is
        truncated toward zero. ``None``, blank strings, unconvertible values
        and non-finite numbers yield ``default``.
        """
        if value is None:
            return default
        try:
            if isinstance(value, int):
                return int(value)
            if isinstance(value, str):
                value = value.strip()
                if value == '':
                    return default
                try:
                    # Exact path: avoids float precision loss on large ints.
                    return int(value)
                except ValueError:
                    pass  # Not a plain integer literal; try via float below.
            return int(float(value))
        except (ValueError, TypeError, OverflowError):
            # OverflowError: int(float('inf')) and similar.
            return default

    def _safe_str(self, value: Any, default: str = '') -> str:
        """Convert ``value`` to a stripped string.

        Returns ``default`` for ``None``, for an empty result, or if the
        conversion raises.
        """
        if value is None:
            return default
        try:
            result = str(value).strip()
            return result if result else default
        except Exception:
            return default

    def _escape_sql_value(self, value: Any) -> str:
        """Render ``value`` as an SQL literal.

        ``None`` and non-finite floats become ``NULL``, bools become
        ``1``/``0``, numbers are rendered with ``str()``, and everything else
        is stringified and single-quoted with embedded single quotes doubled.

        NOTE(review): only single quotes are escaped. On a database that
        treats backslash as an escape character inside string literals this is
        not sufficient for untrusted input; prefer parameterised queries.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, bool):
            return '1' if value else '0'
        if isinstance(value, float) and not math.isfinite(value):
            # Bare nan/inf is not a valid SQL numeric literal.
            return 'NULL'
        if isinstance(value, (int, float)):
            return str(value)
        escaped = str(value).replace("'", "''")
        return f"'{escaped}'"

    def _build_sql_values_list(self, data_list: List[Dict], fields_mapping: Optional[Dict[str, str]] = None) -> List[str]:
        """Build ``(v1, v2, ...)`` SQL VALUES tuples, one per input dict.

        Values are emitted in each dict's own iteration order and escaped with
        ``_escape_sql_value``. Rows that fail to build are logged and skipped.

        Args:
            data_list: Row dicts to render.
            fields_mapping: Accepted for interface compatibility. Only values
                are rendered, so column-name mapping has no effect on the
                output.
        """
        values_list: List[str] = []
        for data in data_list:
            try:
                value_parts = [
                    self._escape_sql_value(value) for value in data.values()
                ]
                values_list.append(f"({', '.join(value_parts)})")
            except Exception as e:
                logger.error(f"构建SQL值失败: {data}, error={e}")
                continue
        return values_list

    def _get_recent_dates(self, days: int) -> List[str]:
        """Return the last ``days`` dates as ``YYYY-MM-DD``, today first.

        Uses the local clock (``datetime.now()``).
        """
        today = datetime.now()
        return [
            (today - timedelta(days=offset)).strftime('%Y-%m-%d')
            for offset in range(days)
        ]

    def _date_to_timestamp(self, date_str: str) -> int:
        """Convert a ``YYYY-MM-DD`` string to a Unix timestamp at 00:00.

        The date is parsed as a naive datetime, so midnight is interpreted in
        the local timezone. Returns 0 if the input is not a valid date string.
        """
        try:
            dt = datetime.strptime(date_str, '%Y-%m-%d')
            return int(dt.timestamp())
        except (ValueError, TypeError):
            return 0

    def update_stats(self, success: bool = True, sync_time: float = 0):
        """Record the outcome of one sync run.

        Args:
            success: Whether the run succeeded.
            sync_time: Duration of the run; ignored for timing when not
                positive.
        """
        if success:
            self.sync_stats['success_count'] += 1
        else:
            self.sync_stats['error_count'] += 1

        if sync_time > 0:
            self.sync_stats['last_sync_time'] = sync_time
            # Exponential moving average; the first sample seeds the average.
            if self.sync_stats['avg_sync_time'] == 0:
                self.sync_stats['avg_sync_time'] = sync_time
            else:
                self.sync_stats['avg_sync_time'] = (
                    self.sync_stats['avg_sync_time'] * 0.9 + sync_time * 0.1
                )

    def print_stats(self, sync_type: str = ""):
        """Log the current statistics (warning level if any errors occurred)."""
        stats = self.sync_stats
        prefix = f"[{sync_type}] " if sync_type else ""
        stats_str = (
            f"{prefix}统计: 账号数={stats['total_accounts']}, "
            f"成功={stats['success_count']}, 失败={stats['error_count']}, "
            f"本次耗时={stats['last_sync_time']:.2f}s, "
            f"平均耗时={stats['avg_sync_time']:.2f}s"
        )
        if stats['error_count'] > 0:
            logger.warning(stats_str)
        else:
            logger.info(stats_str)

    def reset_stats(self):
        """Reset all statistics to zero."""
        self.sync_stats = self._empty_stats()