This commit is contained in:
lz_db
2025-12-02 22:36:52 +08:00
parent 7fa249a767
commit c8a6cfead1
6 changed files with 466 additions and 197 deletions

10
.env
View File

@@ -31,5 +31,11 @@ LOG_LEVEL=INFO
LOG_ROTATION=10 MB
LOG_RETENTION=7 days
# 计算机名(用于过滤账号)
COMPUTER_NAME=lz_c01
# 计算机名配置(支持多个)
COMPUTER_NAMES=lz_c01,lz_c02,lz_c03
# 或者使用模式匹配
COMPUTER_NAME_PATTERN=^lz_c\d{2}$
# 并发配置
MAX_CONCURRENT=10

View File

@@ -31,5 +31,7 @@ LOG_CONFIG = {
'format': '{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} - {message}'
}
# 计算机名配置(用于过滤账号
COMPUTER_NAME = os.getenv('COMPUTER_NAME', 'lz_c01')
# 计算机名配置(支持多个,用逗号分隔
COMPUTER_NAMES = os.getenv('COMPUTER_NAMES', 'lz_c01')
# 或者使用模式匹配
COMPUTER_NAME_PATTERN = os.getenv('COMPUTER_NAME_PATTERN', r'^lz_c\d{2}$')

View File

@@ -1,10 +1,12 @@
from abc import ABC, abstractmethod
from loguru import logger
from typing import List, Dict, Any
from typing import List, Dict, Any, Set
import json
import re
from utils.redis_client import RedisClient
from utils.database_manager import DatabaseManager
from config.settings import COMPUTER_NAMES, COMPUTER_NAME_PATTERN
class BaseSync(ABC):
"""同步基类"""
@@ -12,27 +14,50 @@ class BaseSync(ABC):
def __init__(self):
self.redis_client = RedisClient()
self.db_manager = DatabaseManager()
self.computer_name = None # 从配置读取
self.computer_names = self._get_computer_names()
self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)
@abstractmethod
async def sync(self):
"""执行同步"""
pass
def _get_computer_names(self) -> List[str]:
"""获取计算机名列表"""
if ',' in COMPUTER_NAMES:
return [name.strip() for name in COMPUTER_NAMES.split(',')]
return [COMPUTER_NAMES.strip()]
def get_accounts_from_redis(self) -> Dict[str, Dict]:
"""从Redis获取账号配置"""
"""从Redis获取所有计算机名的账号配置"""
try:
if self.computer_name is None:
from config.settings import COMPUTER_NAME
self.computer_name = COMPUTER_NAME
accounts_dict = {}
# 方法1使用配置的计算机名列表
for computer_name in self.computer_names:
accounts = self._get_accounts_by_computer_name(computer_name)
accounts_dict.update(accounts)
# 方法2自动发现所有匹配的key备用方案
if not accounts_dict:
accounts_dict = self._discover_all_accounts()
logger.info(f"{len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号")
return accounts_dict
except Exception as e:
logger.error(f"获取账户信息失败: {e}")
return {}
def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]:
"""获取指定计算机名的账号"""
accounts_dict = {}
try:
# 构建key
redis_key = f"{computer_name}_strategy_api"
# 从Redis获取数据
result = self.redis_client.client.hgetall(f"{self.computer_name}_strategy_api")
result = self.redis_client.client.hgetall(redis_key)
if not result:
logger.warning(f"未找到 {self.computer_name} 的策略API配置")
logger.debug(f"未找到 {redis_key} 的策略API配置")
return {}
accounts_dict = {}
for exchange_name, accounts_json in result.items():
try:
accounts = json.loads(accounts_json)
@@ -45,46 +70,52 @@ class BaseSync(ABC):
for account_id, account_info in accounts.items():
parsed_account = self.parse_account(exchange_id, account_id, account_info)
if parsed_account:
# 添加计算机名标记
parsed_account['computer_name'] = computer_name
accounts_dict[account_id] = parsed_account
except json.JSONDecodeError as e:
logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}")
continue
return accounts_dict
logger.info(f"{redis_key} 获取到 {len(accounts_dict)} 个账号")
except Exception as e:
logger.error(f"获取账户信息失败: {e}")
return {}
logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}")
def format_exchange_id(self, key: str) -> str:
"""格式化交易所ID"""
key = key.lower().strip()
return accounts_dict
# 交易所名称映射
exchange_mapping = {
'metatrader': 'mt5',
'binance_spot_test': 'binance',
'binance_spot': 'binance',
'binance': 'binance',
'gate_spot': 'gate',
'okex': 'okx'
}
def _discover_all_accounts(self) -> Dict[str, Dict]:
"""自动发现所有匹配的账号key"""
accounts_dict = {}
return exchange_mapping.get(key, key)
def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Dict:
"""解析账号信息"""
try:
source_account_info = json.loads(account_info)
account_data = {
'exchange_id': exchange_id,
'k_id': account_id,
'st_id': int(source_account_info.get('st_id', 0)),
'add_time': int(source_account_info.get('add_time', 0))
}
return {**source_account_info, **account_data}
# 获取所有匹配模式的key
pattern = f"*_strategy_api"
cursor = 0
except json.JSONDecodeError as e:
logger.error(f"解析账号 {account_id} 数据失败: {e}")
return {}
while True:
cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100)
for key in keys:
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
# 提取计算机名
computer_name = key_str.replace('_strategy_api', '')
# 验证计算机名格式
if self.computer_name_pattern.match(computer_name):
accounts = self._get_accounts_by_computer_name(computer_name)
accounts_dict.update(accounts)
if cursor == 0:
break
logger.info(f"自动发现 {len(accounts_dict)} 个账号")
except Exception as e:
logger.error(f"自动发现账号失败: {e}")
return accounts_dict
# 其他方法保持不变...

