You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

355 lines
12 KiB
Python

"""
FastAPI 服务主模块
提供训练、评估和推理的 HTTP 接口
"""
import os
import time
import datetime
import logging
import shutil
import uvicorn
import schedule
import threading
import yaml
from fastapi import FastAPI, Request, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List
import atexit
from utils.feature_processor import (
Features, FeatureProcessor, FeatureWeightApplier, normalize_features
)
from utils.model_trainer import MLModel
# API 响应模型
class PredictionResult(BaseModel):
    """Response body of the /inference/ endpoint."""

    # Labels returned by MLModel.inference_model; presumably one integer class
    # id per submitted sample, in request order — confirm against MLModel.
    predictions: List[int]
class ClassificationResult(BaseModel):
    """Classification metrics reported by the /train/ and /evaluate/ endpoints."""

    # Precision / recall / F1 lists exactly as returned by MLModel;
    # presumably one entry per class — confirm against MLModel.
    precision: List[float]
    recall: List[float]
    f1: List[float]
    # Share of misclassified samples as reported by MLModel. Whether this is a
    # 0-1 fraction or a 0-100 percentage is not visible here — confirm.
    wrong_percentage: float
class APIResponse(BaseModel):
    """Common response body of the /train/ and /evaluate/ endpoints."""

    # Metrics of the trained / evaluated model.
    classification_result: ClassificationResult
    # Download URLs of the artefacts produced by the request, keyed by
    # artefact kind (e.g. "log_file_url", "data_file_url").
    data_file: dict
class APIConfig:
    """Load the service configuration and prepare logging and working directories.

    Attributes:
        config_path: absolute path of ``config/config.yaml`` next to this module.
        config: the parsed YAML configuration. The request handlers read it
            and also write per-request artefact paths into it.
        log_path: absolute path of the default log file ``logfile.log``.
    """

    # Shared by the default log configuration and the per-request handlers so
    # every log file uses the same record layout.
    _LOG_FORMAT = '%(asctime)s %(levelname)s: %(message)s'
    _LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'

    def __init__(self):
        # Resolve every path relative to this module, so the directory the
        # server is launched from does not matter.
        base_dir = os.path.dirname(__file__)
        self.config_path = os.path.abspath(
            os.path.join(base_dir, "config/config.yaml"))
        with open(self.config_path, 'r') as f:
            # safe_load only builds plain YAML types (no arbitrary objects).
            self.config = yaml.safe_load(f)
        self.log_path = os.path.abspath(os.path.join(base_dir, "logfile.log"))
        self._setup_logging()
        self._setup_directories()

    def _setup_logging(self):
        """Send root-logger records at INFO and above to ``self.log_path``."""
        # NOTE: basicConfig is a no-op if the root logger already has handlers.
        logging.basicConfig(
            filename=self.log_path,
            level=logging.INFO,
            format=self._LOG_FORMAT,
            datefmt=self._LOG_DATEFMT
        )

    def _setup_directories(self):
        """Create each directory listed in ``system.clean_dirs`` (relative to this module) if missing."""
        for dir_name in self.config['system']['clean_dirs']:
            dir_path = os.path.abspath(os.path.join(
                os.path.dirname(__file__), dir_name))
            os.makedirs(dir_path, exist_ok=True)

    def update_log_handler(self, log_path: str):
        """Redirect all root-logger output to ``log_path``.

        Every handler currently attached to the root logger is detached and
        closed. Then a single INFO-level FileHandler for ``log_path`` is
        installed, using the same format as the default log.

        Args:
            log_path: file to append log records to; its directory must exist.
        """
        logger = logging.getLogger()
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            # Close the detached handler. Otherwise every call leaks the open
            # file descriptor of the previous log file for the life of the process.
            handler.close()
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter(
            self._LOG_FORMAT,
            datefmt=self._LOG_DATEFMT
        ))
        logger.addHandler(file_handler)
