From 2412af6abe293ead46536a4d72a7698c1db6d154 Mon Sep 17 00:00:00 2001 From: wangchunlin Date: Fri, 29 Nov 2024 16:31:11 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=AA=E8=AE=AD=E7=BB=83=E6=97=B6=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E9=BB=98=E8=AE=A4=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.yaml | 13 ++++++------- utils/feature_processor.py | 17 +++++++++++++++++ utils/model_trainer.py | 20 ++++++++++++++++++-- 3 files changed, 41 insertions(+), 9 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 618cf18..ecccbd2 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -8,7 +8,7 @@ system: - "evaluate_api" host: "0.0.0.0" port: 8088 - device: "cuda" # 可选: "cpu", "cuda" + device: "cpu" # 可选: "cpu", "cuda" #---路径配置---# paths: @@ -24,10 +24,10 @@ paths: #---训练配置---# training: n_epochs: 300 - # batch_size: 16 # CPU - # learning_rate: 0.001 # CPU - batch_size: 128 # GPU - learning_rate: 0.001 # GPU + batch_size: 16 # CPU + learning_rate: 0.001 # CPU + # batch_size: 128 # GPU + # learning_rate: 0.001 # GPU early_stop_patience: 50 scheduler: gamma: 0.98 @@ -36,7 +36,7 @@ training: l1_lambda: 1e-5 l2_lambda: 1e-4 dropout_rate: 0.2 - experimental_mode: true + experimental_mode: false experiments_count: 50 replace_model: true data_mode: "train_val" # 可选: "train", "train_val", "all" @@ -59,7 +59,6 @@ model: #---特征配置---# features: label_name: "类别" - correlation_threshold: 0.7 groups: 核心症状: - "强迫症状数字化" diff --git a/utils/feature_processor.py b/utils/feature_processor.py index 91dc623..52896e7 100644 --- a/utils/feature_processor.py +++ b/utils/feature_processor.py @@ -10,6 +10,7 @@ from sklearn.preprocessing import StandardScaler import pickle from pydantic import BaseModel, Field import logging +import shutil class Features(BaseModel): """ @@ -189,6 +190,7 @@ class FeatureWeightApplier: class FeatureNormalizer: """特征归一化处理类""" def __init__(self, config: Dict[str, Any]): + self.config = config self.feature_groups = {} self.scalers = {} @@ -232,6 +234,21 @@ class FeatureNormalizer: Args: path: 加载路径 """ + # 如果主路径不存在,尝试从默认路径复制 + if not os.path.exists(path): + default_path = os.path.join( + os.path.dirname(__file__), + '..', + self.config['paths']['normalizer']['default'] + ) + if os.path.exists(default_path): + # 确保目标目录存在 + os.makedirs(os.path.dirname(path), exist_ok=True) + shutil.copyfile(default_path, path) + logging.info(f"Copied normalizer from default path: {default_path} to {path}") + else: + raise FileNotFoundError(f"No normalizer found in either path: {path} or {default_path}") + with open(path, 'rb') as f: self.scalers = pickle.load(f) diff --git a/utils/model_trainer.py b/utils/model_trainer.py index c7a33ff..62ce7b9 100644 --- a/utils/model_trainer.py +++ b/utils/model_trainer.py @@ -93,8 +93,24 @@ class MLModel: def load_model(self) -> None: """加载已训练的模型""" self.create_model() - self.model.load_state_dict(torch.load(self.config['paths']['model']['train'], - map_location=self.device)) + model_path = self.config['paths']['model']['train'] + + # 如果主路径不存在,尝试从默认路径复制 + if not os.path.exists(model_path): + default_model_path = os.path.join( + os.path.dirname(__file__), + '..', + self.config['paths']['model']['default'] + ) + if os.path.exists(default_model_path): + # 确保目标目录存在 + os.makedirs(os.path.dirname(model_path), exist_ok=True) + shutil.copyfile(default_model_path, model_path) + logging.info(f"Copied model from default path: {default_model_path} to {model_path}") + else: + raise FileNotFoundError(f"No model found in either path: {model_path} or {default_model_path}") + + self.model.load_state_dict(torch.load(model_path, map_location=self.device)) def plot_correlation(self, X: np.ndarray, feature_names: List[str]) -> None: """绘制特征相关性矩阵"""