View File

@@ -1,8 +1,11 @@
import asyncio
from loguru import logger
from typing import List, Dict
from typing import List, Dict, Optional
import signal
import sys
from concurrent.futures import ThreadPoolExecutor
import time
from asyncio import Semaphore
from config.settings import SYNC_CONFIG
from .position_sync import PositionSync
@@ -10,14 +13,18 @@ from .order_sync import OrderSync
from .account_sync import AccountSync
class SyncManager:
"""同步管理器"""
"""同步管理器(支持批量并发处理)"""
def __init__(self):
self.is_running = True
self.sync_interval = SYNC_CONFIG['interval']
self.max_concurrent = int(os.getenv('MAX_CONCURRENT', '10')) # 最大并发数
# 初始化同步器
self.syncers = []
self.executor = ThreadPoolExecutor(max_workers=self.max_concurrent)
self.semaphore = Semaphore(self.max_concurrent) # 控制并发数
if SYNC_CONFIG['enable_position_sync']:
self.syncers.append(PositionSync())
@@ -31,26 +38,62 @@ class SyncManager:
self.syncers.append(AccountSync())
logger.info("启用账户信息同步")
# 性能统计
self.stats = {
'total_accounts': 0,
'success_count': 0,
'error_count': 0,
'last_sync_time': 0,
'avg_sync_time': 0
}
# 注册信号处理器
signal.signal(signal.SIGINT, self.signal_handler)
signal.signal(signal.SIGTERM, self.signal_handler)
async def _run_syncer_with_limit(self, syncer):
"""带并发限制的运行"""
async with self.semaphore:
return await self._run_syncer(syncer)
def signal_handler(self, signum, frame):
"""信号处理器"""
logger.info(f"接收到信号 {signum},正在关闭...")
self.is_running = False
def batch_process_accounts(self, accounts: Dict[str, Dict], batch_size: int = 100):
"""分批处理账号"""
account_items = list(accounts.items())
for i in range(0, len(account_items), batch_size):
batch = dict(account_items[i:i + batch_size])
# 处理这批账号
self._process_account_batch(batch)
# 批次间休息,避免数据库压力过大
time.sleep(0.1)
async def start(self):
"""启动同步服务"""
logger.info(f"同步服务启动,间隔 {self.sync_interval}")
logger.info(f"同步服务启动,间隔 {self.sync_interval},最大并发 {self.max_concurrent}")
while self.is_running:
try:
# 执行所有同步器
tasks = [syncer.sync() for syncer in self.syncers]
await asyncio.gather(*tasks, return_exceptions=True)
start_time = time.time()
logger.debug(f"同步完成,等待 {self.sync_interval}")
# 执行所有同步器
tasks = [self._run_syncer(syncer) for syncer in self.syncers]
results = await asyncio.gather(*tasks, return_exceptions=True)
# 更新统计
sync_time = time.time() - start_time
self.stats['last_sync_time'] = sync_time
self.stats['avg_sync_time'] = (self.stats['avg_sync_time'] * 0.9 + sync_time * 0.1)
# 打印统计信息
self._print_stats()
logger.debug(f"同步完成,耗时 {sync_time:.2f} 秒,等待 {self.sync_interval}")
await asyncio.sleep(self.sync_interval)
except asyncio.CancelledError:
@@ -58,9 +101,41 @@ class SyncManager:
break
except Exception as e:
logger.error(f"同步任务异常: {e}")
self.stats['error_count'] += 1
await asyncio.sleep(30) # 出错后等待30秒
async def _run_syncer(self, syncer):
"""运行单个同步器"""
try:
# 获取所有账号
accounts = syncer.get_accounts_from_redis()
self.stats['total_accounts'] = len(accounts)
if not accounts:
logger.warning("未获取到任何账号")
return
# 批量处理账号
await syncer.sync_batch(accounts)
self.stats['success_count'] += 1
except Exception as e:
logger.error(f"同步器 {syncer.__class__.__name__} 执行失败: {e}")
self.stats['error_count'] += 1
def _print_stats(self):
"""打印统计信息"""
stats_str = (
f"统计: 账号数={self.stats['total_accounts']}, "
f"成功={self.stats['success_count']}, "
f"失败={self.stats['error_count']}, "
f"本次耗时={self.stats['last_sync_time']:.2f}s, "
f"平均耗时={self.stats['avg_sync_time']:.2f}s"
)
logger.info(stats_str)
async def stop(self):
"""停止同步服务"""
self.is_running = False
self.executor.shutdown(wait=True)
logger.info("同步服务停止")

