diff --git a/.env b/.env index 0c1c6aa..7396ac0 100644 --- a/.env +++ b/.env @@ -38,4 +38,19 @@ COMPUTER_NAMES=lz_c01,lz_c02,lz_c03 COMPUTER_NAME_PATTERN=^lz_c\d{2}$ # 并发配置 -MAX_CONCURRENT=10 \ No newline at end of file +MAX_CONCURRENT=10 + +# 订单同步配置 +ORDER_SYNC_RECENT_DAYS=3 +ORDER_BATCH_SIZE=1000 +ORDER_REDIS_SCAN_COUNT=1000 + +# 持仓同步配置 +POSITION_BATCH_SIZE=500 + +# 账户同步配置 +ACCOUNT_SYNC_RECENT_DAYS=3 + +# 并发控制 +MAX_CONCURRENT_ACCOUNTS=50 +REDIS_BATCH_SIZE=20 \ No newline at end of file diff --git a/sync/account_sync_batch.py b/sync/account_sync_batch.py new file mode 100644 index 0000000..d35ac04 --- /dev/null +++ b/sync/account_sync_batch.py @@ -0,0 +1,366 @@ +from .base_sync import BaseSync +from loguru import logger +from typing import List, Dict, Any, Set +import json +import time +from datetime import datetime, timedelta +from sqlalchemy import text, and_ +from models.orm_models import StrategyKX + +class AccountSyncBatch(BaseSync): + """账户信息批量同步器""" + + async def sync_batch(self, accounts: Dict[str, Dict]): + """批量同步所有账号的账户信息""" + try: + logger.info(f"开始批量同步账户信息,共 {len(accounts)} 个账号") + + # 收集所有账号的数据 + all_account_data = await self._collect_all_account_data(accounts) + + if not all_account_data: + logger.info("无账户信息数据需要同步") + return + + # 批量同步到数据库 + success = await self._sync_account_info_batch_to_db(all_account_data) + + if success: + logger.info(f"账户信息批量同步完成: 处理 {len(all_account_data)} 条记录") + else: + logger.error("账户信息批量同步失败") + + except Exception as e: + logger.error(f"账户信息批量同步失败: {e}") + + async def _collect_all_account_data(self, accounts: Dict[str, Dict]) -> List[Dict]: + """收集所有账号的账户信息数据""" + all_account_data = [] + + try: + # 按交易所分组账号 + account_groups = self._group_accounts_by_exchange(accounts) + + # 并发收集每个交易所的数据 + tasks = [] + for exchange_id, account_list in account_groups.items(): + task = self._collect_exchange_account_data(exchange_id, account_list) + tasks.append(task) + + # 等待所有任务完成并合并结果 + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, list): + all_account_data.extend(result) + + logger.info(f"收集到 {len(all_account_data)} 条账户信息记录") + + except Exception as e: + logger.error(f"收集账户信息数据失败: {e}") + + return all_account_data + + def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: + """按交易所分组账号""" + groups = {} + for account_id, account_info in accounts.items(): + exchange_id = account_info.get('exchange_id') + if exchange_id: + if exchange_id not in groups: + groups[exchange_id] = [] + groups[exchange_id].append(account_info) + return groups + + async def _collect_exchange_account_data(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]: + """收集某个交易所的账户信息数据""" + account_data_list = [] + + try: + for account_info in account_list: + k_id = int(account_info['k_id']) + st_id = account_info.get('st_id', 0) + + # 从Redis获取账户信息数据 + account_data = await self._get_account_info_from_redis(k_id, st_id, exchange_id) + account_data_list.extend(account_data) + + logger.debug(f"交易所 {exchange_id}: 收集到 {len(account_data_list)} 条账户信息") + + except Exception as e: + logger.error(f"收集交易所 {exchange_id} 账户信息失败: {e}") + + return account_data_list + + async def _get_account_info_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]: + """从Redis获取账户信息数据(批量优化版本)""" + try: + redis_key = f"{exchange_id}:balance:{k_id}" + redis_funds = self.redis_client.client.hgetall(redis_key) + + if not redis_funds: + return [] + + # 按天统计数据 + from config.settings import SYNC_CONFIG + recent_days = SYNC_CONFIG['recent_days'] + + today = datetime.now() + date_stats = {} + + # 收集所有日期的数据 + for fund_key, fund_json in redis_funds.items(): + try: + fund_data = json.loads(fund_json) + date_str = fund_data.get('lz_time', '') + lz_type = fund_data.get('lz_type', '') + + if not date_str or lz_type not in ['lz_balance', 'deposit', 'withdrawal']: + continue + + # 只处理最近N天的数据 + date_obj = datetime.strptime(date_str, '%Y-%m-%d') + if (today - date_obj).days > recent_days: + continue + + if date_str not in date_stats: + date_stats[date_str] = { + 'balance': 0.0, + 'deposit': 0.0, + 'withdrawal': 0.0, + 'has_balance': False + } + + lz_amount = float(fund_data.get('lz_amount', 0)) + + if lz_type == 'lz_balance': + date_stats[date_str]['balance'] = lz_amount + date_stats[date_str]['has_balance'] = True + elif lz_type == 'deposit': + date_stats[date_str]['deposit'] += lz_amount + elif lz_type == 'withdrawal': + date_stats[date_str]['withdrawal'] += lz_amount + + except (json.JSONDecodeError, ValueError) as e: + logger.debug(f"解析Redis数据失败: {fund_key}, error={e}") + continue + + # 转换为账户信息数据 + account_data_list = [] + sorted_dates = sorted(date_stats.keys()) + + # 获取前一天余额用于计算利润 + prev_balance_map = self._get_previous_balances(redis_funds, sorted_dates) + + for date_str in sorted_dates: + stats = date_stats[date_str] + + # 如果没有余额数据但有充提数据,仍然处理 + if not stats['has_balance'] and stats['deposit'] == 0 and stats['withdrawal'] == 0: + continue + + balance = stats['balance'] + deposit = stats['deposit'] + withdrawal = stats['withdrawal'] + + # 计算利润 + prev_balance = prev_balance_map.get(date_str, 0.0) + profit = balance - deposit - withdrawal - prev_balance + + # 转换时间戳 + date_obj = datetime.strptime(date_str, '%Y-%m-%d') + time_timestamp = int(date_obj.timestamp()) + + account_data = { + 'st_id': st_id, + 'k_id': k_id, + 'balance': balance, + 'withdrawal': withdrawal, + 'deposit': deposit, + 'other': 0.0, # 暂时为0 + 'profit': profit, + 'time': time_timestamp + } + + account_data_list.append(account_data) + + return account_data_list + + except Exception as e: + logger.error(f"获取Redis账户信息失败: k_id={k_id}, error={e}") + return [] + + def _get_previous_balances(self, redis_funds: Dict, sorted_dates: List[str]) -> Dict[str, float]: + """获取前一天的余额""" + prev_balance_map = {} + prev_date = None + + for date_str in sorted_dates: + # 查找前一天的余额 + if prev_date: + for fund_key, fund_json in redis_funds.items(): + try: + fund_data = json.loads(fund_json) + if (fund_data.get('lz_time') == prev_date and + fund_data.get('lz_type') == 'lz_balance'): + prev_balance_map[date_str] = float(fund_data.get('lz_amount', 0)) + break + except: + continue + else: + prev_balance_map[date_str] = 0.0 + + prev_date = date_str + + return prev_balance_map + + async def _sync_account_info_batch_to_db(self, account_data_list: List[Dict]) -> bool: + """批量同步账户信息到数据库(最高效版本)""" + session = self.db_manager.get_session() + try: + if not account_data_list: + return True + + with session.begin(): + # 方法1:使用原生SQL批量插入/更新(性能最好) + success = self._batch_upsert_account_info(session, account_data_list) + + if not success: + # 方法2:回退到ORM批量操作 + success = self._batch_orm_upsert_account_info(session, account_data_list) + + return success + + except Exception as e: + logger.error(f"批量同步账户信息到数据库失败: {e}") + return False + finally: + session.close() + + def _batch_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool: + """使用原生SQL批量插入/更新账户信息""" + try: + # 准备批量数据 + values_list = [] + for data in account_data_list: + values = ( + f"({data['st_id']}, {data['k_id']}, 'USDT', " + f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, " + f"{data['other']}, {data['profit']}, {data['time']})" + ) + values_list.append(values) + + if not values_list: + return True + + values_str = ", ".join(values_list) + + # 使用INSERT ... ON DUPLICATE KEY UPDATE + sql = f""" + INSERT INTO deh_strategy_kx_new + (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time) + VALUES {values_str} + ON DUPLICATE KEY UPDATE + balance = VALUES(balance), + withdrawal = VALUES(withdrawal), + deposit = VALUES(deposit), + other = VALUES(other), + profit = VALUES(profit), + up_time = NOW() + """ + + session.execute(text(sql)) + + logger.info(f"原生SQL批量更新账户信息: {len(account_data_list)} 条记录") + return True + + except Exception as e: + logger.error(f"原生SQL批量更新账户信息失败: {e}") + return False + + def _batch_orm_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool: + """使用ORM批量插入/更新账户信息""" + try: + # 分组数据以提高效率 + account_data_by_key = {} + for data in account_data_list: + key = (data['k_id'], data['st_id'], data['time']) + account_data_by_key[key] = data + + # 批量查询现有记录 + existing_records = self._batch_query_existing_records(session, list(account_data_by_key.keys())) + + # 批量更新或插入 + to_update = [] + to_insert = [] + + for key, data in account_data_by_key.items(): + if key in existing_records: + # 更新 + record = existing_records[key] + record.balance = data['balance'] + record.withdrawal = data['withdrawal'] + record.deposit = data['deposit'] + record.other = data['other'] + record.profit = data['profit'] + else: + # 插入 + to_insert.append(StrategyKX(**data)) + + # 批量插入新记录 + if to_insert: + session.add_all(to_insert) + + logger.info(f"ORM批量更新账户信息: 更新 {len(existing_records)} 条,插入 {len(to_insert)} 条") + return True + + except Exception as e: + logger.error(f"ORM批量更新账户信息失败: {e}") + return False + + def _batch_query_existing_records(self, session, keys: List[tuple]) -> Dict[tuple, StrategyKX]: + """批量查询现有记录""" + existing_records = {} + + try: + if not keys: + return existing_records + + # 构建查询条件 + conditions = [] + for k_id, st_id, time_val in keys: + conditions.append(f"(k_id = {k_id} AND st_id = {st_id} AND time = {time_val})") + + if conditions: + conditions_str = " OR ".join(conditions) + sql = f""" + SELECT * FROM deh_strategy_kx_new + WHERE {conditions_str} + """ + + results = session.execute(text(sql)).fetchall() + + for row in results: + key = (row.k_id, row.st_id, row.time) + existing_records[key] = StrategyKX( + id=row.id, + st_id=row.st_id, + k_id=row.k_id, + asset=row.asset, + balance=row.balance, + withdrawal=row.withdrawal, + deposit=row.deposit, + other=row.other, + profit=row.profit, + time=row.time + ) + + except Exception as e: + logger.error(f"批量查询现有记录失败: {e}") + + return existing_records + + async def sync(self): + """兼容旧接口""" + accounts = self.get_accounts_from_redis() + await self.sync_batch(accounts) \ No newline at end of file diff --git a/sync/base_sync.py b/sync/base_sync.py index 4a43fff..8afb734 100644 --- a/sync/base_sync.py +++ b/sync/base_sync.py @@ -1,8 +1,10 @@ +# sync/base_sync.py from abc import ABC, abstractmethod from loguru import logger -from typing import List, Dict, Any, Set +from typing import List, Dict, Any, Set, Optional import json import re +import time from utils.redis_client import RedisClient from utils.database_manager import DatabaseManager @@ -16,28 +18,52 @@ class BaseSync(ABC): self.db_manager = DatabaseManager() self.computer_names = self._get_computer_names() self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN) + self.sync_stats = { + 'total_accounts': 0, + 'success_count': 0, + 'error_count': 0, + 'last_sync_time': 0, + 'avg_sync_time': 0 + } def _get_computer_names(self) -> List[str]: """获取计算机名列表""" if ',' in COMPUTER_NAMES: - return [name.strip() for name in COMPUTER_NAMES.split(',')] + names = [name.strip() for name in COMPUTER_NAMES.split(',')] + logger.info(f"使用配置的计算机名列表: {names}") + return names return [COMPUTER_NAMES.strip()] + @abstractmethod + async def sync(self): + """执行同步(兼容旧接口)""" + pass + + @abstractmethod + async def sync_batch(self, accounts: Dict[str, Dict]): + """批量同步数据""" + pass + def get_accounts_from_redis(self) -> Dict[str, Dict]: """从Redis获取所有计算机名的账号配置""" try: accounts_dict = {} + total_keys_processed = 0 # 方法1:使用配置的计算机名列表 for computer_name in self.computer_names: accounts = self._get_accounts_by_computer_name(computer_name) + total_keys_processed += 1 accounts_dict.update(accounts) - # 方法2:自动发现所有匹配的key(备用方案) + # 方法2:如果配置的计算机名没有数据,尝试自动发现(备用方案) if not accounts_dict: + logger.warning("配置的计算机名未找到数据,尝试自动发现...") accounts_dict = self._discover_all_accounts() + self.sync_stats['total_accounts'] = len(accounts_dict) logger.info(f"从 {len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号") + return accounts_dict except Exception as e: @@ -58,6 +84,8 @@ class BaseSync(ABC): logger.debug(f"未找到 {redis_key} 的策略API配置") return {} + logger.info(f"从 {redis_key} 获取到 {len(result)} 个交易所配置") + for exchange_name, accounts_json in result.items(): try: accounts = json.loads(accounts_json) @@ -77,8 +105,11 @@ class BaseSync(ABC): except json.JSONDecodeError as e: logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}") continue + except Exception as e: + logger.error(f"处理交易所 {exchange_name} 数据异常: {e}") + continue - logger.info(f"从 {redis_key} 获取到 {len(accounts_dict)} 个账号") + logger.info(f"从 {redis_key} 解析到 {len(accounts_dict)} 个账号") except Exception as e: logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}") @@ -88,10 +119,11 @@ class BaseSync(ABC): def _discover_all_accounts(self) -> Dict[str, Dict]: """自动发现所有匹配的账号key""" accounts_dict = {} + discovered_keys = [] try: # 获取所有匹配模式的key - pattern = f"*_strategy_api" + pattern = "*_strategy_api" cursor = 0 while True: @@ -99,23 +131,265 @@ class BaseSync(ABC): for key in keys: key_str = key.decode('utf-8') if isinstance(key, bytes) else key - - # 提取计算机名 - computer_name = key_str.replace('_strategy_api', '') - - # 验证计算机名格式 - if self.computer_name_pattern.match(computer_name): - accounts = self._get_accounts_by_computer_name(computer_name) - accounts_dict.update(accounts) + discovered_keys.append(key_str) if cursor == 0: break - logger.info(f"自动发现 {len(accounts_dict)} 个账号") + logger.info(f"自动发现 {len(discovered_keys)} 个策略API key") + + # 处理每个发现的key + for key_str in discovered_keys: + # 提取计算机名 + computer_name = key_str.replace('_strategy_api', '') + + # 验证计算机名格式 + if self.computer_name_pattern.match(computer_name): + accounts = self._get_accounts_by_computer_name(computer_name) + accounts_dict.update(accounts) + else: + logger.warning(f"跳过不符合格式的计算机名: {computer_name}") + + logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号") except Exception as e: logger.error(f"自动发现账号失败: {e}") return accounts_dict - # 其他方法保持不变... \ No newline at end of file + def format_exchange_id(self, key: str) -> str: + """格式化交易所ID""" + key = key.lower().strip() + + # 交易所名称映射 + exchange_mapping = { + 'metatrader': 'mt5', + 'binance_spot_test': 'binance', + 'binance_spot': 'binance', + 'binance': 'binance', + 'gate_spot': 'gate', + 'okex': 'okx', + 'okx': 'okx', + 'bybit': 'bybit', + 'bybit_spot': 'bybit', + 'bybit_test': 'bybit', + 'huobi': 'huobi', + 'huobi_spot': 'huobi', + 'gate': 'gate', + 'gateio': 'gate', + 'kucoin': 'kucoin', + 'kucoin_spot': 'kucoin', + 'mexc': 'mexc', + 'mexc_spot': 'mexc', + 'bitget': 'bitget', + 'bitget_spot': 'bitget' + } + + normalized_key = exchange_mapping.get(key, key) + + # 记录未映射的交易所 + if normalized_key == key and key not in exchange_mapping.values(): + logger.debug(f"未映射的交易所名称: {key}") + + return normalized_key + + def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Optional[Dict]: + """解析账号信息""" + try: + source_account_info = json.loads(account_info) + + # 基础信息 + account_data = { + 'exchange_id': exchange_id, + 'k_id': account_id, + 'st_id': self._safe_int(source_account_info.get('st_id'), 0), + 'add_time': self._safe_int(source_account_info.get('add_time'), 0), + 'account_type': source_account_info.get('account_type', 'real'), + 'api_key': source_account_info.get('api_key', ''), + 'secret_key': source_account_info.get('secret_key', ''), + 'password': source_account_info.get('password', ''), + 'access_token': source_account_info.get('access_token', ''), + 'remark': source_account_info.get('remark', '') + } + + # MT5特殊处理 + if exchange_id == 'mt5': + # 解析服务器地址和端口 + server_info = source_account_info.get('secret_key', '') + if ':' in server_info: + host, port = server_info.split(':', 1) + account_data['mt5_host'] = host + account_data['mt5_port'] = self._safe_int(port, 0) + + # 合并原始信息 + result = {**source_account_info, **account_data} + + # 验证必要字段 + if not result.get('st_id') or not result.get('exchange_id'): + logger.warning(f"账号 {account_id} 缺少必要字段: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}") + return None + + return result + + except json.JSONDecodeError as e: + logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {account_info[:100]}...") + return None + except Exception as e: + logger.error(f"处理账号 {account_id} 数据异常: {e}") + return None + + def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: + """按交易所分组账号""" + groups = {} + for account_id, account_info in accounts.items(): + exchange_id = account_info.get('exchange_id') + if exchange_id: + if exchange_id not in groups: + groups[exchange_id] = [] + groups[exchange_id].append(account_info) + return groups + + def _safe_float(self, value: Any, default: float = 0.0) -> float: + """安全转换为float""" + if value is None: + return default + try: + if isinstance(value, str): + value = value.strip() + if value == '': + return default + return float(value) + except (ValueError, TypeError): + return default + + def _safe_int(self, value: Any, default: int = 0) -> int: + """安全转换为int""" + if value is None: + return default + try: + if isinstance(value, str): + value = value.strip() + if value == '': + return default + return int(float(value)) + except (ValueError, TypeError): + return default + + def _safe_str(self, value: Any, default: str = '') -> str: + """安全转换为str""" + if value is None: + return default + try: + result = str(value).strip() + return result if result else default + except: + return default + + def _escape_sql_value(self, value: Any) -> str: + """转义SQL值""" + if value is None: + return 'NULL' + if isinstance(value, bool): + return '1' if value else '0' + if isinstance(value, (int, float)): + return str(value) + if isinstance(value, str): + # 转义单引号 + escaped = value.replace("'", "''") + return f"'{escaped}'" + # 其他类型转换为字符串 + escaped = str(value).replace("'", "''") + return f"'{escaped}'" + + def _build_sql_values_list(self, data_list: List[Dict], fields_mapping: Dict[str, str] = None) -> List[str]: + """构建SQL VALUES列表""" + values_list = [] + + for data in data_list: + try: + value_parts = [] + for field, value in data.items(): + # 应用字段映射 + if fields_mapping and field in fields_mapping: + db_field = fields_mapping[field] + else: + db_field = field + + escaped_value = self._escape_sql_value(value) + value_parts.append(escaped_value) + + values_str = ", ".join(value_parts) + values_list.append(f"({values_str})") + + except Exception as e: + logger.error(f"构建SQL值失败: {data}, error={e}") + continue + + return values_list + + def _get_recent_dates(self, days: int) -> List[str]: + """获取最近N天的日期列表""" + from datetime import datetime, timedelta + + dates = [] + today = datetime.now() + + for i in range(days): + date = today - timedelta(days=i) + dates.append(date.strftime('%Y-%m-%d')) + + return dates + + def _date_to_timestamp(self, date_str: str) -> int: + """将日期字符串转换为时间戳(当天0点)""" + from datetime import datetime + + try: + dt = datetime.strptime(date_str, '%Y-%m-%d') + return int(dt.timestamp()) + except ValueError: + return 0 + + def update_stats(self, success: bool = True, sync_time: float = 0): + """更新统计信息""" + if success: + self.sync_stats['success_count'] += 1 + else: + self.sync_stats['error_count'] += 1 + + if sync_time > 0: + self.sync_stats['last_sync_time'] = sync_time + # 计算平均时间(滑动平均) + if self.sync_stats['avg_sync_time'] == 0: + self.sync_stats['avg_sync_time'] = sync_time + else: + self.sync_stats['avg_sync_time'] = ( + self.sync_stats['avg_sync_time'] * 0.9 + sync_time * 0.1 + ) + + def print_stats(self, sync_type: str = ""): + """打印统计信息""" + stats = self.sync_stats + prefix = f"[{sync_type}] " if sync_type else "" + + stats_str = ( + f"{prefix}统计: 账号数={stats['total_accounts']}, " + f"成功={stats['success_count']}, 失败={stats['error_count']}, " + f"本次耗时={stats['last_sync_time']:.2f}s, " + f"平均耗时={stats['avg_sync_time']:.2f}s" + ) + + if stats['error_count'] > 0: + logger.warning(stats_str) + else: + logger.info(stats_str) + + def reset_stats(self): + """重置统计信息""" + self.sync_stats = { + 'total_accounts': 0, + 'success_count': 0, + 'error_count': 0, + 'last_sync_time': 0, + 'avg_sync_time': 0 + } \ No newline at end of file diff --git a/sync/manager.py b/sync/manager.py index ed00a2f..4019f21 100644 --- a/sync/manager.py +++ b/sync/manager.py @@ -1,99 +1,97 @@ import asyncio from loguru import logger -from typing import List, Dict, Optional import signal import sys -from concurrent.futures import ThreadPoolExecutor import time -from asyncio import Semaphore +from typing import Dict from config.settings import SYNC_CONFIG -from .position_sync import PositionSync -from .order_sync import OrderSync -from .account_sync import AccountSync +from .position_sync_batch import PositionSyncBatch +from .order_sync_batch import OrderSyncBatch # 使用批量版本 +from .account_sync_batch import AccountSyncBatch +from utils.batch_position_sync import BatchPositionSync +from utils.batch_order_sync import BatchOrderSync +from utils.batch_account_sync import BatchAccountSync +from utils.redis_batch_helper import RedisBatchHelper class SyncManager: - """同步管理器(支持批量并发处理)""" + """同步管理器(完整批量版本)""" def __init__(self): self.is_running = True self.sync_interval = SYNC_CONFIG['interval'] - self.max_concurrent = int(os.getenv('MAX_CONCURRENT', '10')) # 最大并发数 + + # 初始化批量同步工具 + self.batch_tools = {} + self.redis_helper = None # 初始化同步器 self.syncers = [] - self.executor = ThreadPoolExecutor(max_workers=self.max_concurrent) - - self.semaphore = Semaphore(self.max_concurrent) # 控制并发数 if SYNC_CONFIG['enable_position_sync']: - self.syncers.append(PositionSync()) - logger.info("启用持仓同步") + position_sync = PositionSyncBatch() + self.syncers.append(position_sync) + self.batch_tools['position'] = BatchPositionSync(position_sync.db_manager) + logger.info("启用持仓批量同步") if SYNC_CONFIG['enable_order_sync']: - self.syncers.append(OrderSync()) - logger.info("启用订单同步") + order_sync = OrderSyncBatch() + self.syncers.append(order_sync) + self.batch_tools['order'] = BatchOrderSync(order_sync.db_manager) + + # 初始化Redis批量助手 + if order_sync.redis_client: + self.redis_helper = RedisBatchHelper(order_sync.redis_client.client) + + logger.info("启用订单批量同步") if SYNC_CONFIG['enable_account_sync']: - self.syncers.append(AccountSync()) - logger.info("启用账户信息同步") + account_sync = AccountSyncBatch() + self.syncers.append(account_sync) + self.batch_tools['account'] = BatchAccountSync(account_sync.db_manager) + logger.info("启用账户信息批量同步") # 性能统计 self.stats = { - 'total_accounts': 0, - 'success_count': 0, - 'error_count': 0, + 'total_syncs': 0, 'last_sync_time': 0, - 'avg_sync_time': 0 + 'avg_sync_time': 0, + 'position': {'accounts': 0, 'positions': 0, 'time': 0}, + 'order': {'accounts': 0, 'orders': 0, 'time': 0}, + 'account': {'accounts': 0, 'records': 0, 'time': 0} } # 注册信号处理器 signal.signal(signal.SIGINT, self.signal_handler) signal.signal(signal.SIGTERM, self.signal_handler) - - async def _run_syncer_with_limit(self, syncer): - """带并发限制的运行""" - async with self.semaphore: - return await self._run_syncer(syncer) - - def signal_handler(self, signum, frame): - """信号处理器""" - logger.info(f"接收到信号 {signum},正在关闭...") - self.is_running = False - def batch_process_accounts(self, accounts: Dict[str, Dict], batch_size: int = 100): - """分批处理账号""" - account_items = list(accounts.items()) - - for i in range(0, len(account_items), batch_size): - batch = dict(account_items[i:i + batch_size]) - # 处理这批账号 - self._process_account_batch(batch) - - # 批次间休息,避免数据库压力过大 - time.sleep(0.1) - async def start(self): """启动同步服务""" - logger.info(f"同步服务启动,间隔 {self.sync_interval} 秒,最大并发 {self.max_concurrent}") + logger.info(f"同步服务启动,间隔 {self.sync_interval} 秒") while self.is_running: try: - start_time = time.time() + self.stats['total_syncs'] += 1 + sync_start = time.time() - # 执行所有同步器 - tasks = [self._run_syncer(syncer) for syncer in self.syncers] - results = await asyncio.gather(*tasks, return_exceptions=True) + # 获取所有账号(只获取一次) + accounts = await self._get_all_accounts() + + if not accounts: + logger.warning("未获取到任何账号,等待下次同步") + await asyncio.sleep(self.sync_interval) + continue + + logger.info(f"第{self.stats['total_syncs']}次同步开始,共 {len(accounts)} 个账号") + + # 并发执行所有同步 + await self._execute_all_syncers_concurrent(accounts) # 更新统计 - sync_time = time.time() - start_time - self.stats['last_sync_time'] = sync_time - self.stats['avg_sync_time'] = (self.stats['avg_sync_time'] * 0.9 + sync_time * 0.1) + sync_time = time.time() - sync_start + self._update_stats(sync_time) - # 打印统计信息 - self._print_stats() - - logger.debug(f"同步完成,耗时 {sync_time:.2f} 秒,等待 {self.sync_interval} 秒") + logger.info(f"同步完成,总耗时 {sync_time:.2f} 秒,等待 {self.sync_interval} 秒") await asyncio.sleep(self.sync_interval) except asyncio.CancelledError: @@ -101,41 +99,182 @@ class SyncManager: break except Exception as e: logger.error(f"同步任务异常: {e}") - self.stats['error_count'] += 1 - await asyncio.sleep(30) # 出错后等待30秒 + await asyncio.sleep(30) - async def _run_syncer(self, syncer): - """运行单个同步器""" - try: - # 获取所有账号 - accounts = syncer.get_accounts_from_redis() - self.stats['total_accounts'] = len(accounts) + async def _get_all_accounts(self) -> Dict[str, Dict]: + """获取所有账号""" + if not self.syncers: + return {} + + # 使用第一个同步器获取账号 + return self.syncers[0].get_accounts_from_redis() + + async def _execute_all_syncers_concurrent(self, accounts: Dict[str, Dict]): + """并发执行所有同步器""" + tasks = [] + + # 持仓批量同步 + if 'position' in self.batch_tools: + task = self._sync_positions_batch(accounts) + tasks.append(task) + + # 订单批量同步 + if 'order' in self.batch_tools: + task = self._sync_orders_batch(accounts) + tasks.append(task) + + # 账户信息批量同步 + if 'account' in self.batch_tools: + task = self._sync_accounts_batch(accounts) + tasks.append(task) + + # 并发执行所有任务 + if tasks: + results = await asyncio.gather(*tasks, return_exceptions=True) - if not accounts: - logger.warning("未获取到任何账号") + # 检查结果 + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"同步任务 {i} 失败: {result}") + + async def _sync_positions_batch(self, accounts: Dict[str, Dict]): + """批量同步持仓数据""" + try: + start_time = time.time() + + # 收集所有持仓数据 + position_sync = next((s for s in self.syncers if isinstance(s, PositionSyncBatch)), None) + if not position_sync: return - # 批量处理账号 - await syncer.sync_batch(accounts) - self.stats['success_count'] += 1 + all_positions = await position_sync._collect_all_positions(accounts) + + if not all_positions: + self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0} + return + + # 使用批量工具同步 + batch_tool = self.batch_tools['position'] + success, stats = batch_tool.sync_positions_batch(all_positions) + + if success: + elapsed = time.time() - start_time + self.stats['position'] = { + 'accounts': len(accounts), + 'positions': stats['total'], + 'time': elapsed + } + + except Exception as e: + logger.error(f"批量同步持仓失败: {e}") + self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0} + + async def _sync_orders_batch(self, accounts: Dict[str, Dict]): + """批量同步订单数据""" + try: + start_time = time.time() + + # 收集所有订单数据 + order_sync = next((s for s in self.syncers if isinstance(s, OrderSyncBatch)), None) + if not order_sync: + return + + all_orders = await order_sync._collect_all_orders(accounts) + + if not all_orders: + self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0} + return + + # 使用批量工具同步 + batch_tool = self.batch_tools['order'] + success, processed_count = batch_tool.sync_orders_batch(all_orders) + + if success: + elapsed = time.time() - start_time + self.stats['order'] = { + 'accounts': len(accounts), + 'orders': processed_count, + 'time': elapsed + } + + except Exception as e: + logger.error(f"批量同步订单失败: {e}") + self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0} + + async def _sync_accounts_batch(self, accounts: Dict[str, Dict]): + """批量同步账户信息数据""" + try: + start_time = time.time() + + # 收集所有账户数据 + account_sync = next((s for s in self.syncers if isinstance(s, AccountSyncBatch)), None) + if not account_sync: + return + + all_account_data = await account_sync._collect_all_account_data(accounts) + + if not all_account_data: + self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0} + return + + # 使用批量工具同步 + batch_tool = self.batch_tools['account'] + updated, inserted = batch_tool.sync_accounts_batch(all_account_data) + + elapsed = time.time() - start_time + self.stats['account'] = { + 'accounts': len(accounts), + 'records': len(all_account_data), + 'time': elapsed + } except Exception as e: - logger.error(f"同步器 {syncer.__class__.__name__} 执行失败: {e}") - self.stats['error_count'] += 1 + logger.error(f"批量同步账户信息失败: {e}") + self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0} - def _print_stats(self): - """打印统计信息""" - stats_str = ( - f"统计: 账号数={self.stats['total_accounts']}, " - f"成功={self.stats['success_count']}, " - f"失败={self.stats['error_count']}, " - f"本次耗时={self.stats['last_sync_time']:.2f}s, " - f"平均耗时={self.stats['avg_sync_time']:.2f}s" - ) - logger.info(stats_str) + def _update_stats(self, sync_time: float): + """更新统计信息""" + self.stats['last_sync_time'] = sync_time + self.stats['avg_sync_time'] = (self.stats['avg_sync_time'] * 0.9 + sync_time * 0.1) + + # 打印详细统计 + stats_lines = [ + f"=== 第{self.stats['total_syncs']}次同步统计 ===", + f"总耗时: {sync_time:.2f}秒 | 平均耗时: {self.stats['avg_sync_time']:.2f}秒" + ] + + if self.stats['position']['accounts'] > 0: + stats_lines.append( + f"持仓: {self.stats['position']['accounts']}账号/{self.stats['position']['positions']}条" + f"/{self.stats['position']['time']:.2f}秒" + ) + + if self.stats['order']['accounts'] > 0: + stats_lines.append( + f"订单: {self.stats['order']['accounts']}账号/{self.stats['order']['orders']}条" + f"/{self.stats['order']['time']:.2f}秒" + ) + + if self.stats['account']['accounts'] > 0: + stats_lines.append( + f"账户: {self.stats['account']['accounts']}账号/{self.stats['account']['records']}条" + f"/{self.stats['account']['time']:.2f}秒" + ) + + logger.info("\n".join(stats_lines)) + + def signal_handler(self, signum, frame): + """信号处理器""" + logger.info(f"接收到信号 {signum},正在关闭...") + self.is_running = False async def stop(self): """停止同步服务""" self.is_running = False - self.executor.shutdown(wait=True) + + # 关闭所有数据库连接 + for syncer in self.syncers: + if hasattr(syncer, 'db_manager'): + syncer.db_manager.close() + logger.info("同步服务停止") \ No newline at end of file diff --git a/sync/order_sync_batch.py b/sync/order_sync_batch.py new file mode 100644 index 0000000..f76c267 --- /dev/null +++ b/sync/order_sync_batch.py @@ -0,0 +1,269 @@ +from .base_sync import BaseSync +from loguru import logger +from typing import List, Dict, Any, Tuple +import json +import asyncio +import time +from datetime import datetime, timedelta +from sqlalchemy import text +import redis + +class OrderSyncBatch(BaseSync): + """订单数据批量同步器""" + + def __init__(self): + super().__init__() + self.batch_size = 1000 # 每批处理数量 + self.recent_days = 3 # 同步最近几天的数据 + + async def sync_batch(self, accounts: Dict[str, Dict]): + """批量同步所有账号的订单数据""" + try: + logger.info(f"开始批量同步订单数据,共 {len(accounts)} 个账号") + start_time = time.time() + + # 1. 收集所有账号的订单数据 + all_orders = await self._collect_all_orders(accounts) + + if not all_orders: + logger.info("无订单数据需要同步") + return + + logger.info(f"收集到 {len(all_orders)} 条订单数据") + + # 2. 批量同步到数据库 + success, processed_count = await self._sync_orders_batch_to_db(all_orders) + + elapsed = time.time() - start_time + if success: + logger.info(f"订单批量同步完成: 处理 {processed_count} 条订单,耗时 {elapsed:.2f}秒") + else: + logger.error("订单批量同步失败") + + except Exception as e: + logger.error(f"订单批量同步失败: {e}") + + async def _collect_all_orders(self, accounts: Dict[str, Dict]) -> List[Dict]: + """收集所有账号的订单数据""" + all_orders = [] + + try: + # 按交易所分组账号 + account_groups = self._group_accounts_by_exchange(accounts) + + # 并发收集每个交易所的数据 + tasks = [] + for exchange_id, account_list in account_groups.items(): + task = self._collect_exchange_orders(exchange_id, account_list) + tasks.append(task) + + # 等待所有任务完成并合并结果 + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, list): + all_orders.extend(result) + + except Exception as e: + logger.error(f"收集订单数据失败: {e}") + + return all_orders + + def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: + """按交易所分组账号""" + groups = {} + for account_id, account_info in accounts.items(): + exchange_id = account_info.get('exchange_id') + if exchange_id: + if exchange_id not in groups: + groups[exchange_id] = [] + groups[exchange_id].append(account_info) + return groups + + async def _collect_exchange_orders(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]: + """收集某个交易所的订单数据""" + orders_list = [] + + try: + # 并发获取每个账号的数据 + tasks = [] + for account_info in account_list: + k_id = int(account_info['k_id']) + st_id = account_info.get('st_id', 0) + task = self._get_recent_orders_from_redis(k_id, st_id, exchange_id) + tasks.append(task) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, list): + orders_list.extend(result) + + logger.debug(f"交易所 {exchange_id}: 收集到 {len(orders_list)} 条订单") + + except Exception as e: + logger.error(f"收集交易所 {exchange_id} 订单数据失败: {e}") + + return orders_list + + async def _get_recent_orders_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]: + """从Redis获取最近N天的订单数据""" + try: + redis_key = f"{exchange_id}:orders:{k_id}" + + # 计算最近N天的日期 + today = datetime.now() + recent_dates = [] + for i in range(self.recent_days): + date = today - timedelta(days=i) + date_format = date.strftime('%Y-%m-%d') + recent_dates.append(date_format) + + # 使用scan获取所有符合条件的key + cursor = 0 + recent_keys = [] + + while True: + cursor, keys = self.redis_client.client.hscan(redis_key, cursor, count=1000) + + for key, _ in keys.items(): + key_str = key.decode('utf-8') if isinstance(key, bytes) else key + + if key_str == 'positions': + continue + + # 检查是否以最近N天的日期开头 + for date_format in recent_dates: + if key_str.startswith(date_format + '_'): + recent_keys.append(key_str) + break + + if cursor == 0: + break + + if not recent_keys: + return [] + + # 批量获取订单数据 + orders_list = [] + + # 分批获取,避免单次hgetall数据量太大 + chunk_size = 500 + for i in range(0, len(recent_keys), chunk_size): + chunk_keys = recent_keys[i:i + chunk_size] + + # 使用hmget批量获取 + chunk_values = self.redis_client.client.hmget(redis_key, chunk_keys) + + for key, order_json in zip(chunk_keys, chunk_values): + if not order_json: + continue + + try: + order = json.loads(order_json) + + # 验证时间 + order_time = order.get('time', 0) + if order_time >= int(time.time()) - self.recent_days * 24 * 3600: + # 添加账号信息 + order['k_id'] = k_id + order['st_id'] = st_id + order['exchange_id'] = exchange_id + orders_list.append(order) + + except json.JSONDecodeError as e: + logger.debug(f"解析订单JSON失败: key={key}, error={e}") + continue + + return orders_list + + except Exception as e: + logger.error(f"获取Redis订单数据失败: k_id={k_id}, error={e}") + return [] + + async def _sync_orders_batch_to_db(self, all_orders: List[Dict]) -> Tuple[bool, int]: + """批量同步订单数据到数据库""" + try: + if not all_orders: + return True, 0 + + # 转换数据 + converted_orders = [] + for order in all_orders: + try: + order_dict = self._convert_order_data(order) + + # 检查完整性 + required_fields = ['order_id', 'symbol', 'side', 'time'] + if not all(order_dict.get(field) for field in required_fields): + continue + + converted_orders.append(order_dict) + + except Exception as e: + logger.error(f"转换订单数据失败: {order}, error={e}") + continue + + if not converted_orders: + return True, 0 + + # 使用批量工具同步 + from utils.batch_order_sync import BatchOrderSync + batch_tool = BatchOrderSync(self.db_manager, self.batch_size) + + success, processed_count = batch_tool.sync_orders_batch(converted_orders) + + return success, processed_count + + except Exception as e: + logger.error(f"批量同步订单到数据库失败: {e}") + return False, 0 + + def _convert_order_data(self, data: Dict) -> Dict: + """转换订单数据格式""" + try: + # 安全转换函数 + def safe_float(value): + if value is None: + return None + try: + return float(value) + except (ValueError, TypeError): + return None + + def safe_int(value): + if value is None: + return None + try: + return int(float(value)) + except (ValueError, TypeError): + return None + + def safe_str(value): + if value is None: + return '' + return str(value) + + return { + 'st_id': safe_int(data.get('st_id'), 0), + 'k_id': safe_int(data.get('k_id'), 0), + 'asset': 'USDT', + 'order_id': safe_str(data.get('order_id')), + 'symbol': safe_str(data.get('symbol')), + 'side': safe_str(data.get('side')), + 'price': safe_float(data.get('price')), + 'time': safe_int(data.get('time')), + 'order_qty': safe_float(data.get('order_qty')), + 'last_qty': safe_float(data.get('last_qty')), + 'avg_price': safe_float(data.get('avg_price')), + 'exchange_id': None # 忽略该字段 + } + + except Exception as e: + logger.error(f"转换订单数据异常: {data}, error={e}") + return {} + + async def sync(self): + """兼容旧接口""" + accounts = self.get_accounts_from_redis() + await self.sync_batch(accounts) \ No newline at end of file diff --git a/sync/position_sync_batch.py b/sync/position_sync_batch.py new file mode 100644 index 0000000..4531a5a --- /dev/null +++ b/sync/position_sync_batch.py @@ -0,0 +1,378 @@ +from .base_sync import BaseSync +from loguru import logger +from typing import List, Dict, Any, Set, Tuple +import json +import asyncio +from datetime import datetime +from sqlalchemy import text, and_, select, delete +from models.orm_models import StrategyPosition +import time + +class PositionSyncBatch(BaseSync): + """持仓数据批量同步器""" + + def __init__(self): + super().__init__() + self.batch_size = 500 # 每批处理数量 + + async def sync_batch(self, accounts: Dict[str, Dict]): + """批量同步所有账号的持仓数据""" + try: + logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号") + start_time = time.time() + + # 1. 收集所有账号的持仓数据 + all_positions = await self._collect_all_positions(accounts) + + if not all_positions: + logger.info("无持仓数据需要同步") + return + + logger.info(f"收集到 {len(all_positions)} 条持仓数据") + + # 2. 批量同步到数据库 + success, stats = await self._sync_positions_batch_to_db(all_positions) + + elapsed = time.time() - start_time + if success: + logger.info(f"持仓批量同步完成: 处理 {stats['total']} 条,更新 {stats['updated']} 条," + f"插入 {stats['inserted']} 条,删除 {stats['deleted']} 条,耗时 {elapsed:.2f}秒") + else: + logger.error("持仓批量同步失败") + + except Exception as e: + logger.error(f"持仓批量同步失败: {e}") + + async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]: + """收集所有账号的持仓数据""" + all_positions = [] + + try: + # 按交易所分组账号 + account_groups = self._group_accounts_by_exchange(accounts) + + # 并发收集每个交易所的数据 + tasks = [] + for exchange_id, account_list in account_groups.items(): + task = self._collect_exchange_positions(exchange_id, account_list) + tasks.append(task) + + # 等待所有任务完成并合并结果 + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, list): + all_positions.extend(result) + + except Exception as e: + logger.error(f"收集持仓数据失败: {e}") + + return all_positions + + def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: + """按交易所分组账号""" + groups = {} + for account_id, account_info in accounts.items(): + exchange_id = account_info.get('exchange_id') + if exchange_id: + if exchange_id not in groups: + groups[exchange_id] = [] + groups[exchange_id].append(account_info) + return groups + + async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]: + """收集某个交易所的持仓数据""" + positions_list = [] + + try: + tasks = [] + for account_info in account_list: + k_id = int(account_info['k_id']) + st_id = account_info.get('st_id', 0) + task = self._get_positions_from_redis(k_id, st_id, exchange_id) + tasks.append(task) + + # 并发获取 + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, list): + positions_list.extend(result) + + except Exception as e: + logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {e}") + + return positions_list + + async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]: + """从Redis获取持仓数据""" + try: + redis_key = f"{exchange_id}:positions:{k_id}" + redis_data = self.redis_client.client.hget(redis_key, 'positions') + + if not redis_data: + return [] + + positions = json.loads(redis_data) + + # 添加账号信息 + for position in positions: + position['k_id'] = k_id + position['st_id'] = st_id + position['exchange_id'] = exchange_id + + return positions + + except Exception as e: + logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}") + return [] + + async def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> Tuple[bool, Dict]: + """批量同步持仓数据到数据库""" + try: + if not all_positions: + return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} + + # 按账号分组 + positions_by_account = {} + for position in all_positions: + k_id = position['k_id'] + if k_id not in positions_by_account: + positions_by_account[k_id] = [] + positions_by_account[k_id].append(position) + + logger.info(f"开始批量处理 {len(positions_by_account)} 个账号的持仓数据") + + # 批量处理每个账号 + total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} + + for k_id, positions in positions_by_account.items(): + st_id = positions[0]['st_id'] if positions else 0 + + # 处理单个账号的批量同步 + success, stats = await self._sync_single_account_batch(k_id, st_id, positions) + + if success: + total_stats['total'] += stats['total'] + total_stats['updated'] += stats['updated'] + total_stats['inserted'] += stats['inserted'] + total_stats['deleted'] += stats['deleted'] + + return True, total_stats + + except Exception as e: + logger.error(f"批量同步持仓到数据库失败: {e}") + return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} + + async def _sync_single_account_batch(self, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]: + """批量同步单个账号的持仓数据""" + session = self.db_manager.get_session() + try: + # 准备数据 + insert_data = [] + new_positions_map = {} # (symbol, side) -> position_id (用于删除) + + for position_data in positions: + try: + position_dict = self._convert_position_data(position_data) + if not all([position_dict.get('symbol'), position_dict.get('side')]): + continue + + symbol = position_dict['symbol'] + side = position_dict['side'] + key = (symbol, side) + + # 重命名qty为sum + if 'qty' in position_dict: + position_dict['sum'] = position_dict.pop('qty') + + insert_data.append(position_dict) + new_positions_map[key] = position_dict.get('id') # 如果有id的话 + + except Exception as e: + logger.error(f"转换持仓数据失败: {position_data}, error={e}") + continue + + with session.begin(): + if not insert_data: + # 清空该账号所有持仓 + result = session.execute( + delete(StrategyPosition).where( + and_( + StrategyPosition.k_id == k_id, + StrategyPosition.st_id == st_id + ) + ) + ) + deleted_count = result.rowcount + + return True, { + 'total': 0, + 'updated': 0, + 'inserted': 0, + 'deleted': deleted_count + } + + # 1. 批量插入/更新持仓数据 + processed_count = self._batch_upsert_positions(session, insert_data) + + # 2. 批量删除多余持仓 + deleted_count = self._batch_delete_extra_positions(session, k_id, st_id, new_positions_map) + + # 注意:这里无法区分插入和更新的数量,processed_count是总处理数 + inserted_count = processed_count # 简化处理 + updated_count = 0 # 需要更复杂的逻辑来区分 + + stats = { + 'total': len(insert_data), + 'updated': updated_count, + 'inserted': inserted_count, + 'deleted': deleted_count + } + + return True, stats + + except Exception as e: + logger.error(f"批量同步账号 {k_id} 持仓失败: {e}") + return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} + finally: + session.close() + + def _batch_upsert_positions(self, session, insert_data: List[Dict]) -> int: + """批量插入/更新持仓数据""" + try: + # 分块处理 + chunk_size = self.batch_size + total_processed = 0 + + for i in range(0, len(insert_data), chunk_size): + chunk = insert_data[i:i + chunk_size] + + values_list = [] + for data in chunk: + values = ( + f"({data['st_id']}, {data['k_id']}, '{data.get('asset', 'USDT')}', " + f"'{data['symbol'].replace(\"'\", \"''\")}', '{data['side']}', " + f"{data.get('price') or 'NULL'}, {data.get('sum') or 'NULL'}, " + f"{data.get('asset_num') or 'NULL'}, {data.get('asset_profit') or 'NULL'}, " + f"{data.get('leverage') or 'NULL'}, {data.get('uptime') or 'NULL'}, " + f"{data.get('profit_price') or 'NULL'}, {data.get('stop_price') or 'NULL'}, " + f"{data.get('liquidation_price') or 'NULL'})" + ) + values_list.append(values) + + if values_list: + values_str = ", ".join(values_list) + + sql = f""" + INSERT INTO deh_strategy_position_new + (st_id, k_id, asset, symbol, side, price, `sum`, + asset_num, asset_profit, leverage, uptime, + profit_price, stop_price, liquidation_price) + VALUES {values_str} + ON DUPLICATE KEY UPDATE + price = VALUES(price), + `sum` = VALUES(`sum`), + asset_num = VALUES(asset_num), + asset_profit = VALUES(asset_profit), + leverage = VALUES(leverage), + uptime = VALUES(uptime), + profit_price = VALUES(profit_price), + stop_price = VALUES(stop_price), + liquidation_price = VALUES(liquidation_price) + """ + + session.execute(text(sql)) + total_processed += len(chunk) + + return total_processed + + except Exception as e: + logger.error(f"批量插入/更新持仓失败: {e}") + raise + + def _batch_delete_extra_positions(self, session, k_id: int, st_id: int, new_positions_map: Dict) -> int: + """批量删除多余持仓""" + try: + if not new_positions_map: + # 删除所有持仓 + result = session.execute( + delete(StrategyPosition).where( + and_( + StrategyPosition.k_id == k_id, + StrategyPosition.st_id == st_id + ) + ) + ) + return result.rowcount + + # 构建保留条件 + conditions = [] + for (symbol, side) in new_positions_map.keys(): + safe_symbol = symbol.replace("'", "''") if symbol else '' + safe_side = side.replace("'", "''") if side else '' + conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')") + + if conditions: + conditions_str = " OR ".join(conditions) + + sql = f""" + DELETE FROM deh_strategy_position_new + WHERE k_id = {k_id} AND st_id = {st_id} + AND NOT ({conditions_str}) + """ + + result = session.execute(text(sql)) + return result.rowcount + + return 0 + + except Exception as e: + logger.error(f"批量删除持仓失败: k_id={k_id}, error={e}") + return 0 + + def _convert_position_data(self, data: Dict) -> Dict: + """转换持仓数据格式""" + try: + # 安全转换函数 + def safe_float(value, default=None): + if value is None: + return default + try: + return float(value) + except (ValueError, TypeError): + return default + + def safe_int(value, default=None): + if value is None: + return default + try: + return int(float(value)) + except (ValueError, TypeError): + return default + + return { + 'st_id': safe_int(data.get('st_id'), 0), + 'k_id': safe_int(data.get('k_id'), 0), + 'asset': data.get('asset', 'USDT'), + 'symbol': data.get('symbol', ''), + 'side': data.get('side', ''), + 'price': safe_float(data.get('price')), + 'qty': safe_float(data.get('qty')), # 后面会重命名为sum + 'asset_num': safe_float(data.get('asset_num')), + 'asset_profit': safe_float(data.get('asset_profit')), + 'leverage': safe_int(data.get('leverage')), + 'uptime': safe_int(data.get('uptime')), + 'profit_price': safe_float(data.get('profit_price')), + 'stop_price': safe_float(data.get('stop_price')), + 'liquidation_price': safe_float(data.get('liquidation_price')) + } + + except Exception as e: + logger.error(f"转换持仓数据异常: {data}, error={e}") + return {} + + async def sync(self): + """兼容旧接口""" + accounts = self.get_accounts_from_redis() + await self.sync_batch(accounts) \ No newline at end of file diff --git a/utils/batch_account_sync.py b/utils/batch_account_sync.py new file mode 100644 index 0000000..d5697ce --- /dev/null +++ b/utils/batch_account_sync.py @@ -0,0 +1,174 @@ +from typing import List, Dict, Any, Tuple +from loguru import logger +from sqlalchemy import text +import time + +class BatchAccountSync: + """账户信息批量同步工具""" + + def __init__(self, db_manager): + self.db_manager = db_manager + + def sync_accounts_batch(self, all_account_data: List[Dict]) -> Tuple[int, int]: + """批量同步账户信息(最高效版本)""" + if not all_account_data: + return 0, 0 + + session = self.db_manager.get_session() + try: + start_time = time.time() + + # 方法1:使用临时表进行批量操作(性能最好) + updated_count, inserted_count = self._sync_using_temp_table(session, all_account_data) + + elapsed = time.time() - start_time + logger.info(f"账户信息批量同步完成: 更新 {updated_count} 条,插入 {inserted_count} 条,耗时 {elapsed:.2f}秒") + + return updated_count, inserted_count + + except Exception as e: + logger.error(f"账户信息批量同步失败: {e}") + return 0, 0 + finally: + session.close() + + def _sync_using_temp_table(self, session, all_account_data: List[Dict]) -> Tuple[int, int]: + """使用临时表进行批量同步""" + try: + # 1. 创建临时表 + session.execute(text(""" + CREATE TEMPORARY TABLE IF NOT EXISTS temp_account_info ( + st_id INT, + k_id INT, + asset VARCHAR(32), + balance DECIMAL(20, 8), + withdrawal DECIMAL(20, 8), + deposit DECIMAL(20, 8), + other DECIMAL(20, 8), + profit DECIMAL(20, 8), + time INT, + PRIMARY KEY (k_id, st_id, time) + ) + """)) + + # 2. 清空临时表 + session.execute(text("TRUNCATE TABLE temp_account_info")) + + # 3. 批量插入数据到临时表 + chunk_size = 1000 + total_inserted = 0 + + for i in range(0, len(all_account_data), chunk_size): + chunk = all_account_data[i:i + chunk_size] + + values_list = [] + for data in chunk: + values = ( + f"({data['st_id']}, {data['k_id']}, 'USDT', " + f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, " + f"{data['other']}, {data['profit']}, {data['time']})" + ) + values_list.append(values) + + if values_list: + values_str = ", ".join(values_list) + sql = f""" + INSERT INTO temp_account_info + (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time) + VALUES {values_str} + """ + session.execute(text(sql)) + total_inserted += len(chunk) + + # 4. 使用临时表更新主表 + # 更新已存在的记录 + update_result = session.execute(text(""" + UPDATE deh_strategy_kx_new main + INNER JOIN temp_account_info temp + ON main.k_id = temp.k_id + AND main.st_id = temp.st_id + AND main.time = temp.time + SET main.balance = temp.balance, + main.withdrawal = temp.withdrawal, + main.deposit = temp.deposit, + main.other = temp.other, + main.profit = temp.profit, + main.up_time = NOW() + """)) + updated_count = update_result.rowcount + + # 插入新记录 + insert_result = session.execute(text(""" + INSERT INTO deh_strategy_kx_new + (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time, up_time) + SELECT + st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time, NOW() + FROM temp_account_info temp + WHERE NOT EXISTS ( + SELECT 1 FROM deh_strategy_kx_new main + WHERE main.k_id = temp.k_id + AND main.st_id = temp.st_id + AND main.time = temp.time + ) + """)) + inserted_count = insert_result.rowcount + + # 5. 删除临时表 + session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_account_info")) + + session.commit() + + return updated_count, inserted_count + + except Exception as e: + session.rollback() + logger.error(f"临时表同步失败: {e}") + raise + + def _sync_using_on_duplicate(self, session, all_account_data: List[Dict]) -> Tuple[int, int]: + """使用ON DUPLICATE KEY UPDATE批量同步(简化版)""" + try: + # 分块执行,避免SQL过长 + chunk_size = 1000 + total_processed = 0 + + for i in range(0, len(all_account_data), chunk_size): + chunk = all_account_data[i:i + chunk_size] + + values_list = [] + for data in chunk: + values = ( + f"({data['st_id']}, {data['k_id']}, 'USDT', " + f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, " + f"{data['other']}, {data['profit']}, {data['time']})" + ) + values_list.append(values) + + if values_list: + values_str = ", ".join(values_list) + + sql = f""" + INSERT INTO deh_strategy_kx_new + (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time) + VALUES {values_str} + ON DUPLICATE KEY UPDATE + balance = VALUES(balance), + withdrawal = VALUES(withdrawal), + deposit = VALUES(deposit), + other = VALUES(other), + profit = VALUES(profit), + up_time = NOW() + """ + + result = session.execute(text(sql)) + total_processed += len(chunk) + + session.commit() + + # 注意:这里无法区分更新和插入的数量 + return total_processed, 0 + + except Exception as e: + session.rollback() + logger.error(f"ON DUPLICATE同步失败: {e}") + raise \ No newline at end of file diff --git a/utils/batch_order_sync.py b/utils/batch_order_sync.py new file mode 100644 index 0000000..ef6a68e --- /dev/null +++ b/utils/batch_order_sync.py @@ -0,0 +1,313 @@ +from typing import List, Dict, Any, Tuple +from loguru import logger +from sqlalchemy import text +import time + +class BatchOrderSync: + """订单数据批量同步工具(最高性能)""" + + def __init__(self, db_manager, batch_size: int = 1000): + self.db_manager = db_manager + self.batch_size = batch_size + + def sync_orders_batch(self, all_orders: List[Dict]) -> Tuple[bool, int]: + """批量同步订单数据""" + if not all_orders: + return True, 0 + + session = self.db_manager.get_session() + try: + start_time = time.time() + + # 方法1:使用临时表(性能最好) + processed_count = self._sync_using_temp_table(session, all_orders) + + elapsed = time.time() - start_time + logger.info(f"订单批量同步完成: 处理 {processed_count} 条订单,耗时 {elapsed:.2f}秒") + + return True, processed_count + + except Exception as e: + logger.error(f"订单批量同步失败: {e}") + return False, 0 + finally: + session.close() + + def _sync_using_temp_table(self, session, all_orders: List[Dict]) -> int: + """使用临时表批量同步订单""" + try: + # 1. 创建临时表 + session.execute(text(""" + CREATE TEMPORARY TABLE IF NOT EXISTS temp_orders ( + st_id INT, + k_id INT, + asset VARCHAR(32), + order_id VARCHAR(765), + symbol VARCHAR(120), + side VARCHAR(120), + price FLOAT, + time INT, + order_qty FLOAT, + last_qty FLOAT, + avg_price FLOAT, + exchange_id INT, + UNIQUE KEY idx_unique_order (order_id, symbol, k_id, side) + ) + """)) + + # 2. 清空临时表 + session.execute(text("TRUNCATE TABLE temp_orders")) + + # 3. 批量插入数据到临时表(分块) + inserted_count = self._batch_insert_to_temp_table(session, all_orders) + + if inserted_count == 0: + session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_orders")) + return 0 + + # 4. 使用临时表更新主表 + # 更新已存在的记录(只更新需要比较的字段) + update_result = session.execute(text(""" + UPDATE deh_strategy_order_new main + INNER JOIN temp_orders temp + ON main.order_id = temp.order_id + AND main.symbol = temp.symbol + AND main.k_id = temp.k_id + AND main.side = temp.side + SET main.side = temp.side, + main.price = temp.price, + main.time = temp.time, + main.order_qty = temp.order_qty, + main.last_qty = temp.last_qty, + main.avg_price = temp.avg_price + WHERE main.side != temp.side + OR main.price != temp.price + OR main.time != temp.time + OR main.order_qty != temp.order_qty + OR main.last_qty != temp.last_qty + OR main.avg_price != temp.avg_price + """)) + updated_count = update_result.rowcount + + # 插入新记录 + insert_result = session.execute(text(""" + INSERT INTO deh_strategy_order_new + (st_id, k_id, asset, order_id, symbol, side, price, time, + order_qty, last_qty, avg_price, exchange_id) + SELECT + st_id, k_id, asset, order_id, symbol, side, price, time, + order_qty, last_qty, avg_price, exchange_id + FROM temp_orders temp + WHERE NOT EXISTS ( + SELECT 1 FROM deh_strategy_order_new main + WHERE main.order_id = temp.order_id + AND main.symbol = temp.symbol + AND main.k_id = temp.k_id + AND main.side = temp.side + ) + """)) + inserted_count = insert_result.rowcount + + # 5. 删除临时表 + session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_orders")) + + session.commit() + + total_processed = updated_count + inserted_count + logger.info(f"订单批量同步: 更新 {updated_count} 条,插入 {inserted_count} 条") + + return total_processed + + except Exception as e: + session.rollback() + logger.error(f"临时表同步订单失败: {e}") + raise + + def _batch_insert_to_temp_table(self, session, all_orders: List[Dict]) -> int: + """批量插入数据到临时表""" + total_inserted = 0 + + try: + # 分块处理 + for i in range(0, len(all_orders), self.batch_size): + chunk = all_orders[i:i + self.batch_size] + + values_list = [] + for order in chunk: + try: + # 处理NULL值 + price = order.get('price') + time_val = order.get('time') + order_qty = order.get('order_qty') + last_qty = order.get('last_qty') + avg_price = order.get('avg_price') + # 转义单引号 + symbol = order.get('symbol').replace("'", "''") if order.get('symbol') else '' + order_id = order.get('order_id').replace("'", "''") if order.get('order_id') else '' + + values = ( + f"({order['st_id']}, {order['k_id']}, '{order.get('asset', 'USDT')}', " + f"'{order_id}', " + f"'{symbol}', " + f"'{order['side']}', " + f"{price if price is not None else 'NULL'}, " + f"{time_val if time_val is not None else 'NULL'}, " + f"{order_qty if order_qty is not None else 'NULL'}, " + f"{last_qty if last_qty is not None else 'NULL'}, " + f"{avg_price if avg_price is not None else 'NULL'}, " + "NULL)" + ) + values_list.append(values) + + except Exception as e: + logger.error(f"构建订单值失败: {order}, error={e}") + continue + + if values_list: + values_str = ", ".join(values_list) + + sql = f""" + INSERT INTO temp_orders + (st_id, k_id, asset, order_id, symbol, side, price, time, + order_qty, last_qty, avg_price, exchange_id) + VALUES {values_str} + """ + + result = session.execute(text(sql)) + total_inserted += len(chunk) + + return total_inserted + + except Exception as e: + logger.error(f"批量插入临时表失败: {e}") + raise + + def _batch_insert_to_temp_table1(self, session, all_orders: List[Dict]) -> int: + """批量插入数据到临时表(使用参数化查询)temp_orders""" + total_inserted = 0 + + try: + # 分块处理 + for i in range(0, len(all_orders), self.batch_size): + chunk = all_orders[i:i + self.batch_size] + + # 准备参数化数据 + insert_data = [] + for order in chunk: + try: + insert_data.append({ + 'st_id': order['st_id'], + 'k_id': order['k_id'], + 'asset': order.get('asset', 'USDT'), + 'order_id': order['order_id'], + 'symbol': order['symbol'], + 'side': order['side'], + 'price': order.get('price'), + 'time': order.get('time'), + 'order_qty': order.get('order_qty'), + 'last_qty': order.get('last_qty'), + 'avg_price': order.get('avg_price') + # exchange_id 留空,使用默认值NULL + }) + except KeyError as e: + logger.error(f"订单数据缺少必要字段: {order}, missing={e}") + continue + except Exception as e: + logger.error(f"处理订单数据失败: {order}, error={e}") + continue + + if insert_data: + sql = text(f""" + INSERT INTO {self.temp_table_name} + (st_id, k_id, asset, order_id, symbol, side, price, time, + order_qty, last_qty, avg_price) + VALUES + (:st_id, :k_id, :asset, :order_id, :symbol, :side, :price, :time, + :order_qty, :last_qty, :avg_price) + """) + + try: + session.execute(sql, insert_data) + session.commit() + total_inserted += len(insert_data) + logger.debug(f"插入 {len(insert_data)} 条数据到临时表") + except Exception as e: + session.rollback() + logger.error(f"执行批量插入失败: {e}") + raise + + logger.info(f"总共插入 {total_inserted} 条数据到临时表") + return total_inserted + + except Exception as e: + logger.error(f"批量插入临时表失败: {e}") + session.rollback() + raise + + + def _sync_using_on_duplicate(self, session, all_orders: List[Dict]) -> int: + """使用ON DUPLICATE KEY UPDATE批量同步(简化版)""" + try: + total_processed = 0 + + # 分块执行 + for i in range(0, len(all_orders), self.batch_size): + chunk = all_orders[i:i + self.batch_size] + + values_list = [] + for order in chunk: + try: + # 处理NULL值 + price = order.get('price') + time_val = order.get('time') + order_qty = order.get('order_qty') + last_qty = order.get('last_qty') + avg_price = order.get('avg_price') + symbol = order.get('symbol').replace("'", "''") if order.get('symbol') else '' + order_id = order.get('order_id').replace("'", "''") if order.get('order_id') else '' + + values = ( + f"({order['st_id']}, {order['k_id']}, '{order.get('asset', 'USDT')}', " + f"'{order_id}', " + f"'{symbol}', " + f"'{order['side']}', " + f"{price if price is not None else 'NULL'}, " + f"{time_val if time_val is not None else 'NULL'}, " + f"{order_qty if order_qty is not None else 'NULL'}, " + f"{last_qty if last_qty is not None else 'NULL'}, " + f"{avg_price if avg_price is not None else 'NULL'}, " + "NULL)" + ) + values_list.append(values) + + except Exception as e: + logger.error(f"构建订单值失败: {order}, error={e}") + continue + + if values_list: + values_str = ", ".join(values_list) + + sql = f""" + INSERT INTO deh_strategy_order_new + (st_id, k_id, asset, order_id, symbol, side, price, time, + order_qty, last_qty, avg_price, exchange_id) + VALUES {values_str} + ON DUPLICATE KEY UPDATE + side = VALUES(side), + price = VALUES(price), + time = VALUES(time), + order_qty = VALUES(order_qty), + last_qty = VALUES(last_qty), + avg_price = VALUES(avg_price) + """ + + session.execute(text(sql)) + total_processed += len(chunk) + + session.commit() + return total_processed + + except Exception as e: + session.rollback() + logger.error(f"ON DUPLICATE同步订单失败: {e}") + raise \ No newline at end of file diff --git a/utils/batch_position_sync.py b/utils/batch_position_sync.py new file mode 100644 index 0000000..c462aff --- /dev/null +++ b/utils/batch_position_sync.py @@ -0,0 +1,254 @@ +from typing import List, Dict, Any, Tuple +from loguru import logger +from sqlalchemy import text +import time + +class BatchPositionSync: + """持仓数据批量同步工具(使用临时表,最高性能)""" + + def __init__(self, db_manager, batch_size: int = 500): + self.db_manager = db_manager + self.batch_size = batch_size + + def sync_positions_batch(self, all_positions: List[Dict]) -> Tuple[bool, Dict]: + """批量同步持仓数据(最高效版本)""" + if not all_positions: + return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} + + session = self.db_manager.get_session() + try: + start_time = time.time() + + # 按账号分组 + positions_by_account = self._group_positions_by_account(all_positions) + + total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} + + with session.begin(): + # 处理每个账号 + for (k_id, st_id), positions in positions_by_account.items(): + success, stats = self._sync_account_using_temp_table( + session, k_id, st_id, positions + ) + + if success: + total_stats['total'] += stats['total'] + total_stats['updated'] += stats['updated'] + total_stats['inserted'] += stats['inserted'] + total_stats['deleted'] += stats['deleted'] + + elapsed = time.time() - start_time + logger.info(f"持仓批量同步完成: 处理 {len(positions_by_account)} 个账号," + f"总持仓 {total_stats['total']} 条,耗时 {elapsed:.2f}秒") + + return True, total_stats + + except Exception as e: + logger.error(f"持仓批量同步失败: {e}") + return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} + finally: + session.close() + + def _group_positions_by_account(self, all_positions: List[Dict]) -> Dict[Tuple[int, int], List[Dict]]: + """按账号分组持仓数据""" + groups = {} + for position in all_positions: + k_id = position.get('k_id') + st_id = position.get('st_id', 0) + key = (k_id, st_id) + + if key not in groups: + groups[key] = [] + groups[key].append(position) + + return groups + + def _sync_account_using_temp_table(self, session, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]: + """使用临时表同步单个账号的持仓数据""" + try: + # 1. 创建临时表 + session.execute(text(""" + CREATE TEMPORARY TABLE IF NOT EXISTS temp_positions ( + st_id INT, + k_id INT, + asset VARCHAR(32), + symbol VARCHAR(50), + side VARCHAR(10), + price FLOAT, + `sum` FLOAT, + asset_num DECIMAL(20, 8), + asset_profit DECIMAL(20, 8), + leverage INT, + uptime INT, + profit_price DECIMAL(20, 8), + stop_price DECIMAL(20, 8), + liquidation_price DECIMAL(20, 8), + PRIMARY KEY (k_id, st_id, symbol, side) + ) + """)) + + # 2. 清空临时表 + session.execute(text("TRUNCATE TABLE temp_positions")) + + # 3. 批量插入数据到临时表 + self._batch_insert_to_temp_table(session, positions) + + # 4. 使用临时表更新主表 + # 更新已存在的记录 + update_result = session.execute(text(f""" + UPDATE deh_strategy_position_new main + INNER JOIN temp_positions temp + ON main.k_id = temp.k_id + AND main.st_id = temp.st_id + AND main.symbol = temp.symbol + AND main.side = temp.side + SET main.price = temp.price, + main.`sum` = temp.`sum`, + main.asset_num = temp.asset_num, + main.asset_profit = temp.asset_profit, + main.leverage = temp.leverage, + main.uptime = temp.uptime, + main.profit_price = temp.profit_price, + main.stop_price = temp.stop_price, + main.liquidation_price = temp.liquidation_price + WHERE main.k_id = {k_id} AND main.st_id = {st_id} + """)) + updated_count = update_result.rowcount + + # 插入新记录 + insert_result = session.execute(text(f""" + INSERT INTO deh_strategy_position_new + (st_id, k_id, asset, symbol, side, price, `sum`, + asset_num, asset_profit, leverage, uptime, + profit_price, stop_price, liquidation_price) + SELECT + st_id, k_id, asset, symbol, side, price, `sum`, + asset_num, asset_profit, leverage, uptime, + profit_price, stop_price, liquidation_price + FROM temp_positions temp + WHERE NOT EXISTS ( + SELECT 1 FROM deh_strategy_position_new main + WHERE main.k_id = temp.k_id + AND main.st_id = temp.st_id + AND main.symbol = temp.symbol + AND main.side = temp.side + ) + AND temp.k_id = {k_id} AND temp.st_id = {st_id} + """)) + inserted_count = insert_result.rowcount + + # 5. 删除多余持仓(在临时表中不存在但在主表中存在的) + delete_result = session.execute(text(f""" + DELETE main + FROM deh_strategy_position_new main + LEFT JOIN temp_positions temp + ON main.k_id = temp.k_id + AND main.st_id = temp.st_id + AND main.symbol = temp.symbol + AND main.side = temp.side + WHERE main.k_id = {k_id} AND main.st_id = {st_id} + AND temp.symbol IS NULL + """)) + deleted_count = delete_result.rowcount + + # 6. 删除临时表 + session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_positions")) + + stats = { + 'total': len(positions), + 'updated': updated_count, + 'inserted': inserted_count, + 'deleted': deleted_count + } + + logger.debug(f"账号({k_id},{st_id})持仓同步: 更新{updated_count} 插入{inserted_count} 删除{deleted_count}") + + return True, stats + + except Exception as e: + logger.error(f"临时表同步账号({k_id},{st_id})持仓失败: {e}") + return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} + + def _batch_insert_to_temp_table(self, session, positions: List[Dict]): + """批量插入数据到临时表(使用参数化查询)""" + if not positions: + return + + # 分块处理 + for i in range(0, len(positions), self.batch_size): + chunk = positions[i:i + self.batch_size] + + # 准备参数化数据 + insert_data = [] + for position in chunk: + try: + data = self._convert_position_for_temp(position) + if not all([data.get('symbol'), data.get('side')]): + continue + + insert_data.append({ + 'st_id': data['st_id'], + 'k_id': data['k_id'], + 'asset': data.get('asset', 'USDT'), + 'symbol': data['symbol'], + 'side': data['side'], + 'price': data.get('price'), + 'sum_val': data.get('sum'), # 注意字段名 + 'asset_num': data.get('asset_num'), + 'asset_profit': data.get('asset_profit'), + 'leverage': data.get('leverage'), + 'uptime': data.get('uptime'), + 'profit_price': data.get('profit_price'), + 'stop_price': data.get('stop_price'), + 'liquidation_price': data.get('liquidation_price') + }) + + except Exception as e: + logger.error(f"转换持仓数据失败: {position}, error={e}") + continue + + if insert_data: + sql = """ + INSERT INTO temp_positions + (st_id, k_id, asset, symbol, side, price, `sum`, + asset_num, asset_profit, leverage, uptime, + profit_price, stop_price, liquidation_price) + VALUES + (:st_id, :k_id, :asset, :symbol, :side, :price, :sum_val, + :asset_num, :asset_profit, :leverage, :uptime, + :profit_price, :stop_price, :liquidation_price) + """ + + session.execute(text(sql), insert_data) + + def _convert_position_for_temp(self, data: Dict) -> Dict: + """转换持仓数据格式用于临时表""" + # 使用安全转换 + def safe_float(value): + try: + return float(value) if value is not None else None + except: + return None + + def safe_int(value): + try: + return int(value) if value is not None else None + except: + return None + + return { + 'st_id': safe_int(data.get('st_id')) or 0, + 'k_id': safe_int(data.get('k_id')) or 0, + 'asset': data.get('asset', 'USDT'), + 'symbol': str(data.get('symbol', '')), + 'side': str(data.get('side', '')), + 'price': safe_float(data.get('price')), + 'sum': safe_float(data.get('qty')), # 注意:这里直接使用sum + 'asset_num': safe_float(data.get('asset_num')), + 'asset_profit': safe_float(data.get('asset_profit')), + 'leverage': safe_int(data.get('leverage')), + 'uptime': safe_int(data.get('uptime')), + 'profit_price': safe_float(data.get('profit_price')), + 'stop_price': safe_float(data.get('stop_price')), + 'liquidation_price': safe_float(data.get('liquidation_price')) + } \ No newline at end of file diff --git a/utils/redis_batch_helper.py b/utils/redis_batch_helper.py new file mode 100644 index 0000000..9037def --- /dev/null +++ b/utils/redis_batch_helper.py @@ -0,0 +1,129 @@ +import redis +from loguru import logger +from typing import List, Dict, Tuple +import json +import time +from datetime import datetime, timedelta + +class RedisBatchHelper: + """Redis批量数据获取助手""" + + def __init__(self, redis_client): + self.redis_client = redis_client + + def get_recent_orders_batch(self, exchange_id: str, account_list: List[Tuple[int, int]], + recent_days: int = 3) -> List[Dict]: + """批量获取多个账号的最近订单数据(优化内存使用)""" + all_orders = [] + + try: + # 分批处理账号,避免内存过大 + batch_size = 20 # 每批处理20个账号 + for i in range(0, len(account_list), batch_size): + batch_accounts = account_list[i:i + batch_size] + + # 并发获取这批账号的数据 + batch_orders = self._get_batch_accounts_orders(exchange_id, batch_accounts, recent_days) + all_orders.extend(batch_orders) + + # 批次间休息,避免Redis压力过大 + if i + batch_size < len(account_list): + time.sleep(0.05) + + logger.info(f"批量获取订单完成: {len(account_list)}个账号,{len(all_orders)}条订单") + + except Exception as e: + logger.error(f"批量获取订单失败: {e}") + + return all_orders + + def _get_batch_accounts_orders(self, exchange_id: str, account_list: List[Tuple[int, int]], + recent_days: int) -> List[Dict]: + """获取一批账号的订单数据""" + batch_orders = [] + + try: + # 计算最近日期 + today = datetime.now() + recent_dates = [] + for i in range(recent_days): + date = today - timedelta(days=i) + recent_dates.append(date.strftime('%Y-%m-%d')) + + # 为每个账号构建key列表 + all_keys = [] + key_to_account = {} + + for k_id, st_id in account_list: + redis_key = f"{exchange_id}:orders:{k_id}" + + # 获取该账号的所有key + try: + account_keys = self.redis_client.hkeys(redis_key) + + for key in account_keys: + key_str = key.decode('utf-8') if isinstance(key, bytes) else key + + if key_str == 'positions': + continue + + # 检查是否是最近日期 + for date_format in recent_dates: + if key_str.startswith(date_format + '_'): + all_keys.append((redis_key, key_str)) + key_to_account[(redis_key, key_str)] = (k_id, st_id) + break + + except Exception as e: + logger.error(f"获取账号 {k_id} 的key失败: {e}") + continue + + if not all_keys: + return batch_orders + + # 分批获取订单数据 + chunk_size = 500 + for i in range(0, len(all_keys), chunk_size): + chunk = all_keys[i:i + chunk_size] + + # 按redis_key分组,使用hmget批量获取 + keys_by_redis_key = {} + for redis_key, key_str in chunk: + if redis_key not in keys_by_redis_key: + keys_by_redis_key[redis_key] = [] + keys_by_redis_key[redis_key].append(key_str) + + # 为每个redis_key批量获取 + for redis_key, key_list in keys_by_redis_key.items(): + try: + values = self.redis_client.hmget(redis_key, key_list) + + for key_str, order_json in zip(key_list, values): + if not order_json: + continue + + try: + order = json.loads(order_json) + + # 验证时间 + order_time = order.get('time', 0) + if order_time >= int(time.time()) - recent_days * 24 * 3600: + # 添加账号信息 + k_id, st_id = key_to_account.get((redis_key, key_str), (0, 0)) + order['k_id'] = k_id + order['st_id'] = st_id + order['exchange_id'] = exchange_id + batch_orders.append(order) + + except json.JSONDecodeError as e: + logger.debug(f"解析订单JSON失败: key={key_str}, error={e}") + continue + + except Exception as e: + logger.error(f"批量获取Redis数据失败: {redis_key}, error={e}") + continue + + except Exception as e: + logger.error(f"获取批量账号订单失败: {e}") + + return batch_orders \ No newline at end of file