fix best model save problem in 1217

master
wangchunlin 1 year ago
parent ec18b2342c
commit 7c69775f4c

@ -11,6 +11,7 @@ from sklearn.utils.class_weight import compute_class_weight
import logging
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import copy
# 控制是否打印的宏定义
PRINT_LOG = True
@ -207,7 +208,7 @@ class MLModel:
best_val_recall = class_recalls_m
best_val_precision = class_precisions_m
best_epoch = epoch
best_model = self.model.state_dict()
best_model = copy.deepcopy(self.model.state_dict())
trigger_times = 0
else:
trigger_times += 1
@ -215,6 +216,9 @@ class MLModel:
log_print(f'Early stopping at epoch {epoch} | Best epoch : {best_epoch + 1}')
break
# 训练完全结束后,加载最佳模型的状态
self.model.load_state_dict(best_model)
return best_val_f1, best_val_recall, best_val_precision, best_epoch, best_model
def train_detect(self):
@ -251,7 +255,10 @@ class MLModel:
best_val_f1, best_val_recall, best_val_precision, best_epoch, best_model = self.train_model(train_loader, val_loader, criterion, optimizer, scheduler)
# Save the best model
# 确保使用的是最佳模型
self.model.load_state_dict(best_model)
# 保存最佳模型
self.save_model(self.config['train_model_path'])
log_print(f"Best Validation F1 Score (Macro): {best_val_f1:.4f}")

Loading…
Cancel
Save