未训练时使用默认模型

v2
wangchunlin 1 year ago
parent 29b0c026a2
commit 2412af6abe

@@ -8,7 +8,7 @@ system:
- "evaluate_api" - "evaluate_api"
host: "0.0.0.0" host: "0.0.0.0"
port: 8088 port: 8088
device: "cuda" # 可选: "cpu", "cuda" device: "cpu" # 可选: "cpu", "cuda"
#---路径配置---# #---路径配置---#
paths: paths:
@@ -24,10 +24,10 @@ paths:
#---训练配置---# #---训练配置---#
training: training:
n_epochs: 300 n_epochs: 300
# batch_size: 16 # CPU batch_size: 16 # CPU
# learning_rate: 0.001 # CPU learning_rate: 0.001 # CPU
batch_size: 128 # GPU # batch_size: 128 # GPU
learning_rate: 0.001 # GPU # learning_rate: 0.001 # GPU
early_stop_patience: 50 early_stop_patience: 50
scheduler: scheduler:
gamma: 0.98 gamma: 0.98
@@ -36,7 +36,7 @@ training:
l1_lambda: 1e-5 l1_lambda: 1e-5
l2_lambda: 1e-4 l2_lambda: 1e-4
dropout_rate: 0.2 dropout_rate: 0.2
experimental_mode: true experimental_mode: false
experiments_count: 50 experiments_count: 50
replace_model: true replace_model: true
data_mode: "train_val" # 可选: "train", "train_val", "all" data_mode: "train_val" # 可选: "train", "train_val", "all"
@@ -59,7 +59,6 @@ model:
#---特征配置---# #---特征配置---#
features: features:
label_name: "类别" label_name: "类别"
correlation_threshold: 0.7
groups: groups:
核心症状: 核心症状:
- "强迫症状数字化" - "强迫症状数字化"

@@ -10,6 +10,7 @@ from sklearn.preprocessing import StandardScaler
import pickle import pickle
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import logging import logging
import shutil
class Features(BaseModel): class Features(BaseModel):
""" """
@@ -189,6 +190,7 @@ class FeatureWeightApplier:
class FeatureNormalizer: class FeatureNormalizer:
"""特征归一化处理类""" """特征归一化处理类"""
def __init__(self, config: Dict[str, Any]): def __init__(self, config: Dict[str, Any]):
self.config = config
self.feature_groups = {} self.feature_groups = {}
self.scalers = {} self.scalers = {}
@@ -232,6 +234,21 @@ class FeatureNormalizer:
Args: Args:
path: 加载路径 path: 加载路径
""" """
# 如果主路径不存在,尝试从默认路径复制
if not os.path.exists(path):
default_path = os.path.join(
os.path.dirname(__file__),
'..',
self.config['paths']['normalizer']['default']
)
if os.path.exists(default_path):
# 确保目标目录存在
os.makedirs(os.path.dirname(path), exist_ok=True)
shutil.copyfile(default_path, path)
logging.info(f"Copied normalizer from default path: {default_path} to {path}")
else:
raise FileNotFoundError(f"No normalizer found in either path: {path} or {default_path}")
with open(path, 'rb') as f: with open(path, 'rb') as f:
self.scalers = pickle.load(f) self.scalers = pickle.load(f)

@@ -93,8 +93,24 @@ class MLModel:
def load_model(self) -> None:
    """Load the trained model weights into ``self.model``.

    Builds the network architecture via ``create_model`` and then loads the
    state dict from the configured training checkpoint path
    (``paths.model.train``). If that checkpoint does not exist yet (the model
    has never been trained), the bundled default checkpoint
    (``paths.model.default``, resolved relative to this file's parent
    directory) is copied into place first, so a fresh deployment still serves
    predictions.

    Raises:
        FileNotFoundError: if neither the trained checkpoint nor the default
            checkpoint exists on disk.
    """
    self.create_model()
    model_path = self.config['paths']['model']['train']

    # Fall back to the shipped default checkpoint when no trained model exists.
    if not os.path.exists(model_path):
        default_model_path = os.path.join(
            os.path.dirname(__file__),
            '..',
            self.config['paths']['model']['default']
        )
        if os.path.exists(default_model_path):
            # Ensure the target directory exists before copying.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            shutil.copyfile(default_model_path, model_path)
            # Lazy %-style args so formatting only happens if the record is emitted.
            logging.info("Copied model from default path: %s to %s",
                         default_model_path, model_path)
        else:
            raise FileNotFoundError(
                f"No model found in either path: {model_path} or {default_model_path}"
            )

    # map_location keeps CPU-only hosts working with GPU-saved checkpoints.
    self.model.load_state_dict(torch.load(model_path, map_location=self.device))
def plot_correlation(self, X: np.ndarray, feature_names: List[str]) -> None: def plot_correlation(self, X: np.ndarray, feature_names: List[str]) -> None:
"""绘制特征相关性矩阵""" """绘制特征相关性矩阵"""

Loading…
Cancel
Save