This commit is contained in:
lz_db
2025-12-04 15:40:19 +08:00
parent f93f334256
commit f85f4ef152
25 changed files with 3182 additions and 314 deletions

14
.env
View File

@@ -1,15 +1,15 @@
# Redis配置 # Redis配置
REDIS_HOST=localhost REDIS_HOST=10.0.4.6
REDIS_PORT=6379 REDIS_PORT=6379
REDIS_PASSWORD= REDIS_PASSWORD=
REDIS_DB=0 REDIS_DB=1
# MySQL配置 # MySQL配置
DB_HOST=localhost DB_HOST=10.0.4.17
DB_PORT=3306 DB_PORT=3306
DB_USER=root DB_USER=root
DB_PASSWORD=your_password DB_PASSWORD=lz_mysqlLZ
DB_DATABASE=exchange_monitor DB_DATABASE=lz_app_test
DB_POOL_SIZE=10 DB_POOL_SIZE=10
DB_MAX_OVERFLOW=20 DB_MAX_OVERFLOW=20
DB_POOL_RECYCLE=3600 DB_POOL_RECYCLE=3600
@@ -19,7 +19,7 @@ SQLALCHEMY_ECHO=false
SQLALCHEMY_ECHO_POOL=false SQLALCHEMY_ECHO_POOL=false
# 同步配置 # 同步配置
SYNC_INTERVAL=60 SYNC_INTERVAL=20
RECENT_DAYS=3 RECENT_DAYS=3
CHUNK_SIZE=1000 CHUNK_SIZE=1000
ENABLE_POSITION_SYNC=true ENABLE_POSITION_SYNC=true
@@ -33,7 +33,7 @@ LOG_RETENTION=7 days
# 计算机名配置(支持多个) # 计算机名配置(支持多个)
COMPUTER_NAMES=lz_c01,lz_c02,lz_c03 COMPUTER_NAMES=lz_c01,lz_c02
# 或者使用模式匹配 # 或者使用模式匹配
COMPUTER_NAME_PATTERN=^lz_c\d{2}$ COMPUTER_NAME_PATTERN=^lz_c\d{2}$

Binary file not shown.

Binary file not shown.

Binary file not shown.

2727
logs/sync_2025-12-04.log Normal file

File diff suppressed because one or more lines are too long

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -14,7 +14,6 @@ class AccountSyncBatch(BaseSync):
"""批量同步所有账号的账户信息""" """批量同步所有账号的账户信息"""
try: try:
logger.info(f"开始批量同步账户信息,共 {len(accounts)} 个账号") logger.info(f"开始批量同步账户信息,共 {len(accounts)} 个账号")
# 收集所有账号的数据 # 收集所有账号的数据
all_account_data = await self._collect_all_account_data(accounts) all_account_data = await self._collect_all_account_data(accounts)

View File

@@ -16,8 +16,6 @@ class BaseSync(ABC):
def __init__(self): def __init__(self):
self.redis_client = RedisClient() self.redis_client = RedisClient()
self.db_manager = DatabaseManager() self.db_manager = DatabaseManager()
self.computer_names = self._get_computer_names()
self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)
self.sync_stats = { self.sync_stats = {
'total_accounts': 0, 'total_accounts': 0,
'success_count': 0, 'success_count': 0,
@@ -26,13 +24,6 @@ class BaseSync(ABC):
'avg_sync_time': 0 'avg_sync_time': 0
} }
def _get_computer_names(self) -> List[str]:
"""获取计算机名列表"""
if ',' in COMPUTER_NAMES:
names = [name.strip() for name in COMPUTER_NAMES.split(',')]
logger.info(f"使用配置的计算机名列表: {names}")
return names
return [COMPUTER_NAMES.strip()]
@abstractmethod @abstractmethod
async def sync(self): async def sync(self):
@@ -44,41 +35,6 @@ class BaseSync(ABC):
"""批量同步数据""" """批量同步数据"""
pass pass
def _safe_float(self, value: Any, default: float = 0.0) -> float:
"""安全转换为float"""
if value is None:
return default
try:
if isinstance(value, str):
value = value.strip()
if value == '':
return default
return float(value)
except (ValueError, TypeError):
return default
def _safe_int(self, value: Any, default: int = 0) -> int:
"""安全转换为int"""
if value is None:
return default
try:
if isinstance(value, str):
value = value.strip()
if value == '':
return default
return int(float(value))
except (ValueError, TypeError):
return default
def _safe_str(self, value: Any, default: str = '') -> str:
"""安全转换为str"""
if value is None:
return default
try:
result = str(value).strip()
return result if result else default
except:
return default
def _escape_sql_value(self, value: Any) -> str: def _escape_sql_value(self, value: Any) -> str:
"""转义SQL值""" """转义SQL值"""

