未训练时使用默认模型

v2
wangchunlin 1 year ago
parent 29b0c026a2
commit 2412af6abe

@ -8,7 +8,7 @@ system:
- "evaluate_api"
host: "0.0.0.0"
port: 8088
device: "cuda" # 可选: "cpu", "cuda"
device: "cpu" # 可选: "cpu", "cuda"
#---路径配置---#
paths:
@ -24,10 +24,10 @@ paths:
#---训练配置---#
training:
n_epochs: 300
# batch_size: 16 # CPU
# learning_rate: 0.001 # CPU
batch_size: 128 # GPU
learning_rate: 0.001 # GPU
batch_size: 16 # CPU
learning_rate: 0.001 # CPU
# batch_size: 128 # GPU
# learning_rate: 0.001 # GPU
early_stop_patience: 50
scheduler:
gamma: 0.98
@ -36,7 +36,7 @@ training:
l1_lambda: 1e-5
l2_lambda: 1e-4
dropout_rate: 0.2
experimental_mode: true
experimental_mode: false
experiments_count: 50
replace_model: true
data_mode: "train_val" # 可选: "train", "train_val", "all"
@ -59,7 +59,6 @@ model:
#---特征配置---#
features:
label_name: "类别"
correlation_threshold: 0.7
groups:
核心症状:
- "强迫症状数字化"

@ -10,6 +10,7 @@ from sklearn.preprocessing import StandardScaler
import pickle
from pydantic import BaseModel, Field
import logging
import shutil
class Features(BaseModel):
"""
@ -189,6 +190,7 @@ class FeatureWeightApplier:
class FeatureNormalizer:
"""特征归一化处理类"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.feature_groups = {}
self.scalers = {}
@ -232,6 +234,21 @@ class FeatureNormalizer:
Args:
path: 加载路径
"""
# 如果主路径不存在,尝试从默认路径复制
if not os.path.exists(path):
default_path = os.path.join(
os.path.dirname(__file__),
'..',
self.config['paths']['normalizer']['default']
)
if os.path.exists(default_path):
# 确保目标目录存在
os.makedirs(os.path.dirname(path), exist_ok=True)
shutil.copyfile(default_path, path)
logging.info(f"Copied normalizer from default path: {default_path} to {path}")
else:
raise FileNotFoundError(f"No normalizer found in either path: {path} or {default_path}")
with open(path, 'rb') as f:
self.scalers = pickle.load(f)

@ -93,8 +93,24 @@ class MLModel:
def load_model(self) -> None:
"""加载已训练的模型"""
self.create_model()
self.model.load_state_dict(torch.load(self.config['paths']['model']['train'],
map_location=self.device))
model_path = self.config['paths']['model']['train']
# 如果主路径不存在,尝试从默认路径复制
if not os.path.exists(model_path):
default_model_path = os.path.join(
os.path.dirname(__file__),
'..',
self.config['paths']['model']['default']
)
if os.path.exists(default_model_path):
# 确保目标目录存在
os.makedirs(os.path.dirname(model_path), exist_ok=True)
shutil.copyfile(default_model_path, model_path)
logging.info(f"Copied model from default path: {default_model_path} to {model_path}")
else:
raise FileNotFoundError(f"No model found in either path: {model_path} or {default_model_path}")
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
def plot_correlation(self, X: np.ndarray, feature_names: List[str]) -> None:
"""绘制特征相关性矩阵"""

Loading…
Cancel
Save