class FileCleanup:
    """Purge stale files from the service's working directories."""

    @staticmethod
    def clean_old_files(directory: str, days: int):
        """Delete every file under ``directory`` (recursively) not modified in the last ``days`` days.

        A file that cannot be inspected or removed is logged as a warning and
        skipped, so one bad entry does not abort the whole sweep. Examples: it
        vanished during the walk, it is a dangling symlink, or permission was
        denied.

        Args:
            directory: root directory to walk; a missing directory is a no-op.
            days: retention period; files with an mtime older than this are removed.
        """
        cutoff = time.time() - days * 86400  # 86400 seconds per day
        for root, _, files in os.walk(directory):
            for name in files:
                file_path = os.path.join(root, name)
                try:
                    if os.path.getmtime(file_path) >= cutoff:
                        continue
                    os.remove(file_path)
                except OSError as exc:
                    # Previously this propagated, stopping the sweep and (when
                    # run from the scheduler) ending the cleanup thread.
                    logging.warning(f"Could not clean up {file_path}: {exc}")
                    continue
                logging.info(f"Removed old file: {file_path}")

    @staticmethod
    def schedule_cleanup(config: dict):
        """Purge old files from every directory listed in ``system.clean_dirs``.

        Each directory is resolved relative to this module. The same
        ``system.log_retention_days`` threshold is applied to all of them,
        i.e. to every file in those directories, not only to log files.
        """
        for directory in config['system']['clean_dirs']:
            abs_directory = os.path.abspath(os.path.join(
                os.path.dirname(__file__), directory))
            FileCleanup.clean_old_files(
                abs_directory,
                config['system']['log_retention_days']
            )
# Build the ASGI application. Constructing APIConfig loads config.yaml,
# configures logging and creates the system.clean_dirs directories.
app = FastAPI()
api_config = APIConfig()

# CORS: accept any origin, method and header, with credentials enabled.
# NOTE(review): Starlette's docs state that wildcard origins/methods/headers
# cannot be combined with allow_credentials=True; check how the installed
# version resolves this. It may end up echoing any caller's Origin. Prefer an
# explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Serve every working directory read-only under "/<dir_name>/". The
# directories were created by APIConfig() above.
# NOTE(review): these mounts are unauthenticated, so everything written there
# (request data, logs, model files) is publicly downloadable.
for dir_name in api_config.config['system']['clean_dirs']:
    dir_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), dir_name)
    )
    app.mount(
        f"/{dir_name}",
        StaticFiles(directory=dir_path),
        name=dir_name,
    )
@app.post("/train/")
async def train_model(request: Request, features_list: List[Features]) -> APIResponse:
    """Train a model on the submitted samples and report its metrics.

    Steps:
    1. Build a feature DataFrame from ``features_list``.
    2. Apply the configured per-feature weights.
    3. Normalise with ``is_train=True``.
    4. Save the processed table under ``train_api/``.
    5. Record this run's artefact paths in the shared ``api_config.config``.
    6. Redirect root logging to a per-run log file.
    7. Call ``MLModel.train_detect``.

    When ``training.replace_model`` is set, the file at ``train_model_path``
    is then copied to ``paths.model.train``.

    Returns:
        APIResponse with the precision/recall/F1 lists, the wrong percentage,
        and download URLs under ``/train_api/``.
    """
    cfg = api_config.config

    # --- feature preprocessing ---------------------------------------------
    processor = FeatureProcessor(cfg)
    weighted = FeatureWeightApplier.apply_weights(
        processor.create_feature_df(features_list),
        cfg['features']['feature_names'],
        cfg['features']['feature_weights'],
    )
    weighted = normalize_features(
        weighted,
        cfg['features']['feature_names'],
        is_train=True,
        config=cfg,
    )

    # --- per-run artefact locations ----------------------------------------
    started = time.time()
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    train_dir = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "train_api"))

    # NOTE: this mutates the process-wide config dict. Nothing in this handler
    # awaits, so the event loop runs it to completion before serving another
    # request. That is what keeps concurrent requests from overwriting each
    # other's paths, but it also blocks the whole server during training.
    cfg['paths']['model'].update({
        'train_process': os.path.join(train_dir, f"train_process_{stamp}.png"),
        'evaluate_result_path': os.path.join(train_dir, f"evaluate_result_{stamp}.png"),
    })

    data_path = os.path.join(train_dir, f"train_feature_label_weighted_{stamp}.xlsx")
    weighted.to_excel(data_path, index=False)
    cfg['data_path'] = data_path
    cfg['train_model_path'] = os.path.join(train_dir, f"train_model_{stamp}.pth")
    cfg['correlation_matrix_path'] = os.path.join(train_dir, f"correlation_matrix_{stamp}.png")

    # From here on, root-logger output goes to this run's log file.
    api_config.update_log_handler(os.path.join(train_dir, f"train_log_{stamp}.log"))

    # --- training ----------------------------------------------------------
    _avg_f1, wrong_percentage, precision, recall, f1 = MLModel(cfg).train_detect()

    if cfg['training']['replace_model']:
        # Per the original author's note, train_model_path is the un-numbered
        # (best) model file; promote it to the configured serving path.
        shutil.copyfile(cfg['train_model_path'], cfg['paths']['model']['train'])

    logging.info(f"训练耗时: {time.time() - started}")

    # NOTE(review): the two image URLs below do not match the paths written to
    # cfg['paths']['model'] above (train_process_<ts>.png / evaluate_result_<ts>.png).
    # They are only valid if MLModel derives image names from train_model_path
    # instead — confirm.
    base = f"{request.base_url}train_api"
    return APIResponse(
        classification_result=ClassificationResult(
            precision=precision,
            recall=recall,
            f1=f1,
            wrong_percentage=wrong_percentage,
        ),
        data_file={
            "model_file_url": f"{base}/train_model_{stamp}.pth",
            "log_file_url": f"{base}/train_log_{stamp}.log",
            "data_file_url": f"{base}/train_feature_label_weighted_{stamp}.xlsx",
            "train_process_img_url": f"{base}/train_model_{stamp}_training_process.png",
            "evaluate_result_img_url": f"{base}/train_model_{stamp}_evaluate_result.png",
        },
    )
