"""
导入配置数据并创建默认用户
功能:
1. 从导出的 JSON 文件导入配置数据到 MongoDB
2. 创建默认管理员用户(admin/admin123)
3. 支持选择性导入集合
4. 支持覆盖或跳过已存在的数据
使用方法:
python scripts/import_config_and_create_user.py <export_file.json>
python scripts/import_config_and_create_user.py <export_file.json> --overwrite
python scripts/import_config_and_create_user.py <export_file.json> --collections system_configs users
"""
import json
import sys
import hashlib
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional
import argparse
import os
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from pymongo import MongoClient
from bson import ObjectId
def load_env_config(script_dir: Path) -> dict:
    """Build the MongoDB connection settings from the environment and ``.env``.

    Process environment variables take precedence.  For every ``MONGODB_*``
    variable that is not set in the environment, the value found in the
    ``.env`` file located in the parent directory of ``script_dir`` is used
    when present; otherwise a built-in default applies.

    Args:
        script_dir: Directory containing this script; ``.env`` is looked up
            in its parent directory.

    Returns:
        A dict with the keys ``mongodb_port`` (int), ``mongodb_host``,
        ``mongodb_username``, ``mongodb_password`` and ``mongodb_database``.

    Raises:
        ValueError: If the ``MONGODB_PORT`` environment variable is set to a
            value that is not an integer.
    """
    env_file = script_dir.parent / '.env'
    config = {
        'mongodb_port': int(os.environ.get('MONGODB_PORT', 27017)),
        'mongodb_host': os.environ.get('MONGODB_HOST', 'localhost'),
        'mongodb_username': os.environ.get('MONGODB_USERNAME', 'admin'),
        'mongodb_password': os.environ.get('MONGODB_PASSWORD', 'tradingagents123'),
        'mongodb_database': os.environ.get('MONGODB_DATABASE', 'tradingagents'),
    }

    if not env_file.exists():
        # Inside a container the settings normally come from the environment,
        # so a missing .env is only worth a warning elsewhere.
        if not os.environ.get("DOCKER_CONTAINER"):
            print(f"⚠️ 警告:.env 文件不存在:{env_file}")
        return config

    file_config = {}
    try:
        with open(env_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith('#') or '=' not in line:
                    continue
                key, value = line.split('=', 1)
                file_config[key.strip()] = value.strip()
    except (OSError, ValueError) as e:
        # ValueError covers UnicodeDecodeError raised while decoding the file.
        print(f"⚠️ 警告:读取 .env 文件失败:{e}")
        return config

    # String-valued settings: the .env value is used only when the process
    # environment does not already define the variable.
    string_settings = {
        'MONGODB_HOST': 'mongodb_host',
        'MONGODB_USERNAME': 'mongodb_username',
        'MONGODB_PASSWORD': 'mongodb_password',
        'MONGODB_DATABASE': 'mongodb_database',
    }
    for env_key, config_key in string_settings.items():
        if env_key not in os.environ and env_key in file_config:
            config[config_key] = file_config[env_key]

    # The port is parsed on its own so that a malformed value only costs the
    # port override instead of discarding every other .env setting.
    if 'MONGODB_PORT' not in os.environ and 'MONGODB_PORT' in file_config:
        try:
            config['mongodb_port'] = int(file_config['MONGODB_PORT'])
        except ValueError as e:
            print(f"⚠️ 警告:读取 .env 文件失败:{e}")

    return config
# Database that main() selects on the connected client.
DB_NAME = "tradingagents"

# Account details used by create_default_admin() for the bootstrap
# administrator; the password is stored hashed but printed in clear text.
DEFAULT_ADMIN = dict(
    username="admin",
    password="admin123",
    email="[email protected]",
)

# Collections imported by default when --collections is not given.
CONFIG_COLLECTIONS = [
    "system_configs",
    "users",
    "llm_providers",
    "market_categories",
    "user_tags",
    "datasource_groupings",
    "platform_configs",
    "user_configs",
    "model_catalog",
]
def hash_password(password: str) -> str:
    """Return the hexadecimal SHA-256 digest of ``password``.

    The password is encoded with the default ``str.encode()`` codec (UTF-8)
    and hashed once, without a salt.
    """
    digest = hashlib.sha256()
    digest.update(password.encode())
    return digest.hexdigest()
def convert_to_bson(data: Any) -> Any:
    """Recursively convert JSON-decoded export data into BSON-friendly values.

    Inside dicts, string values are converted based on their key:

    * keys ending in ``_id`` (including ``_id`` itself) whose value is a
      valid 24-character ObjectId hex string become ``ObjectId``;
    * keys ending in ``_at``, and ``last_login``, whose value parses as an
      ISO-8601 timestamp become ``datetime`` (a trailing ``Z`` is read as
      ``+00:00``).

    Values that do not convert are kept unchanged.  Lists and nested dicts
    are processed recursively; every other value is returned as-is.

    Args:
        data: Any JSON-decoded value.

    Returns:
        A new structure with the conversions applied; dicts and lists are
        rebuilt, scalars are returned unchanged.
    """
    if isinstance(data, list):
        return [convert_to_bson(item) for item in data]
    if not isinstance(data, dict):
        return data

    result = {}
    for key, value in data.items():
        if isinstance(value, str):
            # is_valid() replaces a bare try/except so that only genuinely
            # malformed ids fall through, not unrelated errors.
            if key.endswith("_id") and len(value) == 24 and ObjectId.is_valid(value):
                result[key] = ObjectId(value)
                continue
            if key.endswith("_at") or key == "last_login":
                try:
                    # fromisoformat() rejects the "Z" suffix before Python 3.11.
                    result[key] = datetime.fromisoformat(value.replace('Z', '+00:00'))
                    continue
                except ValueError:
                    pass  # Not an ISO timestamp: keep the original string.
        result[key] = convert_to_bson(value)
    return result
def load_export_file(file_path: str) -> Dict[str, Any]:
    """Read an export JSON file and return its parsed content.

    The document must contain both ``export_info`` and ``data``.  A short
    summary taken from ``export_info`` is printed on success.

    Args:
        file_path: Path of the JSON export file.

    Returns:
        The parsed JSON document.

    Raises:
        SystemExit: With status 1 when the file is missing, is not valid
            JSON, lacks a required field, or cannot be loaded for any other
            reason.
    """
    print(f"\n📂 加载导出文件:{file_path}")
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)

        absent = [field for field in ("export_info", "data") if field not in payload]
        if absent:
            print("❌ 错误:文件格式不正确,缺少 export_info 或 data 字段")
            sys.exit(1)

        summary = payload["export_info"]
        created = summary.get('created_at', 'Unknown')
        fmt = summary.get('format', 'Unknown')
        collection_total = len(summary.get('collections', []))
        print(f"✅ 文件加载成功")
        print(f" 导出时间:{created}")
        print(f" 导出格式:{fmt}")
        print(f" 集合数量:{collection_total}")
        return payload
    except FileNotFoundError:
        failure = f"❌ 错误:文件不存在:{file_path}"
    except json.JSONDecodeError as e:
        failure = f"❌ 错误:JSON 解析失败:{e}"
    except Exception as e:
        failure = f"❌ 错误:加载文件失败:{e}"

    # Every handled failure is reported once and ends the process.
    print(failure)
    sys.exit(1)
