From 3b20e332b2d8fcb8646195c4ebff8d3d0fadabe9 Mon Sep 17 00:00:00 2001 From: wangchunlin Date: Sat, 22 Jun 2024 11:21:39 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=A8=A1=E5=9E=8B=E8=B7=AF?= =?UTF-8?q?=E5=BE=84=EF=BC=8C=E8=A7=A3=E5=86=B3docker=E9=87=8D=E5=90=AF?= =?UTF-8?q?=E6=8E=89=E6=A8=A1=E5=9E=8B=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.yaml | 3 ++- psy_api.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/config/config.yaml b/config/config.yaml index 13da64e..cd0444a 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -18,7 +18,8 @@ replace_model: true # 是否替换现有模型 #---检测和推理配置---# # 检测和推理使用模型路径 -model_path: model/psychology.pth +model_path: /data/model/psychology.pth +default_model_path: model/psychology.pth #---样本特征---# # 标签名称 diff --git a/psy_api.py b/psy_api.py index dd1a636..32d7ee2 100644 --- a/psy_api.py +++ b/psy_api.py @@ -245,6 +245,14 @@ async def evaluate_model(request: Request, features_list: List[Features]): abs_model_path = os.path.abspath(os.path.join(os.path.dirname(__file__), config["model_path"])) config["model_path"] = abs_model_path + # 检查模型文件是否存在,如果不存在则复制模型文件 + if not os.path.exists(config["model_path"]): + if not os.path.isabs(config['default_model_path']): + config['default_model_path'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), config['default_model_path']) + os.makedirs(os.path.dirname(config["model_path"]), exist_ok=True) + shutil.copyfile(config['default_model_path'], config["model_path"]) + log_print(f"Model file not found. Copied default model from {config['default_model_path']} to {config['model_path']}") + # 配置日志 log_path = os.path.abspath(os.path.join(static_dir, f"evaluate_log_{now}.log")) update_log_handler(log_path) @@ -324,6 +332,14 @@ async def inference_model(request: Request, features_list: List[Features]): abs_model_path = os.path.abspath(os.path.join(os.path.dirname(__file__), config["model_path"])) config["model_path"] = abs_model_path + # 检查模型文件是否存在,如果不存在则复制模型文件 + if not os.path.exists(config["model_path"]): + if not os.path.isabs(config['default_model_path']): + config['default_model_path'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), config['default_model_path']) + os.makedirs(os.path.dirname(config["model_path"]), exist_ok=True) + shutil.copyfile(config['default_model_path'], config["model_path"]) + log_print(f"Model file not found. Copied default model from {config['default_model_path']} to {config['model_path']}") + # 特征和标签 X = feature_label_weighted[config['feature_names']].values