1
This commit is contained in:
10
.env
10
.env
@@ -31,5 +31,11 @@ LOG_LEVEL=INFO
|
|||||||
LOG_ROTATION=10 MB
|
LOG_ROTATION=10 MB
|
||||||
LOG_RETENTION=7 days
|
LOG_RETENTION=7 days
|
||||||
|
|
||||||
# 计算机名(用于过滤账号)
|
|
||||||
COMPUTER_NAME=lz_c01
|
# 计算机名配置(支持多个)
|
||||||
|
COMPUTER_NAMES=lz_c01,lz_c02,lz_c03
|
||||||
|
# 或者使用模式匹配
|
||||||
|
COMPUTER_NAME_PATTERN=^lz_c\d{2}$
|
||||||
|
|
||||||
|
# 并发配置
|
||||||
|
MAX_CONCURRENT=10
|
||||||
@@ -31,5 +31,7 @@ LOG_CONFIG = {
|
|||||||
'format': '{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} - {message}'
|
'format': '{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} - {message}'
|
||||||
}
|
}
|
||||||
|
|
||||||
# 计算机名配置(用于过滤账号)
|
# 计算机名配置(支持多个,用逗号分隔)
|
||||||
COMPUTER_NAME = os.getenv('COMPUTER_NAME', 'lz_c01')
|
COMPUTER_NAMES = os.getenv('COMPUTER_NAMES', 'lz_c01')
|
||||||
|
# 或者使用模式匹配
|
||||||
|
COMPUTER_NAME_PATTERN = os.getenv('COMPUTER_NAME_PATTERN', r'^lz_c\d{2}$')
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Set
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
from utils.redis_client import RedisClient
|
from utils.redis_client import RedisClient
|
||||||
from utils.database_manager import DatabaseManager
|
from utils.database_manager import DatabaseManager
|
||||||
|
from config.settings import COMPUTER_NAMES, COMPUTER_NAME_PATTERN
|
||||||
|
|
||||||
class BaseSync(ABC):
|
class BaseSync(ABC):
|
||||||
"""同步基类"""
|
"""同步基类"""
|
||||||
@@ -12,27 +14,50 @@ class BaseSync(ABC):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.redis_client = RedisClient()
|
self.redis_client = RedisClient()
|
||||||
self.db_manager = DatabaseManager()
|
self.db_manager = DatabaseManager()
|
||||||
self.computer_name = None # 从配置读取
|
self.computer_names = self._get_computer_names()
|
||||||
|
self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)
|
||||||
|
|
||||||
@abstractmethod
|
def _get_computer_names(self) -> List[str]:
|
||||||
async def sync(self):
|
"""获取计算机名列表"""
|
||||||
"""执行同步"""
|
if ',' in COMPUTER_NAMES:
|
||||||
pass
|
return [name.strip() for name in COMPUTER_NAMES.split(',')]
|
||||||
|
return [COMPUTER_NAMES.strip()]
|
||||||
|
|
||||||
def get_accounts_from_redis(self) -> Dict[str, Dict]:
|
def get_accounts_from_redis(self) -> Dict[str, Dict]:
|
||||||
"""从Redis获取账号配置"""
|
"""从Redis获取所有计算机名的账号配置"""
|
||||||
try:
|
try:
|
||||||
if self.computer_name is None:
|
accounts_dict = {}
|
||||||
from config.settings import COMPUTER_NAME
|
|
||||||
self.computer_name = COMPUTER_NAME
|
# 方法1:使用配置的计算机名列表
|
||||||
|
for computer_name in self.computer_names:
|
||||||
|
accounts = self._get_accounts_by_computer_name(computer_name)
|
||||||
|
accounts_dict.update(accounts)
|
||||||
|
|
||||||
|
# 方法2:自动发现所有匹配的key(备用方案)
|
||||||
|
if not accounts_dict:
|
||||||
|
accounts_dict = self._discover_all_accounts()
|
||||||
|
|
||||||
|
logger.info(f"从 {len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号")
|
||||||
|
return accounts_dict
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取账户信息失败: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]:
|
||||||
|
"""获取指定计算机名的账号"""
|
||||||
|
accounts_dict = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建key
|
||||||
|
redis_key = f"{computer_name}_strategy_api"
|
||||||
|
|
||||||
# 从Redis获取数据
|
# 从Redis获取数据
|
||||||
result = self.redis_client.client.hgetall(f"{self.computer_name}_strategy_api")
|
result = self.redis_client.client.hgetall(redis_key)
|
||||||
if not result:
|
if not result:
|
||||||
logger.warning(f"未找到 {self.computer_name} 的策略API配置")
|
logger.debug(f"未找到 {redis_key} 的策略API配置")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
accounts_dict = {}
|
|
||||||
for exchange_name, accounts_json in result.items():
|
for exchange_name, accounts_json in result.items():
|
||||||
try:
|
try:
|
||||||
accounts = json.loads(accounts_json)
|
accounts = json.loads(accounts_json)
|
||||||
@@ -45,46 +70,52 @@ class BaseSync(ABC):
|
|||||||
for account_id, account_info in accounts.items():
|
for account_id, account_info in accounts.items():
|
||||||
parsed_account = self.parse_account(exchange_id, account_id, account_info)
|
parsed_account = self.parse_account(exchange_id, account_id, account_info)
|
||||||
if parsed_account:
|
if parsed_account:
|
||||||
|
# 添加计算机名标记
|
||||||
|
parsed_account['computer_name'] = computer_name
|
||||||
accounts_dict[account_id] = parsed_account
|
accounts_dict[account_id] = parsed_account
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}")
|
logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return accounts_dict
|
logger.info(f"从 {redis_key} 获取到 {len(accounts_dict)} 个账号")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取账户信息失败: {e}")
|
logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}")
|
||||||
return {}
|
|
||||||
|
|
||||||
def format_exchange_id(self, key: str) -> str:
|
return accounts_dict
|
||||||
"""格式化交易所ID"""
|
|
||||||
key = key.lower().strip()
|
|
||||||
|
|
||||||
# 交易所名称映射
|
def _discover_all_accounts(self) -> Dict[str, Dict]:
|
||||||
exchange_mapping = {
|
"""自动发现所有匹配的账号key"""
|
||||||
'metatrader': 'mt5',
|
accounts_dict = {}
|
||||||
'binance_spot_test': 'binance',
|
|
||||||
'binance_spot': 'binance',
|
|
||||||
'binance': 'binance',
|
|
||||||
'gate_spot': 'gate',
|
|
||||||
'okex': 'okx'
|
|
||||||
}
|
|
||||||
|
|
||||||
return exchange_mapping.get(key, key)
|
|
||||||
|
|
||||||
def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Dict:
|
|
||||||
"""解析账号信息"""
|
|
||||||
try:
|
try:
|
||||||
source_account_info = json.loads(account_info)
|
# 获取所有匹配模式的key
|
||||||
account_data = {
|
pattern = f"*_strategy_api"
|
||||||
'exchange_id': exchange_id,
|
cursor = 0
|
||||||
'k_id': account_id,
|
|
||||||
'st_id': int(source_account_info.get('st_id', 0)),
|
|
||||||
'add_time': int(source_account_info.get('add_time', 0))
|
|
||||||
}
|
|
||||||
return {**source_account_info, **account_data}
|
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
while True:
|
||||||
logger.error(f"解析账号 {account_id} 数据失败: {e}")
|
cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100)
|
||||||
return {}
|
|
||||||
|
for key in keys:
|
||||||
|
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
|
||||||
|
|
||||||
|
# 提取计算机名
|
||||||
|
computer_name = key_str.replace('_strategy_api', '')
|
||||||
|
|
||||||
|
# 验证计算机名格式
|
||||||
|
if self.computer_name_pattern.match(computer_name):
|
||||||
|
accounts = self._get_accounts_by_computer_name(computer_name)
|
||||||
|
accounts_dict.update(accounts)
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"自动发现 {len(accounts_dict)} 个账号")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"自动发现账号失败: {e}")
|
||||||
|
|
||||||
|
return accounts_dict
|
||||||
|
|
||||||
|
# 其他方法保持不变...
|
||||||
@@ -1,8 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import List, Dict
|
from typing import List, Dict, Optional
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import time
|
||||||
|
from asyncio import Semaphore
|
||||||
|
|
||||||
from config.settings import SYNC_CONFIG
|
from config.settings import SYNC_CONFIG
|
||||||
from .position_sync import PositionSync
|
from .position_sync import PositionSync
|
||||||
@@ -10,14 +13,18 @@ from .order_sync import OrderSync
|
|||||||
from .account_sync import AccountSync
|
from .account_sync import AccountSync
|
||||||
|
|
||||||
class SyncManager:
|
class SyncManager:
|
||||||
"""同步管理器"""
|
"""同步管理器(支持批量并发处理)"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.is_running = True
|
self.is_running = True
|
||||||
self.sync_interval = SYNC_CONFIG['interval']
|
self.sync_interval = SYNC_CONFIG['interval']
|
||||||
|
self.max_concurrent = int(os.getenv('MAX_CONCURRENT', '10')) # 最大并发数
|
||||||
|
|
||||||
# 初始化同步器
|
# 初始化同步器
|
||||||
self.syncers = []
|
self.syncers = []
|
||||||
|
self.executor = ThreadPoolExecutor(max_workers=self.max_concurrent)
|
||||||
|
|
||||||
|
self.semaphore = Semaphore(self.max_concurrent) # 控制并发数
|
||||||
|
|
||||||
if SYNC_CONFIG['enable_position_sync']:
|
if SYNC_CONFIG['enable_position_sync']:
|
||||||
self.syncers.append(PositionSync())
|
self.syncers.append(PositionSync())
|
||||||
@@ -31,26 +38,62 @@ class SyncManager:
|
|||||||
self.syncers.append(AccountSync())
|
self.syncers.append(AccountSync())
|
||||||
logger.info("启用账户信息同步")
|
logger.info("启用账户信息同步")
|
||||||
|
|
||||||
|
# 性能统计
|
||||||
|
self.stats = {
|
||||||
|
'total_accounts': 0,
|
||||||
|
'success_count': 0,
|
||||||
|
'error_count': 0,
|
||||||
|
'last_sync_time': 0,
|
||||||
|
'avg_sync_time': 0
|
||||||
|
}
|
||||||
|
|
||||||
# 注册信号处理器
|
# 注册信号处理器
|
||||||
signal.signal(signal.SIGINT, self.signal_handler)
|
signal.signal(signal.SIGINT, self.signal_handler)
|
||||||
signal.signal(signal.SIGTERM, self.signal_handler)
|
signal.signal(signal.SIGTERM, self.signal_handler)
|
||||||
|
|
||||||
|
async def _run_syncer_with_limit(self, syncer):
|
||||||
|
"""带并发限制的运行"""
|
||||||
|
async with self.semaphore:
|
||||||
|
return await self._run_syncer(syncer)
|
||||||
|
|
||||||
def signal_handler(self, signum, frame):
|
def signal_handler(self, signum, frame):
|
||||||
"""信号处理器"""
|
"""信号处理器"""
|
||||||
logger.info(f"接收到信号 {signum},正在关闭...")
|
logger.info(f"接收到信号 {signum},正在关闭...")
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|
||||||
|
def batch_process_accounts(self, accounts: Dict[str, Dict], batch_size: int = 100):
|
||||||
|
"""分批处理账号"""
|
||||||
|
account_items = list(accounts.items())
|
||||||
|
|
||||||
|
for i in range(0, len(account_items), batch_size):
|
||||||
|
batch = dict(account_items[i:i + batch_size])
|
||||||
|
# 处理这批账号
|
||||||
|
self._process_account_batch(batch)
|
||||||
|
|
||||||
|
# 批次间休息,避免数据库压力过大
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""启动同步服务"""
|
"""启动同步服务"""
|
||||||
logger.info(f"同步服务启动,间隔 {self.sync_interval} 秒")
|
logger.info(f"同步服务启动,间隔 {self.sync_interval} 秒,最大并发 {self.max_concurrent}")
|
||||||
|
|
||||||
while self.is_running:
|
while self.is_running:
|
||||||
try:
|
try:
|
||||||
# 执行所有同步器
|
start_time = time.time()
|
||||||
tasks = [syncer.sync() for syncer in self.syncers]
|
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
logger.debug(f"同步完成,等待 {self.sync_interval} 秒")
|
# 执行所有同步器
|
||||||
|
tasks = [self._run_syncer(syncer) for syncer in self.syncers]
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# 更新统计
|
||||||
|
sync_time = time.time() - start_time
|
||||||
|
self.stats['last_sync_time'] = sync_time
|
||||||
|
self.stats['avg_sync_time'] = (self.stats['avg_sync_time'] * 0.9 + sync_time * 0.1)
|
||||||
|
|
||||||
|
# 打印统计信息
|
||||||
|
self._print_stats()
|
||||||
|
|
||||||
|
logger.debug(f"同步完成,耗时 {sync_time:.2f} 秒,等待 {self.sync_interval} 秒")
|
||||||
await asyncio.sleep(self.sync_interval)
|
await asyncio.sleep(self.sync_interval)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@@ -58,9 +101,41 @@ class SyncManager:
|
|||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"同步任务异常: {e}")
|
logger.error(f"同步任务异常: {e}")
|
||||||
|
self.stats['error_count'] += 1
|
||||||
await asyncio.sleep(30) # 出错后等待30秒
|
await asyncio.sleep(30) # 出错后等待30秒
|
||||||
|
|
||||||
|
async def _run_syncer(self, syncer):
|
||||||
|
"""运行单个同步器"""
|
||||||
|
try:
|
||||||
|
# 获取所有账号
|
||||||
|
accounts = syncer.get_accounts_from_redis()
|
||||||
|
self.stats['total_accounts'] = len(accounts)
|
||||||
|
|
||||||
|
if not accounts:
|
||||||
|
logger.warning("未获取到任何账号")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 批量处理账号
|
||||||
|
await syncer.sync_batch(accounts)
|
||||||
|
self.stats['success_count'] += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"同步器 {syncer.__class__.__name__} 执行失败: {e}")
|
||||||
|
self.stats['error_count'] += 1
|
||||||
|
|
||||||
|
def _print_stats(self):
|
||||||
|
"""打印统计信息"""
|
||||||
|
stats_str = (
|
||||||
|
f"统计: 账号数={self.stats['total_accounts']}, "
|
||||||
|
f"成功={self.stats['success_count']}, "
|
||||||
|
f"失败={self.stats['error_count']}, "
|
||||||
|
f"本次耗时={self.stats['last_sync_time']:.2f}s, "
|
||||||
|
f"平均耗时={self.stats['avg_sync_time']:.2f}s"
|
||||||
|
)
|
||||||
|
logger.info(stats_str)
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""停止同步服务"""
|
"""停止同步服务"""
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
self.executor.shutdown(wait=True)
|
||||||
logger.info("同步服务停止")
|
logger.info("同步服务停止")
|
||||||
@@ -2,173 +2,190 @@ from .base_sync import BaseSync
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
import asyncio
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
class PositionSync(BaseSync):
|
class PositionSync(BaseSync):
|
||||||
"""持仓数据同步器"""
|
"""持仓数据同步器(批量版本)"""
|
||||||
|
|
||||||
async def sync(self):
|
def __init__(self):
|
||||||
"""同步持仓数据"""
|
super().__init__()
|
||||||
|
self.max_concurrent = 10 # 每个同步器的最大并发数
|
||||||
|
|
||||||
|
async def sync_batch(self, accounts: Dict[str, Dict]):
|
||||||
|
"""批量同步所有账号的持仓数据"""
|
||||||
try:
|
try:
|
||||||
# 获取所有账号
|
logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号")
|
||||||
accounts = self.get_accounts_from_redis()
|
|
||||||
|
|
||||||
for k_id_str, account_info in accounts.items():
|
# 按账号分组
|
||||||
try:
|
account_groups = self._group_accounts_by_exchange(accounts)
|
||||||
k_id = int(k_id_str)
|
|
||||||
st_id = account_info.get('st_id', 0)
|
|
||||||
exchange_id = account_info['exchange_id']
|
|
||||||
|
|
||||||
if k_id <= 0 or st_id <= 0:
|
# 并发处理每个交易所的账号
|
||||||
continue
|
tasks = []
|
||||||
|
for exchange_id, account_list in account_groups.items():
|
||||||
|
task = self._sync_exchange_accounts(exchange_id, account_list)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
# 从Redis获取持仓数据
|
# 等待所有任务完成
|
||||||
positions = await self._get_positions_from_redis(k_id, exchange_id)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
# 同步到数据库
|
# 统计结果
|
||||||
if positions:
|
success_count = sum(1 for r in results if isinstance(r, bool) and r)
|
||||||
success = self._sync_positions_to_db(k_id, st_id, positions)
|
logger.info(f"持仓批量同步完成: 成功 {success_count}/{len(results)} 个交易所组")
|
||||||
if success:
|
|
||||||
logger.debug(f"持仓同步成功: k_id={k_id}, 持仓数={len(positions)}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"同步账号 {k_id_str} 持仓失败: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info("持仓数据同步完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"持仓同步失败: {e}")
|
logger.error(f"持仓批量同步失败: {e}")
|
||||||
|
|
||||||
async def _get_positions_from_redis(self, k_id: int, exchange_id: str) -> List[Dict]:
|
def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
|
||||||
"""从Redis获取持仓数据"""
|
"""按交易所分组账号"""
|
||||||
|
groups = {}
|
||||||
|
for account_id, account_info in accounts.items():
|
||||||
|
exchange_id = account_info.get('exchange_id')
|
||||||
|
if exchange_id:
|
||||||
|
if exchange_id not in groups:
|
||||||
|
groups[exchange_id] = []
|
||||||
|
groups[exchange_id].append(account_info)
|
||||||
|
return groups
|
||||||
|
|
||||||
|
async def _sync_exchange_accounts(self, exchange_id: str, account_list: List[Dict]):
|
||||||
|
"""同步某个交易所的所有账号"""
|
||||||
try:
|
try:
|
||||||
redis_key = f"{exchange_id}:positions:{k_id}"
|
# 收集所有账号的持仓数据
|
||||||
redis_data = self.redis_client.client.hget(redis_key, 'positions')
|
all_positions = []
|
||||||
|
|
||||||
if not redis_data:
|
for account_info in account_list:
|
||||||
return []
|
k_id = int(account_info['k_id'])
|
||||||
|
st_id = account_info.get('st_id', 0)
|
||||||
|
|
||||||
positions = json.loads(redis_data)
|
# 从Redis获取持仓数据
|
||||||
|
positions = await self._get_positions_from_redis(k_id, exchange_id)
|
||||||
|
|
||||||
# 添加账号信息
|
if positions:
|
||||||
for position in positions:
|
# 添加账号信息
|
||||||
position['k_id'] = k_id
|
for position in positions:
|
||||||
|
position['k_id'] = k_id
|
||||||
|
position['st_id'] = st_id
|
||||||
|
all_positions.extend(positions)
|
||||||
|
|
||||||
return positions
|
if not all_positions:
|
||||||
|
logger.debug(f"交易所 {exchange_id} 无持仓数据")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 批量同步到数据库
|
||||||
|
success = self._sync_positions_batch_to_db(all_positions)
|
||||||
|
if success:
|
||||||
|
logger.info(f"交易所 {exchange_id} 持仓同步成功: {len(all_positions)} 条持仓")
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}")
|
logger.error(f"同步交易所 {exchange_id} 持仓失败: {e}")
|
||||||
return []
|
return False
|
||||||
|
|
||||||
def _sync_positions_to_db(self, k_id: int, st_id: int, positions_data: List[Dict]) -> bool:
|
def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> bool:
|
||||||
"""同步持仓数据到数据库"""
|
"""批量同步持仓数据到数据库(优化版)"""
|
||||||
session = self.db_manager.get_session()
|
session = self.db_manager.get_session()
|
||||||
try:
|
try:
|
||||||
# 使用批量优化方案
|
# 按k_id分组
|
||||||
from sqlalchemy.dialects.mysql import insert
|
positions_by_account = {}
|
||||||
|
for position in all_positions:
|
||||||
|
k_id = position['k_id']
|
||||||
|
if k_id not in positions_by_account:
|
||||||
|
positions_by_account[k_id] = []
|
||||||
|
positions_by_account[k_id].append(position)
|
||||||
|
|
||||||
# 准备数据
|
success_count = 0
|
||||||
insert_data = []
|
|
||||||
keep_keys = set() # 需要保留的(symbol, side)组合
|
|
||||||
|
|
||||||
for pos_data in positions_data:
|
|
||||||
try:
|
|
||||||
# 转换数据(这里需要实现转换逻辑)
|
|
||||||
pos_dict = self._convert_position_data(pos_data)
|
|
||||||
if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 重命名qty为sum
|
|
||||||
if 'qty' in pos_dict:
|
|
||||||
pos_dict['sum'] = pos_dict.pop('qty')
|
|
||||||
|
|
||||||
insert_data.append(pos_dict)
|
|
||||||
keep_keys.add((pos_dict['symbol'], pos_dict['side']))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"转换持仓数据失败: {pos_data}, error={e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
with session.begin():
|
with session.begin():
|
||||||
if not insert_data:
|
for k_id, positions in positions_by_account.items():
|
||||||
# 清空该账号持仓
|
try:
|
||||||
session.execute(
|
st_id = positions[0]['st_id'] if positions else 0
|
||||||
delete(StrategyPosition).where(
|
|
||||||
and_(
|
|
||||||
StrategyPosition.k_id == k_id,
|
|
||||||
StrategyPosition.st_id == st_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 批量插入/更新
|
# 准备数据
|
||||||
stmt = insert(StrategyPosition.__table__).values(insert_data)
|
insert_data = []
|
||||||
|
keep_keys = set()
|
||||||
|
|
||||||
update_dict = {
|
for pos_data in positions:
|
||||||
'price': stmt.inserted.price,
|
try:
|
||||||
'sum': stmt.inserted.sum,
|
pos_dict = self._convert_position_data(pos_data)
|
||||||
'asset_num': stmt.inserted.asset_num,
|
if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
|
||||||
'asset_profit': stmt.inserted.asset_profit,
|
continue
|
||||||
'leverage': stmt.inserted.leverage,
|
|
||||||
'uptime': stmt.inserted.uptime,
|
|
||||||
'profit_price': stmt.inserted.profit_price,
|
|
||||||
'stop_price': stmt.inserted.stop_price,
|
|
||||||
'liquidation_price': stmt.inserted.liquidation_price
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt = stmt.on_duplicate_key_update(**update_dict)
|
# 重命名qty为sum
|
||||||
session.execute(stmt)
|
if 'qty' in pos_dict:
|
||||||
|
pos_dict['sum'] = pos_dict.pop('qty')
|
||||||
|
|
||||||
# 删除多余持仓
|
insert_data.append(pos_dict)
|
||||||
if keep_keys:
|
keep_keys.add((pos_dict['symbol'], pos_dict['side']))
|
||||||
existing_positions = session.execute(
|
|
||||||
select(StrategyPosition).where(
|
|
||||||
and_(
|
|
||||||
StrategyPosition.k_id == k_id,
|
|
||||||
StrategyPosition.st_id == st_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalars().all()
|
|
||||||
|
|
||||||
to_delete_ids = []
|
except Exception as e:
|
||||||
for existing in existing_positions:
|
logger.error(f"转换持仓数据失败: {pos_data}, error={e}")
|
||||||
key = (existing.symbol, existing.side)
|
continue
|
||||||
if key not in keep_keys:
|
|
||||||
to_delete_ids.append(existing.id)
|
|
||||||
|
|
||||||
if to_delete_ids:
|
if not insert_data:
|
||||||
session.execute(
|
continue
|
||||||
delete(StrategyPosition).where(
|
|
||||||
StrategyPosition.id.in_(to_delete_ids)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
# 批量插入/更新
|
||||||
|
from sqlalchemy.dialects.mysql import insert
|
||||||
|
|
||||||
|
stmt = insert(StrategyPosition.__table__).values(insert_data)
|
||||||
|
|
||||||
|
update_dict = {
|
||||||
|
'price': stmt.inserted.price,
|
||||||
|
'sum': stmt.inserted.sum,
|
||||||
|
'asset_num': stmt.inserted.asset_num,
|
||||||
|
'asset_profit': stmt.inserted.asset_profit,
|
||||||
|
'leverage': stmt.inserted.leverage,
|
||||||
|
'uptime': stmt.inserted.uptime,
|
||||||
|
'profit_price': stmt.inserted.profit_price,
|
||||||
|
'stop_price': stmt.inserted.stop_price,
|
||||||
|
'liquidation_price': stmt.inserted.liquidation_price
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = stmt.on_duplicate_key_update(**update_dict)
|
||||||
|
session.execute(stmt)
|
||||||
|
|
||||||
|
# 删除多余持仓
|
||||||
|
if keep_keys:
|
||||||
|
existing_positions = session.execute(
|
||||||
|
select(StrategyPosition).where(
|
||||||
|
and_(
|
||||||
|
StrategyPosition.k_id == k_id,
|
||||||
|
StrategyPosition.st_id == st_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalars().all()
|
||||||
|
|
||||||
|
to_delete_ids = []
|
||||||
|
for existing in existing_positions:
|
||||||
|
key = (existing.symbol, existing.side)
|
||||||
|
if key not in keep_keys:
|
||||||
|
to_delete_ids.append(existing.id)
|
||||||
|
|
||||||
|
if to_delete_ids:
|
||||||
|
# 分块删除
|
||||||
|
chunk_size = 100
|
||||||
|
for i in range(0, len(to_delete_ids), chunk_size):
|
||||||
|
chunk = to_delete_ids[i:i + chunk_size]
|
||||||
|
session.execute(
|
||||||
|
delete(StrategyPosition).where(
|
||||||
|
StrategyPosition.id.in_(chunk)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
success_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"同步账号 {k_id} 持仓失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"批量同步完成: 成功 {success_count}/{len(positions_by_account)} 个账号")
|
||||||
|
return success_count > 0
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"同步持仓到数据库失败: k_id={k_id}, error={e}")
|
logger.error(f"批量同步持仓到数据库失败: {e}")
|
||||||
return False
|
return False
|
||||||
finally:
|
finally:
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
def _convert_position_data(self, data: Dict) -> Dict:
|
# 其他方法保持不变...
|
||||||
"""转换持仓数据格式"""
|
|
||||||
# 这里实现具体的转换逻辑
|
|
||||||
return {
|
|
||||||
'st_id': int(data.get('st_id', 0)),
|
|
||||||
'k_id': int(data.get('k_id', 0)),
|
|
||||||
'asset': 'USDT',
|
|
||||||
'symbol': data.get('symbol', ''),
|
|
||||||
'side': data.get('side', ''),
|
|
||||||
'price': float(data.get('price', 0)) if data.get('price') is not None else None,
|
|
||||||
'qty': float(data.get('qty', 0)) if data.get('qty') is not None else None,
|
|
||||||
'asset_num': float(data.get('asset_num', 0)) if data.get('asset_num') is not None else None,
|
|
||||||
'asset_profit': float(data.get('asset_profit', 0)) if data.get('asset_profit') is not None else None,
|
|
||||||
'leverage': int(data.get('leverage', 0)) if data.get('leverage') is not None else None,
|
|
||||||
'uptime': int(data.get('uptime', 0)) if data.get('uptime') is not None else None,
|
|
||||||
'profit_price': float(data.get('profit_price', 0)) if data.get('profit_price') is not None else None,
|
|
||||||
'stop_price': float(data.get('stop_price', 0)) if data.get('stop_price') is not None else None,
|
|
||||||
'liquidation_price': float(data.get('liquidation_price', 0)) if data.get('liquidation_price') is not None else None
|
|
||||||
}
|
|
||||||
138
utils/batch_operations.py
Normal file
138
utils/batch_operations.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
from typing import List, Dict, Any, Tuple
|
||||||
|
from loguru import logger
|
||||||
|
from sqlalchemy import text
|
||||||
|
from .database_manager import DatabaseManager
|
||||||
|
|
||||||
|
class BatchOperations:
|
||||||
|
"""批量数据库操作工具"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.db_manager = DatabaseManager()
|
||||||
|
|
||||||
|
def batch_insert_update_positions(self, positions_data: List[Dict]) -> Tuple[int, int]:
|
||||||
|
"""批量插入/更新持仓数据"""
|
||||||
|
session = self.db_manager.get_session()
|
||||||
|
try:
|
||||||
|
if not positions_data:
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
# 按账号分组
|
||||||
|
positions_by_account = {}
|
||||||
|
for position in positions_data:
|
||||||
|
k_id = position.get('k_id')
|
||||||
|
if k_id not in positions_by_account:
|
||||||
|
positions_by_account[k_id] = []
|
||||||
|
positions_by_account[k_id].append(position)
|
||||||
|
|
||||||
|
total_processed = 0
|
||||||
|
total_deleted = 0
|
||||||
|
|
||||||
|
with session.begin():
|
||||||
|
for k_id, positions in positions_by_account.items():
|
||||||
|
processed, deleted = self._process_account_positions(session, k_id, positions)
|
||||||
|
total_processed += processed
|
||||||
|
total_deleted += deleted
|
||||||
|
|
||||||
|
logger.info(f"批量处理持仓完成: 处理 {total_processed} 条,删除 {total_deleted} 条")
|
||||||
|
return total_processed, total_deleted
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量处理持仓失败: {e}")
|
||||||
|
return 0, 0
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
def _process_account_positions(self, session, k_id: int, positions: List[Dict]) -> Tuple[int, int]:
|
||||||
|
"""处理单个账号的持仓数据"""
|
||||||
|
try:
|
||||||
|
st_id = positions[0].get('st_id', 0) if positions else 0
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
insert_data = []
|
||||||
|
keep_keys = set()
|
||||||
|
|
||||||
|
for pos_data in positions:
|
||||||
|
# 转换数据
|
||||||
|
pos_dict = self._convert_position_data(pos_data)
|
||||||
|
if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 重命名qty为sum
|
||||||
|
if 'qty' in pos_dict:
|
||||||
|
pos_dict['sum'] = pos_dict.pop('qty')
|
||||||
|
|
||||||
|
insert_data.append(pos_dict)
|
||||||
|
keep_keys.add((pos_dict['symbol'], pos_dict['side']))
|
||||||
|
|
||||||
|
if not insert_data:
|
||||||
|
# 清空该账号持仓
|
||||||
|
result = session.execute(
|
||||||
|
text("DELETE FROM deh_strategy_position_new WHERE k_id = :k_id AND st_id = :st_id"),
|
||||||
|
{'k_id': k_id, 'st_id': st_id}
|
||||||
|
)
|
||||||
|
return 0, result.rowcount
|
||||||
|
|
||||||
|
# 批量插入/更新
|
||||||
|
sql = """
|
||||||
|
INSERT INTO deh_strategy_position_new
|
||||||
|
(st_id, k_id, asset, symbol, side, price, `sum`,
|
||||||
|
asset_num, asset_profit, leverage, uptime,
|
||||||
|
profit_price, stop_price, liquidation_price)
|
||||||
|
VALUES
|
||||||
|
(:st_id, :k_id, :asset, :symbol, :side, :price, :sum,
|
||||||
|
:asset_num, :asset_profit, :leverage, :uptime,
|
||||||
|
:profit_price, :stop_price, :liquidation_price)
|
||||||
|
ON DUPLICATE KEY UPDATE
|
||||||
|
price = VALUES(price),
|
||||||
|
`sum` = VALUES(`sum`),
|
||||||
|
asset_num = VALUES(asset_num),
|
||||||
|
asset_profit = VALUES(asset_profit),
|
||||||
|
leverage = VALUES(leverage),
|
||||||
|
uptime = VALUES(uptime),
|
||||||
|
profit_price = VALUES(profit_price),
|
||||||
|
stop_price = VALUES(stop_price),
|
||||||
|
liquidation_price = VALUES(liquidation_price)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 分块执行
|
||||||
|
chunk_size = 500
|
||||||
|
processed_count = 0
|
||||||
|
|
||||||
|
for i in range(0, len(insert_data), chunk_size):
|
||||||
|
chunk = insert_data[i:i + chunk_size]
|
||||||
|
session.execute(text(sql), chunk)
|
||||||
|
processed_count += len(chunk)
|
||||||
|
|
||||||
|
# 删除多余持仓
|
||||||
|
deleted_count = 0
|
||||||
|
if keep_keys:
|
||||||
|
# 构建删除条件
|
||||||
|
conditions = []
|
||||||
|
for symbol, side in keep_keys:
|
||||||
|
safe_symbol = symbol.replace("'", "''") if symbol else ''
|
||||||
|
safe_side = side.replace("'", "''") if side else ''
|
||||||
|
conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')")
|
||||||
|
|
||||||
|
if conditions:
|
||||||
|
conditions_str = " OR ".join(conditions)
|
||||||
|
delete_sql = f"""
|
||||||
|
DELETE FROM deh_strategy_position_new
|
||||||
|
WHERE k_id = {k_id} AND st_id = {st_id}
|
||||||
|
AND NOT ({conditions_str})
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = session.execute(text(delete_sql))
|
||||||
|
deleted_count = result.rowcount
|
||||||
|
|
||||||
|
return processed_count, deleted_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理账号 {k_id} 持仓失败: {e}")
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
def _convert_position_data(self, data: Dict) -> Dict:
|
||||||
|
"""转换持仓数据格式"""
|
||||||
|
# 转换逻辑...
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 类似的批量方法 for orders and account info...
|
||||||
Reference in New Issue
Block a user