def connect_mongodb(use_docker: bool = True, config: Optional[dict] = None) -> MongoClient:
    """Open a MongoDB connection and verify it with a ``ping``.

    Args:
        use_docker: When true the host name ``mongodb`` is used regardless of
            ``config``; otherwise ``config['mongodb_host']`` is used.
        config: Connection settings with the keys ``mongodb_port``,
            ``mongodb_host``, ``mongodb_username``, ``mongodb_password`` and
            ``mongodb_database``.  Missing keys, or ``None``, fall back to
            the built-in defaults.

    Returns:
        A connected ``MongoClient``.

    Raises:
        SystemExit: With status 1 when the server cannot be reached or the
            ping fails.
    """
    # Local import keeps this block self-contained; urllib is standard library.
    from urllib.parse import quote_plus

    defaults = {
        'mongodb_port': 27017,
        'mongodb_host': 'localhost',
        'mongodb_username': 'admin',
        'mongodb_password': 'tradingagents123',
        'mongodb_database': 'tradingagents',
    }
    settings = {**defaults, **(config or {})}

    host = 'mongodb' if use_docker else settings['mongodb_host']
    port = settings['mongodb_port']
    username = settings['mongodb_username']
    password = settings['mongodb_password']
    database = settings['mongodb_database']

    # Credentials must be percent-escaped, otherwise characters such as
    # '@', ':', '/' or '%' in them produce an invalid connection URI.
    mongo_uri = (
        f"mongodb://{quote_plus(str(username))}:{quote_plus(str(password))}"
        f"@{host}:{port}/{database}?authSource=admin"
    )

    env_name = "Docker 容器内" if use_docker else "宿主机"
    print(f"\n🔌 连接到 MongoDB ({env_name})...")
    # The password is masked in the echoed URI.
    print(f" URI: mongodb://{username}:***@{host}:{port}/{database}?authSource=admin")

    client = None
    try:
        client = MongoClient(mongo_uri, serverSelectionTimeoutMS=5000)
        # The client connects lazily; ping forces a round trip to the server.
        client.admin.command('ping')
    except Exception as e:
        if client is not None:
            client.close()
        print(f"❌ 错误:MongoDB 连接失败:{e}")
        if use_docker:
            print(f" 请确保在 Docker 容器内运行,或使用 --host 参数在宿主机运行")
            print(f" 检查容器:docker ps | grep mongodb")
        else:
            print(f" 请确保 MongoDB 正在运行并监听端口 {port}")
            print(f" 检查端口:netstat -an | findstr {port}")
        sys.exit(1)

    print(f"✅ MongoDB 连接成功")
    return client