View File

@@ -6,6 +6,7 @@ import time
import json import json
from typing import Dict from typing import Dict
import re import re
import utils.helpers as helpers
from utils.redis_client import RedisClient from utils.redis_client import RedisClient
from config.settings import SYNC_CONFIG from config.settings import SYNC_CONFIG
@@ -69,21 +70,21 @@ class SyncManager:
try: try:
# 获取所有账号(只获取一次) # 获取所有账号(只获取一次)
accounts = await self.get_accounts_from_redis() accounts = self.get_accounts_from_redis()
if not accounts: if not accounts:
logger.warning("未获取到任何账号,等待下次同步") logger.warning("未获取到任何账号,等待下次同步")
await asyncio.sleep(self.sync_interval) await asyncio.sleep(self.sync_interval)
continue continue
# return
self.stats['total_syncs'] += 1 self.stats['total_syncs'] += 1
sync_start = time.time() sync_start = time.time()
logger.info(f"{self.stats['total_syncs']}次同步开始,共 {len(accounts)} 个账号") logger.info(f"{self.stats['total_syncs']}次同步开始,共 {len(accounts)} 个账号")
# 执行所有同步器 # 执行所有同步器
tasks = [syncer.sync(accounts) for syncer in self.syncers] tasks = [syncer.sync_batch(accounts) for syncer in self.syncers]
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
@@ -99,6 +100,15 @@ class SyncManager:
break break
except Exception as e: except Exception as e:
logger.error(f"同步任务异常: {e}") logger.error(f"同步任务异常: {e}")
# 获取完整的错误信息
import traceback
error_details = {
'error_type': type(e).__name__,
'error_message': str(e),
'traceback': traceback.format_exc()
}
logger.error("完整堆栈跟踪:\n{traceback}", traceback=error_details['traceback'])
await asyncio.sleep(30) await asyncio.sleep(30)
def get_accounts_from_redis(self) -> Dict[str, Dict]: def get_accounts_from_redis(self) -> Dict[str, Dict]:
@@ -118,7 +128,6 @@ class SyncManager:
logger.warning("配置的计算机名未找到数据,尝试自动发现...") logger.warning("配置的计算机名未找到数据,尝试自动发现...")
accounts_dict = self._discover_all_accounts() accounts_dict = self._discover_all_accounts()
self.sync_stats['total_accounts'] = len(accounts_dict)
logger.info(f"{len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号") logger.info(f"{len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号")
return accounts_dict return accounts_dict
@@ -303,30 +312,17 @@ class SyncManager:
"""解析账号信息""" """解析账号信息"""
try: try:
source_account_info = json.loads(account_info) source_account_info = json.loads(account_info)
# print(source_account_info)
# 基础信息 # 基础信息
account_data = { account_data = {
'exchange_id': exchange_id, 'exchange_id': exchange_id,
'k_id': account_id, 'k_id': account_id,
'st_id': self._safe_int(source_account_info.get('st_id'), 0), 'st_id': helpers.safe_int(source_account_info.get('st_id'), 0),
'add_time': self._safe_int(source_account_info.get('add_time'), 0), 'add_time': helpers.safe_int(source_account_info.get('add_time'), 0),
'account_type': source_account_info.get('account_type', 'real'),
'api_key': source_account_info.get('api_key', ''), 'api_key': source_account_info.get('api_key', ''),
'secret_key': source_account_info.get('secret_key', ''),
'password': source_account_info.get('password', ''),
'access_token': source_account_info.get('access_token', ''),
'remark': source_account_info.get('remark', '')
} }
# 合并原始信息 return account_data
result = {**source_account_info, **account_data}
# 验证必要字段
if not result.get('st_id') or not result.get('exchange_id'):
logger.warning(f"账号 {account_id} 缺少必要字段: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}")
return None
return result
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {account_info[:100]}...") logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {account_info[:100]}...")

