From a2a5f50bfd4a0de5734749b5647d33e1c0b56753 Mon Sep 17 00:00:00 2001 From: wangchunlin Date: Sat, 22 Jun 2024 03:29:26 +0800 Subject: [PATCH] config_path --- psy_api.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/psy_api.py b/psy_api.py index c2cce49..dd1a636 100644 --- a/psy_api.py +++ b/psy_api.py @@ -106,6 +106,7 @@ def start_scheduler(): # 定义训练接口 @app.post("/train/") async def train_model(request: Request, features_list: List[Features]): + global config_path # 遍历每个特征对象,并将其添加到 all_features 中 all_features = create_feature_df(features_list) @@ -171,9 +172,6 @@ async def train_model(request: Request, features_list: List[Features]): end_time = time.time() # 记录结束时间 log_print("预测耗时: " + str(end_time - start_time) + " 秒") # 打印执行时间 - static_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "train_api")) # 设置模型文件和配置文件的存放目录,和本py同级 - config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "config/config.yaml")) - # 替换现有检测模型 if(config["replace_model"] == True): # 如果模型路径不是绝对路径,则转换为绝对路径 @@ -213,6 +211,7 @@ async def train_model(request: Request, features_list: List[Features]): # 定义验证接口 @app.post("/evaluate/") async def evaluate_model(request: Request, features_list: List[Features]): + global config_path # 遍历每个特征对象,并将其添加到 all_features 中 all_features = create_feature_df(features_list) @@ -291,6 +290,7 @@ async def evaluate_model(request: Request, features_list: List[Features]): # 定义推理接口 @app.post("/inference/") async def inference_model(request: Request, features_list: List[Features]): + global config_path # 遍历每个特征对象,并将其添加到 all_features 中 all_features = create_feature_df(features_list) @@ -349,6 +349,7 @@ async def inference_model(request: Request, features_list: List[Features]): # 定义模型上传接口 @app.post("/upload_model/") async def upload_model(file: UploadFile = File(...)): + global config_path # 创建模型存放文件夹 models_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "models")) os.makedirs(models_dir, exist_ok=True)