deep-risk/backend/scripts/check_and_migrate_database.py
2025-12-14 20:08:27 +08:00

192 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
数据库迁移检查和执行脚本
检查并执行entity_id和entity_type字段的添加
"""
import asyncio
import sys
import os
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
# 添加项目根目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.config import settings
async def check_and_migrate():
"""检查数据库迁移状态并执行迁移"""
print("="*60)
print("数据库迁移检查和执行")
print("="*60)
# 创建数据库引擎
engine = create_async_engine(
settings.DATABASE_URL,
echo=True,
future=True,
)
async with engine.connect() as conn:
try:
# 1. 检查 sys_user 表是否存在
print("\n1. 检查 sys_user 表...")
result = await conn.execute(text("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'sys_user'
);
"""))
table_exists = result.scalar()
if not table_exists:
print("❌ sys_user 表不存在")
return False
print("✅ sys_user 表存在")
# 2. 检查 entity_id 字段是否存在
print("\n2. 检查 entity_id 字段...")
result = await conn.execute(text("""
SELECT EXISTS (
SELECT FROM information_schema.columns
WHERE table_name = 'sys_user'
AND column_name = 'entity_id'
);
"""))
has_entity_id = result.scalar()
if has_entity_id:
print("✅ entity_id 字段已存在")
else:
print("❌ entity_id 字段不存在,需要迁移")
# 执行迁移:添加 entity_id 字段
print("\n 执行迁移:添加 entity_id 字段...")
await conn.execute(text("""
ALTER TABLE sys_user
ADD COLUMN entity_id VARCHAR(50) NULL;
"""))
await conn.execute(text("""
COMMENT ON COLUMN sys_user.entity_id IS '关联的企业实体IDMCN机构或主播';
"""))
print("✅ entity_id 字段添加成功")
# 3. 检查 entity_type 字段是否存在
print("\n3. 检查 entity_type 字段...")
result = await conn.execute(text("""
SELECT EXISTS (
SELECT FROM information_schema.columns
WHERE table_name = 'sys_user'
AND column_name = 'entity_type'
);
"""))
has_entity_type = result.scalar()
if has_entity_type:
print("✅ entity_type 字段已存在")
else:
print("❌ entity_type 字段不存在,需要迁移")
# 执行迁移:添加 entity_type 字段
print("\n 执行迁移:添加 entity_type 字段...")
await conn.execute(text("""
ALTER TABLE sys_user
ADD COLUMN entity_type VARCHAR(20) NULL;
"""))
await conn.execute(text("""
COMMENT ON COLUMN sys_user.entity_type IS '关联实体类型mcn-机构streamer-主播';
"""))
print("✅ entity_type 字段添加成功")
# 4. 检查索引是否存在
print("\n4. 检查索引...")
result = await conn.execute(text("""
SELECT EXISTS (
SELECT FROM pg_indexes
WHERE tablename = 'sys_user'
AND indexname = 'ix_sys_user_entity_id'
);
"""))
has_entity_id_index = result.scalar()
if not has_entity_id_index:
print(" 创建 entity_id 索引...")
await conn.execute(text("""
CREATE INDEX ix_sys_user_entity_id ON sys_user(entity_id);
"""))
print("✅ entity_id 索引创建成功")
else:
print("✅ entity_id 索引已存在")
result = await conn.execute(text("""
SELECT EXISTS (
SELECT FROM pg_indexes
WHERE tablename = 'sys_user'
AND indexname = 'ix_sys_user_entity_type'
);
"""))
has_entity_type_index = result.scalar()
if not has_entity_type_index:
print(" 创建 entity_type 索引...")
await conn.execute(text("""
CREATE INDEX ix_sys_user_entity_type ON sys_user(entity_type);
"""))
print("✅ entity_type 索引创建成功")
else:
print("✅ entity_type 索引已存在")
# 5. 显示现有用户数据
print("\n5. 现有用户数据...")
result = await conn.execute(text("""
SELECT username, entity_id, entity_type
FROM sys_user
LIMIT 10;
"""))
users = result.fetchall()
if users:
print(f"{len(users)} 个用户:")
for user in users:
entity_info = f'entity_id={user[1]}, entity_type={user[2]}' if user[1] else '未设置'
print(f" - {user[0]}: {entity_info}")
else:
print(" ❌ 没有找到用户数据")
# 提交事务
await conn.commit()
print("\n" + "="*60)
print("✅ 数据库迁移完成!")
print("="*60)
return True
except Exception as e:
print(f"\n❌ 迁移失败: {str(e)}")
await conn.rollback()
raise
finally:
await engine.dispose()
async def main():
"""主函数"""
try:
success = await check_and_migrate()
if success:
print("\n✅ 所有迁移操作已成功完成")
sys.exit(0)
else:
print("\n❌ 迁移过程中出现问题")
sys.exit(1)
except Exception as e:
print(f"\n❌ 执行失败: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())