"""
|
|
FastAPI 服务主模块
|
|
提供训练、评估和推理的 HTTP 接口
|
|
"""
import os
import time
import datetime
import logging
import shutil
import uvicorn
import schedule
import threading
import yaml
from fastapi import FastAPI, Request, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List
import atexit

from utils.feature_processor import (
    Features, FeatureProcessor, FeatureWeightApplier, normalize_features
)
from utils.model_trainer import MLModel


# API response models
class PredictionResult(BaseModel):
    """Inference result model."""
    predictions: List[int]


class ClassificationResult(BaseModel):
    """Classification evaluation result model."""
    precision: List[float]
    recall: List[float]
    f1: List[float]
    wrong_percentage: float


class APIResponse(BaseModel):
    """Generic API response model."""
    classification_result: ClassificationResult
    data_file: dict
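

# Illustrative shape of an APIResponse body as returned by /train/ and /evaluate/
# (placeholder values; the data_file keys differ per endpoint):
# {
#     "classification_result": {
#         "precision": [0.91, 0.88],
#         "recall": [0.90, 0.86],
#         "f1": [0.90, 0.87],
#         "wrong_percentage": 8.5
#     },
#     "data_file": {
#         "model_file_url": "http://<host>:<port>/train_api/train_model_<timestamp>.pth",
#         "log_file_url": "http://<host>:<port>/train_api/train_log_<timestamp>.log"
#     }
# }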


class APIConfig:
    """API configuration class."""
    def __init__(self):
        # Initialize the configuration file path
        self.config_path = os.path.abspath(os.path.join(
            os.path.dirname(__file__), "config/config.yaml"))

        # Load the configuration
        with open(self.config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        # Initialize logging
        self.log_path = os.path.abspath(os.path.join(
            os.path.dirname(__file__), "logfile.log"))
        self._setup_logging()

        # Initialize directories
        self._setup_directories()

    def _setup_logging(self):
        """Configure logging."""
        logging.basicConfig(
            filename=self.log_path,
            level=logging.INFO,
            format='%(asctime)s %(levelname)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    def _setup_directories(self):
        """Create the required directories."""
        for dir_name in self.config['system']['clean_dirs']:
            dir_path = os.path.abspath(os.path.join(
                os.path.dirname(__file__), dir_name))
            os.makedirs(dir_path, exist_ok=True)

    def update_log_handler(self, log_path: str):
        """Swap the active log handler for one that writes to log_path."""
        logger = logging.getLogger()
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter(
            '%(asctime)s %(levelname)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        ))
        logger.addHandler(file_handler)
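

# Minimal sketch of the config.yaml keys this module reads. Values are placeholders,
# and MLModel / FeatureProcessor may require additional keys that are not shown here:
#
# system:
#   host: "0.0.0.0"
#   port: 8000
#   clean_dirs: ["train_api", "evaluate_api", "inference_api"]
#   log_retention_days: 7
# features:
#   feature_names: ["feature_a", "feature_b"]   # hypothetical feature names
#   feature_weights: [1.0, 1.0]
#   label_name: "label"                         # hypothetical label column
# paths:
#   model:
#     train: "models/psy.pth"                   # target for the "best model" copy
# training:
#   replace_model: true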


class FileCleanup:
    """File cleanup utilities."""
    @staticmethod
    def clean_old_files(directory: str, days: int):
        """Remove files older than the given number of days."""
        now = time.time()
        cutoff = now - (days * 86400)

        for root, _, files in os.walk(directory):
            for file in files:
                file_path = os.path.join(root, file)
                if os.path.getmtime(file_path) < cutoff:
                    os.remove(file_path)
                    logging.info(f"Removed old file: {file_path}")

    @staticmethod
    def schedule_cleanup(config: dict):
        """Periodic cleanup task."""
        for directory in config['system']['clean_dirs']:
            abs_directory = os.path.abspath(os.path.join(
                os.path.dirname(__file__), directory))
            FileCleanup.clean_old_files(
                abs_directory,
                config['system']['log_retention_days']
            )


# Initialize the FastAPI application
app = FastAPI()
api_config = APIConfig()

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static file directories
for dir_name in api_config.config['system']['clean_dirs']:
    dir_path = os.path.abspath(os.path.join(os.path.dirname(__file__), dir_name))
    app.mount(f"/{dir_name}", StaticFiles(directory=dir_path), name=dir_name)
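
# Note: the endpoints below write their artifacts (models, logs, plots, Excel exports)
# into these directories, so the URLs returned in each response resolve to downloadable
# files (assuming clean_dirs lists train_api, evaluate_api and inference_api).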


@app.post("/train/")
async def train_model(request: Request, features_list: List[Features]) -> APIResponse:
    """Model training endpoint."""
    # Create the feature processor
    processor = FeatureProcessor(api_config.config)
    all_features = processor.create_feature_df(features_list)

    # Apply feature weights
    feature_label_weighted = FeatureWeightApplier.apply_weights(
        all_features,
        api_config.config['features']['feature_names'],
        api_config.config['features']['feature_weights']
    )

    # Normalize features
    feature_label_weighted = normalize_features(
        feature_label_weighted,
        api_config.config['features']['feature_names'],
        is_train=True,
        config=api_config.config
    )

    # Record the start time
    start_time = time.time()

    # Set up training-related paths
    now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    train_dir = os.path.abspath(os.path.join(
        os.path.dirname(__file__), "train_api"))

    # Update the paths in the config
    api_config.config['paths']['model'].update({
        'train_process': os.path.join(train_dir, f"train_process_{now}.png"),
        'evaluate_result_path': os.path.join(train_dir, f"evaluate_result_{now}.png")
    })

    # Save the training data
    data_path = os.path.join(train_dir, f"train_feature_label_weighted_{now}.xlsx")
    feature_label_weighted.to_excel(data_path, index=False)

    # Update the config
    api_config.config['data_path'] = data_path
    api_config.config['train_model_path'] = os.path.join(
        train_dir, f"train_model_{now}.pth")
    # Add the correlation-matrix plot path
    api_config.config['correlation_matrix_path'] = os.path.join(
        train_dir, f"correlation_matrix_{now}.png")

    # Configure logging
    log_path = os.path.join(train_dir, f"train_log_{now}.log")
    api_config.update_log_handler(log_path)

    # Train the model
    ml_model = MLModel(api_config.config)
    avg_f1, wrong_percentage, precision, recall, f1 = ml_model.train_detect()

    # If enabled in the config, copy the best model to the target location
    if api_config.config['training']['replace_model']:
        shutil.copyfile(
            api_config.config['train_model_path'],  # the unsuffixed version (best model)
            api_config.config['paths']['model']['train']  # copy to the target location
        )

    # Record the end time
    end_time = time.time()
    logging.info(f"Training took {end_time - start_time} seconds")

    # Build the response
    return APIResponse(
        classification_result=ClassificationResult(
            precision=precision,
            recall=recall,
            f1=f1,
            wrong_percentage=wrong_percentage
        ),
        data_file={
            "model_file_url": f"{request.base_url}train_api/train_model_{now}.pth",
            "log_file_url": f"{request.base_url}train_api/train_log_{now}.log",
            "data_file_url": f"{request.base_url}train_api/train_feature_label_weighted_{now}.xlsx",
            "train_process_img_url": f"{request.base_url}train_api/train_model_{now}_training_process.png",
            "evaluate_result_img_url": f"{request.base_url}train_api/train_model_{now}_evaluate_result.png"
        }
    )
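
# Sketch of a client call to /train/ (kept as a comment so importing this module has no
# side effects). The request body is a JSON array of Features objects; the field names
# below are hypothetical placeholders, the real schema is defined by Features in
# utils.feature_processor, and host/port come from config.yaml:
#
#   import requests
#   payload = [{"feature_a": 0.5, "feature_b": 1.2, "label": 1}]  # hypothetical fields
#   resp = requests.post("http://localhost:8000/train/", json=payload)
#   print(resp.json()["classification_result"]["f1"])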


@app.post("/evaluate/")
async def evaluate_model(request: Request, features_list: List[Features]) -> APIResponse:
    """Model evaluation endpoint."""
    # Feature processing
    processor = FeatureProcessor(api_config.config)
    all_features = processor.create_feature_df(features_list)

    feature_label_weighted = FeatureWeightApplier.apply_weights(
        all_features,
        api_config.config['features']['feature_names'],
        api_config.config['features']['feature_weights']
    )

    feature_label_weighted = normalize_features(
        feature_label_weighted,
        api_config.config['features']['feature_names'],
        is_train=False,
        config=api_config.config
    )

    # Evaluation
    start_time = time.time()
    now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    evaluate_dir = os.path.abspath(os.path.join(
        os.path.dirname(__file__), "evaluate_api"))

    # Update the evaluation-result plot path in the config
    api_config.config['paths']['model']['evaluate_result_path'] = os.path.join(
        evaluate_dir, f"evaluate_result_{now}.png")

    data_path = os.path.join(
        evaluate_dir, f"evaluate_feature_label_weighted_{now}.xlsx")
    feature_label_weighted.to_excel(data_path, index=False)

    log_path = os.path.join(evaluate_dir, f"evaluate_log_{now}.log")
    api_config.update_log_handler(log_path)

    ml_model = MLModel(api_config.config)
    avg_f1, wrong_percentage, precision, recall, f1 = ml_model.evaluate_model(
        feature_label_weighted[api_config.config['features']['feature_names']].values,
        feature_label_weighted[api_config.config['features']['label_name']].values
    )

    end_time = time.time()
    logging.info(f"Evaluation took {end_time - start_time} seconds")

    return APIResponse(
        classification_result=ClassificationResult(
            precision=precision,
            recall=recall,
            f1=f1,
            wrong_percentage=wrong_percentage
        ),
        data_file={
            "log_file_url": f"{request.base_url}evaluate_api/evaluate_log_{now}.log",
            "data_file_url": f"{request.base_url}evaluate_api/evaluate_feature_label_weighted_{now}.xlsx",
            "evaluate_result_img_url": f"{request.base_url}evaluate_api/evaluate_model_{now}_evaluate_result.png"
        }
    )


@app.post("/inference/")
async def inference_model(request: Request, features_list: List[Features]) -> PredictionResult:
    """Inference endpoint."""
    processor = FeatureProcessor(api_config.config)
    all_features = processor.create_feature_df(features_list)

    feature_label_weighted = FeatureWeightApplier.apply_weights(
        all_features,
        api_config.config['features']['feature_names'],
        api_config.config['features']['feature_weights']
    )

    feature_label_weighted = normalize_features(
        feature_label_weighted,
        api_config.config['features']['feature_names'],
        is_train=False,
        config=api_config.config
    )

    start_time = time.time()
    now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    inference_dir = os.path.abspath(os.path.join(
        os.path.dirname(__file__), "inference_api"))

    data_path = os.path.join(
        inference_dir, f"inference_feature_label_weighted_{now}.xlsx")
    feature_label_weighted.to_excel(data_path, index=False)

    log_path = os.path.join(inference_dir, f"inference_log_{now}.log")
    api_config.update_log_handler(log_path)

    ml_model = MLModel(api_config.config)
    predictions = ml_model.inference_model(
        feature_label_weighted[api_config.config['features']['feature_names']].values
    )

    end_time = time.time()
    logging.info(f"Inference took {end_time - start_time} seconds")

    return PredictionResult(predictions=predictions)
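
# The /inference/ response is simply the PredictionResult model serialized to JSON,
# e.g. (placeholder values): {"predictions": [0, 1, 1, 0]}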


@app.post("/upload_model/")
async def upload_model(file: UploadFile = File(...)):
    """Model upload endpoint."""
    models_dir = os.path.abspath(os.path.join(
        os.path.dirname(__file__), "models"))
    os.makedirs(models_dir, exist_ok=True)

    file_path = os.path.join(models_dir, "psy.pth")
    with open(file_path, "wb") as buffer:
        buffer.write(await file.read())

    return {"message": "Model uploaded successfully", "file_path": file_path}


def start_scheduler():
    """Start the scheduled cleanup task."""
    schedule.every(1).hours.do(
        FileCleanup.schedule_cleanup,
        config=api_config.config
    )
    while True:
        schedule.run_pending()
        time.sleep(1)


if __name__ == "__main__":
    # Start the periodic cleanup task
    scheduler_thread = threading.Thread(target=start_scheduler)
    scheduler_thread.daemon = True
    scheduler_thread.start()

    # Start the service
    uvicorn.run(
        app,
        host=api_config.config['system']['host'],
        port=api_config.config['system']['port'],
        reload=False
    )