Files
exchange_monitor_sync/sync/base_sync.py
lz_db 803d40b88e 1
2025-12-03 14:40:14 +08:00

395 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# sync/base_sync.py
from abc import ABC, abstractmethod
from loguru import logger
from typing import List, Dict, Any, Set, Optional
import json
import re
import time
from utils.redis_client import RedisClient
from utils.database_manager import DatabaseManager
from config.settings import COMPUTER_NAMES, COMPUTER_NAME_PATTERN
class BaseSync(ABC):
    """Abstract base class for exchange-account sync jobs.

    Loads account configurations from Redis hashes named
    ``<computer_name>_strategy_api`` and provides shared helpers for value
    coercion, SQL literal rendering, date handling and run statistics.
    Subclasses implement :meth:`sync` and :meth:`sync_batch`.
    """

    # Suffix of the Redis hash keys holding per-computer account configs.
    _STRATEGY_API_SUFFIX = '_strategy_api'

    # Lower-cased raw exchange name -> canonical exchange id.
    # Kept at class level so it is built once (and can be extended by subclasses).
    EXCHANGE_ID_MAPPING: Dict[str, str] = {
        'metatrader': 'mt5',
        'binance_spot_test': 'binance',
        'binance_spot': 'binance',
        'binance': 'binance',
        'gate_spot': 'gate',
        'okex': 'okx',
        'okx': 'okx',
        'bybit': 'bybit',
        'bybit_spot': 'bybit',
        'bybit_test': 'bybit',
        'huobi': 'huobi',
        'huobi_spot': 'huobi',
        'gate': 'gate',
        'gateio': 'gate',
        'kucoin': 'kucoin',
        'kucoin_spot': 'kucoin',
        'mexc': 'mexc',
        'mexc_spot': 'mexc',
        'bitget': 'bitget',
        'bitget_spot': 'bitget',
    }

    def __init__(self):
        """Create the Redis/DB clients and resolve the configured computer names."""
        self.redis_client = RedisClient()
        self.db_manager = DatabaseManager()
        self.computer_names = self._get_computer_names()
        self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)
        self.sync_stats = self._empty_stats()

    @staticmethod
    def _empty_stats() -> Dict[str, Any]:
        """Return a fresh, zeroed statistics dict (used by __init__ and reset_stats)."""
        return {
            'total_accounts': 0,
            'success_count': 0,
            'error_count': 0,
            'last_sync_time': 0,
            'avg_sync_time': 0,
        }

    @staticmethod
    def _split_computer_names(raw: str) -> List[str]:
        """Split a comma-separated list of names.

        Entries are stripped; blank entries (e.g. from a trailing comma) and
        duplicates are dropped while first-seen order is kept.
        """
        stripped = (part.strip() for part in raw.split(','))
        return list(dict.fromkeys(name for name in stripped if name))

    def _get_computer_names(self) -> List[str]:
        """Return the computer names configured in ``COMPUTER_NAMES``.

        Blank names are dropped so no lookup is made against the bare key
        ``_strategy_api``. An empty result means get_accounts_from_redis will
        fall back to auto-discovery.
        """
        names = self._split_computer_names(COMPUTER_NAMES)
        if ',' in COMPUTER_NAMES:
            logger.info(f"使用配置的计算机名列表: {names}")
        return names

    @abstractmethod
    async def sync(self):
        """Run one sync pass (kept for compatibility with the old interface)."""
        pass

    @abstractmethod
    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Sync data for a batch of accounts (as returned by get_accounts_from_redis)."""
        pass

    def get_accounts_from_redis(self) -> Dict[str, Dict]:
        """Load account configs for all configured computer names.

        Returns:
            Parsed accounts keyed by account id. When two computers define the
            same id, the one read later wins. Returns ``{}`` on an unexpected
            error, which is logged.

        Side effects:
            Sets ``sync_stats['total_accounts']``.
        """
        try:
            accounts_dict: Dict[str, Dict] = {}

            # Primary source: the explicitly configured computer names.
            for computer_name in self.computer_names:
                accounts_dict.update(self._get_accounts_by_computer_name(computer_name))

            # Fallback: nothing found under the configured names -> scan Redis.
            if not accounts_dict:
                logger.warning("配置的计算机名未找到数据,尝试自动发现...")
                accounts_dict = self._discover_all_accounts()

            self.sync_stats['total_accounts'] = len(accounts_dict)
            logger.info(f"{len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号")
            return accounts_dict
        except Exception as e:
            logger.error(f"获取账户信息失败: {e}")
            return {}

    def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]:
        """Read and parse the ``<computer_name>_strategy_api`` hash.

        Each hash field is an exchange name whose value is a JSON object that
        maps account ids to account configs. A malformed exchange entry is
        logged and skipped.

        Returns:
            Parsed accounts keyed by account id, each tagged with
            ``computer_name``. Returns ``{}`` if the key is missing or empty.
        """
        accounts_dict: Dict[str, Dict] = {}
        try:
            redis_key = f"{computer_name}{self._STRATEGY_API_SUFFIX}"
            result = self.redis_client.client.hgetall(redis_key)
            if not result:
                logger.debug(f"未找到 {redis_key} 的策略API配置")
                return {}

            logger.info(f"{redis_key} 获取到 {len(result)} 个交易所配置")

            for exchange_name, accounts_json in result.items():
                try:
                    # Field names may arrive as bytes (discovery handles bytes keys too).
                    if isinstance(exchange_name, bytes):
                        exchange_name = exchange_name.decode('utf-8')
                    accounts = json.loads(accounts_json)
                    if not accounts:
                        continue

                    exchange_id = self.format_exchange_id(exchange_name)
                    for account_id, account_info in accounts.items():
                        parsed_account = self.parse_account(exchange_id, account_id, account_info)
                        if parsed_account:
                            # Remember which computer's hash the account came from.
                            parsed_account['computer_name'] = computer_name
                            accounts_dict[account_id] = parsed_account
                except json.JSONDecodeError as e:
                    logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}")
                except Exception as e:
                    logger.error(f"处理交易所 {exchange_name} 数据异常: {e}")

            logger.info(f"{redis_key} 解析到 {len(accounts_dict)} 个账号")
        except Exception as e:
            logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}")
        return accounts_dict

    @classmethod
    def _computer_name_from_key(cls, key: Any) -> str:
        """Return the computer-name prefix of a ``*_strategy_api`` Redis key.

        Only the trailing suffix is removed. ``str.replace`` would also strip an
        inner occurrence, and the name would then no longer map back to the key.
        Bytes keys are decoded as UTF-8.
        """
        key_str = key.decode('utf-8') if isinstance(key, bytes) else key
        suffix = cls._STRATEGY_API_SUFFIX
        return key_str[:-len(suffix)] if key_str.endswith(suffix) else key_str

    def _discover_all_accounts(self) -> Dict[str, Dict]:
        """Scan Redis for every ``*_strategy_api`` key and load its accounts.

        Only computer names that match ``COMPUTER_NAME_PATTERN`` are loaded.
        Returns what was collected so far if the scan fails; the error is logged.
        """
        accounts_dict: Dict[str, Dict] = {}
        # A dict used as an ordered set: SCAN may return the same key more than once.
        discovered_keys: Dict[str, None] = {}
        try:
            pattern = f"*{self._STRATEGY_API_SUFFIX}"
            cursor = 0
            while True:
                cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100)
                for key in keys:
                    key_str = key.decode('utf-8') if isinstance(key, bytes) else key
                    discovered_keys[key_str] = None
                if cursor == 0:
                    break

            logger.info(f"自动发现 {len(discovered_keys)} 个策略API key")

            for key_str in discovered_keys:
                computer_name = self._computer_name_from_key(key_str)
                if self.computer_name_pattern.match(computer_name):
                    accounts_dict.update(self._get_accounts_by_computer_name(computer_name))
                else:
                    logger.warning(f"跳过不符合格式的计算机名: {computer_name}")

            logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号")
        except Exception as e:
            logger.error(f"自动发现账号失败: {e}")
        return accounts_dict

    def format_exchange_id(self, key: str) -> str:
        """Normalize a raw exchange name to its canonical exchange id.

        The name is lower-cased and stripped, then looked up in
        ``EXCHANGE_ID_MAPPING``. A name that is neither a mapping key nor an
        already-canonical id is logged at debug level and returned normalized
        but unmapped.
        """
        key = key.lower().strip()
        mapping = self.EXCHANGE_ID_MAPPING
        if key not in mapping and key not in mapping.values():
            logger.debug(f"未映射的交易所名称: {key}")
        return mapping.get(key, key)

    def parse_account(self, exchange_id: str, account_id: str, account_info: Any) -> Optional[Dict]:
        """Build a normalized account record from one raw account entry.

        Args:
            exchange_id: Canonical exchange id (see format_exchange_id).
            account_id: Account id; stored as ``k_id``.
            account_info: Either a JSON string (or bytes) encoding an object, or
                an already-decoded dict.

        Returns:
            The original fields merged with normalized fields (the normalized
            ones win). For ``mt5``, ``secret_key`` in ``host:port`` form also
            yields ``mt5_host``/``mt5_port``. Returns ``None`` if the data cannot
            be parsed or ``st_id``/``exchange_id`` is missing or falsy.
        """
        try:
            if isinstance(account_info, dict):
                source_account_info = account_info
            else:
                source_account_info = json.loads(account_info)

            account_data = {
                'exchange_id': exchange_id,
                'k_id': account_id,
                'st_id': self._safe_int(source_account_info.get('st_id'), 0),
                'add_time': self._safe_int(source_account_info.get('add_time'), 0),
                'account_type': source_account_info.get('account_type', 'real'),
                'api_key': source_account_info.get('api_key', ''),
                'secret_key': source_account_info.get('secret_key', ''),
                'password': source_account_info.get('password', ''),
                'access_token': source_account_info.get('access_token', ''),
                'remark': source_account_info.get('remark', ''),
            }

            if exchange_id == 'mt5':
                # For MT5 the secret_key field carries the server as "host:port".
                server_info = source_account_info.get('secret_key', '')
                if isinstance(server_info, str) and ':' in server_info:
                    host, port = server_info.split(':', 1)
                    account_data['mt5_host'] = host
                    account_data['mt5_port'] = self._safe_int(port, 0)

            # Build a new dict; the caller's account_info is never mutated.
            result = {**source_account_info, **account_data}

            if not result.get('st_id') or not result.get('exchange_id'):
                logger.warning(f"账号 {account_id} 缺少必要字段: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}")
                return None
            return result
        except json.JSONDecodeError as e:
            logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {account_info[:100]}...")
            return None
        except Exception as e:
            logger.error(f"处理账号 {account_id} 数据异常: {e}")
            return None

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group account records by their ``exchange_id``.

        Records without a truthy ``exchange_id`` are dropped. Within each group,
        records keep their input order.
        """
        groups: Dict[str, List[Dict]] = {}
        for account_info in accounts.values():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                groups.setdefault(exchange_id, []).append(account_info)
        return groups

    def _safe_float(self, value: Any, default: float = 0.0) -> float:
        """Convert ``value`` to float, returning ``default`` when that fails.

        ``None``, blank strings and unconvertible or overflowing values (e.g. an
        int too large for a float) all yield ``default``.
        """
        if value is None:
            return default
        if isinstance(value, str):
            value = value.strip()
            if value == '':
                return default
        try:
            return float(value)
        except (ValueError, TypeError, OverflowError):
            return default

    def _safe_int(self, value: Any, default: int = 0) -> int:
        """Convert ``value`` to int (truncating decimals), or return ``default``.

        Integer strings are converted exactly; only non-integer forms such as
        ``"3.9"`` or ``"1e3"`` go through float. ``None``, blanks, NaN, infinity
        and unparsable values yield ``default``.
        """
        if value is None:
            return default
        if isinstance(value, str):
            value = value.strip()
            if value == '':
                return default
        try:
            # Exact path: avoids float rounding for large integers.
            return int(value)
        except (ValueError, TypeError, OverflowError):
            pass
        try:
            return int(float(value))
        except (ValueError, TypeError, OverflowError):
            return default

    def _safe_str(self, value: Any, default: str = '') -> str:
        """Return ``str(value).strip()``, or ``default`` if it is None, blank or unconvertible."""
        if value is None:
            return default
        try:
            result = str(value).strip()
        except Exception:
            return default
        return result if result else default

    def _escape_sql_value(self, value: Any) -> str:
        """Render ``value`` as an SQL literal.

        - ``None`` and non-finite floats become ``NULL`` (``nan``/``inf`` are not
          valid SQL literals).
        - bools become ``1``/``0``, and ints and finite floats become ``str(value)``.
        - Anything else becomes ``str(value)`` in single quotes, with embedded
          quotes doubled.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, bool):
            return '1' if value else '0'
        if isinstance(value, float):
            import math
            if not math.isfinite(value):
                return 'NULL'
            return str(value)
        if isinstance(value, int):
            return str(value)
        # NOTE(review): backslashes are not escaped. If the target DB treats "\"
        # as an escape character (e.g. MySQL's default sql_mode), a value ending
        # in "\" breaks the literal -- prefer parameterized queries for untrusted data.
        escaped = str(value).replace("'", "''")
        return f"'{escaped}'"

    def _build_sql_values_list(self, data_list: List[Dict], fields_mapping: Optional[Dict[str, str]] = None) -> List[str]:
        """Render each row dict as a ``(v1, v2, ...)`` SQL VALUES tuple.

        Values are emitted in each dict's own iteration order, so every row must
        use the same key order as the target column list. ``fields_mapping`` is
        accepted for backward compatibility only: it would rename fields, but
        only values are emitted, so it has no effect on the output. A row that
        fails to render is logged and skipped.
        """
        values_list: List[str] = []
        for data in data_list:
            try:
                values_str = ", ".join(self._escape_sql_value(value) for value in data.values())
            except Exception as e:
                logger.error(f"构建SQL值失败: {data}, error={e}")
                continue
            values_list.append(f"({values_str})")
        return values_list

    def _get_recent_dates(self, days: int) -> List[str]:
        """Return the last ``days`` local dates as ``YYYY-MM-DD``, newest first, starting today."""
        from datetime import datetime, timedelta
        today = datetime.now()
        return [(today - timedelta(days=offset)).strftime('%Y-%m-%d') for offset in range(days)]

    def _date_to_timestamp(self, date_str: str) -> int:
        """Return the Unix timestamp of local midnight on ``date_str`` (``YYYY-MM-DD``).

        Returns 0 if ``date_str`` is malformed or not a string.
        """
        from datetime import datetime
        try:
            dt = datetime.strptime(date_str, '%Y-%m-%d')
            # A naive datetime is interpreted as local time by timestamp().
            return int(dt.timestamp())
        except (ValueError, TypeError):
            return 0

    def update_stats(self, success: bool = True, sync_time: float = 0):
        """Count one sync outcome and, if ``sync_time`` > 0, update the timings.

        ``avg_sync_time`` is an exponential moving average (weight 0.1 for the
        newest sample), seeded with the first positive sample.
        """
        if success:
            self.sync_stats['success_count'] += 1
        else:
            self.sync_stats['error_count'] += 1

        if sync_time > 0:
            self.sync_stats['last_sync_time'] = sync_time
            if self.sync_stats['avg_sync_time'] == 0:
                self.sync_stats['avg_sync_time'] = sync_time
            else:
                self.sync_stats['avg_sync_time'] = (
                    self.sync_stats['avg_sync_time'] * 0.9 + sync_time * 0.1
                )

    def print_stats(self, sync_type: str = ""):
        """Log the current statistics; uses warning level if any sync failed."""
        stats = self.sync_stats
        prefix = f"[{sync_type}] " if sync_type else ""
        stats_str = (
            f"{prefix}统计: 账号数={stats['total_accounts']}, "
            f"成功={stats['success_count']}, 失败={stats['error_count']}, "
            f"本次耗时={stats['last_sync_time']:.2f}s, "
            f"平均耗时={stats['avg_sync_time']:.2f}s"
        )
        if stats['error_count'] > 0:
            logger.warning(stats_str)
        else:
            logger.info(stats_str)

    def reset_stats(self):
        """Reset all statistics to zero."""
        self.sync_stats = self._empty_stats()