|
|
"""
|
|
|
模型定义和训练模块
|
|
|
包含神经网络模型定义、训练和评估功能
|
|
|
"""
|
|
|
import os
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
from torch.utils.data import DataLoader, TensorDataset
|
|
|
from sklearn.model_selection import StratifiedKFold
|
|
|
from sklearn.metrics import precision_score, recall_score, f1_score
|
|
|
from sklearn.utils.class_weight import compute_class_weight
|
|
|
import logging
|
|
|
import matplotlib.pyplot as plt
|
|
|
from matplotlib.font_manager import FontProperties
|
|
|
import seaborn as sns
|
|
|
import warnings
|
|
|
from typing import Dict, Any, Tuple, List
|
|
|
import pandas as pd
|
|
|
import datetime
|
|
|
import shutil
|
|
|
|
|
|
# Suppress font-related UserWarnings emitted by matplotlib and seaborn
# (triggered when the configured CJK font is substituted).
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib.font_manager")
warnings.filterwarnings("ignore", category=UserWarning, module="seaborn.utils")
|
|
|
|
|
|
class BaseModel(nn.Module):
    """Base class for networks that expose L1/L2 regularization terms.

    Subclasses inherit two helpers that fold every registered parameter
    into a scalar penalty tensor on the model's own device.
    """

    def get_l1_loss(self) -> torch.Tensor:
        """Return the sum of the L1 norms of all parameter tensors."""
        device = next(self.parameters()).device
        penalty = torch.zeros((), device=device)
        for weight in self.parameters():
            # |w|_1 of the flattened tensor, accumulated across parameters.
            penalty = penalty + weight.abs().sum()
        return penalty

    def get_l2_loss(self) -> torch.Tensor:
        """Return the sum of the L2 norms of all parameter tensors.

        Note: this sums per-parameter Euclidean norms (not squared norms),
        matching the original formulation.
        """
        device = next(self.parameters()).device
        penalty = torch.zeros((), device=device)
        for weight in self.parameters():
            # sqrt(sum of squares) == 2-norm of the flattened tensor.
            penalty = penalty + weight.pow(2).sum().sqrt()
        return penalty