@app.post("/evaluate/")
async def evaluate_model(request: Request, features_list: List[Features]) -> APIResponse:
    """Evaluate the configured model on the submitted labelled samples.

    The samples go through the same feature pipeline as training: weights,
    then normalisation with ``is_train=False``. The processed table is saved
    under ``evaluate_api/``, and this run's evaluation-image path is written
    into the shared ``api_config.config``. Root logging is redirected to a
    per-run log file, and ``MLModel.evaluate_model`` is called with the
    feature-column matrix and the label column.

    Returns:
        APIResponse with the metrics and download URLs under ``/evaluate_api/``.
    """
    cfg = api_config.config

    # --- feature preprocessing ---------------------------------------------
    processor = FeatureProcessor(cfg)
    weighted = FeatureWeightApplier.apply_weights(
        processor.create_feature_df(features_list),
        cfg['features']['feature_names'],
        cfg['features']['feature_weights'],
    )
    weighted = normalize_features(
        weighted,
        cfg['features']['feature_names'],
        is_train=False,
        config=cfg,
    )

    # --- per-run artefact locations ----------------------------------------
    started = time.time()
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    evaluate_dir = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "evaluate_api"))

    # Mutates the process-wide config. This handler never awaits, so the event
    # loop finishes it before serving another request.
    cfg['paths']['model']['evaluate_result_path'] = os.path.join(
        evaluate_dir, f"evaluate_result_{stamp}.png")

    weighted.to_excel(
        os.path.join(evaluate_dir, f"evaluate_feature_label_weighted_{stamp}.xlsx"),
        index=False,
    )
    api_config.update_log_handler(
        os.path.join(evaluate_dir, f"evaluate_log_{stamp}.log"))

    # --- evaluation --------------------------------------------------------
    model = MLModel(cfg)
    _avg_f1, wrong_percentage, precision, recall, f1 = model.evaluate_model(
        weighted[cfg['features']['feature_names']].values,
        weighted[cfg['features']['label_name']].values,
    )
    logging.info(f"评估耗时: {time.time() - started}")

    # NOTE(review): the image URL below names evaluate_model_<ts>_evaluate_result.png,
    # but the path configured above is evaluate_result_<ts>.png. Nothing visible
    # here produces the URL's filename — confirm against MLModel.
    base = f"{request.base_url}evaluate_api"
    return APIResponse(
        classification_result=ClassificationResult(
            precision=precision,
            recall=recall,
            f1=f1,
            wrong_percentage=wrong_percentage,
        ),
        data_file={
            "log_file_url": f"{base}/evaluate_log_{stamp}.log",
            "data_file_url": f"{base}/evaluate_feature_label_weighted_{stamp}.xlsx",
            "evaluate_result_img_url": f"{base}/evaluate_model_{stamp}_evaluate_result.png",
        },
    )
