Files
exchange_monitor_sync/sync/base_sync.py

395 lines
15 KiB
Python
Raw Normal View History

2025-12-03 14:40:14 +08:00
# sync/base_sync.py
2025-12-02 22:05:54 +08:00
from abc import ABC, abstractmethod
from loguru import logger
2025-12-03 14:40:14 +08:00
from typing import List, Dict, Any, Set, Optional
2025-12-02 22:05:54 +08:00
import json
2025-12-02 22:36:52 +08:00
import re
2025-12-03 14:40:14 +08:00
import time
2025-12-02 22:05:54 +08:00
from utils.redis_client import RedisClient
from utils.database_manager import DatabaseManager
2025-12-02 22:36:52 +08:00
from config.settings import COMPUTER_NAMES, COMPUTER_NAME_PATTERN
2025-12-02 22:05:54 +08:00
class BaseSync(ABC):
    """Abstract base class for jobs that sync exchange accounts read from Redis.

    The class loads account configurations from Redis hashes named
    ``<computer_name>_strategy_api``, normalises them into plain dicts, and
    offers small helpers (safe type conversion, SQL value escaping, date
    utilities, run statistics) for concrete subclasses, which must implement
    :meth:`sync` and :meth:`sync_batch`.
    """

    # Suffix of the Redis hash keys that hold the per-computer strategy API config.
    _STRATEGY_KEY_SUFFIX = "_strategy_api"

    # Raw exchange name (lower-cased and stripped) -> canonical exchange id.
    _EXCHANGE_ID_MAPPING = {
        'metatrader': 'mt5',
        'binance_spot_test': 'binance',
        'binance_spot': 'binance',
        'binance': 'binance',
        'gate_spot': 'gate',
        'okex': 'okx',
        'okx': 'okx',
        'bybit': 'bybit',
        'bybit_spot': 'bybit',
        'bybit_test': 'bybit',
        'huobi': 'huobi',
        'huobi_spot': 'huobi',
        'gate': 'gate',
        'gateio': 'gate',
        'kucoin': 'kucoin',
        'kucoin_spot': 'kucoin',
        'mexc': 'mexc',
        'mexc_spot': 'mexc',
        'bitget': 'bitget',
        'bitget_spot': 'bitget',
    }

    # Canonical ids; a name that is already canonical is not reported as unmapped.
    _KNOWN_EXCHANGE_IDS = frozenset(_EXCHANGE_ID_MAPPING.values())

    def __init__(self):
        self.redis_client = RedisClient()
        self.db_manager = DatabaseManager()
        self.computer_names = self._get_computer_names()
        self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)
        self.sync_stats = self._new_stats()

    @staticmethod
    def _new_stats() -> Dict[str, Any]:
        """Return a fresh, zeroed statistics dict."""
        return {
            'total_accounts': 0,
            'success_count': 0,
            'error_count': 0,
            'last_sync_time': 0,
            'avg_sync_time': 0
        }

    @staticmethod
    def _to_text(value: Any) -> Any:
        """Decode ``bytes`` as UTF-8; return any other value unchanged."""
        return value.decode('utf-8') if isinstance(value, bytes) else value

    def _get_computer_names(self) -> List[str]:
        """Return the computer names configured in ``COMPUTER_NAMES``.

        ``COMPUTER_NAMES`` is treated as a comma-separated string. Each name is
        stripped, and empty entries (for example from a trailing comma) are
        dropped so that no lookup is made for a nameless key.
        """
        names = [part.strip() for part in COMPUTER_NAMES.split(',') if part.strip()]
        if ',' in COMPUTER_NAMES:
            logger.info(f"使用配置的计算机名列表: {names}")
        return names

    @abstractmethod
    async def sync(self):
        """Run a synchronisation pass (kept for the legacy interface)."""
        pass

    @abstractmethod
    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Synchronise the given batch of accounts."""
        pass

    def get_accounts_from_redis(self) -> Dict[str, Dict]:
        """Load the account configurations for every configured computer name.

        Falls back to key discovery (:meth:`_discover_all_accounts`) when the
        configured names yield no accounts. The number of accounts found is
        stored in ``sync_stats['total_accounts']``.

        Returns:
            Mapping of account id to parsed account dict; ``{}`` if any
            exception is raised while loading.
        """
        try:
            accounts_dict: Dict[str, Dict] = {}

            # Primary path: the explicitly configured computer names.
            for computer_name in self.computer_names:
                accounts_dict.update(self._get_accounts_by_computer_name(computer_name))

            # Fallback path: nothing configured matched, so scan Redis for keys.
            if not accounts_dict:
                logger.warning("配置的计算机名未找到数据,尝试自动发现...")
                accounts_dict = self._discover_all_accounts()

            self.sync_stats['total_accounts'] = len(accounts_dict)
            logger.info(f"{len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号")
            return accounts_dict
        except Exception as e:
            logger.error(f"获取账户信息失败: {e}")
            return {}

    def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]:
        """Load and parse the accounts stored for one computer name.

        Reads the Redis hash ``<computer_name>_strategy_api``. Each hash field
        is an exchange name and each value is a JSON object mapping account id
        to account info. Entries that fail to parse are logged and skipped.

        Returns:
            Mapping of account id to parsed account dict, each tagged with
            ``computer_name``. An account id that appears under several
            exchanges keeps only the last one processed.
        """
        accounts_dict: Dict[str, Dict] = {}
        try:
            redis_key = f"{computer_name}{self._STRATEGY_KEY_SUFFIX}"
            result = self.redis_client.client.hgetall(redis_key)
            if not result:
                logger.debug(f"未找到 {redis_key} 的策略API配置")
                return {}

            logger.info(f"{redis_key} 获取到 {len(result)} 个交易所配置")
            for raw_exchange_name, accounts_json in result.items():
                # Decode bytes field names, as key discovery already does for keys,
                # so the exchange-name mapping also works with a non-decoding client.
                exchange_name = self._to_text(raw_exchange_name)
                try:
                    accounts = json.loads(accounts_json)
                    if not accounts:
                        continue

                    exchange_id = self.format_exchange_id(exchange_name)
                    for account_id, account_info in accounts.items():
                        parsed_account = self.parse_account(exchange_id, account_id, account_info)
                        if parsed_account:
                            # Remember which computer the account came from.
                            parsed_account['computer_name'] = computer_name
                            accounts_dict[account_id] = parsed_account
                except json.JSONDecodeError as e:
                    logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}")
                    continue
                except Exception as e:
                    logger.error(f"处理交易所 {exchange_name} 数据异常: {e}")
                    continue

            logger.info(f"{redis_key} 解析到 {len(accounts_dict)} 个账号")
        except Exception as e:
            logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}")

        return accounts_dict

    def _discover_all_accounts(self) -> Dict[str, Dict]:
        """Discover accounts by scanning Redis for ``*_strategy_api`` keys.

        Only keys whose computer-name prefix matches
        ``self.computer_name_pattern`` are loaded; others are logged and skipped.

        Returns:
            Mapping of account id to parsed account dict. Whatever was
            collected before an exception is still returned.
        """
        accounts_dict: Dict[str, Dict] = {}
        suffix = self._STRATEGY_KEY_SUFFIX

        try:
            # Dict keys serve as an ordered set: SCAN may return a key more than once.
            seen: Dict[str, None] = {}
            cursor = 0
            while True:
                cursor, keys = self.redis_client.client.scan(cursor, match=f"*{suffix}", count=100)
                for key in keys:
                    seen[self._to_text(key)] = None
                if cursor == 0:
                    break

            discovered_keys = list(seen)
            logger.info(f"自动发现 {len(discovered_keys)} 个策略API key")

            for key_str in discovered_keys:
                # Strip only the trailing suffix; str.replace would also remove
                # occurrences inside the computer name itself.
                if key_str.endswith(suffix):
                    computer_name = key_str[:-len(suffix)]
                else:
                    computer_name = key_str

                if self.computer_name_pattern.match(computer_name):
                    accounts_dict.update(self._get_accounts_by_computer_name(computer_name))
                else:
                    logger.warning(f"跳过不符合格式的计算机名: {computer_name}")

            logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号")
        except Exception as e:
            logger.error(f"自动发现账号失败: {e}")

        return accounts_dict

    def format_exchange_id(self, key: str) -> str:
        """Normalise a raw exchange name to its canonical exchange id.

        The name is lower-cased and stripped, then looked up in the alias
        table. A name with no alias is returned as normalised, and logged at
        debug level unless it is already a canonical id.
        """
        key = key.lower().strip()
        normalized_key = self._EXCHANGE_ID_MAPPING.get(key, key)
        if normalized_key == key and key not in self._KNOWN_EXCHANGE_IDS:
            logger.debug(f"未映射的交易所名称: {key}")
        return normalized_key

    def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Optional[Dict]:
        """Parse one account's JSON string into a normalised account dict.

        Args:
            exchange_id: Canonical exchange id for the account.
            account_id: Account identifier; stored under ``k_id``.
            account_info: JSON-encoded object with the account's settings.

        Returns:
            The source fields overlaid with the normalised fields, or ``None``
            when the JSON is invalid, ``st_id`` or ``exchange_id`` is
            missing/zero, or any other exception occurs.
        """
        try:
            source_account_info = json.loads(account_info)

            account_data = {
                'exchange_id': exchange_id,
                'k_id': account_id,
                'st_id': self._safe_int(source_account_info.get('st_id'), 0),
                'add_time': self._safe_int(source_account_info.get('add_time'), 0),
                'account_type': source_account_info.get('account_type', 'real'),
                'api_key': source_account_info.get('api_key', ''),
                'secret_key': source_account_info.get('secret_key', ''),
                'password': source_account_info.get('password', ''),
                'access_token': source_account_info.get('access_token', ''),
                'remark': source_account_info.get('remark', '')
            }

            # For MT5 the secret_key field is read as "host:port".
            if exchange_id == 'mt5':
                server_info = source_account_info.get('secret_key', '')
                if ':' in server_info:
                    host, port = server_info.split(':', 1)
                    account_data['mt5_host'] = host
                    account_data['mt5_port'] = self._safe_int(port, 0)

            # Normalised fields take precedence over the raw source fields.
            result = {**source_account_info, **account_data}

            if not result.get('st_id') or not result.get('exchange_id'):
                logger.warning(f"账号 {account_id} 缺少必要字段: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}")
                return None

            return result
        except json.JSONDecodeError as e:
            # NOTE(review): this logs up to 100 characters of the raw payload, which
            # may include credential fields - consider redacting before logging.
            logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {account_info[:100]}...")
            return None
        except Exception as e:
            logger.error(f"处理账号 {account_id} 数据异常: {e}")
            return None

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group account dicts by ``exchange_id``; accounts without one are dropped."""
        groups: Dict[str, List[Dict]] = {}
        for account_info in accounts.values():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                groups.setdefault(exchange_id, []).append(account_info)
        return groups

    def _safe_float(self, value: Any, default: float = 0.0) -> float:
        """Convert ``value`` to ``float``; return ``default`` for ``None``, blank or invalid input."""
        if value is None:
            return default
        try:
            if isinstance(value, str):
                value = value.strip()
                if value == '':
                    return default
            return float(value)
        except (ValueError, TypeError):
            return default

    def _safe_int(self, value: Any, default: int = 0) -> int:
        """Convert ``value`` to ``int``; return ``default`` when that is not possible.

        Integers and integer strings are converted exactly. Other numeric
        input (for example ``"12.7"`` or ``1e3``) goes through ``float`` and is
        truncated toward zero. ``None``, blank strings, unparsable values and
        non-finite numbers yield ``default``.
        """
        if value is None:
            return default
        try:
            if isinstance(value, str):
                value = value.strip()
                if value == '':
                    return default
            if isinstance(value, (int, str)):
                # Parse exactly first: a float round-trip corrupts integers above 2**53.
                try:
                    return int(value)
                except ValueError:
                    pass  # not a plain integer literal; try float parsing below
            return int(float(value))
        except (ValueError, TypeError, OverflowError):
            # OverflowError: int(float('inf')); ValueError also covers NaN.
            return default

    def _safe_str(self, value: Any, default: str = '') -> str:
        """Convert ``value`` to a stripped ``str``; return ``default`` for ``None`` or blank results."""
        if value is None:
            return default
        try:
            result = str(value).strip()
            return result if result else default
        except Exception:
            return default

    def _escape_sql_value(self, value: Any) -> str:
        """Render ``value`` as a SQL literal.

        ``None`` becomes ``NULL``, booleans ``1``/``0``, ints and floats their
        ``str()`` form, and everything else a single-quoted string with
        embedded single quotes doubled.

        NOTE(review): only single quotes are escaped. If the target database
        treats backslash as an escape character this is not sufficient; prefer
        parameterised queries where the driver supports them - confirm against
        the database manager.
        """
        if value is None:
            return 'NULL'
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return '1' if value else '0'
        if isinstance(value, (int, float)):
            return str(value)
        escaped = str(value).replace("'", "''")
        return f"'{escaped}'"

    def _build_sql_values_list(self, data_list: List[Dict], fields_mapping: Optional[Dict[str, str]] = None) -> List[str]:
        """Build one ``(v1, v2, ...)`` SQL VALUES tuple string per row.

        Values are emitted in each dict's own iteration order, so every row
        must list its fields in the same order for the columns to line up.

        Args:
            data_list: Rows to render.
            fields_mapping: Accepted for interface compatibility. Only values
                are rendered, so a field-name mapping does not affect the output.

        Returns:
            List of tuple strings; rows that raise while rendering are logged
            and skipped.
        """
        values_list = []
        for data in data_list:
            try:
                value_parts = [self._escape_sql_value(value) for value in data.values()]
                values_list.append(f"({', '.join(value_parts)})")
            except Exception as e:
                logger.error(f"构建SQL值失败: {data}, error={e}")
                continue
        return values_list

    def _get_recent_dates(self, days: int) -> List[str]:
        """Return the last ``days`` dates as ``YYYY-MM-DD`` strings, today first."""
        from datetime import datetime, timedelta

        today = datetime.now()
        return [(today - timedelta(days=offset)).strftime('%Y-%m-%d') for offset in range(days)]

    def _date_to_timestamp(self, date_str: str) -> int:
        """Convert a ``YYYY-MM-DD`` string to the Unix timestamp of that day's midnight.

        The date is parsed as a naive datetime, so the local timezone applies.
        Returns 0 when the input is not a string in the expected format.
        """
        from datetime import datetime

        try:
            dt = datetime.strptime(date_str, '%Y-%m-%d')
            return int(dt.timestamp())
        except (ValueError, TypeError):
            return 0

    def update_stats(self, success: bool = True, sync_time: float = 0):
        """Record the outcome of one sync run.

        Args:
            success: Increments ``success_count`` when true, else ``error_count``.
            sync_time: Duration of the run; ignored unless positive. Feeds
                ``last_sync_time`` and the smoothed ``avg_sync_time``.
        """
        if success:
            self.sync_stats['success_count'] += 1
        else:
            self.sync_stats['error_count'] += 1

        if sync_time > 0:
            self.sync_stats['last_sync_time'] = sync_time
            # Exponential moving average: 90% previous average, 10% new sample;
            # the first sample seeds the average.
            if self.sync_stats['avg_sync_time'] == 0:
                self.sync_stats['avg_sync_time'] = sync_time
            else:
                self.sync_stats['avg_sync_time'] = (
                    self.sync_stats['avg_sync_time'] * 0.9 + sync_time * 0.1
                )

    def print_stats(self, sync_type: str = ""):
        """Log the current statistics: at warning level if any errors were recorded, else info."""
        stats = self.sync_stats
        prefix = f"[{sync_type}] " if sync_type else ""
        stats_str = (
            f"{prefix}统计: 账号数={stats['total_accounts']}, "
            f"成功={stats['success_count']}, 失败={stats['error_count']}, "
            f"本次耗时={stats['last_sync_time']:.2f}s, "
            f"平均耗时={stats['avg_sync_time']:.2f}s"
        )
        if stats['error_count'] > 0:
            logger.warning(stats_str)
        else:
            logger.info(stats_str)

    def reset_stats(self):
        """Reset all statistics counters to zero."""
        self.sync_stats = self._new_stats()