|
|
|
|
|
|
class MLP(BaseModel):
    """Feed-forward multi-layer perceptron built from a nested config dict.

    Hidden layers come from ``config['model']['mlp']['layers']`` (each entry
    gives ``output_dim`` and an optional ``activation``); a final linear head
    maps to ``config['model']['mlp']['output_dim']`` classes.
    """

    def __init__(self, config: Dict[str, Any]):
        super(MLP, self).__init__()
        stack = []
        in_features = config['model']['input_dim']
        # One shared Dropout module, re-inserted after every ReLU below.
        self.dropout = nn.Dropout(p=config['training']['dropout_rate'])

        for spec in config['model']['mlp']['layers']:
            out_features = spec['output_dim']
            fc = nn.Linear(in_features, out_features)
            # Hidden weights use Xavier init; biases keep PyTorch defaults.
            nn.init.xavier_normal_(fc.weight)
            stack.append(fc)

            if spec.get('activation') == 'relu':
                stack.append(nn.ReLU())
                stack.append(self.dropout)

            in_features = out_features

        # Classification head (default init, matching the original design).
        self.output = nn.Linear(in_features, config['model']['mlp']['output_dim'])
        self.model = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the hidden stack, then the classification head."""
        return self.output(self.model(x))
|
|
|
|
|
|
class MLModel:
    """Training, evaluation, inference and persistence wrapper for the MLP.

    All behavior is driven by a nested config dict (``system``, ``training``,
    ``paths``, ``features`` sections plus a few top-level path keys).
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Args:
            config: Nested configuration dictionary.
        """
        self.config = config
        self.model = None
        # Fall back to CPU when CUDA is requested but not available.
        self.device = torch.device(
            self.config['system']['device']
            if torch.cuda.is_available() and self.config['system']['device'] == 'cuda'
            else 'cpu'
        )
        logging.info(f"Using device: {self.device}")
        print(f"Using device: {self.device}")

        # Font used for CJK labels in plots; path is relative to the package root.
        font_path = os.path.join(os.path.dirname(__file__), '..',
                                 config['paths']['fonts']['chinese'])
        self.font_prop = FontProperties(fname=font_path)

    def create_model(self) -> None:
        """Instantiate a fresh MLP on the configured device."""
        self.model = MLP(self.config).to(self.device)

    def load_model(self) -> None:
        """Load trained weights, copying from the default path if needed.

        Raises:
            FileNotFoundError: if neither the train path nor the default
                model file exists.
        """
        self.create_model()
        model_path = self.config['paths']['model']['train']

        # If the primary path is missing, try to seed it from the default model.
        if not os.path.exists(model_path):
            default_model_path = os.path.join(
                os.path.dirname(__file__),
                '..',
                self.config['paths']['model']['default']
            )
            if os.path.exists(default_model_path):
                # Make sure the target directory exists before copying.
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                shutil.copyfile(default_model_path, model_path)
                logging.info(f"Copied model from default path: {default_model_path} to {model_path}")
            else:
                raise FileNotFoundError(f"No model found in either path: {model_path} or {default_model_path}")

        self.model.load_state_dict(torch.load(model_path, map_location=self.device))

    def plot_correlation(self, X: np.ndarray, feature_names: List[str]) -> None:
        """Plot and save the feature correlation matrix as a heatmap.

        Args:
            X: Feature matrix, shape (n_samples, n_features).
            feature_names: Column labels for both axes.
        """
        corr_matrix = np.corrcoef(X.T)
        plt.figure(figsize=(12, 10))

        sns.heatmap(corr_matrix,
                    annot=True,
                    cmap='coolwarm',
                    xticklabels=feature_names,
                    yticklabels=feature_names,
                    fmt='.2f')

        plt.xticks(rotation=45, ha='right', fontproperties=self.font_prop)
        plt.yticks(rotation=0, fontproperties=self.font_prop)
        plt.title('特征相关性矩阵', fontproperties=self.font_prop)
        plt.tight_layout()

        plt.savefig(self.config['correlation_matrix_path'], bbox_inches='tight', dpi=300)
        plt.close()

    def train_model(self, train_loader: DataLoader, val_loader: DataLoader,
                    criterion: nn.Module, optimizer: torch.optim.Optimizer,
                    scheduler: torch.optim.lr_scheduler._LRScheduler) -> Tuple:
        """Train with early stopping on macro-F1 of the validation set.

        Returns:
            Tuple of (best validation F1, deep-copied best state_dict,
            1-based epoch at which the best model was seen).
        """
        n_epochs = self.config['training']['n_epochs']
        best_val_f1 = 0.0
        best_model = None
        best_epoch = -1  # 1-based epoch of the best checkpoint
        patience = self.config['training']['early_stop_patience']
        trigger_times = 0

        l1_lambda = float(self.config['training']['regularization']['l1_lambda'])
        l2_lambda = float(self.config['training']['regularization']['l2_lambda'])

        train_metrics = {'loss': [], 'acc': []}
        # NOTE(review): 'precision'/'recall' are declared but never filled.
        val_metrics = {'loss': [], 'acc': [], 'f1': [], 'precision': [], 'recall': []}

        for epoch in range(n_epochs):
            # Training phase
            self.model.train()
            train_loss, train_acc = self._train_epoch(train_loader, criterion,
                                                      optimizer, l1_lambda, l2_lambda)

            # Validation phase
            val_loss, val_acc, val_f1 = self._validate_epoch(val_loader, criterion)

            # Step the learning-rate schedule once per epoch.
            scheduler.step()

            # Record metrics
            self._update_metrics(train_metrics, val_metrics, train_loss, train_acc,
                                 val_loss, val_acc, val_f1, epoch)

            # Track the best checkpoint
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                # Deep-copy the weights: state_dict() returns *live* tensor
                # references, so without cloning the "best" snapshot would be
                # silently overwritten by subsequent optimizer steps.
                best_model = {k: v.detach().clone()
                              for k, v in self.model.state_dict().items()}
                best_epoch = epoch + 1
                trigger_times = 0
            else:
                trigger_times += 1
                if trigger_times >= patience:
                    log_message = (
                        f'Early stopping at epoch {epoch+1}\n'
                        f'Best model was saved at epoch {best_epoch} with F1: {best_val_f1:.4f}'
                    )
                    logging.info(log_message)
                    print(log_message)
                    break

        # Report the best checkpoint
        log_message = f'Training completed. Best model at epoch {best_epoch} with F1: {best_val_f1:.4f}'
        logging.info(log_message)
        print(log_message)

        return best_val_f1, best_model, best_epoch

    def _train_epoch(self, train_loader: DataLoader, criterion: nn.Module,
                     optimizer: torch.optim.Optimizer, l1_lambda: float,
                     l2_lambda: float) -> Tuple[float, float]:
        """Run one training epoch; returns (mean loss, accuracy)."""
        train_loss = 0
        train_acc = 0

        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = self.model(inputs)

            # Cross-entropy plus L1/L2 penalties.  Multiplying the penalty
            # tensors by plain floats keeps everything on the loss's device.
            ce_loss = criterion(outputs, targets)
            loss = (ce_loss
                    + l1_lambda * self.model.get_l1_loss()
                    + l2_lambda * self.model.get_l2_loss())

            loss.backward()
            optimizer.step()

            # Accumulate sample-weighted loss and correct-prediction counts.
            train_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            train_acc += torch.sum(preds == targets.data)

        return train_loss / len(train_loader.dataset), train_acc.double() / len(train_loader.dataset)

    def _validate_epoch(self, val_loader: DataLoader, criterion: nn.Module) -> Tuple[float, float, float]:
        """Run one validation epoch; returns (mean loss, accuracy, macro-F1)."""
        self.model.eval()
        val_loss = 0
        val_acc = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)

                val_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                val_acc += torch.sum(preds == targets.data)

                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_f1 = f1_score(all_targets, all_preds, average='macro')
        return (val_loss / len(val_loader.dataset),
                val_acc.double() / len(val_loader.dataset),
                val_f1)

    def evaluate_model(self, features_data: np.ndarray, labels: np.ndarray, is_training: bool = False) -> Tuple:
        """Evaluate the model on the given data.

        Args:
            features_data: Input feature matrix.
            labels: Ground-truth labels.
            is_training: True when called mid-training (skips reloading
                weights from disk).

        Returns:
            Tuple[float, float, List[float], List[float], List[float]]:
                (mean F1, wrong-prediction percentage, per-class precision,
                per-class recall, per-class F1)
        """
        # Only reload weights from disk when evaluating outside training.
        if not is_training:
            self.load_model()

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(torch.from_numpy(features_data).float().to(self.device))
            _, predictions = torch.max(outputs, 1)

        predictions = predictions.cpu().numpy()
        precision = precision_score(labels, predictions, average=None)
        recall = recall_score(labels, predictions, average=None)
        f1 = f1_score(labels, predictions, average=None)

        wrong_count = np.sum(labels != predictions)
        wrong_percentage = (wrong_count / len(labels)) * 100

        # Report evaluation results
        log_message = (
            f"\nEvaluation Results:\n"
            f"Average F1: {np.mean(f1):.4f}\n"
            f"Wrong Percentage: {wrong_percentage:.2f}%\n"
            f"Precision: {precision.tolist()} | {np.mean(precision):.4f}\n"
            f"Recall: {recall.tolist()} | {np.mean(recall):.4f}\n"
            f"F1: {f1.tolist()} | {np.mean(f1):.4f}\n"
            f"Total samples: {len(labels)}\n"
            f"Wrong predictions: {wrong_count}"
        )
        logging.info(log_message)
        print(log_message)

        return np.mean(f1), wrong_percentage, precision.tolist(), recall.tolist(), f1.tolist()

    def inference_model(self, features_data: np.ndarray) -> List[int]:
        """Run inference with the persisted model.

        Args:
            features_data: Input feature matrix.

        Returns:
            Predicted class indices, one per input row.
        """
        # Load weights first
        self.load_model()

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(torch.from_numpy(features_data).float().to(self.device))
            _, predictions = torch.max(outputs, 1)
            return predictions.cpu().numpy().tolist()

    def save_model(self, path: str) -> None:
        """Save the model's state_dict to ``path``.

        Args:
            path: Destination file path; parent directories are created.
        """
        # Guard against a bare filename: os.makedirs('') raises, so only
        # create the directory when the path actually has one.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        torch.save(self.model.state_dict(), path)

    def train_detect(self) -> Tuple[float, float, List[float], List[float], List[float]]:
        """Full train-and-evaluate pipeline driven by the config.

        Returns:
            Tuple[float, float, List[float], List[float], List[float]]:
                - average_f1: mean F1 score
                - wrong_percentage: error rate (%)
                - precision_scores: per-class precision
                - recall_scores: per-class recall
                - f1_scores: per-class F1
        """
        # Load the dataset from the configured Excel file.
        data = pd.read_excel(self.config['data_path'])

        # Split columns into features and label.
        X = data[self.config['features']['feature_names']].values
        y = data[self.config['features']['label_name']].values

        # Visualize feature correlations before training.
        self.plot_correlation(X, self.config['features']['feature_names'])

        # Choose the data split mode from the config.
        if self.config['training']['data_mode'] == 'train_val':
            # Stratified split preserves class balance in both folds.
            # NOTE(review): no random_state, so splits are not reproducible.
            skf = StratifiedKFold(n_splits=5, shuffle=True)
            train_idx, val_idx = next(skf.split(X, y))
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]
        else:
            # Train and validate on the full dataset.
            X_train, X_val = X, X
            y_train, y_val = y, y

        # (Dead code removed: datasets/loaders were built here but never
        # used — _single_train_detect constructs its own.)

        # Balanced class weights counteract label imbalance in the CE loss.
        class_weights = torch.tensor(
            compute_class_weight('balanced',
                                 classes=np.unique(y_train),
                                 y=y_train),
            dtype=torch.float32
        ).to(self.device)

        if self.config['training']['experimental_mode']:
            return self._run_experiments(X_train, y_train, X_val, y_val, class_weights)
        else:
            results = self._single_train_detect(X_train, y_train, X_val, y_val, class_weights)
            # Drop the trailing best_epoch to keep the public return shape.
            return results[:-1]

    def _update_metrics(self, train_metrics: Dict, val_metrics: Dict,
                        train_loss: float, train_acc: float,
                        val_loss: float, val_acc: float,
                        val_f1: float, epoch: int) -> None:
        """Append one epoch's metrics and log a one-line summary."""
        train_metrics['loss'].append(train_loss)
        train_metrics['acc'].append(train_acc)

        val_metrics['loss'].append(val_loss)
        val_metrics['acc'].append(val_acc)
        val_metrics['f1'].append(val_f1)

        # Log to file and console simultaneously.
        log_message = (
            f'Epoch {epoch+1:03d} | '
            f'Train Loss: {train_loss:.4f} | '
            f'Train Acc: {train_acc:.4f} | '
            f'Val Loss: {val_loss:.4f} | '
            f'Val Acc: {val_acc:.4f} | '
            f'Val F1: {val_f1:.4f}'
        )
        logging.info(log_message)
        print(log_message)

    def _run_experiments(self, X_train, y_train, X_val, y_val, class_weights):
        """Run repeated training experiments and keep the best model."""
        list_avg_f1 = []
        list_wrong_percentage = []
        list_precision = []
        list_recall = []
        list_f1 = []

        # Start below any attainable F1 so the first experiment is always
        # recorded; otherwise best_model_path could stay None (and
        # shutil.copyfile(None, ...) would crash) when every F1 is 0.
        best_experiment_f1 = -1.0
        best_experiment_results = None
        best_experiment_num = -1
        best_experiment_epoch = -1
        best_model_path = None

        base_model_path = self.config['train_model_path']
        base_name = os.path.splitext(base_model_path)[0]
        ext = os.path.splitext(base_model_path)[1]

        for i in range(self.config['training']['experiments_count']):
            exp_num = i + 1
            logging.info(f"Starting experiment {exp_num}/{self.config['training']['experiments_count']}")
            print(f"Starting experiment {exp_num}/{self.config['training']['experiments_count']}")

            # Per-experiment model file: base name + experiment index.
            exp_model_path = f"{base_name}_exp{exp_num}{ext}"

            results = self._single_train_detect(X_train, y_train, X_val, y_val, class_weights, exp_model_path)
            avg_f1, wrong_percentage, precision, recall, f1 = results[:5]
            best_epoch = results[5]

            # Record this experiment's results.
            list_avg_f1.append(avg_f1)
            list_wrong_percentage.append(wrong_percentage)
            list_precision.append(precision)
            list_recall.append(recall)
            list_f1.append(f1)

            # Track the best experiment so far.
            if avg_f1 > best_experiment_f1:
                best_experiment_f1 = avg_f1
                best_experiment_results = results
                best_experiment_num = exp_num
                best_experiment_epoch = best_epoch
                best_model_path = exp_model_path

        # Copy the best experiment's weights to the unsuffixed path, and
        # optionally to the deployment path.
        shutil.copyfile(best_model_path, base_model_path)
        if self.config['training']['replace_model']:
            shutil.copyfile(best_model_path, self.config['paths']['model']['train'])

        # Summarize all experiments.
        log_message = (
            f"\nExperiments Summary:\n"
            f"Mean F1: {np.mean(list_avg_f1):.4f} "
            f"Mean Wrong Percentage: {np.mean(list_wrong_percentage):.2f}%\n"
            f"Mean Precision: {[np.mean([p[i] for p in list_precision]) for i in range(len(list_precision[0]))]} | {np.mean(list_precision)}\n"
            f"Mean Recall: {[np.mean([r[i] for r in list_recall]) for i in range(len(list_recall[0]))]} | {np.mean(list_recall)}\n"
            f"Mean F1: {[np.mean([scores[i] for scores in list_f1]) for i in range(len(list_f1[0]))]} | {np.mean(list_f1)}\n"
            f"\nBest Model Details:\n"
            f"Best Model Path: {os.path.basename(best_model_path)}\n"
            f"From Experiment: {best_experiment_num}/{self.config['training']['experiments_count']}\n"
            f"Best Epoch: {best_experiment_epoch}\n"
            f"Best F1: {best_experiment_f1:.4f} "
            f"Wrong Percentage: {best_experiment_results[1]:.2f}%\n"
            f"Precision: {best_experiment_results[2]} | {np.mean(best_experiment_results[2])}\n"
            f"Recall: {best_experiment_results[3]} | {np.mean(best_experiment_results[3])}\n"
            f"F1: {best_experiment_results[4]} | {np.mean(best_experiment_results[4])}"
        )
        logging.info(log_message)
        print(log_message)

        return (best_experiment_results[0],  # average_f1
                best_experiment_results[1],  # wrong_percentage
                best_experiment_results[2],  # precision
                best_experiment_results[3],  # recall
                best_experiment_results[4])  # f1

    def _single_train_detect(self, X_train, y_train, X_val, y_val, class_weights, model_path=None):
        """One full training run; returns eval results plus the best epoch."""
        # Build data loaders on the target device.
        train_dataset = TensorDataset(
            torch.from_numpy(X_train).float().to(self.device),
            torch.from_numpy(y_train).long().to(self.device)
        )
        val_dataset = TensorDataset(
            torch.from_numpy(X_val).float().to(self.device),
            torch.from_numpy(y_val).long().to(self.device)
        )

        train_loader = DataLoader(train_dataset, batch_size=self.config['training']['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.config['training']['batch_size'])

        # Fresh model, loss, optimizer, and LR schedule per run.
        self.create_model()
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.config['training']['learning_rate']
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            self.config['training']['scheduler']['step_size'],
            self.config['training']['scheduler']['gamma']
        )

        # Train with early stopping.
        best_val_f1, best_model, best_epoch = self.train_model(
            train_loader, val_loader, criterion, optimizer, scheduler
        )

        # Restore the best checkpoint before saving/evaluating.
        self.model.load_state_dict(best_model)

        # Persist the trained model.
        if model_path:
            self.save_model(model_path)
        else:
            # Single-run mode: save straight to the configured path.
            self.save_model(self.config['train_model_path'])

        # Evaluate with the best weights (skip reloading from disk).
        eval_results = self.evaluate_model(X_val, y_val, is_training=True)
        return eval_results + (best_epoch,)

    def _plot_training_process(self, train_metrics: Dict, val_metrics: Dict) -> None:
        """Plot loss/accuracy/F1 curves and save the figure."""
        plt.figure(figsize=(15, 5))

        # Loss curves
        plt.subplot(131)
        plt.plot(train_metrics['loss'], label='Train Loss')
        plt.plot(val_metrics['loss'], label='Val Loss')
        plt.title('Loss', fontproperties=self.font_prop)
        plt.legend(prop=self.font_prop)

        # Accuracy curves
        plt.subplot(132)
        plt.plot(train_metrics['acc'], label='Train Acc')
        plt.plot(val_metrics['acc'], label='Val Acc')
        plt.title('Accuracy', fontproperties=self.font_prop)
        plt.legend(prop=self.font_prop)

        # F1 curve
        plt.subplot(133)
        plt.plot(val_metrics['f1'], label='Val F1')
        plt.title('F1 Score', fontproperties=self.font_prop)
        plt.legend(prop=self.font_prop)

        plt.tight_layout()
        plt.savefig(self.config['paths']['model']['train_process'])
        plt.close()
|