def import_collection(db: Any, collection_name: str, documents: List[Dict[str, Any]], overwrite: bool = False) -> Dict[str, int]:
    """Import ``documents`` into one collection and report what happened.

    In overwrite mode every existing document of the collection is deleted
    and all converted documents are inserted in bulk.  Otherwise each
    document is looked up by the first of ``_id``, ``username`` or ``name``
    it contains and inserted only when no match exists; documents with none
    of these fields are always inserted.

    Args:
        db: Database handle indexable by collection name.
        collection_name: Name of the target collection.
        documents: JSON-decoded documents; each is passed through
            ``convert_to_bson`` before being written.
        overwrite: Replace the collection content instead of merging.

    Returns:
        A dict with the counters ``deleted``, ``inserted`` and ``skipped``.
    """
    target = db[collection_name]
    prepared = [convert_to_bson(doc) for doc in documents]

    if overwrite:
        removed = target.delete_many({}).deleted_count
        # Bulk insert is skipped for an empty batch.
        added = len(target.insert_many(prepared).inserted_ids) if prepared else 0
        return {"deleted": removed, "inserted": added, "skipped": 0}

    added = 0
    skipped = 0
    for doc in prepared:
        # First identifying field present in the document, in priority order.
        identity = next((field for field in ("_id", "username", "name") if field in doc), None)
        if identity is not None and target.find_one({identity: doc[identity]}):
            skipped += 1
            continue
        target.insert_one(doc)
        added += 1
    return {"deleted": 0, "inserted": added, "skipped": skipped}
def create_default_admin(db: Any, overwrite: bool = False) -> bool:
    """Create the default administrator account described by ``DEFAULT_ADMIN``.

    Args:
        db: Database handle exposing a ``users`` collection.
        overwrite: When an account with the default username already exists,
            delete and recreate it instead of leaving it untouched.

    Returns:
        ``True`` when the account was created, ``False`` when an existing
        account was kept.
    """
    print(f"\n👤 创建默认管理员用户...")
    users_collection = db.users
    username = DEFAULT_ADMIN["username"]

    if users_collection.find_one({"username": username}):
        if not overwrite:
            print(f"⚠️ 用户 '{username}' 已存在,跳过创建")
            return False
        print(f"⚠️ 用户 '{username}' 已存在,将覆盖")
        users_collection.delete_one({"username": username})

    # One timestamp for both fields so that a freshly created account has
    # created_at == updated_at.
    now = datetime.utcnow()
    user_doc = {
        "username": username,
        "email": DEFAULT_ADMIN["email"],
        "hashed_password": hash_password(DEFAULT_ADMIN["password"]),
        "is_active": True,
        "is_verified": True,
        "is_admin": True,
        "created_at": now,
        "updated_at": now,
        "last_login": None,
        "preferences": {
            "default_market": "A 股",
            "default_depth": "深度",
            "ui_theme": "light",
            "language": "zh-CN",
            "notifications_enabled": True,
            "email_notifications": False,
        },
        "daily_quota": 10000,
        "concurrent_limit": 10,
        "total_analyses": 0,
        "successful_analyses": 0,
        "failed_analyses": 0,
        "favorite_stocks": [],
    }
    users_collection.insert_one(user_doc)

    print(f"✅ 默认管理员用户创建成功")
    print(f" 用户名:{DEFAULT_ADMIN['username']}")
    print(f" 密码:{DEFAULT_ADMIN['password']}")
    print(f" 邮箱:{DEFAULT_ADMIN['email']}")
    print(f" 角色:管理员")
    return True