View File

@@ -20,6 +20,7 @@ class OrderSyncBatch(BaseSync):
"""批量同步所有账号的订单数据""" """批量同步所有账号的订单数据"""
try: try:
logger.info(f"开始批量同步订单数据,共 {len(accounts)} 个账号") logger.info(f"开始批量同步订单数据,共 {len(accounts)} 个账号")
return
start_time = time.time() start_time = time.time()
# 1. 收集所有账号的订单数据 # 1. 收集所有账号的订单数据

View File

@@ -3,6 +3,7 @@ from loguru import logger
from typing import List, Dict, Any, Set, Tuple from typing import List, Dict, Any, Set, Tuple
import json import json
import asyncio import asyncio
import utils.helpers as helpers
from datetime import datetime from datetime import datetime
from sqlalchemy import text, and_, select, delete from sqlalchemy import text, and_, select, delete
from models.orm_models import StrategyPosition from models.orm_models import StrategyPosition
@@ -19,6 +20,7 @@ class PositionSyncBatch(BaseSync):
"""批量同步所有账号的持仓数据""" """批量同步所有账号的持仓数据"""
try: try:
logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号") logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号")
start_time = time.time() start_time = time.time()
# 1. 收集所有账号的持仓数据 # 1. 收集所有账号的持仓数据
@@ -31,17 +33,393 @@ class PositionSyncBatch(BaseSync):
logger.info(f"收集到 {len(all_positions)} 条持仓数据") logger.info(f"收集到 {len(all_positions)} 条持仓数据")
# 2. 批量同步到数据库 # 2. 批量同步到数据库
success, stats = await self._sync_positions_batch_to_db(all_positions) success, stats = await self._sync_positions_batch_to_db_optimized_v3(all_positions)
elapsed = time.time() - start_time elapsed = time.time() - start_time
if success: if success:
logger.info(f"持仓批量同步完成: 处理 {stats['total']} 条,更新 {stats['updated']} 条," logger.info(f"持仓批量同步完成: 处理 {stats['total']} 条,受影响 {stats['affected']} 条,"
f"插入 {stats['inserted']} 条,删除 {stats['deleted']} 条,耗时 {elapsed:.2f}") f"删除 {stats['deleted']} 条,耗时 {elapsed:.2f}")
else: else:
logger.error("持仓批量同步失败") logger.error("持仓批量同步失败")
except Exception as e: except Exception as e:
logger.error(f"持仓批量同步失败: {e}") logger.error(f"持仓批量同步失败: {e}")# 获取完整的错误信息
import traceback
error_details = {
'error_type': type(e).__name__,
'error_message': str(e),
'traceback': traceback.format_exc()
}
logger.error("完整堆栈跟踪:\n{traceback}", traceback=error_details['traceback'])
async def _sync_positions_batch_to_db_optimized(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
"""
批量同步持仓数据(不使用临时表)
Args:
all_positions: 所有持仓数据列表每个持仓包含k_id账号ID等字段
Returns:
Tuple[bool, Dict]: (是否成功, 结果统计)
"""
if not all_positions:
return True, {'total': 0, 'affected': 0, 'deleted': 0, 'errors': []}
session = self.db_manager.get_session()
results = {
'total': 0,
'affected': 0,
'deleted': 0,
'errors': []
}
# 按账号分组
positions_by_account = {}
for position in all_positions:
# print(position['symbol'])
k_id = position['k_id']
if k_id not in positions_by_account:
positions_by_account[k_id] = []
positions_by_account[k_id].append(position)
logger.info(f"开始批量处理 {len(positions_by_account)} 个账号的持仓数据")
try:
# 按分组处理10个账号一组
account_ids = list(positions_by_account.keys())
for group_idx in range(0, len(account_ids), 10):
group_account_ids = account_ids[group_idx:group_idx + 10]
logger.info(f"处理第 {group_idx//10 + 1} 组账号: {group_account_ids}")
# 收集本组所有持仓数据
group_positions = []
for k_id in group_account_ids:
group_positions.extend(positions_by_account[k_id])
if not group_positions:
continue
# 处理持仓数据
processed_positions = []
account_position_keys = {} # 记录每个账号的持仓标识
for raw_position in group_positions:
try:
k_id = raw_position['k_id']
processed = self._convert_position_data(raw_position)
# 检查必要字段
if not all([processed.get('symbol'), processed.get('side')]):
continue
# 确保st_id存在
if 'st_id' not in processed:
processed['st_id'] = raw_position.get('st_id', 0)
# 确保k_id存在
if 'k_id' not in processed:
processed['k_id'] = k_id
# 重命名qty为sum如果存在
if 'qty' in processed:
processed['sum'] = processed.pop('qty')
processed_positions.append(processed)
# 记录持仓唯一标识
if k_id not in account_position_keys:
account_position_keys[k_id] = set()
position_key = f"{processed['st_id']}&{processed['symbol']}&{processed['side']}"
# print(position_key)
account_position_keys[k_id].add(position_key)
except Exception as e:
logger.error(f"处理持仓数据失败: {raw_position}, error={e}")
continue
# 批量插入或更新
if processed_positions:
try:
# 使用ON DUPLICATE KEY UPDATE批量处理
upsert_sql = text("""
INSERT INTO deh_strategy_position_new
(st_id, k_id, asset, symbol, side, price, `sum`,
asset_num, asset_profit, leverage, uptime,
profit_price, stop_price, liquidation_price)
VALUES
(:st_id, :k_id, :asset, :symbol, :side, :price, :sum,
:asset_num, :asset_profit, :leverage, :uptime,
:profit_price, :stop_price, :liquidation_price)
ON DUPLICATE KEY UPDATE
price = VALUES(price),
`sum` = VALUES(`sum`),
asset_num = VALUES(asset_num),
asset_profit = VALUES(asset_profit),
leverage = VALUES(leverage),
uptime = VALUES(uptime),
profit_price = VALUES(profit_price),
stop_price = VALUES(stop_price),
liquidation_price = VALUES(liquidation_price)
""")
result = session.execute(upsert_sql, processed_positions)
# 正确计算插入和更新的数量
total_affected = result.rowcount # 受影响的总行数
batch_size = len(processed_positions) # 本次尝试插入的数量
# 累加到总结果
results['total'] += batch_size
results['affected'] += total_affected
logger.debug(f"{group_idx//10 + 1} 组: "
f"处理 {batch_size} 条, "
f"受影响 {total_affected}")
except Exception as e:
logger.error(f"批量插入/更新失败: {e}", exc_info=True)
session.rollback()
results['errors'].append(f"批量插入/更新失败: {str(e)}")
# 继续处理下一组
continue
# 删除本组每个账号中已不存在的持仓
for k_id in group_account_ids:
try:
if k_id not in account_position_keys or not account_position_keys[k_id]:
# 如果该账号没有任何持仓,删除所有
delete_sql = text("""
DELETE FROM deh_strategy_position_new
WHERE k_id = :k_id
""")
result = session.execute(delete_sql, {'k_id': k_id})
deleted_count = result.rowcount
results['deleted'] += deleted_count
if deleted_count > 0:
logger.debug(f"账号 {k_id}: 删除所有旧持仓,共 {deleted_count}")
else:
# 构建当前持仓的条件
current_keys = account_position_keys[k_id]
# 使用多个OR条件来处理IN子句的限制
conditions = []
params = {'k_id': k_id}
for idx, key in enumerate(current_keys):
parts = key.split('&')
if len(parts) >= 3: # 确保有st_id, symbol, side三部分
st_id_val = parts[0]
symbol_val = parts[1]
side_val = parts[2]
conditions.append(f"(st_id = :st_id_{idx} AND symbol = :symbol_{idx} AND side = :side_{idx})")
params[f'st_id_{idx}'] = int(st_id_val) if st_id_val.isdigit() else st_id_val
params[f'symbol_{idx}'] = symbol_val
params[f'side_{idx}'] = side_val
if conditions:
conditions_str = " OR ".join(conditions)
# 删除不在当前持仓列表中的记录
delete_sql = text(f"""
DELETE FROM deh_strategy_position_new
WHERE k_id = :k_id
AND NOT ({conditions_str})
""")
result = session.execute(delete_sql, params)
deleted_count = result.rowcount
results['deleted'] += deleted_count
if deleted_count > 0:
logger.debug(f"账号 {k_id}: 删除 {deleted_count} 条过期持仓")
except Exception as e:
logger.error(f"删除账号 {k_id} 旧持仓失败: {e}")
# 记录错误但继续处理其他账号
results['errors'].append(f"删除账号 {k_id} 旧持仓失败: {str(e)}")
# 每组结束后提交
try:
session.commit()
logger.debug(f"{group_idx//10 + 1} 组处理完成并提交")
except Exception as e:
session.rollback()
logger.error(f"{group_idx//10 + 1} 组提交失败: {e}")
results['errors'].append(f"{group_idx//10 + 1} 组提交失败: {str(e)}")
logger.info(f"批量同步完成: "
f"总数={results['total']}, "
f"受影响={results['affected']}, "
f"删除={results['deleted']}, "
f"错误数={len(results['errors'])}")
success = len(results['errors']) == 0
return success, results
except Exception as e:
session.rollback()
logger.error(f"批量同步过程中发生错误: {e}", exc_info=True)
results['errors'].append(f"同步过程错误: {str(e)}")
return False, results
finally:
session.close()
async def _sync_positions_batch_to_db_optimized_v3(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
"""
最优化的批量同步兼容所有MySQL版本
使用策略:
1. 一次性UPSERT所有持仓数据
2. 使用UNION ALL构造虚拟表进行JOIN删除
Args:
all_positions: 所有持仓数据列表
Returns:
Tuple[bool, Dict]: (是否成功, 结果统计)
"""
if not all_positions:
return True, {'total': 0, 'affected': 0, 'deleted': 0, 'errors': []}
session = self.db_manager.get_session()
results = {
'total': 0,
'affected': 0,
'deleted': 0,
'errors': []
}
try:
session.begin()
# 准备数据
processed_positions = []
current_position_records = set() # 使用set去重避免重复
for raw_position in all_positions:
try:
processed = self._convert_position_data(raw_position)
if not all([processed.get('symbol'), processed.get('side')]):
continue
if 'qty' in processed:
processed['sum'] = processed.pop('qty')
k_id = processed.get('k_id', raw_position['k_id'])
st_id = processed.get('st_id', raw_position.get('st_id', 0))
symbol = processed.get('symbol')
side = processed.get('side')
processed_positions.append(processed)
# 去重记录当前持仓
record_key = (k_id, st_id, symbol, side)
current_position_records.add(record_key)
except Exception as e:
logger.error(f"处理持仓数据失败: {raw_position}, error={e}")
continue
if not processed_positions:
session.commit()
return True, results
results['total'] = len(processed_positions)
logger.info(f"准备同步 {results['total']} 条持仓数据,去重后 {len(current_position_records)} 条唯一记录")
# 批量UPSERT
upsert_sql = text("""
INSERT INTO deh_strategy_position_new
(st_id, k_id, asset, symbol, side, price, `sum`,
asset_num, asset_profit, leverage, uptime,
profit_price, stop_price, liquidation_price)
VALUES
(:st_id, :k_id, :asset, :symbol, :side, :price, :sum,
:asset_num, :asset_profit, :leverage, :uptime,
:profit_price, :stop_price, :liquidation_price)
ON DUPLICATE KEY UPDATE
price = VALUES(price),
`sum` = VALUES(`sum`),
asset_num = VALUES(asset_num),
asset_profit = VALUES(asset_profit),
leverage = VALUES(leverage),
uptime = VALUES(uptime),
profit_price = VALUES(profit_price),
stop_price = VALUES(stop_price),
liquidation_price = VALUES(liquidation_price)
""")
result = session.execute(upsert_sql, processed_positions)
total_affected = result.rowcount
results['affected'] =total_affected
logger.info(f"UPSERT完成: 总数 {results['total']} 条, 受影响 {results['affected']}")
# 批量删除使用UNION ALL构造虚拟表
if current_position_records:
# 构建UNION ALL查询
union_parts = []
for record in current_position_records:
k_id, st_id, symbol, side = record
# 转义单引号
symbol_escaped = symbol.replace("'", "''")
side_escaped = side.replace("'", "''")
union_parts.append(f"SELECT {k_id} as k_id, {st_id} as st_id, '{symbol_escaped}' as symbol, '{side_escaped}' as side")
if union_parts:
union_sql = " UNION ALL ".join(union_parts)
# 或者使用LEFT JOIN方式
delete_sql_join = text(f"""
DELETE p FROM deh_strategy_position_new p
LEFT JOIN (
{union_sql}
) AS current_pos ON
p.k_id = current_pos.k_id
AND p.st_id = current_pos.st_id
AND p.symbol = current_pos.symbol
AND p.side = current_pos.side
WHERE current_pos.k_id IS NULL
""")
result = session.execute(delete_sql_join)
deleted_count = result.rowcount
results['deleted'] = deleted_count
logger.info(f"删除 {deleted_count} 条过期持仓")
session.commit()
logger.info(f"批量同步V3完成: 总数={results['total']}, "
f"受影响={results['affected']}, "
f"删除={results['deleted']}")
return True, results
except Exception as e:
session.rollback()
logger.error(f"批量同步V3失败: {e}", exc_info=True)
results['errors'].append(f"同步失败: {str(e)}")
return False, results
finally:
session.close()
async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]: async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]:
"""收集所有账号的持仓数据""" """收集所有账号的持仓数据"""
@@ -117,6 +495,7 @@ class PositionSyncBatch(BaseSync):
# 添加账号信息 # 添加账号信息
for position in positions: for position in positions:
# print(position['symbol'])
position['k_id'] = k_id position['k_id'] = k_id
position['st_id'] = st_id position['st_id'] = st_id
position['exchange_id'] = exchange_id position['exchange_id'] = exchange_id
@@ -127,246 +506,25 @@ class PositionSyncBatch(BaseSync):
logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}") logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}")
return [] return []
async def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
"""批量同步持仓数据到数据库"""
try:
if not all_positions:
return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
# 按账号分组
positions_by_account = {}
for position in all_positions:
k_id = position['k_id']
if k_id not in positions_by_account:
positions_by_account[k_id] = []
positions_by_account[k_id].append(position)
logger.info(f"开始批量处理 {len(positions_by_account)} 个账号的持仓数据")
# 批量处理每个账号
total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
for k_id, positions in positions_by_account.items():
st_id = positions[0]['st_id'] if positions else 0
# 处理单个账号的批量同步
success, stats = await self._sync_single_account_batch(k_id, st_id, positions)
if success:
total_stats['total'] += stats['total']
total_stats['updated'] += stats['updated']
total_stats['inserted'] += stats['inserted']
total_stats['deleted'] += stats['deleted']
return True, total_stats
except Exception as e:
logger.error(f"批量同步持仓到数据库失败: {e}")
return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
async def _sync_single_account_batch(self, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]:
"""批量同步单个账号的持仓数据"""
session = self.db_manager.get_session()
try:
# 准备数据
insert_data = []
new_positions_map = {} # (symbol, side) -> position_id (用于删除)
for position_data in positions:
try:
position_dict = self._convert_position_data(position_data)
if not all([position_dict.get('symbol'), position_dict.get('side')]):
continue
symbol = position_dict['symbol']
side = position_dict['side']
key = (symbol, side)
# 重命名qty为sum
if 'qty' in position_dict:
position_dict['sum'] = position_dict.pop('qty')
insert_data.append(position_dict)
new_positions_map[key] = position_dict.get('id') # 如果有id的话
except Exception as e:
logger.error(f"转换持仓数据失败: {position_data}, error={e}")
continue
with session.begin():
if not insert_data:
# 清空该账号所有持仓
result = session.execute(
delete(StrategyPosition).where(
and_(
StrategyPosition.k_id == k_id,
StrategyPosition.st_id == st_id
)
)
)
deleted_count = result.rowcount
return True, {
'total': 0,
'updated': 0,
'inserted': 0,
'deleted': deleted_count
}
# 1. 批量插入/更新持仓数据
processed_count = self._batch_upsert_positions(session, insert_data)
# 2. 批量删除多余持仓
deleted_count = self._batch_delete_extra_positions(session, k_id, st_id, new_positions_map)
# 注意这里无法区分插入和更新的数量processed_count是总处理数
inserted_count = processed_count # 简化处理
updated_count = 0 # 需要更复杂的逻辑来区分
stats = {
'total': len(insert_data),
'updated': updated_count,
'inserted': inserted_count,
'deleted': deleted_count
}
return True, stats
except Exception as e:
logger.error(f"批量同步账号 {k_id} 持仓失败: {e}")
return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
finally:
session.close()
def _batch_upsert_positions(self, session, insert_data: List[Dict]) -> int:
"""批量插入/更新持仓数据"""
try:
# 分块处理
chunk_size = self.batch_size
total_processed = 0
for i in range(0, len(insert_data), chunk_size):
chunk = insert_data[i:i + chunk_size]
values_list = []
for data in chunk:
symbol = data.get('symbol').replace("'", "''") if data.get('symbol') else ''
values = (
f"({data['st_id']}, {data['k_id']}, '{data.get('asset', 'USDT')}', "
f"'{symbol}', "
f"{data.get('price') or 'NULL'}, {data.get('sum') or 'NULL'}, "
f"{data.get('asset_num') or 'NULL'}, {data.get('asset_profit') or 'NULL'}, "
f"{data.get('leverage') or 'NULL'}, {data.get('uptime') or 'NULL'}, "
f"{data.get('profit_price') or 'NULL'}, {data.get('stop_price') or 'NULL'}, "
f"{data.get('liquidation_price') or 'NULL'})"
)
values_list.append(values)
if values_list:
values_str = ", ".join(values_list)
sql = f"""
INSERT INTO deh_strategy_position_new
(st_id, k_id, asset, symbol, side, price, `sum`,
asset_num, asset_profit, leverage, uptime,
profit_price, stop_price, liquidation_price)
VALUES {values_str}
ON DUPLICATE KEY UPDATE
price = VALUES(price),
`sum` = VALUES(`sum`),
asset_num = VALUES(asset_num),
asset_profit = VALUES(asset_profit),
leverage = VALUES(leverage),
uptime = VALUES(uptime),
profit_price = VALUES(profit_price),
stop_price = VALUES(stop_price),
liquidation_price = VALUES(liquidation_price)
"""
session.execute(text(sql))
total_processed += len(chunk)
return total_processed
except Exception as e:
logger.error(f"批量插入/更新持仓失败: {e}")
raise
def _batch_delete_extra_positions(self, session, k_id: int, st_id: int, new_positions_map: Dict) -> int:
"""批量删除多余持仓"""
try:
if not new_positions_map:
# 删除所有持仓
result = session.execute(
delete(StrategyPosition).where(
and_(
StrategyPosition.k_id == k_id,
StrategyPosition.st_id == st_id
)
)
)
return result.rowcount
# 构建保留条件
conditions = []
for (symbol, side) in new_positions_map.keys():
safe_symbol = symbol.replace("'", "''") if symbol else ''
safe_side = side.replace("'", "''") if side else ''
conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')")
if conditions:
conditions_str = " OR ".join(conditions)
sql = f"""
DELETE FROM deh_strategy_position_new
WHERE k_id = {k_id} AND st_id = {st_id}
AND NOT ({conditions_str})
"""
result = session.execute(text(sql))
return result.rowcount
return 0
except Exception as e:
logger.error(f"批量删除持仓失败: k_id={k_id}, error={e}")
return 0
def _convert_position_data(self, data: Dict) -> Dict: def _convert_position_data(self, data: Dict) -> Dict:
"""转换持仓数据格式""" """转换持仓数据格式"""
try: try:
# 安全转换函数
def safe_float(value, default=None):
if value is None:
return default
try:
return float(value)
except (ValueError, TypeError):
return default
def safe_int(value, default=None):
if value is None:
return default
try:
return int(float(value))
except (ValueError, TypeError):
return default
return { return {
'st_id': safe_int(data.get('st_id'), 0), 'st_id': helpers.safe_int(data.get('st_id'), 0),
'k_id': safe_int(data.get('k_id'), 0), 'k_id': helpers.safe_int(data.get('k_id'), 0),
'asset': data.get('asset', 'USDT'), 'asset': data.get('asset', 'USDT'),
'symbol': data.get('symbol', ''), 'symbol': data.get('symbol', ''),
'side': data.get('side', ''), 'side': data.get('side', ''),
'price': safe_float(data.get('price')), 'price': helpers.safe_float(data.get('price')),
'qty': safe_float(data.get('qty')), # 后面会重命名为sum 'qty': helpers.safe_float(data.get('qty')), # 后面会重命名为sum
'asset_num': safe_float(data.get('asset_num')), 'asset_num': helpers.safe_float(data.get('asset_num')),
'asset_profit': safe_float(data.get('asset_profit')), 'asset_profit': helpers.safe_float(data.get('asset_profit')),
'leverage': safe_int(data.get('leverage')), 'leverage': helpers.safe_int(data.get('leverage')),
'uptime': safe_int(data.get('uptime')), 'uptime': helpers.safe_int(data.get('uptime')),
'profit_price': safe_float(data.get('profit_price')), 'profit_price': helpers.safe_float(data.get('profit_price')),
'stop_price': safe_float(data.get('stop_price')), 'stop_price': helpers.safe_float(data.get('stop_price')),
'liquidation_price': safe_float(data.get('liquidation_price')) 'liquidation_price': helpers.safe_float(data.get('liquidation_price'))
} }
except Exception as e: except Exception as e:

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -56,7 +56,7 @@ class DatabaseManager:
# 创建表(如果不存在) # 创建表(如果不存在)
Base.metadata.create_all(self._engine) Base.metadata.create_all(self._engine)
logger.info("SQLAlchemy数据库引擎初始化成功") # logger.info("SQLAlchemy数据库引擎初始化成功")
except Exception as e: except Exception as e:
logger.error(f"数据库引擎初始化失败: {e}") logger.error(f"数据库引擎初始化失败: {e}")

View File

@@ -0,0 +1,31 @@
from typing import List, Dict, Optional, Any
from loguru import logger
def safe_float(value, default=0.0):
"""安全转换为float处理None和空值"""
if value is None:
return default
try:
return float(value)
except (ValueError, TypeError):
return default
def safe_int(value, default=0):
"""安全转换为int"""
if value is None:
return default
try:
return int(float(value))
except (ValueError, TypeError):
return default
def safe_str(self, value: Any, default: str = '') -> str:
"""安全转换为str"""
if value is None:
return ""
try:
return str(value)
except Exception as e:
logger.error(f"safe_str error: {e}")
return ""