|
|
|
@ -2,10 +2,11 @@ import os
|
|
|
|
import time
|
|
|
|
import time
|
|
|
|
import datetime
|
|
|
|
import datetime
|
|
|
|
import logging
|
|
|
|
import logging
|
|
|
|
|
|
|
|
import shutil
|
|
|
|
import uvicorn
|
|
|
|
import uvicorn
|
|
|
|
import yaml
|
|
|
|
import yaml
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
from fastapi import FastAPI, Request
|
|
|
|
from fastapi import FastAPI, Request, File, UploadFile
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from typing import List
|
|
|
|
from typing import List
|
|
|
|
import atexit
|
|
|
|
import atexit
|
|
|
|
@ -19,6 +20,13 @@ app = FastAPI()
|
|
|
|
# 控制是否打印的宏定义
|
|
|
|
# 控制是否打印的宏定义
|
|
|
|
PRINT_LOG = True
|
|
|
|
PRINT_LOG = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 初始化配置文件
|
|
|
|
|
|
|
|
config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "config/config.yaml"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 初始化日志配置
|
|
|
|
|
|
|
|
log_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "logfile.log"))
|
|
|
|
|
|
|
|
logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
|
|
|
|
|
|
|
|
|
|
|
def log_print(message):
|
|
|
|
def log_print(message):
|
|
|
|
logging.info(message)
|
|
|
|
logging.info(message)
|
|
|
|
if PRINT_LOG:
|
|
|
|
if PRINT_LOG:
|
|
|
|
@ -49,8 +57,17 @@ app.add_middleware(
|
|
|
|
allow_headers=["*"],
|
|
|
|
allow_headers=["*"],
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 初始化配置文件
|
|
|
|
# 定义一个函数来动态更新日志处理器
|
|
|
|
config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "config/config.yaml"))
|
|
|
|
def update_log_handler(log_path):
|
|
|
|
|
|
|
|
logger = logging.getLogger()
|
|
|
|
|
|
|
|
# 移除所有现有的处理器
|
|
|
|
|
|
|
|
for handler in logger.handlers[:]:
|
|
|
|
|
|
|
|
logger.removeHandler(handler)
|
|
|
|
|
|
|
|
# 添加新的文件处理器
|
|
|
|
|
|
|
|
file_handler = logging.FileHandler(log_path)
|
|
|
|
|
|
|
|
file_handler.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S'))
|
|
|
|
|
|
|
|
logger.addHandler(file_handler)
|
|
|
|
|
|
|
|
|
|
|
|
# 定义训练接口
|
|
|
|
# 定义训练接口
|
|
|
|
@app.post("/train/")
|
|
|
|
@app.post("/train/")
|
|
|
|
@ -81,11 +98,11 @@ async def train_model(request: Request, features_list: List[Features]):
|
|
|
|
|
|
|
|
|
|
|
|
# 添加模型保存路径
|
|
|
|
# 添加模型保存路径
|
|
|
|
model_path = os.path.abspath(os.path.join(static_dir, f"train_model_{now}.pth"))
|
|
|
|
model_path = os.path.abspath(os.path.join(static_dir, f"train_model_{now}.pth"))
|
|
|
|
config['model_path'] = model_path
|
|
|
|
config['train_model_path'] = model_path
|
|
|
|
|
|
|
|
|
|
|
|
# 配置日志
|
|
|
|
# 配置日志
|
|
|
|
log_path = os.path.abspath(os.path.join(static_dir, f"train_log_{now}.log"))
|
|
|
|
log_path = os.path.abspath(os.path.join(static_dir, f"train_log_{now}.log"))
|
|
|
|
logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
|
|
|
update_log_handler(log_path)
|
|
|
|
|
|
|
|
|
|
|
|
# 配置训练和验证结果图片路径
|
|
|
|
# 配置训练和验证结果图片路径
|
|
|
|
train_process_path = os.path.abspath(os.path.join(static_dir, f"train_progress_img_{now}.png"))
|
|
|
|
train_process_path = os.path.abspath(os.path.join(static_dir, f"train_progress_img_{now}.png"))
|
|
|
|
@ -103,7 +120,7 @@ async def train_model(request: Request, features_list: List[Features]):
|
|
|
|
list_precision = []
|
|
|
|
list_precision = []
|
|
|
|
list_recall = []
|
|
|
|
list_recall = []
|
|
|
|
list_f1 = []
|
|
|
|
list_f1 = []
|
|
|
|
train_times = 1 if config['data_train'] == 'all' else config["experiments_count"]
|
|
|
|
train_times = 1 if config['experimental_mode'] == False else config["experiments_count"]
|
|
|
|
for _ in range(train_times):
|
|
|
|
for _ in range(train_times):
|
|
|
|
avg_f1, wrong_percentage, precision, recall, f1 = ml_model.train_detect()
|
|
|
|
avg_f1, wrong_percentage, precision, recall, f1 = ml_model.train_detect()
|
|
|
|
list_avg_f1.append(avg_f1)
|
|
|
|
list_avg_f1.append(avg_f1)
|
|
|
|
@ -120,6 +137,11 @@ async def train_model(request: Request, features_list: List[Features]):
|
|
|
|
end_time = time.time() # 记录结束时间
|
|
|
|
end_time = time.time() # 记录结束时间
|
|
|
|
log_print("预测耗时: " + str(end_time - start_time) + " 秒") # 打印执行时间
|
|
|
|
log_print("预测耗时: " + str(end_time - start_time) + " 秒") # 打印执行时间
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 替换现有检测模型
|
|
|
|
|
|
|
|
if(config["replace_model"] == True):
|
|
|
|
|
|
|
|
shutil.copyfile(config["train_model_path"], config["model_path"])
|
|
|
|
|
|
|
|
log_print(f"Model file has been copied from {config['train_model_path']} to {config['model_path']}")
|
|
|
|
|
|
|
|
|
|
|
|
# 保证日志写到文件
|
|
|
|
# 保证日志写到文件
|
|
|
|
atexit.register(flush_log)
|
|
|
|
atexit.register(flush_log)
|
|
|
|
|
|
|
|
|
|
|
|
@ -127,6 +149,8 @@ async def train_model(request: Request, features_list: List[Features]):
|
|
|
|
model_file_url = f"{request.base_url}train_api/train_model_{now}.pth"
|
|
|
|
model_file_url = f"{request.base_url}train_api/train_model_{now}.pth"
|
|
|
|
log_file_url = f"{request.base_url}train_api/train_log_{now}.log"
|
|
|
|
log_file_url = f"{request.base_url}train_api/train_log_{now}.log"
|
|
|
|
data_file_url = f"{request.base_url}train_api/train_feature_label_weighted_{now}.xlsx"
|
|
|
|
data_file_url = f"{request.base_url}train_api/train_feature_label_weighted_{now}.xlsx"
|
|
|
|
|
|
|
|
train_process_img_url = f"{request.base_url}train_api/train_progress_img_{now}.png"
|
|
|
|
|
|
|
|
evaluate_result_img_url = f"{request.base_url}train_api/evaluate_result_img_{now}.png"
|
|
|
|
|
|
|
|
|
|
|
|
# 返回分类结果和模型文件
|
|
|
|
# 返回分类结果和模型文件
|
|
|
|
return {
|
|
|
|
return {
|
|
|
|
@ -139,7 +163,9 @@ async def train_model(request: Request, features_list: List[Features]):
|
|
|
|
"data_file": {
|
|
|
|
"data_file": {
|
|
|
|
"model_file_url": model_file_url,
|
|
|
|
"model_file_url": model_file_url,
|
|
|
|
"log_file_url": log_file_url,
|
|
|
|
"log_file_url": log_file_url,
|
|
|
|
"data_file_url": data_file_url
|
|
|
|
"data_file_url": data_file_url,
|
|
|
|
|
|
|
|
"train_process_img_url": train_process_img_url,
|
|
|
|
|
|
|
|
"evaluate_result_img_url": evaluate_result_img_url
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@ -176,7 +202,7 @@ async def evaluate_model(request: Request, features_list: List[Features]):
|
|
|
|
|
|
|
|
|
|
|
|
# 配置日志
|
|
|
|
# 配置日志
|
|
|
|
log_path = os.path.abspath(os.path.join(static_dir, f"evaluate_log_{now}.log"))
|
|
|
|
log_path = os.path.abspath(os.path.join(static_dir, f"evaluate_log_{now}.log"))
|
|
|
|
logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
|
|
|
update_log_handler(log_path)
|
|
|
|
|
|
|
|
|
|
|
|
# 特征和标签
|
|
|
|
# 特征和标签
|
|
|
|
X = feature_label_weighted[config['feature_names']].values
|
|
|
|
X = feature_label_weighted[config['feature_names']].values
|
|
|
|
@ -243,7 +269,7 @@ async def inference_model(request: Request, features_list: List[Features]):
|
|
|
|
|
|
|
|
|
|
|
|
# 配置日志
|
|
|
|
# 配置日志
|
|
|
|
log_path = os.path.abspath(os.path.join(static_dir, f"inference_log_{now}.log"))
|
|
|
|
log_path = os.path.abspath(os.path.join(static_dir, f"inference_log_{now}.log"))
|
|
|
|
logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
|
|
|
update_log_handler(log_path)
|
|
|
|
|
|
|
|
|
|
|
|
# 特征和标签
|
|
|
|
# 特征和标签
|
|
|
|
X = feature_label_weighted[config['feature_names']].values
|
|
|
|
X = feature_label_weighted[config['feature_names']].values
|
|
|
|
@ -267,9 +293,28 @@ async def inference_model(request: Request, features_list: List[Features]):
|
|
|
|
# 返回预测结果
|
|
|
|
# 返回预测结果
|
|
|
|
return PredictionResult(predictions=predictions)
|
|
|
|
return PredictionResult(predictions=predictions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 定义模型上传接口
|
|
|
|
|
|
|
|
@app.post("/upload_model/")
|
|
|
|
|
|
|
|
async def upload_model(file: UploadFile = File(...)):
|
|
|
|
|
|
|
|
# 创建模型存放文件夹
|
|
|
|
|
|
|
|
models_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "models"))
|
|
|
|
|
|
|
|
os.makedirs(models_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 保存模型文件
|
|
|
|
|
|
|
|
file_path = os.path.join(models_dir, "psy.pth")
|
|
|
|
|
|
|
|
with open(file_path, "wb") as buffer:
|
|
|
|
|
|
|
|
buffer.write(await file.read())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return {"message": "模型上传成功", "file_path": file_path}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 以下是fastapi启动配置
|
|
|
|
# 以下是fastapi启动配置
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
|
|
# 获取当前时间并格式化为字符串
|
|
|
|
|
|
|
|
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
|
|
|
|
# 打印程序启动时间
|
|
|
|
|
|
|
|
log_print(f"Program started at {current_time}")
|
|
|
|
|
|
|
|
|
|
|
|
name_app = os.path.basename(__file__)[0:-3] # Get the name of the script
|
|
|
|
name_app = os.path.basename(__file__)[0:-3] # Get the name of the script
|
|
|
|
log_config = {
|
|
|
|
log_config = {
|
|
|
|
"version": 1,
|
|
|
|
"version": 1,
|
|
|
|
@ -289,11 +334,14 @@ if __name__ == "__main__":
|
|
|
|
static_dir_train = os.path.abspath(os.path.join(os.path.dirname(__file__), "train_api")) # 设置模型文件和配置文件的存放目录,和本py同级
|
|
|
|
static_dir_train = os.path.abspath(os.path.join(os.path.dirname(__file__), "train_api")) # 设置模型文件和配置文件的存放目录,和本py同级
|
|
|
|
static_dir_evaluate = os.path.abspath(os.path.join(os.path.dirname(__file__), "evaluate_api"))
|
|
|
|
static_dir_evaluate = os.path.abspath(os.path.join(os.path.dirname(__file__), "evaluate_api"))
|
|
|
|
static_dir_inference = os.path.abspath(os.path.join(os.path.dirname(__file__), "inference_api"))
|
|
|
|
static_dir_inference = os.path.abspath(os.path.join(os.path.dirname(__file__), "inference_api"))
|
|
|
|
|
|
|
|
static_dir_models = os.path.abspath(os.path.join(os.path.dirname(__file__), "models"))
|
|
|
|
os.makedirs(static_dir_train, exist_ok=True)
|
|
|
|
os.makedirs(static_dir_train, exist_ok=True)
|
|
|
|
os.makedirs(static_dir_evaluate, exist_ok=True)
|
|
|
|
os.makedirs(static_dir_evaluate, exist_ok=True)
|
|
|
|
os.makedirs(static_dir_inference, exist_ok=True)
|
|
|
|
os.makedirs(static_dir_inference, exist_ok=True)
|
|
|
|
|
|
|
|
os.makedirs(static_dir_models, exist_ok=True)
|
|
|
|
# 同级目录下的static文件夹
|
|
|
|
# 同级目录下的static文件夹
|
|
|
|
app.mount("/train_api", StaticFiles(directory=static_dir_train), name="static_dir_train")
|
|
|
|
app.mount("/train_api", StaticFiles(directory=static_dir_train), name="static_dir_train")
|
|
|
|
app.mount("/evaluate_api", StaticFiles(directory=static_dir_evaluate), name="static_dir_evaluate")
|
|
|
|
app.mount("/evaluate_api", StaticFiles(directory=static_dir_evaluate), name="static_dir_evaluate")
|
|
|
|
app.mount("/inference_api", StaticFiles(directory=static_dir_inference), name="static_dir_inference")
|
|
|
|
app.mount("/inference_api", StaticFiles(directory=static_dir_inference), name="static_dir_inference")
|
|
|
|
|
|
|
|
app.mount("/models", StaticFiles(directory=static_dir_models), name="static_dir_models")
|
|
|
|
uvicorn.run(app, host="0.0.0.0", port=3397, reload=False)
|
|
|
|
uvicorn.run(app, host="0.0.0.0", port=3397, reload=False)
|