diff --git a/sync/account_sync.py b/sync/account_sync.py index 75b83ae..d35ac04 100644 --- a/sync/account_sync.py +++ b/sync/account_sync.py @@ -1,48 +1,99 @@ from .base_sync import BaseSync from loguru import logger -from typing import List, Dict +from typing import List, Dict, Any, Set import json import time from datetime import datetime, timedelta +from sqlalchemy import text, and_ +from models.orm_models import StrategyKX -class AccountSync(BaseSync): - """账户信息同步器""" +class AccountSyncBatch(BaseSync): + """账户信息批量同步器""" - async def sync(self): - """同步账户信息数据""" + async def sync_batch(self, accounts: Dict[str, Dict]): + """批量同步所有账号的账户信息""" try: - # 获取所有账号 - accounts = self.get_accounts_from_redis() + logger.info(f"开始批量同步账户信息,共 {len(accounts)} 个账号") - for k_id_str, account_info in accounts.items(): - try: - k_id = int(k_id_str) - st_id = account_info.get('st_id', 0) - exchange_id = account_info['exchange_id'] - - if k_id <= 0 or st_id <= 0: - continue - - # 从Redis获取账户信息数据 - account_data = await self._get_account_info_from_redis(k_id, st_id, exchange_id) - - # 同步到数据库 - if account_data: - success = self._sync_account_info_to_db(account_data) - if success: - logger.debug(f"账户信息同步成功: k_id={k_id}") - - except Exception as e: - logger.error(f"同步账号 {k_id_str} 账户信息失败: {e}") - continue + # 收集所有账号的数据 + all_account_data = await self._collect_all_account_data(accounts) - logger.info("账户信息同步完成") + if not all_account_data: + logger.info("无账户信息数据需要同步") + return + + # 批量同步到数据库 + success = await self._sync_account_info_batch_to_db(all_account_data) + + if success: + logger.info(f"账户信息批量同步完成: 处理 {len(all_account_data)} 条记录") + else: + logger.error("账户信息批量同步失败") except Exception as e: - logger.error(f"账户信息同步失败: {e}") + logger.error(f"账户信息批量同步失败: {e}") + + async def _collect_all_account_data(self, accounts: Dict[str, Dict]) -> List[Dict]: + """收集所有账号的账户信息数据""" + all_account_data = [] + + try: + # 按交易所分组账号 + account_groups = self._group_accounts_by_exchange(accounts) + + # 并发收集每个交易所的数据 + tasks = [] + for exchange_id, account_list in account_groups.items(): + task = self._collect_exchange_account_data(exchange_id, account_list) + tasks.append(task) + + # 等待所有任务完成并合并结果 + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, list): + all_account_data.extend(result) + + logger.info(f"收集到 {len(all_account_data)} 条账户信息记录") + + except Exception as e: + logger.error(f"收集账户信息数据失败: {e}") + + return all_account_data + + def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: + """按交易所分组账号""" + groups = {} + for account_id, account_info in accounts.items(): + exchange_id = account_info.get('exchange_id') + if exchange_id: + if exchange_id not in groups: + groups[exchange_id] = [] + groups[exchange_id].append(account_info) + return groups + + async def _collect_exchange_account_data(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]: + """收集某个交易所的账户信息数据""" + account_data_list = [] + + try: + for account_info in account_list: + k_id = int(account_info['k_id']) + st_id = account_info.get('st_id', 0) + + # 从Redis获取账户信息数据 + account_data = await self._get_account_info_from_redis(k_id, st_id, exchange_id) + account_data_list.extend(account_data) + + logger.debug(f"交易所 {exchange_id}: 收集到 {len(account_data_list)} 条账户信息") + + except Exception as e: + logger.error(f"收集交易所 {exchange_id} 账户信息失败: {e}") + + return account_data_list async def _get_account_info_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]: - """从Redis获取账户信息数据""" + """从Redis获取账户信息数据(批量优化版本)""" try: redis_key = f"{exchange_id}:balance:{k_id}" redis_funds = self.redis_client.client.hgetall(redis_key) @@ -97,7 +148,9 @@ class AccountSync(BaseSync): # 转换为账户信息数据 account_data_list = [] sorted_dates = sorted(date_stats.keys()) - prev_balance = 0.0 + + # 获取前一天余额用于计算利润 + prev_balance_map = self._get_previous_balances(redis_funds, sorted_dates) for date_str in sorted_dates: stats = date_stats[date_str] @@ -111,6 +164,7 @@ class AccountSync(BaseSync): withdrawal = stats['withdrawal'] # 计算利润 + prev_balance = prev_balance_map.get(date_str, 0.0) profit = balance - deposit - withdrawal - prev_balance # 转换时间戳 @@ -129,10 +183,6 @@ class AccountSync(BaseSync): } account_data_list.append(account_data) - - # 更新前一天的余额 - if stats['has_balance']: - prev_balance = balance return account_data_list @@ -140,44 +190,177 @@ class AccountSync(BaseSync): logger.error(f"获取Redis账户信息失败: k_id={k_id}, error={e}") return [] - def _sync_account_info_to_db(self, account_data_list: List[Dict]) -> bool: - """同步账户信息到数据库""" + def _get_previous_balances(self, redis_funds: Dict, sorted_dates: List[str]) -> Dict[str, float]: + """获取前一天的余额""" + prev_balance_map = {} + prev_date = None + + for date_str in sorted_dates: + # 查找前一天的余额 + if prev_date: + for fund_key, fund_json in redis_funds.items(): + try: + fund_data = json.loads(fund_json) + if (fund_data.get('lz_time') == prev_date and + fund_data.get('lz_type') == 'lz_balance'): + prev_balance_map[date_str] = float(fund_data.get('lz_amount', 0)) + break + except: + continue + else: + prev_balance_map[date_str] = 0.0 + + prev_date = date_str + + return prev_balance_map + + async def _sync_account_info_batch_to_db(self, account_data_list: List[Dict]) -> bool: + """批量同步账户信息到数据库(最高效版本)""" session = self.db_manager.get_session() try: - with session.begin(): - for account_data in account_data_list: - try: - # 查询是否已存在 - existing = session.execute( - select(StrategyKX).where( - and_( - StrategyKX.k_id == account_data['k_id'], - StrategyKX.st_id == account_data['st_id'], - StrategyKX.time == account_data['time'] - ) - ) - ).scalar_one_or_none() - - if existing: - # 更新 - existing.balance = account_data['balance'] - existing.withdrawal = account_data['withdrawal'] - existing.deposit = account_data['deposit'] - existing.other = account_data['other'] - existing.profit = account_data['profit'] - else: - # 插入 - new_account = StrategyKX(**account_data) - session.add(new_account) - - except Exception as e: - logger.error(f"处理账户数据失败: {account_data}, error={e}") - continue + if not account_data_list: + return True + with session.begin(): + # 方法1:使用原生SQL批量插入/更新(性能最好) + success = self._batch_upsert_account_info(session, account_data_list) + + if not success: + # 方法2:回退到ORM批量操作 + success = self._batch_orm_upsert_account_info(session, account_data_list) + + return success + + except Exception as e: + logger.error(f"批量同步账户信息到数据库失败: {e}") + return False + finally: + session.close() + + def _batch_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool: + """使用原生SQL批量插入/更新账户信息""" + try: + # 准备批量数据 + values_list = [] + for data in account_data_list: + values = ( + f"({data['st_id']}, {data['k_id']}, 'USDT', " + f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, " + f"{data['other']}, {data['profit']}, {data['time']})" + ) + values_list.append(values) + + if not values_list: + return True + + values_str = ", ".join(values_list) + + # 使用INSERT ... ON DUPLICATE KEY UPDATE + sql = f""" + INSERT INTO deh_strategy_kx_new + (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time) + VALUES {values_str} + ON DUPLICATE KEY UPDATE + balance = VALUES(balance), + withdrawal = VALUES(withdrawal), + deposit = VALUES(deposit), + other = VALUES(other), + profit = VALUES(profit), + up_time = NOW() + """ + + session.execute(text(sql)) + + logger.info(f"原生SQL批量更新账户信息: {len(account_data_list)} 条记录") return True except Exception as e: - logger.error(f"同步账户信息到数据库失败: error={e}") + logger.error(f"原生SQL批量更新账户信息失败: {e}") return False - finally: - session.close() \ No newline at end of file + + def _batch_orm_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool: + """使用ORM批量插入/更新账户信息""" + try: + # 分组数据以提高效率 + account_data_by_key = {} + for data in account_data_list: + key = (data['k_id'], data['st_id'], data['time']) + account_data_by_key[key] = data + + # 批量查询现有记录 + existing_records = self._batch_query_existing_records(session, list(account_data_by_key.keys())) + + # 批量更新或插入 + to_update = [] + to_insert = [] + + for key, data in account_data_by_key.items(): + if key in existing_records: + # 更新 + record = existing_records[key] + record.balance = data['balance'] + record.withdrawal = data['withdrawal'] + record.deposit = data['deposit'] + record.other = data['other'] + record.profit = data['profit'] + else: + # 插入 + to_insert.append(StrategyKX(**data)) + + # 批量插入新记录 + if to_insert: + session.add_all(to_insert) + + logger.info(f"ORM批量更新账户信息: 更新 {len(existing_records)} 条,插入 {len(to_insert)} 条") + return True + + except Exception as e: + logger.error(f"ORM批量更新账户信息失败: {e}") + return False + + def _batch_query_existing_records(self, session, keys: List[tuple]) -> Dict[tuple, StrategyKX]: + """批量查询现有记录""" + existing_records = {} + + try: + if not keys: + return existing_records + + # 构建查询条件 + conditions = [] + for k_id, st_id, time_val in keys: + conditions.append(f"(k_id = {k_id} AND st_id = {st_id} AND time = {time_val})") + + if conditions: + conditions_str = " OR ".join(conditions) + sql = f""" + SELECT * FROM deh_strategy_kx_new + WHERE {conditions_str} + """ + + results = session.execute(text(sql)).fetchall() + + for row in results: + key = (row.k_id, row.st_id, row.time) + existing_records[key] = StrategyKX( + id=row.id, + st_id=row.st_id, + k_id=row.k_id, + asset=row.asset, + balance=row.balance, + withdrawal=row.withdrawal, + deposit=row.deposit, + other=row.other, + profit=row.profit, + time=row.time + ) + + except Exception as e: + logger.error(f"批量查询现有记录失败: {e}") + + return existing_records + + async def sync(self): + """兼容旧接口""" + accounts = self.get_accounts_from_redis() + await self.sync_batch(accounts) \ No newline at end of file diff --git a/sync/account_sync_batch.py b/sync/account_sync_batch.py deleted file mode 100644 index d35ac04..0000000 --- a/sync/account_sync_batch.py +++ /dev/null @@ -1,366 +0,0 @@ -from .base_sync import BaseSync -from loguru import logger -from typing import List, Dict, Any, Set -import json -import time -from datetime import datetime, timedelta -from sqlalchemy import text, and_ -from models.orm_models import StrategyKX - -class AccountSyncBatch(BaseSync): - """账户信息批量同步器""" - - async def sync_batch(self, accounts: Dict[str, Dict]): - """批量同步所有账号的账户信息""" - try: - logger.info(f"开始批量同步账户信息,共 {len(accounts)} 个账号") - - # 收集所有账号的数据 - all_account_data = await self._collect_all_account_data(accounts) - - if not all_account_data: - logger.info("无账户信息数据需要同步") - return - - # 批量同步到数据库 - success = await self._sync_account_info_batch_to_db(all_account_data) - - if success: - logger.info(f"账户信息批量同步完成: 处理 {len(all_account_data)} 条记录") - else: - logger.error("账户信息批量同步失败") - - except Exception as e: - logger.error(f"账户信息批量同步失败: {e}") - - async def _collect_all_account_data(self, accounts: Dict[str, Dict]) -> List[Dict]: - """收集所有账号的账户信息数据""" - all_account_data = [] - - try: - # 按交易所分组账号 - account_groups = self._group_accounts_by_exchange(accounts) - - # 并发收集每个交易所的数据 - tasks = [] - for exchange_id, account_list in account_groups.items(): - task = self._collect_exchange_account_data(exchange_id, account_list) - tasks.append(task) - - # 等待所有任务完成并合并结果 - results = await asyncio.gather(*tasks, return_exceptions=True) - - for result in results: - if isinstance(result, list): - all_account_data.extend(result) - - logger.info(f"收集到 {len(all_account_data)} 条账户信息记录") - - except Exception as e: - logger.error(f"收集账户信息数据失败: {e}") - - return all_account_data - - def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: - """按交易所分组账号""" - groups = {} - for account_id, account_info in accounts.items(): - exchange_id = account_info.get('exchange_id') - if exchange_id: - if exchange_id not in groups: - groups[exchange_id] = [] - groups[exchange_id].append(account_info) - return groups - - async def _collect_exchange_account_data(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]: - """收集某个交易所的账户信息数据""" - account_data_list = [] - - try: - for account_info in account_list: - k_id = int(account_info['k_id']) - st_id = account_info.get('st_id', 0) - - # 从Redis获取账户信息数据 - account_data = await self._get_account_info_from_redis(k_id, st_id, exchange_id) - account_data_list.extend(account_data) - - logger.debug(f"交易所 {exchange_id}: 收集到 {len(account_data_list)} 条账户信息") - - except Exception as e: - logger.error(f"收集交易所 {exchange_id} 账户信息失败: {e}") - - return account_data_list - - async def _get_account_info_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]: - """从Redis获取账户信息数据(批量优化版本)""" - try: - redis_key = f"{exchange_id}:balance:{k_id}" - redis_funds = self.redis_client.client.hgetall(redis_key) - - if not redis_funds: - return [] - - # 按天统计数据 - from config.settings import SYNC_CONFIG - recent_days = SYNC_CONFIG['recent_days'] - - today = datetime.now() - date_stats = {} - - # 收集所有日期的数据 - for fund_key, fund_json in redis_funds.items(): - try: - fund_data = json.loads(fund_json) - date_str = fund_data.get('lz_time', '') - lz_type = fund_data.get('lz_type', '') - - if not date_str or lz_type not in ['lz_balance', 'deposit', 'withdrawal']: - continue - - # 只处理最近N天的数据 - date_obj = datetime.strptime(date_str, '%Y-%m-%d') - if (today - date_obj).days > recent_days: - continue - - if date_str not in date_stats: - date_stats[date_str] = { - 'balance': 0.0, - 'deposit': 0.0, - 'withdrawal': 0.0, - 'has_balance': False - } - - lz_amount = float(fund_data.get('lz_amount', 0)) - - if lz_type == 'lz_balance': - date_stats[date_str]['balance'] = lz_amount - date_stats[date_str]['has_balance'] = True - elif lz_type == 'deposit': - date_stats[date_str]['deposit'] += lz_amount - elif lz_type == 'withdrawal': - date_stats[date_str]['withdrawal'] += lz_amount - - except (json.JSONDecodeError, ValueError) as e: - logger.debug(f"解析Redis数据失败: {fund_key}, error={e}") - continue - - # 转换为账户信息数据 - account_data_list = [] - sorted_dates = sorted(date_stats.keys()) - - # 获取前一天余额用于计算利润 - prev_balance_map = self._get_previous_balances(redis_funds, sorted_dates) - - for date_str in sorted_dates: - stats = date_stats[date_str] - - # 如果没有余额数据但有充提数据,仍然处理 - if not stats['has_balance'] and stats['deposit'] == 0 and stats['withdrawal'] == 0: - continue - - balance = stats['balance'] - deposit = stats['deposit'] - withdrawal = stats['withdrawal'] - - # 计算利润 - prev_balance = prev_balance_map.get(date_str, 0.0) - profit = balance - deposit - withdrawal - prev_balance - - # 转换时间戳 - date_obj = datetime.strptime(date_str, '%Y-%m-%d') - time_timestamp = int(date_obj.timestamp()) - - account_data = { - 'st_id': st_id, - 'k_id': k_id, - 'balance': balance, - 'withdrawal': withdrawal, - 'deposit': deposit, - 'other': 0.0, # 暂时为0 - 'profit': profit, - 'time': time_timestamp - } - - account_data_list.append(account_data) - - return account_data_list - - except Exception as e: - logger.error(f"获取Redis账户信息失败: k_id={k_id}, error={e}") - return [] - - def _get_previous_balances(self, redis_funds: Dict, sorted_dates: List[str]) -> Dict[str, float]: - """获取前一天的余额""" - prev_balance_map = {} - prev_date = None - - for date_str in sorted_dates: - # 查找前一天的余额 - if prev_date: - for fund_key, fund_json in redis_funds.items(): - try: - fund_data = json.loads(fund_json) - if (fund_data.get('lz_time') == prev_date and - fund_data.get('lz_type') == 'lz_balance'): - prev_balance_map[date_str] = float(fund_data.get('lz_amount', 0)) - break - except: - continue - else: - prev_balance_map[date_str] = 0.0 - - prev_date = date_str - - return prev_balance_map - - async def _sync_account_info_batch_to_db(self, account_data_list: List[Dict]) -> bool: - """批量同步账户信息到数据库(最高效版本)""" - session = self.db_manager.get_session() - try: - if not account_data_list: - return True - - with session.begin(): - # 方法1:使用原生SQL批量插入/更新(性能最好) - success = self._batch_upsert_account_info(session, account_data_list) - - if not success: - # 方法2:回退到ORM批量操作 - success = self._batch_orm_upsert_account_info(session, account_data_list) - - return success - - except Exception as e: - logger.error(f"批量同步账户信息到数据库失败: {e}") - return False - finally: - session.close() - - def _batch_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool: - """使用原生SQL批量插入/更新账户信息""" - try: - # 准备批量数据 - values_list = [] - for data in account_data_list: - values = ( - f"({data['st_id']}, {data['k_id']}, 'USDT', " - f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, " - f"{data['other']}, {data['profit']}, {data['time']})" - ) - values_list.append(values) - - if not values_list: - return True - - values_str = ", ".join(values_list) - - # 使用INSERT ... ON DUPLICATE KEY UPDATE - sql = f""" - INSERT INTO deh_strategy_kx_new - (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time) - VALUES {values_str} - ON DUPLICATE KEY UPDATE - balance = VALUES(balance), - withdrawal = VALUES(withdrawal), - deposit = VALUES(deposit), - other = VALUES(other), - profit = VALUES(profit), - up_time = NOW() - """ - - session.execute(text(sql)) - - logger.info(f"原生SQL批量更新账户信息: {len(account_data_list)} 条记录") - return True - - except Exception as e: - logger.error(f"原生SQL批量更新账户信息失败: {e}") - return False - - def _batch_orm_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool: - """使用ORM批量插入/更新账户信息""" - try: - # 分组数据以提高效率 - account_data_by_key = {} - for data in account_data_list: - key = (data['k_id'], data['st_id'], data['time']) - account_data_by_key[key] = data - - # 批量查询现有记录 - existing_records = self._batch_query_existing_records(session, list(account_data_by_key.keys())) - - # 批量更新或插入 - to_update = [] - to_insert = [] - - for key, data in account_data_by_key.items(): - if key in existing_records: - # 更新 - record = existing_records[key] - record.balance = data['balance'] - record.withdrawal = data['withdrawal'] - record.deposit = data['deposit'] - record.other = data['other'] - record.profit = data['profit'] - else: - # 插入 - to_insert.append(StrategyKX(**data)) - - # 批量插入新记录 - if to_insert: - session.add_all(to_insert) - - logger.info(f"ORM批量更新账户信息: 更新 {len(existing_records)} 条,插入 {len(to_insert)} 条") - return True - - except Exception as e: - logger.error(f"ORM批量更新账户信息失败: {e}") - return False - - def _batch_query_existing_records(self, session, keys: List[tuple]) -> Dict[tuple, StrategyKX]: - """批量查询现有记录""" - existing_records = {} - - try: - if not keys: - return existing_records - - # 构建查询条件 - conditions = [] - for k_id, st_id, time_val in keys: - conditions.append(f"(k_id = {k_id} AND st_id = {st_id} AND time = {time_val})") - - if conditions: - conditions_str = " OR ".join(conditions) - sql = f""" - SELECT * FROM deh_strategy_kx_new - WHERE {conditions_str} - """ - - results = session.execute(text(sql)).fetchall() - - for row in results: - key = (row.k_id, row.st_id, row.time) - existing_records[key] = StrategyKX( - id=row.id, - st_id=row.st_id, - k_id=row.k_id, - asset=row.asset, - balance=row.balance, - withdrawal=row.withdrawal, - deposit=row.deposit, - other=row.other, - profit=row.profit, - time=row.time - ) - - except Exception as e: - logger.error(f"批量查询现有记录失败: {e}") - - return existing_records - - async def sync(self): - """兼容旧接口""" - accounts = self.get_accounts_from_redis() - await self.sync_batch(accounts) \ No newline at end of file diff --git a/sync/base_sync.py b/sync/base_sync.py index 8afb734..265ebe5 100644 --- a/sync/base_sync.py +++ b/sync/base_sync.py @@ -44,211 +44,6 @@ class BaseSync(ABC): """批量同步数据""" pass - def get_accounts_from_redis(self) -> Dict[str, Dict]: - """从Redis获取所有计算机名的账号配置""" - try: - accounts_dict = {} - total_keys_processed = 0 - - # 方法1:使用配置的计算机名列表 - for computer_name in self.computer_names: - accounts = self._get_accounts_by_computer_name(computer_name) - total_keys_processed += 1 - accounts_dict.update(accounts) - - # 方法2:如果配置的计算机名没有数据,尝试自动发现(备用方案) - if not accounts_dict: - logger.warning("配置的计算机名未找到数据,尝试自动发现...") - accounts_dict = self._discover_all_accounts() - - self.sync_stats['total_accounts'] = len(accounts_dict) - logger.info(f"从 {len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号") - - return accounts_dict - - except Exception as e: - logger.error(f"获取账户信息失败: {e}") - return {} - - def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]: - """获取指定计算机名的账号""" - accounts_dict = {} - - try: - # 构建key - redis_key = f"{computer_name}_strategy_api" - - # 从Redis获取数据 - result = self.redis_client.client.hgetall(redis_key) - if not result: - logger.debug(f"未找到 {redis_key} 的策略API配置") - return {} - - logger.info(f"从 {redis_key} 获取到 {len(result)} 个交易所配置") - - for exchange_name, accounts_json in result.items(): - try: - accounts = json.loads(accounts_json) - if not accounts: - continue - - # 格式化交易所ID - exchange_id = self.format_exchange_id(exchange_name) - - for account_id, account_info in accounts.items(): - parsed_account = self.parse_account(exchange_id, account_id, account_info) - if parsed_account: - # 添加计算机名标记 - parsed_account['computer_name'] = computer_name - accounts_dict[account_id] = parsed_account - - except json.JSONDecodeError as e: - logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}") - continue - except Exception as e: - logger.error(f"处理交易所 {exchange_name} 数据异常: {e}") - continue - - logger.info(f"从 {redis_key} 解析到 {len(accounts_dict)} 个账号") - - except Exception as e: - logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}") - - return accounts_dict - - def _discover_all_accounts(self) -> Dict[str, Dict]: - """自动发现所有匹配的账号key""" - accounts_dict = {} - discovered_keys = [] - - try: - # 获取所有匹配模式的key - pattern = "*_strategy_api" - cursor = 0 - - while True: - cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100) - - for key in keys: - key_str = key.decode('utf-8') if isinstance(key, bytes) else key - discovered_keys.append(key_str) - - if cursor == 0: - break - - logger.info(f"自动发现 {len(discovered_keys)} 个策略API key") - - # 处理每个发现的key - for key_str in discovered_keys: - # 提取计算机名 - computer_name = key_str.replace('_strategy_api', '') - - # 验证计算机名格式 - if self.computer_name_pattern.match(computer_name): - accounts = self._get_accounts_by_computer_name(computer_name) - accounts_dict.update(accounts) - else: - logger.warning(f"跳过不符合格式的计算机名: {computer_name}") - - logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号") - - except Exception as e: - logger.error(f"自动发现账号失败: {e}") - - return accounts_dict - - def format_exchange_id(self, key: str) -> str: - """格式化交易所ID""" - key = key.lower().strip() - - # 交易所名称映射 - exchange_mapping = { - 'metatrader': 'mt5', - 'binance_spot_test': 'binance', - 'binance_spot': 'binance', - 'binance': 'binance', - 'gate_spot': 'gate', - 'okex': 'okx', - 'okx': 'okx', - 'bybit': 'bybit', - 'bybit_spot': 'bybit', - 'bybit_test': 'bybit', - 'huobi': 'huobi', - 'huobi_spot': 'huobi', - 'gate': 'gate', - 'gateio': 'gate', - 'kucoin': 'kucoin', - 'kucoin_spot': 'kucoin', - 'mexc': 'mexc', - 'mexc_spot': 'mexc', - 'bitget': 'bitget', - 'bitget_spot': 'bitget' - } - - normalized_key = exchange_mapping.get(key, key) - - # 记录未映射的交易所 - if normalized_key == key and key not in exchange_mapping.values(): - logger.debug(f"未映射的交易所名称: {key}") - - return normalized_key - - def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Optional[Dict]: - """解析账号信息""" - try: - source_account_info = json.loads(account_info) - - # 基础信息 - account_data = { - 'exchange_id': exchange_id, - 'k_id': account_id, - 'st_id': self._safe_int(source_account_info.get('st_id'), 0), - 'add_time': self._safe_int(source_account_info.get('add_time'), 0), - 'account_type': source_account_info.get('account_type', 'real'), - 'api_key': source_account_info.get('api_key', ''), - 'secret_key': source_account_info.get('secret_key', ''), - 'password': source_account_info.get('password', ''), - 'access_token': source_account_info.get('access_token', ''), - 'remark': source_account_info.get('remark', '') - } - - # MT5特殊处理 - if exchange_id == 'mt5': - # 解析服务器地址和端口 - server_info = source_account_info.get('secret_key', '') - if ':' in server_info: - host, port = server_info.split(':', 1) - account_data['mt5_host'] = host - account_data['mt5_port'] = self._safe_int(port, 0) - - # 合并原始信息 - result = {**source_account_info, **account_data} - - # 验证必要字段 - if not result.get('st_id') or not result.get('exchange_id'): - logger.warning(f"账号 {account_id} 缺少必要字段: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}") - return None - - return result - - except json.JSONDecodeError as e: - logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {account_info[:100]}...") - return None - except Exception as e: - logger.error(f"处理账号 {account_id} 数据异常: {e}") - return None - - def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: - """按交易所分组账号""" - groups = {} - for account_id, account_info in accounts.items(): - exchange_id = account_info.get('exchange_id') - if exchange_id: - if exchange_id not in groups: - groups[exchange_id] = [] - groups[exchange_id].append(account_info) - return groups - def _safe_float(self, value: Any, default: float = 0.0) -> float: """安全转换为float""" if value is None: diff --git a/sync/manager.py b/sync/manager.py index 4019f21..1d5ae58 100644 --- a/sync/manager.py +++ b/sync/manager.py @@ -3,26 +3,30 @@ from loguru import logger import signal import sys import time +import json from typing import Dict +import re +from utils.redis_client import RedisClient from config.settings import SYNC_CONFIG -from .position_sync_batch import PositionSyncBatch -from .order_sync_batch import OrderSyncBatch # 使用批量版本 -from .account_sync_batch import AccountSyncBatch -from utils.batch_position_sync import BatchPositionSync -from utils.batch_order_sync import BatchOrderSync -from utils.batch_account_sync import BatchAccountSync +from .position_sync import PositionSyncBatch +from .order_sync import OrderSyncBatch # 使用批量版本 +from .account_sync import AccountSyncBatch from utils.redis_batch_helper import RedisBatchHelper +from config.settings import COMPUTER_NAMES, COMPUTER_NAME_PATTERN +from typing import List, Dict, Any, Set, Optional class SyncManager: """同步管理器(完整批量版本)""" def __init__(self): self.is_running = True + self.redis_client = RedisClient() self.sync_interval = SYNC_CONFIG['interval'] + self.computer_names = self._get_computer_names() + self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN) # 初始化批量同步工具 - self.batch_tools = {} self.redis_helper = None # 初始化同步器 @@ -31,24 +35,16 @@ class SyncManager: if SYNC_CONFIG['enable_position_sync']: position_sync = PositionSyncBatch() self.syncers.append(position_sync) - self.batch_tools['position'] = BatchPositionSync(position_sync.db_manager) logger.info("启用持仓批量同步") if SYNC_CONFIG['enable_order_sync']: order_sync = OrderSyncBatch() self.syncers.append(order_sync) - self.batch_tools['order'] = BatchOrderSync(order_sync.db_manager) - - # 初始化Redis批量助手 - if order_sync.redis_client: - self.redis_helper = RedisBatchHelper(order_sync.redis_client.client) - logger.info("启用订单批量同步") if SYNC_CONFIG['enable_account_sync']: account_sync = AccountSyncBatch() self.syncers.append(account_sync) - self.batch_tools['account'] = BatchAccountSync(account_sync.db_manager) logger.info("启用账户信息批量同步") # 性能统计 @@ -71,21 +67,25 @@ class SyncManager: while self.is_running: try: - self.stats['total_syncs'] += 1 - sync_start = time.time() # 获取所有账号(只获取一次) - accounts = await self._get_all_accounts() + accounts = await self.get_accounts_from_redis() if not accounts: logger.warning("未获取到任何账号,等待下次同步") await asyncio.sleep(self.sync_interval) continue + + self.stats['total_syncs'] += 1 + sync_start = time.time() + logger.info(f"第{self.stats['total_syncs']}次同步开始,共 {len(accounts)} 个账号") - # 并发执行所有同步 - await self._execute_all_syncers_concurrent(accounts) + # 执行所有同步器 + tasks = [syncer.sync(accounts) for syncer in self.syncers] + await asyncio.gather(*tasks, return_exceptions=True) + # 更新统计 sync_time = time.time() - sync_start @@ -101,137 +101,255 @@ class SyncManager: logger.error(f"同步任务异常: {e}") await asyncio.sleep(30) - async def _get_all_accounts(self) -> Dict[str, Dict]: - """获取所有账号""" - if not self.syncers: + def get_accounts_from_redis(self) -> Dict[str, Dict]: + """从Redis获取所有计算机名的账号配置""" + try: + accounts_dict = {} + total_keys_processed = 0 + + # 方法1:使用配置的计算机名列表 + for computer_name in self.computer_names: + accounts = self._get_accounts_by_computer_name(computer_name) + total_keys_processed += 1 + accounts_dict.update(accounts) + + # 方法2:如果配置的计算机名没有数据,尝试自动发现(备用方案) + if not accounts_dict: + logger.warning("配置的计算机名未找到数据,尝试自动发现...") + accounts_dict = self._discover_all_accounts() + + self.sync_stats['total_accounts'] = len(accounts_dict) + logger.info(f"从 {len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号") + + return accounts_dict + + except Exception as e: + logger.error(f"获取账户信息失败: {e}") return {} + + def _get_computer_names(self) -> List[str]: + """获取计算机名列表""" + if ',' in COMPUTER_NAMES: + names = [name.strip() for name in COMPUTER_NAMES.split(',')] + logger.info(f"使用配置的计算机名列表: {names}") + return names + return [COMPUTER_NAMES.strip()] + + def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]: + """获取指定计算机名的账号""" + accounts_dict = {} - # 使用第一个同步器获取账号 - return self.syncers[0].get_accounts_from_redis() - - async def _execute_all_syncers_concurrent(self, accounts: Dict[str, Dict]): - """并发执行所有同步器""" - tasks = [] - - # 持仓批量同步 - if 'position' in self.batch_tools: - task = self._sync_positions_batch(accounts) - tasks.append(task) - - # 订单批量同步 - if 'order' in self.batch_tools: - task = self._sync_orders_batch(accounts) - tasks.append(task) - - # 账户信息批量同步 - if 'account' in self.batch_tools: - task = self._sync_accounts_batch(accounts) - tasks.append(task) - - # 并发执行所有任务 - if tasks: - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 检查结果 - for i, result in enumerate(results): - if isinstance(result, Exception): - logger.error(f"同步任务 {i} 失败: {result}") - - async def _sync_positions_batch(self, accounts: Dict[str, Dict]): - """批量同步持仓数据""" try: - start_time = time.time() + # 构建key + redis_key = f"{computer_name}_strategy_api" - # 收集所有持仓数据 - position_sync = next((s for s in self.syncers if isinstance(s, PositionSyncBatch)), None) - if not position_sync: - return + # 从Redis获取数据 + result = self.redis_client.client.hgetall(redis_key) + if not result: + logger.debug(f"未找到 {redis_key} 的策略API配置") + return {} - all_positions = await position_sync._collect_all_positions(accounts) + logger.info(f"从 {redis_key} 获取到 {len(result)} 个交易所配置") - if not all_positions: - self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0} - return + for exchange_name, accounts_json in result.items(): + try: + accounts = json.loads(accounts_json) + if not accounts: + continue + + # 格式化交易所ID + exchange_id = self.format_exchange_id(exchange_name) + + for account_id, account_info in accounts.items(): + parsed_account = self.parse_account(exchange_id, account_id, account_info) + if parsed_account: + # 添加计算机名标记 + parsed_account['computer_name'] = computer_name + accounts_dict[account_id] = parsed_account + + except json.JSONDecodeError as e: + logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}") + continue + except Exception as e: + logger.error(f"处理交易所 {exchange_name} 数据异常: {e}") + continue - # 使用批量工具同步 - batch_tool = self.batch_tools['position'] - success, stats = batch_tool.sync_positions_batch(all_positions) + logger.info(f"从 {redis_key} 解析到 {len(accounts_dict)} 个账号") - if success: - elapsed = time.time() - start_time - self.stats['position'] = { - 'accounts': len(accounts), - 'positions': stats['total'], - 'time': elapsed - } - except Exception as e: - logger.error(f"批量同步持仓失败: {e}") - self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0} - - async def _sync_orders_batch(self, accounts: Dict[str, Dict]): - """批量同步订单数据""" + logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}") + + return accounts_dict + + def _discover_all_accounts(self) -> Dict[str, Dict]: + """自动发现所有匹配的账号key""" + accounts_dict = {} + discovered_keys = [] + try: - start_time = time.time() + # 获取所有匹配模式的key + pattern = "*_strategy_api" + cursor = 0 - # 收集所有订单数据 - order_sync = next((s for s in self.syncers if isinstance(s, OrderSyncBatch)), None) - if not order_sync: - return - - all_orders = await order_sync._collect_all_orders(accounts) - - if not all_orders: - self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0} - return - - # 使用批量工具同步 - batch_tool = self.batch_tools['order'] - success, processed_count = batch_tool.sync_orders_batch(all_orders) - - if success: - elapsed = time.time() - start_time - self.stats['order'] = { - 'accounts': len(accounts), - 'orders': processed_count, - 'time': elapsed - } + while True: + cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100) + for key in keys: + key_str = key.decode('utf-8') if isinstance(key, bytes) else key + discovered_keys.append(key_str) + + if cursor == 0: + break + + logger.info(f"自动发现 {len(discovered_keys)} 个策略API key") + + # 处理每个发现的key + for key_str in discovered_keys: + # 提取计算机名 + computer_name = key_str.replace('_strategy_api', '') + + # 验证计算机名格式 + if self.computer_name_pattern.match(computer_name): + accounts = self._get_accounts_by_computer_name(computer_name) + accounts_dict.update(accounts) + else: + logger.warning(f"跳过不符合格式的计算机名: {computer_name}") + + logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号") + except Exception as e: - logger.error(f"批量同步订单失败: {e}") - self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0} + logger.error(f"自动发现账号失败: {e}") + + return accounts_dict - async def _sync_accounts_batch(self, accounts: Dict[str, Dict]): - """批量同步账户信息数据""" + def _discover_all_accounts(self) -> Dict[str, Dict]: + """自动发现所有匹配的账号key""" + accounts_dict = {} + discovered_keys = [] + try: - start_time = time.time() + # 获取所有匹配模式的key + pattern = "*_strategy_api" + cursor = 0 - # 收集所有账户数据 - account_sync = next((s for s in self.syncers if isinstance(s, AccountSyncBatch)), None) - if not account_sync: - return + while True: + cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100) + + for key in keys: + key_str = key.decode('utf-8') if isinstance(key, bytes) else key + discovered_keys.append(key_str) + + if cursor == 0: + break - all_account_data = await account_sync._collect_all_account_data(accounts) + logger.info(f"自动发现 {len(discovered_keys)} 个策略API key") - if not all_account_data: - self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0} - return + # 处理每个发现的key + for key_str in discovered_keys: + # 提取计算机名 + computer_name = key_str.replace('_strategy_api', '') + + # 验证计算机名格式 + if self.computer_name_pattern.match(computer_name): + accounts = self._get_accounts_by_computer_name(computer_name) + accounts_dict.update(accounts) + else: + logger.warning(f"跳过不符合格式的计算机名: {computer_name}") - # 使用批量工具同步 - batch_tool = self.batch_tools['account'] - updated, inserted = batch_tool.sync_accounts_batch(all_account_data) + logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号") - elapsed = time.time() - start_time - self.stats['account'] = { - 'accounts': len(accounts), - 'records': len(all_account_data), - 'time': elapsed + except Exception as e: + logger.error(f"自动发现账号失败: {e}") + + return accounts_dict + + def format_exchange_id(self, key: str) -> str: + """格式化交易所ID""" + key = key.lower().strip() + + # 交易所名称映射 + exchange_mapping = { + 'metatrader': 'mt5', + 'binance_spot_test': 'binance', + 'binance_spot': 'binance', + 'binance': 'binance', + 'gate_spot': 'gate', + 'okex': 'okx', + 'okx': 'okx', + 'bybit': 'bybit', + 'bybit_spot': 'bybit', + 'bybit_test': 'bybit', + 'huobi': 'huobi', + 'huobi_spot': 'huobi', + 'gate': 'gate', + 'gateio': 'gate', + 'kucoin': 'kucoin', + 'kucoin_spot': 'kucoin', + 'mexc': 'mexc', + 'mexc_spot': 'mexc', + 'bitget': 'bitget', + 'bitget_spot': 'bitget' + } + + normalized_key = exchange_mapping.get(key, key) + + # 记录未映射的交易所 + if normalized_key == key and key not in exchange_mapping.values(): + logger.debug(f"未映射的交易所名称: {key}") + + return normalized_key + + def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Optional[Dict]: + """解析账号信息""" + try: + source_account_info = json.loads(account_info) + + # 基础信息 + account_data = { + 'exchange_id': exchange_id, + 'k_id': account_id, + 'st_id': self._safe_int(source_account_info.get('st_id'), 0), + 'add_time': self._safe_int(source_account_info.get('add_time'), 0), + 'account_type': source_account_info.get('account_type', 'real'), + 'api_key': source_account_info.get('api_key', ''), + 'secret_key': source_account_info.get('secret_key', ''), + 'password': source_account_info.get('password', ''), + 'access_token': source_account_info.get('access_token', ''), + 'remark': source_account_info.get('remark', '') } + # 合并原始信息 + result = {**source_account_info, **account_data} + + # 验证必要字段 + if not result.get('st_id') or not result.get('exchange_id'): + logger.warning(f"账号 {account_id} 缺少必要字段: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}") + return None + + return result + + except json.JSONDecodeError as e: + logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {account_info[:100]}...") + return None except Exception as e: - logger.error(f"批量同步账户信息失败: {e}") - self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0} + logger.error(f"处理账号 {account_id} 数据异常: {e}") + return None + def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: + """按交易所分组账号""" + groups = {} + for account_id, account_info in accounts.items(): + exchange_id = account_info.get('exchange_id') + if exchange_id: + if exchange_id not in groups: + groups[exchange_id] = [] + groups[exchange_id].append(account_info) + return groups + + + + + def _update_stats(self, sync_time: float): """更新统计信息""" self.stats['last_sync_time'] = sync_time diff --git a/sync/order_sync.py b/sync/order_sync.py index d21b87a..f76c267 100644 --- a/sync/order_sync.py +++ b/sync/order_sync.py @@ -1,88 +1,179 @@ from .base_sync import BaseSync from loguru import logger -from typing import List, Dict +from typing import List, Dict, Any, Tuple import json +import asyncio import time from datetime import datetime, timedelta +from sqlalchemy import text +import redis -class OrderSync(BaseSync): - """订单数据同步器""" +class OrderSyncBatch(BaseSync): + """订单数据批量同步器""" - async def sync(self): - """同步订单数据""" + def __init__(self): + super().__init__() + self.batch_size = 1000 # 每批处理数量 + self.recent_days = 3 # 同步最近几天的数据 + + async def sync_batch(self, accounts: Dict[str, Dict]): + """批量同步所有账号的订单数据""" try: - # 获取所有账号 - accounts = self.get_accounts_from_redis() + logger.info(f"开始批量同步订单数据,共 {len(accounts)} 个账号") + start_time = time.time() - for k_id_str, account_info in accounts.items(): - try: - k_id = int(k_id_str) - st_id = account_info.get('st_id', 0) - exchange_id = account_info['exchange_id'] - - if k_id <= 0 or st_id <= 0: - continue - - # 从Redis获取最近N天的订单数据 - orders = await self._get_recent_orders_from_redis(k_id, exchange_id) - - # 同步到数据库 - if orders: - success = self._sync_orders_to_db(k_id, st_id, orders) - if success: - logger.debug(f"订单同步成功: k_id={k_id}, 订单数={len(orders)}") - - except Exception as e: - logger.error(f"同步账号 {k_id_str} 订单失败: {e}") - continue + # 1. 收集所有账号的订单数据 + all_orders = await self._collect_all_orders(accounts) - logger.info("订单数据同步完成") + if not all_orders: + logger.info("无订单数据需要同步") + return + + logger.info(f"收集到 {len(all_orders)} 条订单数据") + + # 2. 批量同步到数据库 + success, processed_count = await self._sync_orders_batch_to_db(all_orders) + + elapsed = time.time() - start_time + if success: + logger.info(f"订单批量同步完成: 处理 {processed_count} 条订单,耗时 {elapsed:.2f}秒") + else: + logger.error("订单批量同步失败") except Exception as e: - logger.error(f"订单同步失败: {e}") + logger.error(f"订单批量同步失败: {e}") - async def _get_recent_orders_from_redis(self, k_id: int, exchange_id: str) -> List[Dict]: + async def _collect_all_orders(self, accounts: Dict[str, Dict]) -> List[Dict]: + """收集所有账号的订单数据""" + all_orders = [] + + try: + # 按交易所分组账号 + account_groups = self._group_accounts_by_exchange(accounts) + + # 并发收集每个交易所的数据 + tasks = [] + for exchange_id, account_list in account_groups.items(): + task = self._collect_exchange_orders(exchange_id, account_list) + tasks.append(task) + + # 等待所有任务完成并合并结果 + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, list): + all_orders.extend(result) + + except Exception as e: + logger.error(f"收集订单数据失败: {e}") + + return all_orders + + def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: + """按交易所分组账号""" + groups = {} + for account_id, account_info in accounts.items(): + exchange_id = account_info.get('exchange_id') + if exchange_id: + if exchange_id not in groups: + groups[exchange_id] = [] + groups[exchange_id].append(account_info) + return groups + + async def _collect_exchange_orders(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]: + """收集某个交易所的订单数据""" + orders_list = [] + + try: + # 并发获取每个账号的数据 + tasks = [] + for account_info in account_list: + k_id = int(account_info['k_id']) + st_id = account_info.get('st_id', 0) + task = self._get_recent_orders_from_redis(k_id, st_id, exchange_id) + tasks.append(task) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, list): + orders_list.extend(result) + + logger.debug(f"交易所 {exchange_id}: 收集到 {len(orders_list)} 条订单") + + except Exception as e: + logger.error(f"收集交易所 {exchange_id} 订单数据失败: {e}") + + return orders_list + + async def _get_recent_orders_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]: """从Redis获取最近N天的订单数据""" try: redis_key = f"{exchange_id}:orders:{k_id}" # 计算最近N天的日期 - from config.settings import SYNC_CONFIG - recent_days = SYNC_CONFIG['recent_days'] - today = datetime.now() recent_dates = [] - for i in range(recent_days): + for i in range(self.recent_days): date = today - timedelta(days=i) date_format = date.strftime('%Y-%m-%d') recent_dates.append(date_format) - # 获取所有key - all_keys = self.redis_client.client.hkeys(redis_key) + # 使用scan获取所有符合条件的key + cursor = 0 + recent_keys = [] - orders_list = [] - for key in all_keys: - key_str = key.decode('utf-8') if isinstance(key, bytes) else key + while True: + cursor, keys = self.redis_client.client.hscan(redis_key, cursor, count=1000) - if key_str == 'positions': - continue - - # 检查是否以最近N天的日期开头 - for date_format in recent_dates: - if key_str.startswith(date_format + '_'): - try: - order_json = self.redis_client.client.hget(redis_key, key_str) - if order_json: - order = json.loads(order_json) - - # 验证时间 - order_time = order.get('time', 0) - if order_time >= int(time.time()) - recent_days * 24 * 3600: - orders_list.append(order) - - break - except: + for key, _ in keys.items(): + key_str = key.decode('utf-8') if isinstance(key, bytes) else key + + if key_str == 'positions': + continue + + # 检查是否以最近N天的日期开头 + for date_format in recent_dates: + if key_str.startswith(date_format + '_'): + recent_keys.append(key_str) break + + if cursor == 0: + break + + if not recent_keys: + return [] + + # 批量获取订单数据 + orders_list = [] + + # 分批获取,避免单次hgetall数据量太大 + chunk_size = 500 + for i in range(0, len(recent_keys), chunk_size): + chunk_keys = recent_keys[i:i + chunk_size] + + # 使用hmget批量获取 + chunk_values = self.redis_client.client.hmget(redis_key, chunk_keys) + + for key, order_json in zip(chunk_keys, chunk_values): + if not order_json: + continue + + try: + order = json.loads(order_json) + + # 验证时间 + order_time = order.get('time', 0) + if order_time >= int(time.time()) - self.recent_days * 24 * 3600: + # 添加账号信息 + order['k_id'] = k_id + order['st_id'] = st_id + order['exchange_id'] = exchange_id + orders_list.append(order) + + except json.JSONDecodeError as e: + logger.debug(f"解析订单JSON失败: key={key}, error={e}") + continue return orders_list @@ -90,77 +181,89 @@ class OrderSync(BaseSync): logger.error(f"获取Redis订单数据失败: k_id={k_id}, error={e}") return [] - def _sync_orders_to_db(self, k_id: int, st_id: int, orders_data: List[Dict]) -> bool: - """同步订单数据到数据库""" - session = self.db_manager.get_session() + async def _sync_orders_batch_to_db(self, all_orders: List[Dict]) -> Tuple[bool, int]: + """批量同步订单数据到数据库""" try: - # 准备批量数据 - insert_data = [] - for order_data in orders_data: + if not all_orders: + return True, 0 + + # 转换数据 + converted_orders = [] + for order in all_orders: try: - order_dict = self._convert_order_data(order_data) + order_dict = self._convert_order_data(order) # 检查完整性 required_fields = ['order_id', 'symbol', 'side', 'time'] if not all(order_dict.get(field) for field in required_fields): continue - insert_data.append(order_dict) + converted_orders.append(order_dict) except Exception as e: - logger.error(f"转换订单数据失败: {order_data}, error={e}") + logger.error(f"转换订单数据失败: {order}, error={e}") continue - if not insert_data: - return True + if not converted_orders: + return True, 0 - with session.begin(): - # 使用参数化批量插入 - sql = """ - INSERT INTO deh_strategy_order_new - (st_id, k_id, asset, order_id, symbol, side, price, time, - order_qty, last_qty, avg_price, exchange_id) - VALUES - (:st_id, :k_id, :asset, :order_id, :symbol, :side, :price, :time, - :order_qty, :last_qty, :avg_price, :exchange_id) - ON DUPLICATE KEY UPDATE - side = VALUES(side), - price = VALUES(price), - time = VALUES(time), - order_qty = VALUES(order_qty), - last_qty = VALUES(last_qty), - avg_price = VALUES(avg_price) - """ - - # 分块执行 - from config.settings import SYNC_CONFIG - chunk_size = SYNC_CONFIG['chunk_size'] - - for i in range(0, len(insert_data), chunk_size): - chunk = insert_data[i:i + chunk_size] - session.execute(text(sql), chunk) + # 使用批量工具同步 + from utils.batch_order_sync import BatchOrderSync + batch_tool = BatchOrderSync(self.db_manager, self.batch_size) - return True + success, processed_count = batch_tool.sync_orders_batch(converted_orders) + + return success, processed_count except Exception as e: - logger.error(f"同步订单到数据库失败: k_id={k_id}, error={e}") - return False - finally: - session.close() + logger.error(f"批量同步订单到数据库失败: {e}") + return False, 0 def _convert_order_data(self, data: Dict) -> Dict: """转换订单数据格式""" - return { - 'st_id': int(data.get('st_id', 0)), - 'k_id': int(data.get('k_id', 0)), - 'asset': 'USDT', - 'order_id': str(data.get('order_id', '')), - 'symbol': data.get('symbol', ''), - 'side': data.get('side', ''), - 'price': float(data.get('price', 0)) if data.get('price') is not None else None, - 'time': int(data.get('time', 0)) if data.get('time') is not None else None, - 'order_qty': float(data.get('order_qty', 0)) if data.get('order_qty') is not None else None, - 'last_qty': float(data.get('last_qty', 0)) if data.get('last_qty') is not None else None, - 'avg_price': float(data.get('avg_price', 0)) if data.get('avg_price') is not None else None, - 'exchange_id': None # 忽略该字段 - } \ No newline at end of file + try: + # 安全转换函数 + def safe_float(value): + if value is None: + return None + try: + return float(value) + except (ValueError, TypeError): + return None + + def safe_int(value): + if value is None: + return None + try: + return int(float(value)) + except (ValueError, TypeError): + return None + + def safe_str(value): + if value is None: + return '' + return str(value) + + return { + 'st_id': safe_int(data.get('st_id'), 0), + 'k_id': safe_int(data.get('k_id'), 0), + 'asset': 'USDT', + 'order_id': safe_str(data.get('order_id')), + 'symbol': safe_str(data.get('symbol')), + 'side': safe_str(data.get('side')), + 'price': safe_float(data.get('price')), + 'time': safe_int(data.get('time')), + 'order_qty': safe_float(data.get('order_qty')), + 'last_qty': safe_float(data.get('last_qty')), + 'avg_price': safe_float(data.get('avg_price')), + 'exchange_id': None # 忽略该字段 + } + + except Exception as e: + logger.error(f"转换订单数据异常: {data}, error={e}") + return {} + + async def sync(self): + """兼容旧接口""" + accounts = self.get_accounts_from_redis() + await self.sync_batch(accounts) \ No newline at end of file diff --git a/sync/order_sync_batch.py b/sync/order_sync_batch.py deleted file mode 100644 index f76c267..0000000 --- a/sync/order_sync_batch.py +++ /dev/null @@ -1,269 +0,0 @@ -from .base_sync import BaseSync -from loguru import logger -from typing import List, Dict, Any, Tuple -import json -import asyncio -import time -from datetime import datetime, timedelta -from sqlalchemy import text -import redis - -class OrderSyncBatch(BaseSync): - """订单数据批量同步器""" - - def __init__(self): - super().__init__() - self.batch_size = 1000 # 每批处理数量 - self.recent_days = 3 # 同步最近几天的数据 - - async def sync_batch(self, accounts: Dict[str, Dict]): - """批量同步所有账号的订单数据""" - try: - logger.info(f"开始批量同步订单数据,共 {len(accounts)} 个账号") - start_time = time.time() - - # 1. 收集所有账号的订单数据 - all_orders = await self._collect_all_orders(accounts) - - if not all_orders: - logger.info("无订单数据需要同步") - return - - logger.info(f"收集到 {len(all_orders)} 条订单数据") - - # 2. 批量同步到数据库 - success, processed_count = await self._sync_orders_batch_to_db(all_orders) - - elapsed = time.time() - start_time - if success: - logger.info(f"订单批量同步完成: 处理 {processed_count} 条订单,耗时 {elapsed:.2f}秒") - else: - logger.error("订单批量同步失败") - - except Exception as e: - logger.error(f"订单批量同步失败: {e}") - - async def _collect_all_orders(self, accounts: Dict[str, Dict]) -> List[Dict]: - """收集所有账号的订单数据""" - all_orders = [] - - try: - # 按交易所分组账号 - account_groups = self._group_accounts_by_exchange(accounts) - - # 并发收集每个交易所的数据 - tasks = [] - for exchange_id, account_list in account_groups.items(): - task = self._collect_exchange_orders(exchange_id, account_list) - tasks.append(task) - - # 等待所有任务完成并合并结果 - results = await asyncio.gather(*tasks, return_exceptions=True) - - for result in results: - if isinstance(result, list): - all_orders.extend(result) - - except Exception as e: - logger.error(f"收集订单数据失败: {e}") - - return all_orders - - def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: - """按交易所分组账号""" - groups = {} - for account_id, account_info in accounts.items(): - exchange_id = account_info.get('exchange_id') - if exchange_id: - if exchange_id not in groups: - groups[exchange_id] = [] - groups[exchange_id].append(account_info) - return groups - - async def _collect_exchange_orders(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]: - """收集某个交易所的订单数据""" - orders_list = [] - - try: - # 并发获取每个账号的数据 - tasks = [] - for account_info in account_list: - k_id = int(account_info['k_id']) - st_id = account_info.get('st_id', 0) - task = self._get_recent_orders_from_redis(k_id, st_id, exchange_id) - tasks.append(task) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - for result in results: - if isinstance(result, list): - orders_list.extend(result) - - logger.debug(f"交易所 {exchange_id}: 收集到 {len(orders_list)} 条订单") - - except Exception as e: - logger.error(f"收集交易所 {exchange_id} 订单数据失败: {e}") - - return orders_list - - async def _get_recent_orders_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]: - """从Redis获取最近N天的订单数据""" - try: - redis_key = f"{exchange_id}:orders:{k_id}" - - # 计算最近N天的日期 - today = datetime.now() - recent_dates = [] - for i in range(self.recent_days): - date = today - timedelta(days=i) - date_format = date.strftime('%Y-%m-%d') - recent_dates.append(date_format) - - # 使用scan获取所有符合条件的key - cursor = 0 - recent_keys = [] - - while True: - cursor, keys = self.redis_client.client.hscan(redis_key, cursor, count=1000) - - for key, _ in keys.items(): - key_str = key.decode('utf-8') if isinstance(key, bytes) else key - - if key_str == 'positions': - continue - - # 检查是否以最近N天的日期开头 - for date_format in recent_dates: - if key_str.startswith(date_format + '_'): - recent_keys.append(key_str) - break - - if cursor == 0: - break - - if not recent_keys: - return [] - - # 批量获取订单数据 - orders_list = [] - - # 分批获取,避免单次hgetall数据量太大 - chunk_size = 500 - for i in range(0, len(recent_keys), chunk_size): - chunk_keys = recent_keys[i:i + chunk_size] - - # 使用hmget批量获取 - chunk_values = self.redis_client.client.hmget(redis_key, chunk_keys) - - for key, order_json in zip(chunk_keys, chunk_values): - if not order_json: - continue - - try: - order = json.loads(order_json) - - # 验证时间 - order_time = order.get('time', 0) - if order_time >= int(time.time()) - self.recent_days * 24 * 3600: - # 添加账号信息 - order['k_id'] = k_id - order['st_id'] = st_id - order['exchange_id'] = exchange_id - orders_list.append(order) - - except json.JSONDecodeError as e: - logger.debug(f"解析订单JSON失败: key={key}, error={e}") - continue - - return orders_list - - except Exception as e: - logger.error(f"获取Redis订单数据失败: k_id={k_id}, error={e}") - return [] - - async def _sync_orders_batch_to_db(self, all_orders: List[Dict]) -> Tuple[bool, int]: - """批量同步订单数据到数据库""" - try: - if not all_orders: - return True, 0 - - # 转换数据 - converted_orders = [] - for order in all_orders: - try: - order_dict = self._convert_order_data(order) - - # 检查完整性 - required_fields = ['order_id', 'symbol', 'side', 'time'] - if not all(order_dict.get(field) for field in required_fields): - continue - - converted_orders.append(order_dict) - - except Exception as e: - logger.error(f"转换订单数据失败: {order}, error={e}") - continue - - if not converted_orders: - return True, 0 - - # 使用批量工具同步 - from utils.batch_order_sync import BatchOrderSync - batch_tool = BatchOrderSync(self.db_manager, self.batch_size) - - success, processed_count = batch_tool.sync_orders_batch(converted_orders) - - return success, processed_count - - except Exception as e: - logger.error(f"批量同步订单到数据库失败: {e}") - return False, 0 - - def _convert_order_data(self, data: Dict) -> Dict: - """转换订单数据格式""" - try: - # 安全转换函数 - def safe_float(value): - if value is None: - return None - try: - return float(value) - except (ValueError, TypeError): - return None - - def safe_int(value): - if value is None: - return None - try: - return int(float(value)) - except (ValueError, TypeError): - return None - - def safe_str(value): - if value is None: - return '' - return str(value) - - return { - 'st_id': safe_int(data.get('st_id'), 0), - 'k_id': safe_int(data.get('k_id'), 0), - 'asset': 'USDT', - 'order_id': safe_str(data.get('order_id')), - 'symbol': safe_str(data.get('symbol')), - 'side': safe_str(data.get('side')), - 'price': safe_float(data.get('price')), - 'time': safe_int(data.get('time')), - 'order_qty': safe_float(data.get('order_qty')), - 'last_qty': safe_float(data.get('last_qty')), - 'avg_price': safe_float(data.get('avg_price')), - 'exchange_id': None # 忽略该字段 - } - - except Exception as e: - logger.error(f"转换订单数据异常: {data}, error={e}") - return {} - - async def sync(self): - """兼容旧接口""" - accounts = self.get_accounts_from_redis() - await self.sync_batch(accounts) \ No newline at end of file diff --git a/sync/position_sync.py b/sync/position_sync.py index 7d1f956..9e249a1 100644 --- a/sync/position_sync.py +++ b/sync/position_sync.py @@ -1,41 +1,74 @@ from .base_sync import BaseSync from loguru import logger -from typing import List, Dict +from typing import List, Dict, Any, Set, Tuple import json import asyncio -from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from sqlalchemy import text, and_, select, delete +from models.orm_models import StrategyPosition +import time -class PositionSync(BaseSync): - """持仓数据同步器(批量版本)""" +class PositionSyncBatch(BaseSync): + """持仓数据批量同步器""" def __init__(self): super().__init__() - self.max_concurrent = 10 # 每个同步器的最大并发数 + self.batch_size = 500 # 每批处理数量 async def sync_batch(self, accounts: Dict[str, Dict]): """批量同步所有账号的持仓数据""" try: logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号") + start_time = time.time() - # 按账号分组 - account_groups = self._group_accounts_by_exchange(accounts) + # 1. 收集所有账号的持仓数据 + all_positions = await self._collect_all_positions(accounts) - # 并发处理每个交易所的账号 - tasks = [] - for exchange_id, account_list in account_groups.items(): - task = self._sync_exchange_accounts(exchange_id, account_list) - tasks.append(task) + if not all_positions: + logger.info("无持仓数据需要同步") + return - # 等待所有任务完成 - results = await asyncio.gather(*tasks, return_exceptions=True) + logger.info(f"收集到 {len(all_positions)} 条持仓数据") - # 统计结果 - success_count = sum(1 for r in results if isinstance(r, bool) and r) - logger.info(f"持仓批量同步完成: 成功 {success_count}/{len(results)} 个交易所组") + # 2. 批量同步到数据库 + success, stats = await self._sync_positions_batch_to_db(all_positions) + + elapsed = time.time() - start_time + if success: + logger.info(f"持仓批量同步完成: 处理 {stats['total']} 条,更新 {stats['updated']} 条," + f"插入 {stats['inserted']} 条,删除 {stats['deleted']} 条,耗时 {elapsed:.2f}秒") + else: + logger.error("持仓批量同步失败") except Exception as e: logger.error(f"持仓批量同步失败: {e}") + async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]: + """收集所有账号的持仓数据""" + all_positions = [] + + try: + # 按交易所分组账号 + account_groups = self._group_accounts_by_exchange(accounts) + + # 并发收集每个交易所的数据 + tasks = [] + for exchange_id, account_list in account_groups.items(): + task = self._collect_exchange_positions(exchange_id, account_list) + tasks.append(task) + + # 等待所有任务完成并合并结果 + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, list): + all_positions.extend(result) + + except Exception as e: + logger.error(f"收集持仓数据失败: {e}") + + return all_positions + def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: """按交易所分组账号""" groups = {} @@ -47,46 +80,60 @@ class PositionSync(BaseSync): groups[exchange_id].append(account_info) return groups - async def _sync_exchange_accounts(self, exchange_id: str, account_list: List[Dict]): - """同步某个交易所的所有账号""" + async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]: + """收集某个交易所的持仓数据""" + positions_list = [] + try: - # 收集所有账号的持仓数据 - all_positions = [] - + tasks = [] for account_info in account_list: k_id = int(account_info['k_id']) st_id = account_info.get('st_id', 0) - - # 从Redis获取持仓数据 - positions = await self._get_positions_from_redis(k_id, exchange_id) - - if positions: - # 添加账号信息 - for position in positions: - position['k_id'] = k_id - position['st_id'] = st_id - all_positions.extend(positions) + task = self._get_positions_from_redis(k_id, st_id, exchange_id) + tasks.append(task) - if not all_positions: - logger.debug(f"交易所 {exchange_id} 无持仓数据") - return True + # 并发获取 + results = await asyncio.gather(*tasks, return_exceptions=True) - # 批量同步到数据库 - success = self._sync_positions_batch_to_db(all_positions) - if success: - logger.info(f"交易所 {exchange_id} 持仓同步成功: {len(all_positions)} 条持仓") - - return success + for result in results: + if isinstance(result, list): + positions_list.extend(result) except Exception as e: - logger.error(f"同步交易所 {exchange_id} 持仓失败: {e}") - return False + logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {e}") + + return positions_list - def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> bool: - """批量同步持仓数据到数据库(优化版)""" - session = self.db_manager.get_session() + async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]: + """从Redis获取持仓数据""" try: - # 按k_id分组 + redis_key = f"{exchange_id}:positions:{k_id}" + redis_data = self.redis_client.client.hget(redis_key, 'positions') + + if not redis_data: + return [] + + positions = json.loads(redis_data) + + # 添加账号信息 + for position in positions: + position['k_id'] = k_id + position['st_id'] = st_id + position['exchange_id'] = exchange_id + + return positions + + except Exception as e: + logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}") + return [] + + async def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> Tuple[bool, Dict]: + """批量同步持仓数据到数据库""" + try: + if not all_positions: + return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} + + # 按账号分组 positions_by_account = {} for position in all_positions: k_id = position['k_id'] @@ -94,98 +141,239 @@ class PositionSync(BaseSync): positions_by_account[k_id] = [] positions_by_account[k_id].append(position) - success_count = 0 + logger.info(f"开始批量处理 {len(positions_by_account)} 个账号的持仓数据") - with session.begin(): - for k_id, positions in positions_by_account.items(): - try: - st_id = positions[0]['st_id'] if positions else 0 - - # 准备数据 - insert_data = [] - keep_keys = set() - - for pos_data in positions: - try: - pos_dict = self._convert_position_data(pos_data) - if not all([pos_dict.get('symbol'), pos_dict.get('side')]): - continue - - # 重命名qty为sum - if 'qty' in pos_dict: - pos_dict['sum'] = pos_dict.pop('qty') - - insert_data.append(pos_dict) - keep_keys.add((pos_dict['symbol'], pos_dict['side'])) - - except Exception as e: - logger.error(f"转换持仓数据失败: {pos_data}, error={e}") - continue - - if not insert_data: - continue - - # 批量插入/更新 - from sqlalchemy.dialects.mysql import insert - - stmt = insert(StrategyPosition.__table__).values(insert_data) - - update_dict = { - 'price': stmt.inserted.price, - 'sum': stmt.inserted.sum, - 'asset_num': stmt.inserted.asset_num, - 'asset_profit': stmt.inserted.asset_profit, - 'leverage': stmt.inserted.leverage, - 'uptime': stmt.inserted.uptime, - 'profit_price': stmt.inserted.profit_price, - 'stop_price': stmt.inserted.stop_price, - 'liquidation_price': stmt.inserted.liquidation_price - } - - stmt = stmt.on_duplicate_key_update(**update_dict) - session.execute(stmt) - - # 删除多余持仓 - if keep_keys: - existing_positions = session.execute( - select(StrategyPosition).where( - and_( - StrategyPosition.k_id == k_id, - StrategyPosition.st_id == st_id - ) - ) - ).scalars().all() - - to_delete_ids = [] - for existing in existing_positions: - key = (existing.symbol, existing.side) - if key not in keep_keys: - to_delete_ids.append(existing.id) - - if to_delete_ids: - # 分块删除 - chunk_size = 100 - for i in range(0, len(to_delete_ids), chunk_size): - chunk = to_delete_ids[i:i + chunk_size] - session.execute( - delete(StrategyPosition).where( - StrategyPosition.id.in_(chunk) - ) - ) - - success_count += 1 - - except Exception as e: - logger.error(f"同步账号 {k_id} 持仓失败: {e}") - continue + # 批量处理每个账号 + total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} - logger.info(f"批量同步完成: 成功 {success_count}/{len(positions_by_account)} 个账号") - return success_count > 0 + for k_id, positions in positions_by_account.items(): + st_id = positions[0]['st_id'] if positions else 0 + + # 处理单个账号的批量同步 + success, stats = await self._sync_single_account_batch(k_id, st_id, positions) + + if success: + total_stats['total'] += stats['total'] + total_stats['updated'] += stats['updated'] + total_stats['inserted'] += stats['inserted'] + total_stats['deleted'] += stats['deleted'] + + return True, total_stats except Exception as e: logger.error(f"批量同步持仓到数据库失败: {e}") - return False + return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} + + async def _sync_single_account_batch(self, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]: + """批量同步单个账号的持仓数据""" + session = self.db_manager.get_session() + try: + # 准备数据 + insert_data = [] + new_positions_map = {} # (symbol, side) -> position_id (用于删除) + + for position_data in positions: + try: + position_dict = self._convert_position_data(position_data) + if not all([position_dict.get('symbol'), position_dict.get('side')]): + continue + + symbol = position_dict['symbol'] + side = position_dict['side'] + key = (symbol, side) + + # 重命名qty为sum + if 'qty' in position_dict: + position_dict['sum'] = position_dict.pop('qty') + + insert_data.append(position_dict) + new_positions_map[key] = position_dict.get('id') # 如果有id的话 + + except Exception as e: + logger.error(f"转换持仓数据失败: {position_data}, error={e}") + continue + + with session.begin(): + if not insert_data: + # 清空该账号所有持仓 + result = session.execute( + delete(StrategyPosition).where( + and_( + StrategyPosition.k_id == k_id, + StrategyPosition.st_id == st_id + ) + ) + ) + deleted_count = result.rowcount + + return True, { + 'total': 0, + 'updated': 0, + 'inserted': 0, + 'deleted': deleted_count + } + + # 1. 批量插入/更新持仓数据 + processed_count = self._batch_upsert_positions(session, insert_data) + + # 2. 批量删除多余持仓 + deleted_count = self._batch_delete_extra_positions(session, k_id, st_id, new_positions_map) + + # 注意:这里无法区分插入和更新的数量,processed_count是总处理数 + inserted_count = processed_count # 简化处理 + updated_count = 0 # 需要更复杂的逻辑来区分 + + stats = { + 'total': len(insert_data), + 'updated': updated_count, + 'inserted': inserted_count, + 'deleted': deleted_count + } + + return True, stats + + except Exception as e: + logger.error(f"批量同步账号 {k_id} 持仓失败: {e}") + return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} finally: session.close() - # 其他方法保持不变... \ No newline at end of file + def _batch_upsert_positions(self, session, insert_data: List[Dict]) -> int: + """批量插入/更新持仓数据""" + try: + # 分块处理 + chunk_size = self.batch_size + total_processed = 0 + + for i in range(0, len(insert_data), chunk_size): + chunk = insert_data[i:i + chunk_size] + + values_list = [] + for data in chunk: + symbol = data.get('symbol').replace("'", "''") if data.get('symbol') else '' + values = ( + f"({data['st_id']}, {data['k_id']}, '{data.get('asset', 'USDT')}', " + f"'{symbol}', " + f"{data.get('price') or 'NULL'}, {data.get('sum') or 'NULL'}, " + f"{data.get('asset_num') or 'NULL'}, {data.get('asset_profit') or 'NULL'}, " + f"{data.get('leverage') or 'NULL'}, {data.get('uptime') or 'NULL'}, " + f"{data.get('profit_price') or 'NULL'}, {data.get('stop_price') or 'NULL'}, " + f"{data.get('liquidation_price') or 'NULL'})" + ) + values_list.append(values) + + if values_list: + values_str = ", ".join(values_list) + + sql = f""" + INSERT INTO deh_strategy_position_new + (st_id, k_id, asset, symbol, side, price, `sum`, + asset_num, asset_profit, leverage, uptime, + profit_price, stop_price, liquidation_price) + VALUES {values_str} + ON DUPLICATE KEY UPDATE + price = VALUES(price), + `sum` = VALUES(`sum`), + asset_num = VALUES(asset_num), + asset_profit = VALUES(asset_profit), + leverage = VALUES(leverage), + uptime = VALUES(uptime), + profit_price = VALUES(profit_price), + stop_price = VALUES(stop_price), + liquidation_price = VALUES(liquidation_price) + """ + + session.execute(text(sql)) + total_processed += len(chunk) + + return total_processed + + except Exception as e: + logger.error(f"批量插入/更新持仓失败: {e}") + raise + + def _batch_delete_extra_positions(self, session, k_id: int, st_id: int, new_positions_map: Dict) -> int: + """批量删除多余持仓""" + try: + if not new_positions_map: + # 删除所有持仓 + result = session.execute( + delete(StrategyPosition).where( + and_( + StrategyPosition.k_id == k_id, + StrategyPosition.st_id == st_id + ) + ) + ) + return result.rowcount + + # 构建保留条件 + conditions = [] + for (symbol, side) in new_positions_map.keys(): + safe_symbol = symbol.replace("'", "''") if symbol else '' + safe_side = side.replace("'", "''") if side else '' + conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')") + + if conditions: + conditions_str = " OR ".join(conditions) + + sql = f""" + DELETE FROM deh_strategy_position_new + WHERE k_id = {k_id} AND st_id = {st_id} + AND NOT ({conditions_str}) + """ + + result = session.execute(text(sql)) + return result.rowcount + + return 0 + + except Exception as e: + logger.error(f"批量删除持仓失败: k_id={k_id}, error={e}") + return 0 + + def _convert_position_data(self, data: Dict) -> Dict: + """转换持仓数据格式""" + try: + # 安全转换函数 + def safe_float(value, default=None): + if value is None: + return default + try: + return float(value) + except (ValueError, TypeError): + return default + + def safe_int(value, default=None): + if value is None: + return default + try: + return int(float(value)) + except (ValueError, TypeError): + return default + + return { + 'st_id': safe_int(data.get('st_id'), 0), + 'k_id': safe_int(data.get('k_id'), 0), + 'asset': data.get('asset', 'USDT'), + 'symbol': data.get('symbol', ''), + 'side': data.get('side', ''), + 'price': safe_float(data.get('price')), + 'qty': safe_float(data.get('qty')), # 后面会重命名为sum + 'asset_num': safe_float(data.get('asset_num')), + 'asset_profit': safe_float(data.get('asset_profit')), + 'leverage': safe_int(data.get('leverage')), + 'uptime': safe_int(data.get('uptime')), + 'profit_price': safe_float(data.get('profit_price')), + 'stop_price': safe_float(data.get('stop_price')), + 'liquidation_price': safe_float(data.get('liquidation_price')) + } + + except Exception as e: + logger.error(f"转换持仓数据异常: {data}, error={e}") + return {} + + async def sync(self): + """兼容旧接口""" + accounts = self.get_accounts_from_redis() + await self.sync_batch(accounts) \ No newline at end of file diff --git a/sync/position_sync_batch.py b/sync/position_sync_batch.py deleted file mode 100644 index 9e249a1..0000000 --- a/sync/position_sync_batch.py +++ /dev/null @@ -1,379 +0,0 @@ -from .base_sync import BaseSync -from loguru import logger -from typing import List, Dict, Any, Set, Tuple -import json -import asyncio -from datetime import datetime -from sqlalchemy import text, and_, select, delete -from models.orm_models import StrategyPosition -import time - -class PositionSyncBatch(BaseSync): - """持仓数据批量同步器""" - - def __init__(self): - super().__init__() - self.batch_size = 500 # 每批处理数量 - - async def sync_batch(self, accounts: Dict[str, Dict]): - """批量同步所有账号的持仓数据""" - try: - logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号") - start_time = time.time() - - # 1. 收集所有账号的持仓数据 - all_positions = await self._collect_all_positions(accounts) - - if not all_positions: - logger.info("无持仓数据需要同步") - return - - logger.info(f"收集到 {len(all_positions)} 条持仓数据") - - # 2. 批量同步到数据库 - success, stats = await self._sync_positions_batch_to_db(all_positions) - - elapsed = time.time() - start_time - if success: - logger.info(f"持仓批量同步完成: 处理 {stats['total']} 条,更新 {stats['updated']} 条," - f"插入 {stats['inserted']} 条,删除 {stats['deleted']} 条,耗时 {elapsed:.2f}秒") - else: - logger.error("持仓批量同步失败") - - except Exception as e: - logger.error(f"持仓批量同步失败: {e}") - - async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]: - """收集所有账号的持仓数据""" - all_positions = [] - - try: - # 按交易所分组账号 - account_groups = self._group_accounts_by_exchange(accounts) - - # 并发收集每个交易所的数据 - tasks = [] - for exchange_id, account_list in account_groups.items(): - task = self._collect_exchange_positions(exchange_id, account_list) - tasks.append(task) - - # 等待所有任务完成并合并结果 - results = await asyncio.gather(*tasks, return_exceptions=True) - - for result in results: - if isinstance(result, list): - all_positions.extend(result) - - except Exception as e: - logger.error(f"收集持仓数据失败: {e}") - - return all_positions - - def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]: - """按交易所分组账号""" - groups = {} - for account_id, account_info in accounts.items(): - exchange_id = account_info.get('exchange_id') - if exchange_id: - if exchange_id not in groups: - groups[exchange_id] = [] - groups[exchange_id].append(account_info) - return groups - - async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]: - """收集某个交易所的持仓数据""" - positions_list = [] - - try: - tasks = [] - for account_info in account_list: - k_id = int(account_info['k_id']) - st_id = account_info.get('st_id', 0) - task = self._get_positions_from_redis(k_id, st_id, exchange_id) - tasks.append(task) - - # 并发获取 - results = await asyncio.gather(*tasks, return_exceptions=True) - - for result in results: - if isinstance(result, list): - positions_list.extend(result) - - except Exception as e: - logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {e}") - - return positions_list - - async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]: - """从Redis获取持仓数据""" - try: - redis_key = f"{exchange_id}:positions:{k_id}" - redis_data = self.redis_client.client.hget(redis_key, 'positions') - - if not redis_data: - return [] - - positions = json.loads(redis_data) - - # 添加账号信息 - for position in positions: - position['k_id'] = k_id - position['st_id'] = st_id - position['exchange_id'] = exchange_id - - return positions - - except Exception as e: - logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}") - return [] - - async def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> Tuple[bool, Dict]: - """批量同步持仓数据到数据库""" - try: - if not all_positions: - return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} - - # 按账号分组 - positions_by_account = {} - for position in all_positions: - k_id = position['k_id'] - if k_id not in positions_by_account: - positions_by_account[k_id] = [] - positions_by_account[k_id].append(position) - - logger.info(f"开始批量处理 {len(positions_by_account)} 个账号的持仓数据") - - # 批量处理每个账号 - total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} - - for k_id, positions in positions_by_account.items(): - st_id = positions[0]['st_id'] if positions else 0 - - # 处理单个账号的批量同步 - success, stats = await self._sync_single_account_batch(k_id, st_id, positions) - - if success: - total_stats['total'] += stats['total'] - total_stats['updated'] += stats['updated'] - total_stats['inserted'] += stats['inserted'] - total_stats['deleted'] += stats['deleted'] - - return True, total_stats - - except Exception as e: - logger.error(f"批量同步持仓到数据库失败: {e}") - return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} - - async def _sync_single_account_batch(self, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]: - """批量同步单个账号的持仓数据""" - session = self.db_manager.get_session() - try: - # 准备数据 - insert_data = [] - new_positions_map = {} # (symbol, side) -> position_id (用于删除) - - for position_data in positions: - try: - position_dict = self._convert_position_data(position_data) - if not all([position_dict.get('symbol'), position_dict.get('side')]): - continue - - symbol = position_dict['symbol'] - side = position_dict['side'] - key = (symbol, side) - - # 重命名qty为sum - if 'qty' in position_dict: - position_dict['sum'] = position_dict.pop('qty') - - insert_data.append(position_dict) - new_positions_map[key] = position_dict.get('id') # 如果有id的话 - - except Exception as e: - logger.error(f"转换持仓数据失败: {position_data}, error={e}") - continue - - with session.begin(): - if not insert_data: - # 清空该账号所有持仓 - result = session.execute( - delete(StrategyPosition).where( - and_( - StrategyPosition.k_id == k_id, - StrategyPosition.st_id == st_id - ) - ) - ) - deleted_count = result.rowcount - - return True, { - 'total': 0, - 'updated': 0, - 'inserted': 0, - 'deleted': deleted_count - } - - # 1. 批量插入/更新持仓数据 - processed_count = self._batch_upsert_positions(session, insert_data) - - # 2. 批量删除多余持仓 - deleted_count = self._batch_delete_extra_positions(session, k_id, st_id, new_positions_map) - - # 注意:这里无法区分插入和更新的数量,processed_count是总处理数 - inserted_count = processed_count # 简化处理 - updated_count = 0 # 需要更复杂的逻辑来区分 - - stats = { - 'total': len(insert_data), - 'updated': updated_count, - 'inserted': inserted_count, - 'deleted': deleted_count - } - - return True, stats - - except Exception as e: - logger.error(f"批量同步账号 {k_id} 持仓失败: {e}") - return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} - finally: - session.close() - - def _batch_upsert_positions(self, session, insert_data: List[Dict]) -> int: - """批量插入/更新持仓数据""" - try: - # 分块处理 - chunk_size = self.batch_size - total_processed = 0 - - for i in range(0, len(insert_data), chunk_size): - chunk = insert_data[i:i + chunk_size] - - values_list = [] - for data in chunk: - symbol = data.get('symbol').replace("'", "''") if data.get('symbol') else '' - values = ( - f"({data['st_id']}, {data['k_id']}, '{data.get('asset', 'USDT')}', " - f"'{symbol}', " - f"{data.get('price') or 'NULL'}, {data.get('sum') or 'NULL'}, " - f"{data.get('asset_num') or 'NULL'}, {data.get('asset_profit') or 'NULL'}, " - f"{data.get('leverage') or 'NULL'}, {data.get('uptime') or 'NULL'}, " - f"{data.get('profit_price') or 'NULL'}, {data.get('stop_price') or 'NULL'}, " - f"{data.get('liquidation_price') or 'NULL'})" - ) - values_list.append(values) - - if values_list: - values_str = ", ".join(values_list) - - sql = f""" - INSERT INTO deh_strategy_position_new - (st_id, k_id, asset, symbol, side, price, `sum`, - asset_num, asset_profit, leverage, uptime, - profit_price, stop_price, liquidation_price) - VALUES {values_str} - ON DUPLICATE KEY UPDATE - price = VALUES(price), - `sum` = VALUES(`sum`), - asset_num = VALUES(asset_num), - asset_profit = VALUES(asset_profit), - leverage = VALUES(leverage), - uptime = VALUES(uptime), - profit_price = VALUES(profit_price), - stop_price = VALUES(stop_price), - liquidation_price = VALUES(liquidation_price) - """ - - session.execute(text(sql)) - total_processed += len(chunk) - - return total_processed - - except Exception as e: - logger.error(f"批量插入/更新持仓失败: {e}") - raise - - def _batch_delete_extra_positions(self, session, k_id: int, st_id: int, new_positions_map: Dict) -> int: - """批量删除多余持仓""" - try: - if not new_positions_map: - # 删除所有持仓 - result = session.execute( - delete(StrategyPosition).where( - and_( - StrategyPosition.k_id == k_id, - StrategyPosition.st_id == st_id - ) - ) - ) - return result.rowcount - - # 构建保留条件 - conditions = [] - for (symbol, side) in new_positions_map.keys(): - safe_symbol = symbol.replace("'", "''") if symbol else '' - safe_side = side.replace("'", "''") if side else '' - conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')") - - if conditions: - conditions_str = " OR ".join(conditions) - - sql = f""" - DELETE FROM deh_strategy_position_new - WHERE k_id = {k_id} AND st_id = {st_id} - AND NOT ({conditions_str}) - """ - - result = session.execute(text(sql)) - return result.rowcount - - return 0 - - except Exception as e: - logger.error(f"批量删除持仓失败: k_id={k_id}, error={e}") - return 0 - - def _convert_position_data(self, data: Dict) -> Dict: - """转换持仓数据格式""" - try: - # 安全转换函数 - def safe_float(value, default=None): - if value is None: - return default - try: - return float(value) - except (ValueError, TypeError): - return default - - def safe_int(value, default=None): - if value is None: - return default - try: - return int(float(value)) - except (ValueError, TypeError): - return default - - return { - 'st_id': safe_int(data.get('st_id'), 0), - 'k_id': safe_int(data.get('k_id'), 0), - 'asset': data.get('asset', 'USDT'), - 'symbol': data.get('symbol', ''), - 'side': data.get('side', ''), - 'price': safe_float(data.get('price')), - 'qty': safe_float(data.get('qty')), # 后面会重命名为sum - 'asset_num': safe_float(data.get('asset_num')), - 'asset_profit': safe_float(data.get('asset_profit')), - 'leverage': safe_int(data.get('leverage')), - 'uptime': safe_int(data.get('uptime')), - 'profit_price': safe_float(data.get('profit_price')), - 'stop_price': safe_float(data.get('stop_price')), - 'liquidation_price': safe_float(data.get('liquidation_price')) - } - - except Exception as e: - logger.error(f"转换持仓数据异常: {data}, error={e}") - return {} - - async def sync(self): - """兼容旧接口""" - accounts = self.get_accounts_from_redis() - await self.sync_batch(accounts) \ No newline at end of file diff --git a/utils/batch_account_sync.py b/utils/batch_account_sync.py deleted file mode 100644 index d5697ce..0000000 --- a/utils/batch_account_sync.py +++ /dev/null @@ -1,174 +0,0 @@ -from typing import List, Dict, Any, Tuple -from loguru import logger -from sqlalchemy import text -import time - -class BatchAccountSync: - """账户信息批量同步工具""" - - def __init__(self, db_manager): - self.db_manager = db_manager - - def sync_accounts_batch(self, all_account_data: List[Dict]) -> Tuple[int, int]: - """批量同步账户信息(最高效版本)""" - if not all_account_data: - return 0, 0 - - session = self.db_manager.get_session() - try: - start_time = time.time() - - # 方法1:使用临时表进行批量操作(性能最好) - updated_count, inserted_count = self._sync_using_temp_table(session, all_account_data) - - elapsed = time.time() - start_time - logger.info(f"账户信息批量同步完成: 更新 {updated_count} 条,插入 {inserted_count} 条,耗时 {elapsed:.2f}秒") - - return updated_count, inserted_count - - except Exception as e: - logger.error(f"账户信息批量同步失败: {e}") - return 0, 0 - finally: - session.close() - - def _sync_using_temp_table(self, session, all_account_data: List[Dict]) -> Tuple[int, int]: - """使用临时表进行批量同步""" - try: - # 1. 创建临时表 - session.execute(text(""" - CREATE TEMPORARY TABLE IF NOT EXISTS temp_account_info ( - st_id INT, - k_id INT, - asset VARCHAR(32), - balance DECIMAL(20, 8), - withdrawal DECIMAL(20, 8), - deposit DECIMAL(20, 8), - other DECIMAL(20, 8), - profit DECIMAL(20, 8), - time INT, - PRIMARY KEY (k_id, st_id, time) - ) - """)) - - # 2. 清空临时表 - session.execute(text("TRUNCATE TABLE temp_account_info")) - - # 3. 批量插入数据到临时表 - chunk_size = 1000 - total_inserted = 0 - - for i in range(0, len(all_account_data), chunk_size): - chunk = all_account_data[i:i + chunk_size] - - values_list = [] - for data in chunk: - values = ( - f"({data['st_id']}, {data['k_id']}, 'USDT', " - f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, " - f"{data['other']}, {data['profit']}, {data['time']})" - ) - values_list.append(values) - - if values_list: - values_str = ", ".join(values_list) - sql = f""" - INSERT INTO temp_account_info - (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time) - VALUES {values_str} - """ - session.execute(text(sql)) - total_inserted += len(chunk) - - # 4. 使用临时表更新主表 - # 更新已存在的记录 - update_result = session.execute(text(""" - UPDATE deh_strategy_kx_new main - INNER JOIN temp_account_info temp - ON main.k_id = temp.k_id - AND main.st_id = temp.st_id - AND main.time = temp.time - SET main.balance = temp.balance, - main.withdrawal = temp.withdrawal, - main.deposit = temp.deposit, - main.other = temp.other, - main.profit = temp.profit, - main.up_time = NOW() - """)) - updated_count = update_result.rowcount - - # 插入新记录 - insert_result = session.execute(text(""" - INSERT INTO deh_strategy_kx_new - (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time, up_time) - SELECT - st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time, NOW() - FROM temp_account_info temp - WHERE NOT EXISTS ( - SELECT 1 FROM deh_strategy_kx_new main - WHERE main.k_id = temp.k_id - AND main.st_id = temp.st_id - AND main.time = temp.time - ) - """)) - inserted_count = insert_result.rowcount - - # 5. 删除临时表 - session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_account_info")) - - session.commit() - - return updated_count, inserted_count - - except Exception as e: - session.rollback() - logger.error(f"临时表同步失败: {e}") - raise - - def _sync_using_on_duplicate(self, session, all_account_data: List[Dict]) -> Tuple[int, int]: - """使用ON DUPLICATE KEY UPDATE批量同步(简化版)""" - try: - # 分块执行,避免SQL过长 - chunk_size = 1000 - total_processed = 0 - - for i in range(0, len(all_account_data), chunk_size): - chunk = all_account_data[i:i + chunk_size] - - values_list = [] - for data in chunk: - values = ( - f"({data['st_id']}, {data['k_id']}, 'USDT', " - f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, " - f"{data['other']}, {data['profit']}, {data['time']})" - ) - values_list.append(values) - - if values_list: - values_str = ", ".join(values_list) - - sql = f""" - INSERT INTO deh_strategy_kx_new - (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time) - VALUES {values_str} - ON DUPLICATE KEY UPDATE - balance = VALUES(balance), - withdrawal = VALUES(withdrawal), - deposit = VALUES(deposit), - other = VALUES(other), - profit = VALUES(profit), - up_time = NOW() - """ - - result = session.execute(text(sql)) - total_processed += len(chunk) - - session.commit() - - # 注意:这里无法区分更新和插入的数量 - return total_processed, 0 - - except Exception as e: - session.rollback() - logger.error(f"ON DUPLICATE同步失败: {e}") - raise \ No newline at end of file diff --git a/utils/batch_operations.py b/utils/batch_operations.py deleted file mode 100644 index 618b70a..0000000 --- a/utils/batch_operations.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import List, Dict, Any, Tuple -from loguru import logger -from sqlalchemy import text -from .database_manager import DatabaseManager - -class BatchOperations: - """批量数据库操作工具""" - - def __init__(self): - self.db_manager = DatabaseManager() - - def batch_insert_update_positions(self, positions_data: List[Dict]) -> Tuple[int, int]: - """批量插入/更新持仓数据""" - session = self.db_manager.get_session() - try: - if not positions_data: - return 0, 0 - - # 按账号分组 - positions_by_account = {} - for position in positions_data: - k_id = position.get('k_id') - if k_id not in positions_by_account: - positions_by_account[k_id] = [] - positions_by_account[k_id].append(position) - - total_processed = 0 - total_deleted = 0 - - with session.begin(): - for k_id, positions in positions_by_account.items(): - processed, deleted = self._process_account_positions(session, k_id, positions) - total_processed += processed - total_deleted += deleted - - logger.info(f"批量处理持仓完成: 处理 {total_processed} 条,删除 {total_deleted} 条") - return total_processed, total_deleted - - except Exception as e: - logger.error(f"批量处理持仓失败: {e}") - return 0, 0 - finally: - session.close() - - def _process_account_positions(self, session, k_id: int, positions: List[Dict]) -> Tuple[int, int]: - """处理单个账号的持仓数据""" - try: - st_id = positions[0].get('st_id', 0) if positions else 0 - - # 准备数据 - insert_data = [] - keep_keys = set() - - for pos_data in positions: - # 转换数据 - pos_dict = self._convert_position_data(pos_data) - if not all([pos_dict.get('symbol'), pos_dict.get('side')]): - continue - - # 重命名qty为sum - if 'qty' in pos_dict: - pos_dict['sum'] = pos_dict.pop('qty') - - insert_data.append(pos_dict) - keep_keys.add((pos_dict['symbol'], pos_dict['side'])) - - if not insert_data: - # 清空该账号持仓 - result = session.execute( - text("DELETE FROM deh_strategy_position_new WHERE k_id = :k_id AND st_id = :st_id"), - {'k_id': k_id, 'st_id': st_id} - ) - return 0, result.rowcount - - # 批量插入/更新 - sql = """ - INSERT INTO deh_strategy_position_new - (st_id, k_id, asset, symbol, side, price, `sum`, - asset_num, asset_profit, leverage, uptime, - profit_price, stop_price, liquidation_price) - VALUES - (:st_id, :k_id, :asset, :symbol, :side, :price, :sum, - :asset_num, :asset_profit, :leverage, :uptime, - :profit_price, :stop_price, :liquidation_price) - ON DUPLICATE KEY UPDATE - price = VALUES(price), - `sum` = VALUES(`sum`), - asset_num = VALUES(asset_num), - asset_profit = VALUES(asset_profit), - leverage = VALUES(leverage), - uptime = VALUES(uptime), - profit_price = VALUES(profit_price), - stop_price = VALUES(stop_price), - liquidation_price = VALUES(liquidation_price) - """ - - # 分块执行 - chunk_size = 500 - processed_count = 0 - - for i in range(0, len(insert_data), chunk_size): - chunk = insert_data[i:i + chunk_size] - session.execute(text(sql), chunk) - processed_count += len(chunk) - - # 删除多余持仓 - deleted_count = 0 - if keep_keys: - # 构建删除条件 - conditions = [] - for symbol, side in keep_keys: - safe_symbol = symbol.replace("'", "''") if symbol else '' - safe_side = side.replace("'", "''") if side else '' - conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')") - - if conditions: - conditions_str = " OR ".join(conditions) - delete_sql = f""" - DELETE FROM deh_strategy_position_new - WHERE k_id = {k_id} AND st_id = {st_id} - AND NOT ({conditions_str}) - """ - - result = session.execute(text(delete_sql)) - deleted_count = result.rowcount - - return processed_count, deleted_count - - except Exception as e: - logger.error(f"处理账号 {k_id} 持仓失败: {e}") - return 0, 0 - - def _convert_position_data(self, data: Dict) -> Dict: - """转换持仓数据格式""" - # 转换逻辑... - pass - - # 类似的批量方法 for orders and account info... \ No newline at end of file diff --git a/utils/batch_order_sync.py b/utils/batch_order_sync.py deleted file mode 100644 index ef6a68e..0000000 --- a/utils/batch_order_sync.py +++ /dev/null @@ -1,313 +0,0 @@ -from typing import List, Dict, Any, Tuple -from loguru import logger -from sqlalchemy import text -import time - -class BatchOrderSync: - """订单数据批量同步工具(最高性能)""" - - def __init__(self, db_manager, batch_size: int = 1000): - self.db_manager = db_manager - self.batch_size = batch_size - - def sync_orders_batch(self, all_orders: List[Dict]) -> Tuple[bool, int]: - """批量同步订单数据""" - if not all_orders: - return True, 0 - - session = self.db_manager.get_session() - try: - start_time = time.time() - - # 方法1:使用临时表(性能最好) - processed_count = self._sync_using_temp_table(session, all_orders) - - elapsed = time.time() - start_time - logger.info(f"订单批量同步完成: 处理 {processed_count} 条订单,耗时 {elapsed:.2f}秒") - - return True, processed_count - - except Exception as e: - logger.error(f"订单批量同步失败: {e}") - return False, 0 - finally: - session.close() - - def _sync_using_temp_table(self, session, all_orders: List[Dict]) -> int: - """使用临时表批量同步订单""" - try: - # 1. 创建临时表 - session.execute(text(""" - CREATE TEMPORARY TABLE IF NOT EXISTS temp_orders ( - st_id INT, - k_id INT, - asset VARCHAR(32), - order_id VARCHAR(765), - symbol VARCHAR(120), - side VARCHAR(120), - price FLOAT, - time INT, - order_qty FLOAT, - last_qty FLOAT, - avg_price FLOAT, - exchange_id INT, - UNIQUE KEY idx_unique_order (order_id, symbol, k_id, side) - ) - """)) - - # 2. 清空临时表 - session.execute(text("TRUNCATE TABLE temp_orders")) - - # 3. 批量插入数据到临时表(分块) - inserted_count = self._batch_insert_to_temp_table(session, all_orders) - - if inserted_count == 0: - session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_orders")) - return 0 - - # 4. 使用临时表更新主表 - # 更新已存在的记录(只更新需要比较的字段) - update_result = session.execute(text(""" - UPDATE deh_strategy_order_new main - INNER JOIN temp_orders temp - ON main.order_id = temp.order_id - AND main.symbol = temp.symbol - AND main.k_id = temp.k_id - AND main.side = temp.side - SET main.side = temp.side, - main.price = temp.price, - main.time = temp.time, - main.order_qty = temp.order_qty, - main.last_qty = temp.last_qty, - main.avg_price = temp.avg_price - WHERE main.side != temp.side - OR main.price != temp.price - OR main.time != temp.time - OR main.order_qty != temp.order_qty - OR main.last_qty != temp.last_qty - OR main.avg_price != temp.avg_price - """)) - updated_count = update_result.rowcount - - # 插入新记录 - insert_result = session.execute(text(""" - INSERT INTO deh_strategy_order_new - (st_id, k_id, asset, order_id, symbol, side, price, time, - order_qty, last_qty, avg_price, exchange_id) - SELECT - st_id, k_id, asset, order_id, symbol, side, price, time, - order_qty, last_qty, avg_price, exchange_id - FROM temp_orders temp - WHERE NOT EXISTS ( - SELECT 1 FROM deh_strategy_order_new main - WHERE main.order_id = temp.order_id - AND main.symbol = temp.symbol - AND main.k_id = temp.k_id - AND main.side = temp.side - ) - """)) - inserted_count = insert_result.rowcount - - # 5. 删除临时表 - session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_orders")) - - session.commit() - - total_processed = updated_count + inserted_count - logger.info(f"订单批量同步: 更新 {updated_count} 条,插入 {inserted_count} 条") - - return total_processed - - except Exception as e: - session.rollback() - logger.error(f"临时表同步订单失败: {e}") - raise - - def _batch_insert_to_temp_table(self, session, all_orders: List[Dict]) -> int: - """批量插入数据到临时表""" - total_inserted = 0 - - try: - # 分块处理 - for i in range(0, len(all_orders), self.batch_size): - chunk = all_orders[i:i + self.batch_size] - - values_list = [] - for order in chunk: - try: - # 处理NULL值 - price = order.get('price') - time_val = order.get('time') - order_qty = order.get('order_qty') - last_qty = order.get('last_qty') - avg_price = order.get('avg_price') - # 转义单引号 - symbol = order.get('symbol').replace("'", "''") if order.get('symbol') else '' - order_id = order.get('order_id').replace("'", "''") if order.get('order_id') else '' - - values = ( - f"({order['st_id']}, {order['k_id']}, '{order.get('asset', 'USDT')}', " - f"'{order_id}', " - f"'{symbol}', " - f"'{order['side']}', " - f"{price if price is not None else 'NULL'}, " - f"{time_val if time_val is not None else 'NULL'}, " - f"{order_qty if order_qty is not None else 'NULL'}, " - f"{last_qty if last_qty is not None else 'NULL'}, " - f"{avg_price if avg_price is not None else 'NULL'}, " - "NULL)" - ) - values_list.append(values) - - except Exception as e: - logger.error(f"构建订单值失败: {order}, error={e}") - continue - - if values_list: - values_str = ", ".join(values_list) - - sql = f""" - INSERT INTO temp_orders - (st_id, k_id, asset, order_id, symbol, side, price, time, - order_qty, last_qty, avg_price, exchange_id) - VALUES {values_str} - """ - - result = session.execute(text(sql)) - total_inserted += len(chunk) - - return total_inserted - - except Exception as e: - logger.error(f"批量插入临时表失败: {e}") - raise - - def _batch_insert_to_temp_table1(self, session, all_orders: List[Dict]) -> int: - """批量插入数据到临时表(使用参数化查询)temp_orders""" - total_inserted = 0 - - try: - # 分块处理 - for i in range(0, len(all_orders), self.batch_size): - chunk = all_orders[i:i + self.batch_size] - - # 准备参数化数据 - insert_data = [] - for order in chunk: - try: - insert_data.append({ - 'st_id': order['st_id'], - 'k_id': order['k_id'], - 'asset': order.get('asset', 'USDT'), - 'order_id': order['order_id'], - 'symbol': order['symbol'], - 'side': order['side'], - 'price': order.get('price'), - 'time': order.get('time'), - 'order_qty': order.get('order_qty'), - 'last_qty': order.get('last_qty'), - 'avg_price': order.get('avg_price') - # exchange_id 留空,使用默认值NULL - }) - except KeyError as e: - logger.error(f"订单数据缺少必要字段: {order}, missing={e}") - continue - except Exception as e: - logger.error(f"处理订单数据失败: {order}, error={e}") - continue - - if insert_data: - sql = text(f""" - INSERT INTO {self.temp_table_name} - (st_id, k_id, asset, order_id, symbol, side, price, time, - order_qty, last_qty, avg_price) - VALUES - (:st_id, :k_id, :asset, :order_id, :symbol, :side, :price, :time, - :order_qty, :last_qty, :avg_price) - """) - - try: - session.execute(sql, insert_data) - session.commit() - total_inserted += len(insert_data) - logger.debug(f"插入 {len(insert_data)} 条数据到临时表") - except Exception as e: - session.rollback() - logger.error(f"执行批量插入失败: {e}") - raise - - logger.info(f"总共插入 {total_inserted} 条数据到临时表") - return total_inserted - - except Exception as e: - logger.error(f"批量插入临时表失败: {e}") - session.rollback() - raise - - - def _sync_using_on_duplicate(self, session, all_orders: List[Dict]) -> int: - """使用ON DUPLICATE KEY UPDATE批量同步(简化版)""" - try: - total_processed = 0 - - # 分块执行 - for i in range(0, len(all_orders), self.batch_size): - chunk = all_orders[i:i + self.batch_size] - - values_list = [] - for order in chunk: - try: - # 处理NULL值 - price = order.get('price') - time_val = order.get('time') - order_qty = order.get('order_qty') - last_qty = order.get('last_qty') - avg_price = order.get('avg_price') - symbol = order.get('symbol').replace("'", "''") if order.get('symbol') else '' - order_id = order.get('order_id').replace("'", "''") if order.get('order_id') else '' - - values = ( - f"({order['st_id']}, {order['k_id']}, '{order.get('asset', 'USDT')}', " - f"'{order_id}', " - f"'{symbol}', " - f"'{order['side']}', " - f"{price if price is not None else 'NULL'}, " - f"{time_val if time_val is not None else 'NULL'}, " - f"{order_qty if order_qty is not None else 'NULL'}, " - f"{last_qty if last_qty is not None else 'NULL'}, " - f"{avg_price if avg_price is not None else 'NULL'}, " - "NULL)" - ) - values_list.append(values) - - except Exception as e: - logger.error(f"构建订单值失败: {order}, error={e}") - continue - - if values_list: - values_str = ", ".join(values_list) - - sql = f""" - INSERT INTO deh_strategy_order_new - (st_id, k_id, asset, order_id, symbol, side, price, time, - order_qty, last_qty, avg_price, exchange_id) - VALUES {values_str} - ON DUPLICATE KEY UPDATE - side = VALUES(side), - price = VALUES(price), - time = VALUES(time), - order_qty = VALUES(order_qty), - last_qty = VALUES(last_qty), - avg_price = VALUES(avg_price) - """ - - session.execute(text(sql)) - total_processed += len(chunk) - - session.commit() - return total_processed - - except Exception as e: - session.rollback() - logger.error(f"ON DUPLICATE同步订单失败: {e}") - raise \ No newline at end of file diff --git a/utils/batch_position_sync.py b/utils/batch_position_sync.py deleted file mode 100644 index c462aff..0000000 --- a/utils/batch_position_sync.py +++ /dev/null @@ -1,254 +0,0 @@ -from typing import List, Dict, Any, Tuple -from loguru import logger -from sqlalchemy import text -import time - -class BatchPositionSync: - """持仓数据批量同步工具(使用临时表,最高性能)""" - - def __init__(self, db_manager, batch_size: int = 500): - self.db_manager = db_manager - self.batch_size = batch_size - - def sync_positions_batch(self, all_positions: List[Dict]) -> Tuple[bool, Dict]: - """批量同步持仓数据(最高效版本)""" - if not all_positions: - return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} - - session = self.db_manager.get_session() - try: - start_time = time.time() - - # 按账号分组 - positions_by_account = self._group_positions_by_account(all_positions) - - total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} - - with session.begin(): - # 处理每个账号 - for (k_id, st_id), positions in positions_by_account.items(): - success, stats = self._sync_account_using_temp_table( - session, k_id, st_id, positions - ) - - if success: - total_stats['total'] += stats['total'] - total_stats['updated'] += stats['updated'] - total_stats['inserted'] += stats['inserted'] - total_stats['deleted'] += stats['deleted'] - - elapsed = time.time() - start_time - logger.info(f"持仓批量同步完成: 处理 {len(positions_by_account)} 个账号," - f"总持仓 {total_stats['total']} 条,耗时 {elapsed:.2f}秒") - - return True, total_stats - - except Exception as e: - logger.error(f"持仓批量同步失败: {e}") - return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} - finally: - session.close() - - def _group_positions_by_account(self, all_positions: List[Dict]) -> Dict[Tuple[int, int], List[Dict]]: - """按账号分组持仓数据""" - groups = {} - for position in all_positions: - k_id = position.get('k_id') - st_id = position.get('st_id', 0) - key = (k_id, st_id) - - if key not in groups: - groups[key] = [] - groups[key].append(position) - - return groups - - def _sync_account_using_temp_table(self, session, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]: - """使用临时表同步单个账号的持仓数据""" - try: - # 1. 创建临时表 - session.execute(text(""" - CREATE TEMPORARY TABLE IF NOT EXISTS temp_positions ( - st_id INT, - k_id INT, - asset VARCHAR(32), - symbol VARCHAR(50), - side VARCHAR(10), - price FLOAT, - `sum` FLOAT, - asset_num DECIMAL(20, 8), - asset_profit DECIMAL(20, 8), - leverage INT, - uptime INT, - profit_price DECIMAL(20, 8), - stop_price DECIMAL(20, 8), - liquidation_price DECIMAL(20, 8), - PRIMARY KEY (k_id, st_id, symbol, side) - ) - """)) - - # 2. 清空临时表 - session.execute(text("TRUNCATE TABLE temp_positions")) - - # 3. 批量插入数据到临时表 - self._batch_insert_to_temp_table(session, positions) - - # 4. 使用临时表更新主表 - # 更新已存在的记录 - update_result = session.execute(text(f""" - UPDATE deh_strategy_position_new main - INNER JOIN temp_positions temp - ON main.k_id = temp.k_id - AND main.st_id = temp.st_id - AND main.symbol = temp.symbol - AND main.side = temp.side - SET main.price = temp.price, - main.`sum` = temp.`sum`, - main.asset_num = temp.asset_num, - main.asset_profit = temp.asset_profit, - main.leverage = temp.leverage, - main.uptime = temp.uptime, - main.profit_price = temp.profit_price, - main.stop_price = temp.stop_price, - main.liquidation_price = temp.liquidation_price - WHERE main.k_id = {k_id} AND main.st_id = {st_id} - """)) - updated_count = update_result.rowcount - - # 插入新记录 - insert_result = session.execute(text(f""" - INSERT INTO deh_strategy_position_new - (st_id, k_id, asset, symbol, side, price, `sum`, - asset_num, asset_profit, leverage, uptime, - profit_price, stop_price, liquidation_price) - SELECT - st_id, k_id, asset, symbol, side, price, `sum`, - asset_num, asset_profit, leverage, uptime, - profit_price, stop_price, liquidation_price - FROM temp_positions temp - WHERE NOT EXISTS ( - SELECT 1 FROM deh_strategy_position_new main - WHERE main.k_id = temp.k_id - AND main.st_id = temp.st_id - AND main.symbol = temp.symbol - AND main.side = temp.side - ) - AND temp.k_id = {k_id} AND temp.st_id = {st_id} - """)) - inserted_count = insert_result.rowcount - - # 5. 删除多余持仓(在临时表中不存在但在主表中存在的) - delete_result = session.execute(text(f""" - DELETE main - FROM deh_strategy_position_new main - LEFT JOIN temp_positions temp - ON main.k_id = temp.k_id - AND main.st_id = temp.st_id - AND main.symbol = temp.symbol - AND main.side = temp.side - WHERE main.k_id = {k_id} AND main.st_id = {st_id} - AND temp.symbol IS NULL - """)) - deleted_count = delete_result.rowcount - - # 6. 删除临时表 - session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_positions")) - - stats = { - 'total': len(positions), - 'updated': updated_count, - 'inserted': inserted_count, - 'deleted': deleted_count - } - - logger.debug(f"账号({k_id},{st_id})持仓同步: 更新{updated_count} 插入{inserted_count} 删除{deleted_count}") - - return True, stats - - except Exception as e: - logger.error(f"临时表同步账号({k_id},{st_id})持仓失败: {e}") - return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0} - - def _batch_insert_to_temp_table(self, session, positions: List[Dict]): - """批量插入数据到临时表(使用参数化查询)""" - if not positions: - return - - # 分块处理 - for i in range(0, len(positions), self.batch_size): - chunk = positions[i:i + self.batch_size] - - # 准备参数化数据 - insert_data = [] - for position in chunk: - try: - data = self._convert_position_for_temp(position) - if not all([data.get('symbol'), data.get('side')]): - continue - - insert_data.append({ - 'st_id': data['st_id'], - 'k_id': data['k_id'], - 'asset': data.get('asset', 'USDT'), - 'symbol': data['symbol'], - 'side': data['side'], - 'price': data.get('price'), - 'sum_val': data.get('sum'), # 注意字段名 - 'asset_num': data.get('asset_num'), - 'asset_profit': data.get('asset_profit'), - 'leverage': data.get('leverage'), - 'uptime': data.get('uptime'), - 'profit_price': data.get('profit_price'), - 'stop_price': data.get('stop_price'), - 'liquidation_price': data.get('liquidation_price') - }) - - except Exception as e: - logger.error(f"转换持仓数据失败: {position}, error={e}") - continue - - if insert_data: - sql = """ - INSERT INTO temp_positions - (st_id, k_id, asset, symbol, side, price, `sum`, - asset_num, asset_profit, leverage, uptime, - profit_price, stop_price, liquidation_price) - VALUES - (:st_id, :k_id, :asset, :symbol, :side, :price, :sum_val, - :asset_num, :asset_profit, :leverage, :uptime, - :profit_price, :stop_price, :liquidation_price) - """ - - session.execute(text(sql), insert_data) - - def _convert_position_for_temp(self, data: Dict) -> Dict: - """转换持仓数据格式用于临时表""" - # 使用安全转换 - def safe_float(value): - try: - return float(value) if value is not None else None - except: - return None - - def safe_int(value): - try: - return int(value) if value is not None else None - except: - return None - - return { - 'st_id': safe_int(data.get('st_id')) or 0, - 'k_id': safe_int(data.get('k_id')) or 0, - 'asset': data.get('asset', 'USDT'), - 'symbol': str(data.get('symbol', '')), - 'side': str(data.get('side', '')), - 'price': safe_float(data.get('price')), - 'sum': safe_float(data.get('qty')), # 注意:这里直接使用sum - 'asset_num': safe_float(data.get('asset_num')), - 'asset_profit': safe_float(data.get('asset_profit')), - 'leverage': safe_int(data.get('leverage')), - 'uptime': safe_int(data.get('uptime')), - 'profit_price': safe_float(data.get('profit_price')), - 'stop_price': safe_float(data.get('stop_price')), - 'liquidation_price': safe_float(data.get('liquidation_price')) - } \ No newline at end of file