From a81a0341beaa2588226f0e22de90ddcdb90f4c54 Mon Sep 17 00:00:00 2001
From: wangchunlin
Date: Fri, 21 Jun 2024 18:01:24 +0800
Subject: [PATCH] Last version before Docker; the problem of too many logs and
 files is still unresolved
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 api_send_model.py  | 15 +++++++++++
 config/config.yaml |  8 +++---
 psy_api.py         | 66 +++++++++++++++++++++++++++++++++++++++-------
 train_api.yaml     | 61 ------------------------------------------
 utils/common.py    |  2 +-
 5 files changed, 78 insertions(+), 74 deletions(-)
 create mode 100644 api_send_model.py
 delete mode 100644 train_api.yaml

diff --git a/api_send_model.py b/api_send_model.py
new file mode 100644
index 0000000..28d74f5
--- /dev/null
+++ b/api_send_model.py
@@ -0,0 +1,15 @@
+import requests
+
+# URL of the file upload endpoint
+url = "http://localhost:3397/upload_model/"
+
+# Open the model file to upload
+file_path = "train_api/train_model_20240605_172031.pth"  # replace with the path of the model file you want to upload
+files = {'file': open(file_path, 'rb')}
+
+# Send a POST request to upload the file
+response = requests.post(url, files=files)
+
+# Print the response
+print(response.status_code)
+print(response.json())
diff --git a/config/config.yaml b/config/config.yaml
index 1cd9ba3..13da64e 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -1,6 +1,6 @@
 #---Device configuration---#
-# device: cpu
-device: cuda
+device: cpu
+# device: cuda
 
 #---Training configuration---#
 n_epochs: 150
@@ -12,7 +12,9 @@ data_train: train_val
 early_stop_patience: 50
 gamma: 0.98
 step_size: 10
-experiments_count: 1
+experimental_mode: false
+experiments_count: 50
+replace_model: true # whether to replace the existing model
 
 #---Detection and inference configuration---#
 # Path of the model used for detection and inference
diff --git a/psy_api.py b/psy_api.py
index 1f0b9b7..e78907a 100644
--- a/psy_api.py
+++ b/psy_api.py
@@ -2,10 +2,11 @@ import os
 import time
 import datetime
 import logging
+import shutil
 import uvicorn
 import yaml
 import numpy as np
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, File, UploadFile
 from pydantic import BaseModel
 from typing import List
 import atexit
@@ -19,6 +20,13 @@ app = FastAPI()
 # Flag that controls whether log messages are also printed
 PRINT_LOG = True
 
+# Initialize the config file path
+config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "config/config.yaml"))
+
+# Initialize the logging configuration
+log_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "logfile.log"))
+logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+
 def log_print(message):
     logging.info(message)
     if PRINT_LOG:
@@ -49,8 +57,17 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Initialize the config file path
-config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "config/config.yaml"))
+# Define a function that dynamically updates the log handler
+def update_log_handler(log_path):
+    logger = logging.getLogger()
+    # Remove all existing handlers
+    for handler in logger.handlers[:]:
+        logger.removeHandler(handler)
+    # Add a new file handler
+    file_handler = logging.FileHandler(log_path)
+    file_handler.setLevel(logging.INFO)
+    file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S'))
+    logger.addHandler(file_handler)
 
 # Define the training endpoint
 @app.post("/train/")
@@ -81,11 +98,11 @@ async def train_model(request: Request, features_list: List[Features]):
 
     # Add the model save path
     model_path = os.path.abspath(os.path.join(static_dir, f"train_model_{now}.pth"))
-    config['model_path'] = model_path
+    config['train_model_path'] = model_path
 
     # Configure logging
     log_path = os.path.abspath(os.path.join(static_dir, f"train_log_{now}.log"))
-    logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+    update_log_handler(log_path)
 
     # Configure the paths of the training and validation result images
     train_process_path = os.path.abspath(os.path.join(static_dir, f"train_progress_img_{now}.png"))
@@ -103,7 +120,7 @@ async def train_model(request: Request, features_list: List[Features]):
     list_precision = []
     list_recall = []
     list_f1 = []
-    train_times = 1 if config['data_train'] == 'all' else config["experiments_count"]
+    train_times = 1 if config['experimental_mode'] == False else config["experiments_count"]
     for _ in range(train_times):
         avg_f1, wrong_percentage, precision, recall, f1 = ml_model.train_detect()
         list_avg_f1.append(avg_f1)
@@ -120,6 +137,11 @@ async def train_model(request: Request, features_list: List[Features]):
     end_time = time.time() # record the end time
     log_print("预测耗时: " + str(end_time - start_time) + " 秒") # print the execution time
 
+    # Replace the existing detection model
+    if(config["replace_model"] == True):
+        shutil.copyfile(config["train_model_path"], config["model_path"])
+        log_print(f"Model file has been copied from {config['train_model_path']} to {config['model_path']}")
+
     # Make sure the log is written to the file
     atexit.register(flush_log)
 
@@ -127,6 +149,8 @@ async def train_model(request: Request, features_list: List[Features]):
     model_file_url = f"{request.base_url}train_api/train_model_{now}.pth"
     log_file_url = f"{request.base_url}train_api/train_log_{now}.log"
     data_file_url = f"{request.base_url}train_api/train_feature_label_weighted_{now}.xlsx"
+    train_process_img_url = f"{request.base_url}train_api/train_progress_img_{now}.png"
+    evaluate_result_img_url = f"{request.base_url}train_api/evaluate_result_img_{now}.png"
 
     # Return the classification results and the model files
     return {
@@ -139,7 +163,9 @@ async def train_model(request: Request, features_list: List[Features]):
         "data_file": {
             "model_file_url": model_file_url,
             "log_file_url": log_file_url,
-            "data_file_url": data_file_url
+            "data_file_url": data_file_url,
+            "train_process_img_url": train_process_img_url,
+            "evaluate_result_img_url": evaluate_result_img_url
         }
     }
 
@@ -176,7 +202,7 @@ async def evaluate_model(request: Request, features_list: List[Features]):
 
     # Configure logging
     log_path = os.path.abspath(os.path.join(static_dir, f"evaluate_log_{now}.log"))
-    logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+    update_log_handler(log_path)
 
     # Features and labels
     X = feature_label_weighted[config['feature_names']].values
@@ -243,7 +269,7 @@ async def inference_model(request: Request, features_list: List[Features]):
 
     # Configure logging
     log_path = os.path.abspath(os.path.join(static_dir, f"inference_log_{now}.log"))
-    logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+    update_log_handler(log_path)
 
     # Features and labels
     X = feature_label_weighted[config['feature_names']].values
@@ -267,9 +293,28 @@ async def inference_model(request: Request, features_list: List[Features]):
     # Return the prediction results
     return PredictionResult(predictions=predictions)
 
+# Define the model upload endpoint
+@app.post("/upload_model/")
+async def upload_model(file: UploadFile = File(...)):
+    # Create the folder where models are stored
+    models_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "models"))
+    os.makedirs(models_dir, exist_ok=True)
+
+    # Save the model file
+    file_path = os.path.join(models_dir, "psy.pth")
+    with open(file_path, "wb") as buffer:
+        buffer.write(await file.read())
+
+    return {"message": "模型上传成功", "file_path": file_path}
+
 
 # FastAPI startup configuration below
 if __name__ == "__main__":
+    # Get the current time formatted as a string
+    current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    # Log the program start time
+    log_print(f"Program started at {current_time}")
+
     name_app = os.path.basename(__file__)[0:-3] # Get the name of the script
     log_config = {
         "version": 1,
@@ -289,11 +334,14 @@ if __name__ == "__main__":
     static_dir_train = os.path.abspath(os.path.join(os.path.dirname(__file__), "train_api")) # Directories for model and config files, at the same level as this .py file
     static_dir_evaluate = os.path.abspath(os.path.join(os.path.dirname(__file__), "evaluate_api"))
     static_dir_inference = os.path.abspath(os.path.join(os.path.dirname(__file__), "inference_api"))
+    static_dir_models = os.path.abspath(os.path.join(os.path.dirname(__file__), "models"))
     os.makedirs(static_dir_train, exist_ok=True)
     os.makedirs(static_dir_evaluate, exist_ok=True)
     os.makedirs(static_dir_inference, exist_ok=True)
+    os.makedirs(static_dir_models, exist_ok=True)
     # static folders in the same directory
     app.mount("/train_api", StaticFiles(directory=static_dir_train), name="static_dir_train")
     app.mount("/evaluate_api", StaticFiles(directory=static_dir_evaluate), name="static_dir_evaluate")
     app.mount("/inference_api", StaticFiles(directory=static_dir_inference), name="static_dir_inference")
+    app.mount("/models", StaticFiles(directory=static_dir_models), name="static_dir_models")
     uvicorn.run(app, host="0.0.0.0", port=3397, reload=False)
\ No newline at end of file
diff --git a/train_api.yaml b/train_api.yaml
deleted file mode 100644
index a548889..0000000
--- a/train_api.yaml
+++ /dev/null
@@ -1,61 +0,0 @@
-#---Device configuration---#
-device: cpu
-#device: cuda
-
-#---Training configuration---#
-n_epochs: 150
-batch_size: 16
-learning_rate: 0.001
-nc: 4
-use_infer_as_val: true
-#data_train: train_val # train: train on train only, validate on val, test on infer; train_val: train on train+val, validate on infer, test on infer; all: train, validate and test on all data (1/5 of the data is first set aside as infer, then 1/5 of the remainder as val, and the remaining 4/5 is used for training)
-data_train: all
-early_stop_patience: 50
-gamma: 0.98
-step_size: 10
-experiments_count: 50
-
-#---Training outputs---#
-# Log path
-log_path: results/training.log
-# Path of the training progress chart
-train_process_path: results/training_progress.png
-# Path of the training result chart
-train_result_path: results/training_result.png
-# Path of the trained model
-model_path: results/psychology.pth
-# Path of the data subset used for testing
-infer_path: results/infer.xlsx
-
-#---Training source data---#
-# Path of the training sample data
-data_path: data_processed/feature_label_weighted.xlsx
-
-#---Sample features---#
-# Label name
-label_name: 类别
-# Feature names
-feature_names:
-  - "强迫症状数字化"
-  - "人际关系敏感数字化"
-  - "抑郁数字化"
-  - "多因子症状"
-  - "母亲教养方式数字化"
-  - "父亲教养方式数字化"
-  - "自评家庭经济条件数字化"
-  - "有无心理治疗(咨询)史数字化"
-  - "学业情况数字化"
-  - "出勤情况数字化"
-# Feature weight list
-feature_weights:
-  - 0.135
-  - 0.085
-  - 0.08
-  - 0.2
-  - 0.09
-  - 0.09
-  - 0.06
-  - 0.06
-  - 0.08
-  - 0.12
-
diff --git a/utils/common.py b/utils/common.py
index e591189..526219c 100644
--- a/utils/common.py
+++ b/utils/common.py
@@ -247,7 +247,7 @@ class MLModel:
         best_val_f1, best_val_recall, best_val_precision, best_epoch, best_model = self.train_model(train_loader, val_loader, criterion, optimizer, scheduler)
 
         # Save the best model
-        self.save_model(self.config['model_path'])
+        self.save_model(self.config['train_model_path'])
 
         log_print(f"Best Validation F1 Score (Macro): {best_val_f1:.4f}")
         log_print(f"Best Validation Recall (Macro): {best_val_recall:.4f}")
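
Usage note: the sketch below is a minimal client for the new /upload_model/ endpoint, doing the same thing as api_send_model.py but with the upload wrapped in a context manager so the file handle is closed after the request. It assumes the service from psy_api.py is running locally on port 3397; the model path is only a placeholder taken from the example above.

    import requests

    # Upload endpoint exposed by psy_api.py (assumed to be running locally on port 3397)
    url = "http://localhost:3397/upload_model/"

    # Placeholder path; point this at the .pth file you want to upload
    file_path = "train_api/train_model_20240605_172031.pth"

    # The context manager closes the file handle once the request has been sent
    with open(file_path, "rb") as f:
        response = requests.post(url, files={"file": f})

    print(response.status_code)  # expected 200 on success
    print(response.json())       # {"message": ..., "file_path": ".../models/psy.pth"}

The server stores the upload as models/psy.pth and also serves that directory back through the /models static mount added in this patch.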