View File

@@ -2,173 +2,190 @@ from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict
import json
from datetime import datetime, timedelta
import asyncio
from concurrent.futures import ThreadPoolExecutor
class PositionSync(BaseSync):
"""持仓数据同步器"""
"""持仓数据同步器(批量版本)"""
async def sync(self):
"""同步持仓数据"""
def __init__(self):
super().__init__()
self.max_concurrent = 10 # 每个同步器的最大并发数
async def sync_batch(self, accounts: Dict[str, Dict]):
"""批量同步所有账号的持仓数据"""
try:
# 获取所有账号
accounts = self.get_accounts_from_redis()
logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号")
for k_id_str, account_info in accounts.items():
try:
k_id = int(k_id_str)
st_id = account_info.get('st_id', 0)
exchange_id = account_info['exchange_id']
# 按账号分组
account_groups = self._group_accounts_by_exchange(accounts)
if k_id <= 0 or st_id <= 0:
continue
# 并发处理每个交易所的账号
tasks = []
for exchange_id, account_list in account_groups.items():
task = self._sync_exchange_accounts(exchange_id, account_list)
tasks.append(task)
# 从Redis获取持仓数据
positions = await self._get_positions_from_redis(k_id, exchange_id)
# 等待所有任务完成
results = await asyncio.gather(*tasks, return_exceptions=True)
# 同步到数据库
if positions:
success = self._sync_positions_to_db(k_id, st_id, positions)
if success:
logger.debug(f"持仓同步成功: k_id={k_id}, 持仓数={len(positions)}")
except Exception as e:
logger.error(f"同步账号 {k_id_str} 持仓失败: {e}")
continue
logger.info("持仓数据同步完成")
# 统计结果
success_count = sum(1 for r in results if isinstance(r, bool) and r)
logger.info(f"持仓批量同步完成: 成功 {success_count}/{len(results)} 个交易所组")
except Exception as e:
logger.error(f"持仓同步失败: {e}")
logger.error(f"持仓批量同步失败: {e}")
async def _get_positions_from_redis(self, k_id: int, exchange_id: str) -> List[Dict]:
"""从Redis获取持仓数据"""
def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
"""按交易所分组账号"""
groups = {}
for account_id, account_info in accounts.items():
exchange_id = account_info.get('exchange_id')
if exchange_id:
if exchange_id not in groups:
groups[exchange_id] = []
groups[exchange_id].append(account_info)
return groups
async def _sync_exchange_accounts(self, exchange_id: str, account_list: List[Dict]):
"""同步某个交易所的所有账号"""
try:
redis_key = f"{exchange_id}:positions:{k_id}"
redis_data = self.redis_client.client.hget(redis_key, 'positions')
# 收集所有账号的持仓数据
all_positions = []
if not redis_data:
return []
for account_info in account_list:
k_id = int(account_info['k_id'])
st_id = account_info.get('st_id', 0)
positions = json.loads(redis_data)
# 从Redis获取持仓数据
positions = await self._get_positions_from_redis(k_id, exchange_id)
# 添加账号信息
for position in positions:
position['k_id'] = k_id
if positions:
# 添加账号信息
for position in positions:
position['k_id'] = k_id
position['st_id'] = st_id
all_positions.extend(positions)
return positions
if not all_positions:
logger.debug(f"交易所 {exchange_id} 无持仓数据")
return True
# 批量同步到数据库
success = self._sync_positions_batch_to_db(all_positions)
if success:
logger.info(f"交易所 {exchange_id} 持仓同步成功: {len(all_positions)} 条持仓")
return success
except Exception as e:
logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}")
return []
logger.error(f"同步交易所 {exchange_id} 持仓失败: {e}")
return False
def _sync_positions_to_db(self, k_id: int, st_id: int, positions_data: List[Dict]) -> bool:
"""同步持仓数据到数据库"""
def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> bool:
"""批量同步持仓数据到数据库(优化版)"""
session = self.db_manager.get_session()
try:
# 使用批量优化方案
from sqlalchemy.dialects.mysql import insert
# 按k_id分组
positions_by_account = {}
for position in all_positions:
k_id = position['k_id']
if k_id not in positions_by_account:
positions_by_account[k_id] = []
positions_by_account[k_id].append(position)
# 准备数据
insert_data = []
keep_keys = set() # 需要保留的(symbol, side)组合
for pos_data in positions_data:
try:
# 转换数据(这里需要实现转换逻辑)
pos_dict = self._convert_position_data(pos_data)
if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
continue
# 重命名qty为sum
if 'qty' in pos_dict:
pos_dict['sum'] = pos_dict.pop('qty')
insert_data.append(pos_dict)
keep_keys.add((pos_dict['symbol'], pos_dict['side']))
except Exception as e:
logger.error(f"转换持仓数据失败: {pos_data}, error={e}")
continue
success_count = 0
with session.begin():
if not insert_data:
# 清空该账号持仓
session.execute(
delete(StrategyPosition).where(
and_(
StrategyPosition.k_id == k_id,
StrategyPosition.st_id == st_id
)
)
)
return True
for k_id, positions in positions_by_account.items():
try:
st_id = positions[0]['st_id'] if positions else 0
# 批量插入/更新
stmt = insert(StrategyPosition.__table__).values(insert_data)
# 准备数据
insert_data = []
keep_keys = set()
update_dict = {
'price': stmt.inserted.price,
'sum': stmt.inserted.sum,
'asset_num': stmt.inserted.asset_num,
'asset_profit': stmt.inserted.asset_profit,
'leverage': stmt.inserted.leverage,
'uptime': stmt.inserted.uptime,
'profit_price': stmt.inserted.profit_price,
'stop_price': stmt.inserted.stop_price,
'liquidation_price': stmt.inserted.liquidation_price
}
for pos_data in positions:
try:
pos_dict = self._convert_position_data(pos_data)
if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
continue
stmt = stmt.on_duplicate_key_update(**update_dict)
session.execute(stmt)
# 重命名qty为sum
if 'qty' in pos_dict:
pos_dict['sum'] = pos_dict.pop('qty')
# 删除多余持仓
if keep_keys:
existing_positions = session.execute(
select(StrategyPosition).where(
and_(
StrategyPosition.k_id == k_id,
StrategyPosition.st_id == st_id
)
)
).scalars().all()
insert_data.append(pos_dict)
keep_keys.add((pos_dict['symbol'], pos_dict['side']))
to_delete_ids = []
for existing in existing_positions:
key = (existing.symbol, existing.side)
if key not in keep_keys:
to_delete_ids.append(existing.id)
except Exception as e:
logger.error(f"转换持仓数据失败: {pos_data}, error={e}")
continue
if to_delete_ids:
session.execute(
delete(StrategyPosition).where(
StrategyPosition.id.in_(to_delete_ids)
)
)
if not insert_data:
continue
return True
# 批量插入/更新
from sqlalchemy.dialects.mysql import insert
stmt = insert(StrategyPosition.__table__).values(insert_data)
update_dict = {
'price': stmt.inserted.price,
'sum': stmt.inserted.sum,
'asset_num': stmt.inserted.asset_num,
'asset_profit': stmt.inserted.asset_profit,
'leverage': stmt.inserted.leverage,
'uptime': stmt.inserted.uptime,
'profit_price': stmt.inserted.profit_price,
'stop_price': stmt.inserted.stop_price,
'liquidation_price': stmt.inserted.liquidation_price
}
stmt = stmt.on_duplicate_key_update(**update_dict)
session.execute(stmt)
# 删除多余持仓
if keep_keys:
existing_positions = session.execute(
select(StrategyPosition).where(
and_(
StrategyPosition.k_id == k_id,
StrategyPosition.st_id == st_id
)
)
).scalars().all()
to_delete_ids = []
for existing in existing_positions:
key = (existing.symbol, existing.side)
if key not in keep_keys:
to_delete_ids.append(existing.id)
if to_delete_ids:
# 分块删除
chunk_size = 100
for i in range(0, len(to_delete_ids), chunk_size):
chunk = to_delete_ids[i:i + chunk_size]
session.execute(
delete(StrategyPosition).where(
StrategyPosition.id.in_(chunk)
)
)
success_count += 1
except Exception as e:
logger.error(f"同步账号 {k_id} 持仓失败: {e}")
continue
logger.info(f"批量同步完成: 成功 {success_count}/{len(positions_by_account)} 个账号")
return success_count > 0
except Exception as e:
logger.error(f"同步持仓到数据库失败: k_id={k_id}, error={e}")
logger.error(f"批量同步持仓到数据库失败: {e}")
return False
finally:
session.close()
def _convert_position_data(self, data: Dict) -> Dict:
"""转换持仓数据格式"""
# 这里实现具体的转换逻辑
return {
'st_id': int(data.get('st_id', 0)),
'k_id': int(data.get('k_id', 0)),
'asset': 'USDT',
'symbol': data.get('symbol', ''),
'side': data.get('side', ''),
'price': float(data.get('price', 0)) if data.get('price') is not None else None,
'qty': float(data.get('qty', 0)) if data.get('qty') is not None else None,
'asset_num': float(data.get('asset_num', 0)) if data.get('asset_num') is not None else None,
'asset_profit': float(data.get('asset_profit', 0)) if data.get('asset_profit') is not None else None,
'leverage': int(data.get('leverage', 0)) if data.get('leverage') is not None else None,
'uptime': int(data.get('uptime', 0)) if data.get('uptime') is not None else None,
'profit_price': float(data.get('profit_price', 0)) if data.get('profit_price') is not None else None,
'stop_price': float(data.get('stop_price', 0)) if data.get('stop_price') is not None else None,
'liquidation_price': float(data.get('liquidation_price', 0)) if data.get('liquidation_price') is not None else None
}
# 其他方法保持不变...

