.env
@@ -31,5 +31,11 @@ LOG_LEVEL=INFO
LOG_ROTATION=10 MB
LOG_RETENTION=7 days

# Computer name (used to filter accounts)
COMPUTER_NAME=lz_c01

# Computer name configuration (multiple names supported)
COMPUTER_NAMES=lz_c01,lz_c02,lz_c03
# Or use pattern matching instead
COMPUTER_NAME_PATTERN=^lz_c\d{2}$

# Concurrency configuration
MAX_CONCURRENT=10
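An illustrative aside, not in the commit: these variables are presumably loaded from .env before config/settings.py reads them; a minimal sketch assuming python-dotenv is the loader:

    from dotenv import load_dotenv  # assumed loader; the commit does not show how .env is read
    import os

    load_dotenv()  # reads .env from the working directory into the process environment
    print(os.getenv('COMPUTER_NAMES'), os.getenv('MAX_CONCURRENT'))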
@@ -31,5 +31,7 @@ LOG_CONFIG = {
    'format': '{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} - {message}'
}

# Computer name configuration (used to filter accounts)
COMPUTER_NAME = os.getenv('COMPUTER_NAME', 'lz_c01')
# Computer name configuration (multiple names supported, comma-separated)
COMPUTER_NAMES = os.getenv('COMPUTER_NAMES', 'lz_c01')
# Or use pattern matching instead
COMPUTER_NAME_PATTERN = os.getenv('COMPUTER_NAME_PATTERN', r'^lz_c\d{2}$')
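An illustrative aside, not in the commit: a quick sanity check, assuming the defaults above, that every configured name also satisfies the discovery pattern (the helper name is hypothetical):

    import os
    import re

    def check_computer_name_config() -> None:
        """Hypothetical helper: every COMPUTER_NAMES entry should match COMPUTER_NAME_PATTERN."""
        names = [n.strip() for n in os.getenv('COMPUTER_NAMES', 'lz_c01').split(',')]
        pattern = re.compile(os.getenv('COMPUTER_NAME_PATTERN', r'^lz_c\d{2}$'))
        for name in names:
            # A name that fails the pattern is still reachable via the explicit list,
            # but pattern-based auto-discovery would never find it.
            assert pattern.match(name), f"{name} does not match COMPUTER_NAME_PATTERN"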
@@ -1,10 +1,12 @@
from abc import ABC, abstractmethod
from loguru import logger
from typing import List, Dict, Any
from typing import List, Dict, Any, Set
import json
import re

from utils.redis_client import RedisClient
from utils.database_manager import DatabaseManager
from config.settings import COMPUTER_NAMES, COMPUTER_NAME_PATTERN

class BaseSync(ABC):
    """Sync base class"""
@@ -12,27 +14,50 @@ class BaseSync(ABC):
    def __init__(self):
        self.redis_client = RedisClient()
        self.db_manager = DatabaseManager()
        self.computer_name = None  # read from config
        self.computer_names = self._get_computer_names()
        self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)

    @abstractmethod
    async def sync(self):
        """Run the sync"""
        pass
    def _get_computer_names(self) -> List[str]:
        """Get the list of computer names"""
        if ',' in COMPUTER_NAMES:
            return [name.strip() for name in COMPUTER_NAMES.split(',')]
        return [COMPUTER_NAMES.strip()]

    def get_accounts_from_redis(self) -> Dict[str, Dict]:
        """Fetch account config from Redis"""
        """Fetch account config for all computer names from Redis"""
        try:
            if self.computer_name is None:
                from config.settings import COMPUTER_NAME
                self.computer_name = COMPUTER_NAME
            accounts_dict = {}

            # Method 1: use the configured list of computer names
            for computer_name in self.computer_names:
                accounts = self._get_accounts_by_computer_name(computer_name)
                accounts_dict.update(accounts)

            # Method 2: auto-discover all matching keys (fallback)
            if not accounts_dict:
                accounts_dict = self._discover_all_accounts()

            logger.info(f"Fetched {len(accounts_dict)} accounts from {len(self.computer_names)} computer names")
            return accounts_dict

        except Exception as e:
            logger.error(f"Failed to fetch account info: {e}")
            return {}

    def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]:
        """Fetch accounts for the given computer name"""
        accounts_dict = {}

        try:
            # Build the key
            redis_key = f"{computer_name}_strategy_api"

            # Fetch the data from Redis
            result = self.redis_client.client.hgetall(f"{self.computer_name}_strategy_api")
            result = self.redis_client.client.hgetall(redis_key)
            if not result:
                logger.warning(f"No strategy API config found for {self.computer_name}")
                logger.debug(f"No strategy API config found for {redis_key}")
                return {}

            accounts_dict = {}
            for exchange_name, accounts_json in result.items():
                try:
                    accounts = json.loads(accounts_json)
@@ -45,46 +70,52 @@ class BaseSync(ABC):
                    for account_id, account_info in accounts.items():
                        parsed_account = self.parse_account(exchange_id, account_id, account_info)
                        if parsed_account:
                            # Tag the account with its computer name
                            parsed_account['computer_name'] = computer_name
                            accounts_dict[account_id] = parsed_account

                except json.JSONDecodeError as e:
                    logger.error(f"Failed to parse JSON data for exchange {exchange_name}: {e}")
                    continue

            return accounts_dict
            logger.info(f"Fetched {len(accounts_dict)} accounts from {redis_key}")

        except Exception as e:
            logger.error(f"Failed to fetch account info: {e}")
            return {}
            logger.error(f"Failed to fetch accounts for computer name {computer_name}: {e}")

    def format_exchange_id(self, key: str) -> str:
        """Normalize the exchange ID"""
        key = key.lower().strip()
        return accounts_dict

        # Exchange name mapping
        exchange_mapping = {
            'metatrader': 'mt5',
            'binance_spot_test': 'binance',
            'binance_spot': 'binance',
            'binance': 'binance',
            'gate_spot': 'gate',
            'okex': 'okx'
        }
    def _discover_all_accounts(self) -> Dict[str, Dict]:
        """Auto-discover all matching account keys"""
        accounts_dict = {}

        return exchange_mapping.get(key, key)

    def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Dict:
        """Parse account info"""
        try:
            source_account_info = json.loads(account_info)
            account_data = {
                'exchange_id': exchange_id,
                'k_id': account_id,
                'st_id': int(source_account_info.get('st_id', 0)),
                'add_time': int(source_account_info.get('add_time', 0))
            }
            return {**source_account_info, **account_data}
            # Get all keys matching the pattern
            pattern = f"*_strategy_api"
            cursor = 0

        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse data for account {account_id}: {e}")
            return {}
            while True:
                cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100)

                for key in keys:
                    key_str = key.decode('utf-8') if isinstance(key, bytes) else key

                    # Extract the computer name
                    computer_name = key_str.replace('_strategy_api', '')

                    # Validate the computer name format
                    if self.computer_name_pattern.match(computer_name):
                        accounts = self._get_accounts_by_computer_name(computer_name)
                        accounts_dict.update(accounts)

                if cursor == 0:
                    break

            logger.info(f"Auto-discovered {len(accounts_dict)} accounts")

        except Exception as e:
            logger.error(f"Auto-discovery of accounts failed: {e}")

        return accounts_dict

    # Other methods unchanged...
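Two illustrative notes on the hunks above, neither part of the commit itself.

First, the code assumes each <computer_name>_strategy_api hash maps an exchange name to a JSON object of account id -> account info; a tiny illustration with fabricated sample values:

    import json

    # Assumed shape of one hash, e.g. "lz_c01_strategy_api" (sample data, not real)
    sample_hash = {
        "binance_spot": json.dumps({"1001": {"st_id": 7, "add_time": 1700000000}})
    }

    for exchange_name, accounts_json in sample_hash.items():
        for account_id, info in json.loads(accounts_json).items():
            print(exchange_name, account_id, info["st_id"])

Second, the discovery logic is split across interleaved diff lines; a consolidated sketch of the same SCAN loop as a standalone function (assuming a redis-py client, as wrapped by utils.redis_client):

    import re
    from typing import List

    def discover_computer_names(client, name_pattern: str = r'^lz_c\d{2}$') -> List[str]:
        """Scan for *_strategy_api keys and keep computer names matching the pattern."""
        pattern = re.compile(name_pattern)
        names: List[str] = []
        cursor = 0
        while True:
            # SCAN walks the keyspace incrementally without blocking Redis.
            cursor, keys = client.scan(cursor, match='*_strategy_api', count=100)
            for key in keys:
                key_str = key.decode('utf-8') if isinstance(key, bytes) else key
                name = key_str.replace('_strategy_api', '')
                if pattern.match(name):
                    names.append(name)
            if cursor == 0:
                break
        return names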
@@ -1,8 +1,11 @@
import asyncio
from loguru import logger
from typing import List, Dict
from typing import List, Dict, Optional
import signal
import sys
import os  # not shown in the original hunk, but required by os.getenv() below
from concurrent.futures import ThreadPoolExecutor
import time
from asyncio import Semaphore

from config.settings import SYNC_CONFIG
from .position_sync import PositionSync
@@ -10,14 +13,18 @@ from .order_sync import OrderSync
from .account_sync import AccountSync

class SyncManager:
    """Sync manager"""
    """Sync manager (supports batched concurrent processing)"""

    def __init__(self):
        self.is_running = True
        self.sync_interval = SYNC_CONFIG['interval']
        self.max_concurrent = int(os.getenv('MAX_CONCURRENT', '10'))  # maximum concurrency

        # Initialize the syncers
        self.syncers = []
        self.executor = ThreadPoolExecutor(max_workers=self.max_concurrent)

        self.semaphore = Semaphore(self.max_concurrent)  # limits concurrency

        if SYNC_CONFIG['enable_position_sync']:
            self.syncers.append(PositionSync())
@@ -31,26 +38,62 @@ class SyncManager:
            self.syncers.append(AccountSync())
            logger.info("Account info sync enabled")

        # Performance stats
        self.stats = {
            'total_accounts': 0,
            'success_count': 0,
            'error_count': 0,
            'last_sync_time': 0,
            'avg_sync_time': 0
        }

        # Register signal handlers
        signal.signal(signal.SIGINT, self.signal_handler)
        signal.signal(signal.SIGTERM, self.signal_handler)

    async def _run_syncer_with_limit(self, syncer):
        """Run a syncer under the concurrency limit"""
        async with self.semaphore:
            return await self._run_syncer(syncer)

    def signal_handler(self, signum, frame):
        """Signal handler"""
        logger.info(f"Received signal {signum}, shutting down...")
        self.is_running = False

    def batch_process_accounts(self, accounts: Dict[str, Dict], batch_size: int = 100):
        """Process accounts in batches"""
        account_items = list(accounts.items())

        for i in range(0, len(account_items), batch_size):
            batch = dict(account_items[i:i + batch_size])
            # Process this batch of accounts
            self._process_account_batch(batch)

            # Pause between batches to avoid overloading the database
            time.sleep(0.1)

    async def start(self):
        """Start the sync service"""
        logger.info(f"Sync service started, interval {self.sync_interval} s")
        logger.info(f"Sync service started, interval {self.sync_interval} s, max concurrency {self.max_concurrent}")

        while self.is_running:
            try:
                # Run all syncers
                tasks = [syncer.sync() for syncer in self.syncers]
                await asyncio.gather(*tasks, return_exceptions=True)
                start_time = time.time()

                logger.debug(f"Sync complete, waiting {self.sync_interval} s")
                # Run all syncers
                tasks = [self._run_syncer(syncer) for syncer in self.syncers]
                results = await asyncio.gather(*tasks, return_exceptions=True)

                # Update stats
                sync_time = time.time() - start_time
                self.stats['last_sync_time'] = sync_time
                self.stats['avg_sync_time'] = (self.stats['avg_sync_time'] * 0.9 + sync_time * 0.1)

                # Print stats
                self._print_stats()

                logger.debug(f"Sync complete in {sync_time:.2f} s, waiting {self.sync_interval} s")
                await asyncio.sleep(self.sync_interval)

            except asyncio.CancelledError:
@@ -58,9 +101,41 @@ class SyncManager:
                break
            except Exception as e:
                logger.error(f"Sync task error: {e}")
                self.stats['error_count'] += 1
                await asyncio.sleep(30)  # wait 30 seconds after an error

    async def _run_syncer(self, syncer):
        """Run a single syncer"""
        try:
            # Fetch all accounts
            accounts = syncer.get_accounts_from_redis()
            self.stats['total_accounts'] = len(accounts)

            if not accounts:
                logger.warning("No accounts fetched")
                return

            # Process the accounts in batch
            await syncer.sync_batch(accounts)
            self.stats['success_count'] += 1

        except Exception as e:
            logger.error(f"Syncer {syncer.__class__.__name__} failed: {e}")
            self.stats['error_count'] += 1

    def _print_stats(self):
        """Print stats"""
        stats_str = (
            f"Stats: accounts={self.stats['total_accounts']}, "
            f"success={self.stats['success_count']}, "
            f"failed={self.stats['error_count']}, "
            f"last={self.stats['last_sync_time']:.2f}s, "
            f"avg={self.stats['avg_sync_time']:.2f}s"
        )
        logger.info(stats_str)

    async def stop(self):
        """Stop the sync service"""
        self.is_running = False
        self.executor.shutdown(wait=True)
        logger.info("Sync service stopped")
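Two illustrative notes on SyncManager, neither part of the commit.

The usual way to combine the Semaphore added above with asyncio.gather is to wrap each coroutine before gathering; a minimal generic sketch:

    import asyncio
    from typing import Awaitable, Iterable, List

    async def gather_with_limit(coros: Iterable[Awaitable], max_concurrent: int = 10) -> List:
        """Run awaitables concurrently, at most max_concurrent at a time."""
        semaphore = asyncio.Semaphore(max_concurrent)

        async def limited(coro: Awaitable):
            async with semaphore:
                return await coro

        # return_exceptions=True keeps one failing task from cancelling the others,
        # mirroring how start() gathers its syncers.
        return await asyncio.gather(*(limited(c) for c in coros), return_exceptions=True)

And a hypothetical entry point tying start() and stop() together (module path assumed):

    import asyncio

    from sync.sync_manager import SyncManager  # assumed import path

    async def main():
        manager = SyncManager()
        try:
            await manager.start()   # loops until SIGINT/SIGTERM flips is_running
        finally:
            await manager.stop()    # also shuts down the thread pool executor

    if __name__ == '__main__':
        asyncio.run(main())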
@@ -2,173 +2,190 @@ from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict
import json
from datetime import datetime, timedelta
import asyncio
from concurrent.futures import ThreadPoolExecutor

class PositionSync(BaseSync):
    """Position data syncer"""
    """Position data syncer (batched version)"""

    async def sync(self):
        """Sync position data"""
    def __init__(self):
        super().__init__()
        self.max_concurrent = 10  # max concurrency per syncer

    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Batch-sync position data for all accounts"""
        try:
            # Fetch all accounts
            accounts = self.get_accounts_from_redis()
            logger.info(f"Starting batch position sync for {len(accounts)} accounts")

            for k_id_str, account_info in accounts.items():
                try:
                    k_id = int(k_id_str)
                    st_id = account_info.get('st_id', 0)
                    exchange_id = account_info['exchange_id']
            # Group accounts by exchange
            account_groups = self._group_accounts_by_exchange(accounts)

                    if k_id <= 0 or st_id <= 0:
                        continue
            # Process each exchange's accounts concurrently
            tasks = []
            for exchange_id, account_list in account_groups.items():
                task = self._sync_exchange_accounts(exchange_id, account_list)
                tasks.append(task)

                    # Fetch position data from Redis
                    positions = await self._get_positions_from_redis(k_id, exchange_id)
            # Wait for all tasks to finish
            results = await asyncio.gather(*tasks, return_exceptions=True)

                    # Sync to the database
                    if positions:
                        success = self._sync_positions_to_db(k_id, st_id, positions)
                        if success:
                            logger.debug(f"Position sync succeeded: k_id={k_id}, positions={len(positions)}")

                except Exception as e:
                    logger.error(f"Failed to sync positions for account {k_id_str}: {e}")
                    continue

            logger.info("Position data sync complete")
            # Tally the results
            success_count = sum(1 for r in results if isinstance(r, bool) and r)
            logger.info(f"Batch position sync complete: {success_count}/{len(results)} exchange groups succeeded")

        except Exception as e:
            logger.error(f"Position sync failed: {e}")
            logger.error(f"Batch position sync failed: {e}")

    async def _get_positions_from_redis(self, k_id: int, exchange_id: str) -> List[Dict]:
        """Fetch position data from Redis"""
    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group accounts by exchange"""
        groups = {}
        for account_id, account_info in accounts.items():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                if exchange_id not in groups:
                    groups[exchange_id] = []
                groups[exchange_id].append(account_info)
        return groups
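As an aside (not in the commit), the grouping helper above can be written more compactly with collections.defaultdict; an equivalent sketch:

    from collections import defaultdict
    from typing import Dict, List

    def group_accounts_by_exchange(accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Same behavior as _group_accounts_by_exchange, using a defaultdict."""
        groups: Dict[str, List[Dict]] = defaultdict(list)
        for account_info in accounts.values():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                groups[exchange_id].append(account_info)
        return dict(groups)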
    async def _sync_exchange_accounts(self, exchange_id: str, account_list: List[Dict]):
        """Sync all accounts of one exchange"""
        try:
            redis_key = f"{exchange_id}:positions:{k_id}"
            redis_data = self.redis_client.client.hget(redis_key, 'positions')
            # Collect position data for all accounts
            all_positions = []

            if not redis_data:
                return []
            for account_info in account_list:
                k_id = int(account_info['k_id'])
                st_id = account_info.get('st_id', 0)

            positions = json.loads(redis_data)
                # Fetch position data from Redis
                positions = await self._get_positions_from_redis(k_id, exchange_id)

            # Attach account info
            for position in positions:
                position['k_id'] = k_id
                if positions:
                    # Attach account info
                    for position in positions:
                        position['k_id'] = k_id
                        position['st_id'] = st_id
                    all_positions.extend(positions)

            return positions
            if not all_positions:
                logger.debug(f"Exchange {exchange_id} has no position data")
                return True

            # Batch-sync to the database
            success = self._sync_positions_batch_to_db(all_positions)
            if success:
                logger.info(f"Exchange {exchange_id} position sync succeeded: {len(all_positions)} positions")

            return success

        except Exception as e:
            logger.error(f"Failed to fetch positions from Redis: k_id={k_id}, error={e}")
            return []
            logger.error(f"Failed to sync positions for exchange {exchange_id}: {e}")
            return False
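For context (not in the commit): the old _get_positions_from_redis lines interleaved above read a JSON-encoded list from the 'positions' field of a "{exchange_id}:positions:{k_id}" hash; a compact sketch of that read, assuming a redis-py client:

    import json

    def get_positions(client, exchange_id: str, k_id: int) -> list:
        """Return the decoded position list for one account, or [] if the field is absent."""
        raw = client.hget(f"{exchange_id}:positions:{k_id}", 'positions')
        return json.loads(raw) if raw else []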
    def _sync_positions_to_db(self, k_id: int, st_id: int, positions_data: List[Dict]) -> bool:
        """Sync position data to the database"""
    def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> bool:
        """Batch-sync position data to the database (optimized)"""
        session = self.db_manager.get_session()
        try:
            # Use the batch-optimized approach
            from sqlalchemy.dialects.mysql import insert
            # Group by k_id
            positions_by_account = {}
            for position in all_positions:
                k_id = position['k_id']
                if k_id not in positions_by_account:
                    positions_by_account[k_id] = []
                positions_by_account[k_id].append(position)

            # Prepare the data
            insert_data = []
            keep_keys = set()  # (symbol, side) combinations to keep

            for pos_data in positions_data:
                try:
                    # Convert the data (conversion logic to be implemented here)
                    pos_dict = self._convert_position_data(pos_data)
                    if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
                        continue

                    # Rename qty to sum
                    if 'qty' in pos_dict:
                        pos_dict['sum'] = pos_dict.pop('qty')

                    insert_data.append(pos_dict)
                    keep_keys.add((pos_dict['symbol'], pos_dict['side']))

                except Exception as e:
                    logger.error(f"Failed to convert position data: {pos_data}, error={e}")
                    continue
            success_count = 0

            with session.begin():
                if not insert_data:
                    # Clear this account's positions
                    session.execute(
                        delete(StrategyPosition).where(
                            and_(
                                StrategyPosition.k_id == k_id,
                                StrategyPosition.st_id == st_id
                            )
                        )
                    )
                    return True
                for k_id, positions in positions_by_account.items():
                    try:
                        st_id = positions[0]['st_id'] if positions else 0

                        # Batch insert/update
                        stmt = insert(StrategyPosition.__table__).values(insert_data)
                        # Prepare the data
                        insert_data = []
                        keep_keys = set()

                        update_dict = {
                            'price': stmt.inserted.price,
                            'sum': stmt.inserted.sum,
                            'asset_num': stmt.inserted.asset_num,
                            'asset_profit': stmt.inserted.asset_profit,
                            'leverage': stmt.inserted.leverage,
                            'uptime': stmt.inserted.uptime,
                            'profit_price': stmt.inserted.profit_price,
                            'stop_price': stmt.inserted.stop_price,
                            'liquidation_price': stmt.inserted.liquidation_price
                        }
                        for pos_data in positions:
                            try:
                                pos_dict = self._convert_position_data(pos_data)
                                if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
                                    continue
                        stmt = stmt.on_duplicate_key_update(**update_dict)
                        session.execute(stmt)
                                # Rename qty to sum
                                if 'qty' in pos_dict:
                                    pos_dict['sum'] = pos_dict.pop('qty')

                        # Delete stale positions
                        if keep_keys:
                            existing_positions = session.execute(
                                select(StrategyPosition).where(
                                    and_(
                                        StrategyPosition.k_id == k_id,
                                        StrategyPosition.st_id == st_id
                                    )
                                )
                            ).scalars().all()
                                insert_data.append(pos_dict)
                                keep_keys.add((pos_dict['symbol'], pos_dict['side']))

                            to_delete_ids = []
                            for existing in existing_positions:
                                key = (existing.symbol, existing.side)
                                if key not in keep_keys:
                                    to_delete_ids.append(existing.id)
                            except Exception as e:
                                logger.error(f"Failed to convert position data: {pos_data}, error={e}")
                                continue

                            if to_delete_ids:
                                session.execute(
                                    delete(StrategyPosition).where(
                                        StrategyPosition.id.in_(to_delete_ids)
                                    )
                                )
                        if not insert_data:
                            continue

                        return True
                        # Batch insert/update
                        from sqlalchemy.dialects.mysql import insert

                        stmt = insert(StrategyPosition.__table__).values(insert_data)

                        update_dict = {
                            'price': stmt.inserted.price,
                            'sum': stmt.inserted.sum,
                            'asset_num': stmt.inserted.asset_num,
                            'asset_profit': stmt.inserted.asset_profit,
                            'leverage': stmt.inserted.leverage,
                            'uptime': stmt.inserted.uptime,
                            'profit_price': stmt.inserted.profit_price,
                            'stop_price': stmt.inserted.stop_price,
                            'liquidation_price': stmt.inserted.liquidation_price
                        }

                        stmt = stmt.on_duplicate_key_update(**update_dict)
                        session.execute(stmt)

                        # Delete stale positions
                        if keep_keys:
                            existing_positions = session.execute(
                                select(StrategyPosition).where(
                                    and_(
                                        StrategyPosition.k_id == k_id,
                                        StrategyPosition.st_id == st_id
                                    )
                                )
                            ).scalars().all()

                            to_delete_ids = []
                            for existing in existing_positions:
                                key = (existing.symbol, existing.side)
                                if key not in keep_keys:
                                    to_delete_ids.append(existing.id)

                            if to_delete_ids:
                                # Delete in chunks
                                chunk_size = 100
                                for i in range(0, len(to_delete_ids), chunk_size):
                                    chunk = to_delete_ids[i:i + chunk_size]
                                    session.execute(
                                        delete(StrategyPosition).where(
                                            StrategyPosition.id.in_(chunk)
                                        )
                                    )

                        success_count += 1

                    except Exception as e:
                        logger.error(f"Failed to sync positions for account {k_id}: {e}")
                        continue

            logger.info(f"Batch sync complete: {success_count}/{len(positions_by_account)} accounts succeeded")
            return success_count > 0

        except Exception as e:
            logger.error(f"Failed to sync positions to the database: k_id={k_id}, error={e}")
            logger.error(f"Batch sync of positions to the database failed: {e}")
            return False
        finally:
            session.close()
    def _convert_position_data(self, data: Dict) -> Dict:
        """Convert the position data format"""
        # Concrete conversion logic goes here
        return {
            'st_id': int(data.get('st_id', 0)),
            'k_id': int(data.get('k_id', 0)),
            'asset': 'USDT',
            'symbol': data.get('symbol', ''),
            'side': data.get('side', ''),
            'price': float(data.get('price', 0)) if data.get('price') is not None else None,
            'qty': float(data.get('qty', 0)) if data.get('qty') is not None else None,
            'asset_num': float(data.get('asset_num', 0)) if data.get('asset_num') is not None else None,
            'asset_profit': float(data.get('asset_profit', 0)) if data.get('asset_profit') is not None else None,
            'leverage': int(data.get('leverage', 0)) if data.get('leverage') is not None else None,
            'uptime': int(data.get('uptime', 0)) if data.get('uptime') is not None else None,
            'profit_price': float(data.get('profit_price', 0)) if data.get('profit_price') is not None else None,
            'stop_price': float(data.get('stop_price', 0)) if data.get('stop_price') is not None else None,
            'liquidation_price': float(data.get('liquidation_price', 0)) if data.get('liquidation_price') is not None else None
        }
    # Other methods unchanged...
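The bulk write in _sync_positions_batch_to_db uses SQLAlchemy's MySQL dialect insert with ON DUPLICATE KEY UPDATE. A trimmed, self-contained sketch of the pattern, not the commit's model (table, columns, and the unique key are illustrative assumptions):

    from sqlalchemy import Column, Float, Integer, String, UniqueConstraint
    from sqlalchemy.dialects.mysql import insert
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Position(Base):  # stand-in for StrategyPosition
        __tablename__ = 'positions'
        __table_args__ = (UniqueConstraint('k_id', 'symbol', 'side'),)  # assumed unique key
        id = Column(Integer, primary_key=True, autoincrement=True)
        k_id = Column(Integer, nullable=False)
        symbol = Column(String(32), nullable=False)
        side = Column(String(8), nullable=False)
        price = Column(Float)

    rows = [{'k_id': 1, 'symbol': 'BTCUSDT', 'side': 'long', 'price': 65000.0}]  # sample row

    stmt = insert(Position.__table__).values(rows)
    # stmt.inserted refers to the incoming VALUES(...) of each row, so on a
    # duplicate key the existing row is refreshed with the new price.
    stmt = stmt.on_duplicate_key_update(price=stmt.inserted.price)
    # session.execute(stmt) would then run it inside a transaction, as the diff does.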
utils/batch_operations.py (new file, 138 lines)
@@ -0,0 +1,138 @@
from typing import List, Dict, Any, Tuple
from loguru import logger
from sqlalchemy import text
from .database_manager import DatabaseManager

class BatchOperations:
    """Batch database operation utilities"""

    def __init__(self):
        self.db_manager = DatabaseManager()

    def batch_insert_update_positions(self, positions_data: List[Dict]) -> Tuple[int, int]:
        """Batch insert/update position data"""
        session = self.db_manager.get_session()
        try:
            if not positions_data:
                return 0, 0

            # Group by account
            positions_by_account = {}
            for position in positions_data:
                k_id = position.get('k_id')
                if k_id not in positions_by_account:
                    positions_by_account[k_id] = []
                positions_by_account[k_id].append(position)

            total_processed = 0
            total_deleted = 0

            with session.begin():
                for k_id, positions in positions_by_account.items():
                    processed, deleted = self._process_account_positions(session, k_id, positions)
                    total_processed += processed
                    total_deleted += deleted

            logger.info(f"Batch position processing complete: processed {total_processed}, deleted {total_deleted}")
            return total_processed, total_deleted

        except Exception as e:
            logger.error(f"Batch position processing failed: {e}")
            return 0, 0
        finally:
            session.close()
    def _process_account_positions(self, session, k_id: int, positions: List[Dict]) -> Tuple[int, int]:
        """Process one account's position data"""
        try:
            st_id = positions[0].get('st_id', 0) if positions else 0

            # Prepare the data
            insert_data = []
            keep_keys = set()

            for pos_data in positions:
                # Convert the data
                pos_dict = self._convert_position_data(pos_data)
                if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
                    continue

                # Rename qty to sum
                if 'qty' in pos_dict:
                    pos_dict['sum'] = pos_dict.pop('qty')

                insert_data.append(pos_dict)
                keep_keys.add((pos_dict['symbol'], pos_dict['side']))

            if not insert_data:
                # Clear this account's positions
                result = session.execute(
                    text("DELETE FROM deh_strategy_position_new WHERE k_id = :k_id AND st_id = :st_id"),
                    {'k_id': k_id, 'st_id': st_id}
                )
                return 0, result.rowcount

            # Batch insert/update
            sql = """
                INSERT INTO deh_strategy_position_new
                (st_id, k_id, asset, symbol, side, price, `sum`,
                 asset_num, asset_profit, leverage, uptime,
                 profit_price, stop_price, liquidation_price)
                VALUES
                (:st_id, :k_id, :asset, :symbol, :side, :price, :sum,
                 :asset_num, :asset_profit, :leverage, :uptime,
                 :profit_price, :stop_price, :liquidation_price)
                ON DUPLICATE KEY UPDATE
                    price = VALUES(price),
                    `sum` = VALUES(`sum`),
                    asset_num = VALUES(asset_num),
                    asset_profit = VALUES(asset_profit),
                    leverage = VALUES(leverage),
                    uptime = VALUES(uptime),
                    profit_price = VALUES(profit_price),
                    stop_price = VALUES(stop_price),
                    liquidation_price = VALUES(liquidation_price)
            """

            # Execute in chunks
            chunk_size = 500
            processed_count = 0

            for i in range(0, len(insert_data), chunk_size):
                chunk = insert_data[i:i + chunk_size]
                session.execute(text(sql), chunk)
                processed_count += len(chunk)

            # Delete stale positions
            deleted_count = 0
            if keep_keys:
                # Build the delete conditions
                conditions = []
                for symbol, side in keep_keys:
                    safe_symbol = symbol.replace("'", "''") if symbol else ''
                    safe_side = side.replace("'", "''") if side else ''
                    conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')")

                if conditions:
                    conditions_str = " OR ".join(conditions)
                    delete_sql = f"""
                        DELETE FROM deh_strategy_position_new
                        WHERE k_id = {k_id} AND st_id = {st_id}
                        AND NOT ({conditions_str})
                    """

                    result = session.execute(text(delete_sql))
                    deleted_count = result.rowcount

            return processed_count, deleted_count

        except Exception as e:
            logger.error(f"Failed to process positions for account {k_id}: {e}")
            return 0, 0
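An aside, not in the commit: the stale-position delete above escapes symbol/side by hand and interpolates them into the SQL string. An alternative sketch that passes the same values as bound parameters instead (illustrative helper name; assumes keep_keys is non-empty, as the calling code already checks):

    from sqlalchemy import text

    def build_stale_delete(keep_keys):
        """Build a parameterized DELETE keeping only the (symbol, side) pairs in keep_keys."""
        conditions = []
        params = {}
        for i, (symbol, side) in enumerate(keep_keys):
            conditions.append(f"(symbol = :symbol_{i} AND side = :side_{i})")
            params[f"symbol_{i}"] = symbol
            params[f"side_{i}"] = side
        sql = (
            "DELETE FROM deh_strategy_position_new "
            "WHERE k_id = :k_id AND st_id = :st_id "
            f"AND NOT ({' OR '.join(conditions)})"
        )
        return text(sql), params

    # usage: stmt, params = build_stale_delete(keep_keys)
    #        session.execute(stmt, {**params, 'k_id': k_id, 'st_id': st_id})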
    def _convert_position_data(self, data: Dict) -> Dict:
        """Convert the position data format"""
        # Conversion logic...
        pass

# Similar batch methods for orders and account info...
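A hypothetical caller of the new utility, assuming the position dicts already carry k_id/st_id the way PositionSync attaches them (sample values are illustrative):

    from utils.batch_operations import BatchOperations

    positions = [
        {'k_id': 1, 'st_id': 7, 'symbol': 'BTCUSDT', 'side': 'long', 'qty': 0.5, 'price': 65000.0},
    ]

    batch_ops = BatchOperations()
    processed, deleted = batch_ops.batch_insert_update_positions(positions)
    print(f"processed={processed}, deleted={deleted}")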