def ensure_new_providers(db: Any):
    """Insert built-in LLM providers that are missing from ``llm_providers``.

    A provider is matched by ``name``.  Missing providers are inserted with
    creation/update timestamps and an empty ``extra_config``.  An existing
    provider whose ``is_aggregator`` is falsy is updated to enable
    OpenAI-compatible aggregator mode.

    Args:
        db: Database handle exposing an ``llm_providers`` collection.
    """
    print(f"\n✨ 检查并补充缺失的新 Provider...")
    collection = db.llm_providers
    timestamp = datetime.utcnow()

    builtin_specs = [
        {
            "name": "oneapi",
            "display_name": "LLM Hub",
            "description": "一站式人工智能集成平台",
            "website": "https://www.llmhub.com.cn",
            "api_doc_url": "https://docs.llmhub.com.cn/",
            "default_base_url": "https://api.llmhub.com.cn/v1",
            "is_active": True,
            "supported_features": ["chat", "completion", "embedding", "image", "vision", "function_calling", "streaming"],
            "is_aggregator": True,
            "aggregator_type": "openai_compatible",
            "logo_url": "/assets/logos/oneapi.png",
        },
    ]

    inserted = 0
    for spec in builtin_specs:
        current = collection.find_one({"name": spec["name"]})

        if current:
            # Already present: only make sure aggregator mode is switched on.
            if not current.get("is_aggregator"):
                aggregator_fields = {"is_aggregator": True, "aggregator_type": "openai_compatible"}
                collection.update_one({"_id": current["_id"]}, {"$set": aggregator_fields})
                print(f" 🔄 已更新以支持聚合模式:{spec['display_name']}")
            continue

        document = {**spec, "created_at": timestamp, "updated_at": timestamp, "extra_config": {}}
        collection.insert_one(document)
        print(f" ➕ 已自动补全:{document['display_name']} ({document['name']})")
        inserted += 1

    if inserted:
        print(f" 🎉 成功补全 {inserted} 个 Provider")
    else:
        print(" ✅ 所有新 Provider 已存在,无需补充")
def _import_selected_collections(db: Any, args: argparse.Namespace) -> None:
    """Load the export file named in ``args`` and import its collections into ``db``."""
    export_data = load_export_file(args.export_file)
    data = export_data["data"]

    if args.collections:
        collections_to_import = args.collections
    else:
        collections_to_import = [c for c in CONFIG_COLLECTIONS if c in data]

    print(f"\n📋 准备导入 {len(collections_to_import)} 个集合:")
    for col in collections_to_import:
        doc_count = len(data.get(col, []))
        print(f" - {col}: {doc_count} 个文档")

    print(f"\n🚀 开始导入...")
    print(f" 模式:{'覆盖' if args.overwrite else '增量'}")

    total_stats = {"deleted": 0, "inserted": 0, "skipped": 0}
    for collection_name in collections_to_import:
        if collection_name not in data:
            print(f"⚠️ 跳过 {collection_name}: 导出文件中不存在")
            continue
        documents = data[collection_name]
        print(f"\n 导入 {collection_name}...")
        try:
            stats = import_collection(db, collection_name, documents, args.overwrite)
            total_stats["deleted"] += stats["deleted"]
            total_stats["inserted"] += stats["inserted"]
            total_stats["skipped"] += stats["skipped"]
            if args.overwrite:
                print(f" ✅ 删除 {stats['deleted']} 个,插入 {stats['inserted']} 个")
            else:
                print(f" ✅ 插入 {stats['inserted']} 个,跳过 {stats['skipped']} 个")
        except Exception as e:
            # Best effort: a failing collection is reported and the remaining
            # collections are still imported.
            print(f" ❌ 失败:{e}")

    print(f"\n📊 导入统计:")
    if args.overwrite:
        print(f" 删除:{total_stats['deleted']} 个文档")
    print(f" 插入:{total_stats['inserted']} 个文档")
    if not args.overwrite:
        print(f" 跳过:{total_stats['skipped']} 个文档")


