diff --git a/utils/common.py b/utils/common.py index a4b6b81..77c3735 100644 --- a/utils/common.py +++ b/utils/common.py @@ -11,6 +11,7 @@ from sklearn.utils.class_weight import compute_class_weight import logging import matplotlib.pyplot as plt from matplotlib.font_manager import FontProperties +import copy # 控制是否打印的宏定义 PRINT_LOG = True @@ -207,7 +208,7 @@ class MLModel: best_val_recall = class_recalls_m best_val_precision = class_precisions_m best_epoch = epoch - best_model = self.model.state_dict() + best_model = copy.deepcopy(self.model.state_dict()) trigger_times = 0 else: trigger_times += 1 @@ -215,6 +216,9 @@ class MLModel: log_print(f'Early stopping at epoch {epoch} | Best epoch : {best_epoch + 1}') break + # 训练完全结束后,加载最佳模型的状态 + self.model.load_state_dict(best_model) + return best_val_f1, best_val_recall, best_val_precision, best_epoch, best_model def train_detect(self): @@ -251,7 +255,10 @@ class MLModel: best_val_f1, best_val_recall, best_val_precision, best_epoch, best_model = self.train_model(train_loader, val_loader, criterion, optimizer, scheduler) - # Save the best model + # 确保使用的是最佳模型 + self.model.load_state_dict(best_model) + + # 保存最佳模型 self.save_model(self.config['train_model_path']) log_print(f"Best Validation F1 Score (Macro): {best_val_f1:.4f}")