@app.post("/inference/")
async def inference_model(request: Request, features_list: List[Features]) -> PredictionResult:
    """Predict labels for the submitted samples with the configured model.

    Features are weighted and normalised (``is_train=False``) as for
    evaluation, and the processed table is saved under ``inference_api/``.
    Root logging is redirected to a per-run log file, and
    ``MLModel.inference_model`` is applied to the feature-column matrix.
    ``request`` is accepted but not used.

    Returns:
        PredictionResult wrapping whatever MLModel.inference_model returned.
    """
    cfg = api_config.config

    # --- feature preprocessing ---------------------------------------------
    processor = FeatureProcessor(cfg)
    weighted = FeatureWeightApplier.apply_weights(
        processor.create_feature_df(features_list),
        cfg['features']['feature_names'],
        cfg['features']['feature_weights'],
    )
    weighted = normalize_features(
        weighted,
        cfg['features']['feature_names'],
        is_train=False,
        config=cfg,
    )

    # --- per-run artefacts -------------------------------------------------
    started = time.time()
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    inference_dir = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "inference_api"))

    weighted.to_excel(
        os.path.join(inference_dir, f"inference_feature_label_weighted_{stamp}.xlsx"),
        index=False,
    )
    api_config.update_log_handler(
        os.path.join(inference_dir, f"inference_log_{stamp}.log"))

    # --- prediction --------------------------------------------------------
    model = MLModel(cfg)
    predictions = model.inference_model(
        weighted[cfg['features']['feature_names']].values
    )
    logging.info(f"推理耗时: {time.time() - started}")

    return PredictionResult(predictions=predictions)
@app.post("/upload_model/")
async def upload_model(file: UploadFile = File(...)):
    """Replace ``models/psy.pth`` (next to this module) with the uploaded file.

    The upload is streamed in 1 MiB chunks into a uniquely named temporary
    file in the same directory. That file is then moved over ``psy.pth``
    with ``os.replace``. Readers therefore never see a half-written model,
    and a failed or cancelled upload leaves the previous model intact.

    NOTE(review): there is no authentication, size limit or content check. If
    this file is later loaded with a pickle-based loader (e.g. torch.load),
    accepting arbitrary uploads is a code-execution risk; restrict access.

    Returns:
        dict with a confirmation message and the absolute path written.
    """
    import uuid  # local import: only this endpoint needs it

    models_dir = os.path.abspath(os.path.join(
        os.path.dirname(__file__), "models"))
    os.makedirs(models_dir, exist_ok=True)
    file_path = os.path.join(models_dir, "psy.pth")

    # Same directory => same filesystem, so the final os.replace is an atomic
    # rename. The uuid keeps concurrent uploads from sharing a temp file.
    tmp_path = f"{file_path}.{uuid.uuid4().hex}.part"
    try:
        with open(tmp_path, "xb") as buffer:
            # Stream instead of reading the whole upload into memory at once.
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                buffer.write(chunk)
        os.replace(tmp_path, file_path)
    except BaseException:
        # BaseException also covers request cancellation. Drop the partial
        # temp file; the existing psy.pth is untouched.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return {"message": "模型上传成功", "file_path": file_path}
def start_scheduler():
    """Run the hourly file-cleanup job forever (blocking; meant for a background thread).

    The job calls ``FileCleanup.schedule_cleanup`` with the current
    configuration. Any exception from a run is logged and suppressed.
    ``schedule.run_pending()`` does not catch job exceptions, so an unhandled
    error would end this loop and silently stop all future cleanups.

    NOTE: log records from the job go to whichever file
    ``APIConfig.update_log_handler`` last installed on the root logger.
    """
    def _cleanup_job():
        try:
            FileCleanup.schedule_cleanup(config=api_config.config)
        except Exception:
            logging.exception("Scheduled file cleanup failed")

    schedule.every(1).hours.do(_cleanup_job)
    # Poll once per second; run_pending() only fires jobs that are due.
    while True:
        schedule.run_pending()
        time.sleep(1)
if __name__ == "__main__":
    # Periodic cleanup runs in a daemon thread, so it does not keep the
    # process alive once the server stops.
    # NOTE(review): the cleanup job is only started in script mode; serving
    # this app via `uvicorn <module>:app` never starts it.
    scheduler_thread = threading.Thread(target=start_scheduler, daemon=True)
    scheduler_thread.start()

    # Serve on the configured host/port with auto-reload disabled.
    system_cfg = api_config.config['system']
    uvicorn.run(
        app,
        reload=False,
        port=system_cfg['port'],
        host=system_cfg['host'],
    )