def main():
    """Command-line entry point: import configuration data and create the admin user.

    Parses the command line, resolves the export file (defaulting to the
    last ``install/database_export_config_*.json`` in sorted order), connects
    to MongoDB, imports the selected collections unless ``--create-user-only``
    is given, creates the default administrator unless ``--skip-user`` is
    given, and finally makes sure the built-in providers exist.
    """
    # Built from separate lines because a plain quoted literal cannot span
    # several source lines.
    epilog_lines = [
        "示例:",
        "# 在 Docker 容器内运行(默认)",
        "python scripts/import_config_and_create_user.py",
        "# 在宿主机运行(连接到 localhost:27017)",
        "python scripts/import_config_and_create_user.py --host",
        "# 从指定文件导入(默认覆盖模式)",
        "python scripts/import_config_and_create_user.py export.json",
        "# 增量模式:跳过已存在的数据",
        "python scripts/import_config_and_create_user.py --incremental",
        "# 只导入指定的集合",
        "python scripts/import_config_and_create_user.py --collections system_configs users",
        "# 只创建默认用户,不导入数据",
        "python scripts/import_config_and_create_user.py --create-user-only",
    ]
    parser = argparse.ArgumentParser(
        description="导入配置数据并创建默认用户",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="\n".join(epilog_lines),
    )
    parser.add_argument("export_file", nargs="?", help="导出的 JSON 文件路径(默认:install/database_export_config_*.json)")
    parser.add_argument("--host", action="store_true", help="在宿主机运行(连接 localhost:27017),默认在 Docker 容器内运行(连接 mongodb:27017)")
    # Overwrite is the default, so this flag is accepted only for
    # compatibility; --incremental is what switches it off.
    parser.add_argument("--overwrite", action="store_true", default=True, help="覆盖已存在的数据(默认:覆盖)")
    parser.add_argument("--incremental", action="store_true", help="增量模式:跳过已存在的数据")
    parser.add_argument("--collections", nargs="+", help="指定要导入的集合(默认:所有配置集合)")
    parser.add_argument("--create-user-only", action="store_true", help="只创建默认用户,不导入数据")
    parser.add_argument("--skip-user", action="store_true", help="跳过创建默认用户")
    parser.add_argument("--mongodb-port", type=int, help="MongoDB 端口(覆盖 .env 配置)")
    parser.add_argument("--mongodb-host", type=str, help="MongoDB 主机(覆盖 .env 配置)")
    args = parser.parse_args()

    if args.incremental:
        args.overwrite = False

    if not args.create_user_only and not args.export_file:
        install_dir = project_root / "install"
        if install_dir.exists():
            config_files = list(install_dir.glob("database_export_config_*.json"))
            if config_files:
                # Last path in sorted order.
                args.export_file = str(sorted(config_files)[-1])
                print(f"💡 未指定文件,使用默认配置:{args.export_file}")
            else:
                parser.error("install 目录中未找到配置文件 (database_export_config_*.json)")
        else:
            parser.error("必须提供导出文件路径,或使用 --create-user-only")

    print("=" * 80)
    print("📦 导入配置数据并创建默认用户")
    print("=" * 80)

    script_dir = Path(__file__).parent
    env_config = load_env_config(script_dir)
    if args.mongodb_port:
        env_config['mongodb_port'] = args.mongodb_port
        print(f"💡 使用命令行指定的 MongoDB 端口:{args.mongodb_port}")
    if args.mongodb_host:
        # Only effective together with --host: in Docker mode
        # connect_mongodb() always uses the host name "mongodb".
        env_config['mongodb_host'] = args.mongodb_host
        print(f"💡 使用命令行指定的 MongoDB 主机:{args.mongodb_host}")

    use_docker = not args.host
    client = connect_mongodb(use_docker=use_docker, config=env_config)
    try:
        # NOTE(review): the database is always DB_NAME, while the connection
        # URI uses env_config['mongodb_database'] - confirm this is intended.
        db = client[DB_NAME]
        if not args.create_user_only:
            _import_selected_collections(db, args)
        if not args.skip_user:
            create_default_admin(db, args.overwrite)
        ensure_new_providers(db)
    finally:
        # Also runs when a step raises or load_export_file() exits.
        client.close()

    print("\n" + "=" * 80)
    print("✅ 操作完成!")
    print("=" * 80)
    if not args.skip_user:
        print(f"\n🔐 登录信息:")
        print(f" 用户名:{DEFAULT_ADMIN['username']}")
        print(f" 密码:{DEFAULT_ADMIN['password']}")
    print(f"\n📝 后续步骤:")
    print(f" 1. 重启后端服务:docker restart tradingagents-backend")
    print(f" 2. 访问前端并使用默认账号登录")
    print(f" 3. 检查系统配置是否正确加载")


if __name__ == "__main__":
    main()