138
utils/batch_operations.py Normal file
View File

@@ -0,0 +1,138 @@
from typing import List, Dict, Any, Tuple
from loguru import logger
from sqlalchemy import text
from .database_manager import DatabaseManager
class BatchOperations:
"""批量数据库操作工具"""
def __init__(self):
self.db_manager = DatabaseManager()
def batch_insert_update_positions(self, positions_data: List[Dict]) -> Tuple[int, int]:
"""批量插入/更新持仓数据"""
session = self.db_manager.get_session()
try:
if not positions_data:
return 0, 0
# 按账号分组
positions_by_account = {}
for position in positions_data:
k_id = position.get('k_id')
if k_id not in positions_by_account:
positions_by_account[k_id] = []
positions_by_account[k_id].append(position)
total_processed = 0
total_deleted = 0
with session.begin():
for k_id, positions in positions_by_account.items():
processed, deleted = self._process_account_positions(session, k_id, positions)
total_processed += processed
total_deleted += deleted
logger.info(f"批量处理持仓完成: 处理 {total_processed} 条,删除 {total_deleted}")
return total_processed, total_deleted
except Exception as e:
logger.error(f"批量处理持仓失败: {e}")
return 0, 0
finally:
session.close()
def _process_account_positions(self, session, k_id: int, positions: List[Dict]) -> Tuple[int, int]:
"""处理单个账号的持仓数据"""
try:
st_id = positions[0].get('st_id', 0) if positions else 0
# 准备数据
insert_data = []
keep_keys = set()
for pos_data in positions:
# 转换数据
pos_dict = self._convert_position_data(pos_data)
if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
continue
# 重命名qty为sum
if 'qty' in pos_dict:
pos_dict['sum'] = pos_dict.pop('qty')
insert_data.append(pos_dict)
keep_keys.add((pos_dict['symbol'], pos_dict['side']))
if not insert_data:
# 清空该账号持仓
result = session.execute(
text("DELETE FROM deh_strategy_position_new WHERE k_id = :k_id AND st_id = :st_id"),
{'k_id': k_id, 'st_id': st_id}
)
return 0, result.rowcount
# 批量插入/更新
sql = """
INSERT INTO deh_strategy_position_new
(st_id, k_id, asset, symbol, side, price, `sum`,
asset_num, asset_profit, leverage, uptime,
profit_price, stop_price, liquidation_price)
VALUES
(:st_id, :k_id, :asset, :symbol, :side, :price, :sum,
:asset_num, :asset_profit, :leverage, :uptime,
:profit_price, :stop_price, :liquidation_price)
ON DUPLICATE KEY UPDATE
price = VALUES(price),
`sum` = VALUES(`sum`),
asset_num = VALUES(asset_num),
asset_profit = VALUES(asset_profit),
leverage = VALUES(leverage),
uptime = VALUES(uptime),
profit_price = VALUES(profit_price),
stop_price = VALUES(stop_price),
liquidation_price = VALUES(liquidation_price)
"""
# 分块执行
chunk_size = 500
processed_count = 0
for i in range(0, len(insert_data), chunk_size):
chunk = insert_data[i:i + chunk_size]
session.execute(text(sql), chunk)
processed_count += len(chunk)
# 删除多余持仓
deleted_count = 0
if keep_keys:
# 构建删除条件
conditions = []
for symbol, side in keep_keys:
safe_symbol = symbol.replace("'", "''") if symbol else ''
safe_side = side.replace("'", "''") if side else ''
conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')")
if conditions:
conditions_str = " OR ".join(conditions)
delete_sql = f"""
DELETE FROM deh_strategy_position_new
WHERE k_id = {k_id} AND st_id = {st_id}
AND NOT ({conditions_str})
"""
result = session.execute(text(delete_sql))
deleted_count = result.rowcount
return processed_count, deleted_count
except Exception as e:
logger.error(f"处理账号 {k_id} 持仓失败: {e}")
return 0, 0
def _convert_position_data(self, data: Dict) -> Dict:
"""转换持仓数据格式"""
# 转换逻辑...
pass
# 类似的批量方法 for orders and account info...