|
|
|
|
@ -11,6 +11,7 @@ from sklearn.utils.class_weight import compute_class_weight
|
|
|
|
|
import logging
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
from matplotlib.font_manager import FontProperties
|
|
|
|
|
import copy
|
|
|
|
|
|
|
|
|
|
# 控制是否打印的宏定义
|
|
|
|
|
PRINT_LOG = True
|
|
|
|
|
@ -207,7 +208,7 @@ class MLModel:
|
|
|
|
|
best_val_recall = class_recalls_m
|
|
|
|
|
best_val_precision = class_precisions_m
|
|
|
|
|
best_epoch = epoch
|
|
|
|
|
best_model = self.model.state_dict()
|
|
|
|
|
best_model = copy.deepcopy(self.model.state_dict())
|
|
|
|
|
trigger_times = 0
|
|
|
|
|
else:
|
|
|
|
|
trigger_times += 1
|
|
|
|
|
@ -215,6 +216,9 @@ class MLModel:
|
|
|
|
|
log_print(f'Early stopping at epoch {epoch} | Best epoch : {best_epoch + 1}')
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
# 训练完全结束后,加载最佳模型的状态
|
|
|
|
|
self.model.load_state_dict(best_model)
|
|
|
|
|
|
|
|
|
|
return best_val_f1, best_val_recall, best_val_precision, best_epoch, best_model
|
|
|
|
|
|
|
|
|
|
def train_detect(self):
|
|
|
|
|
@ -251,7 +255,10 @@ class MLModel:
|
|
|
|
|
|
|
|
|
|
best_val_f1, best_val_recall, best_val_precision, best_epoch, best_model = self.train_model(train_loader, val_loader, criterion, optimizer, scheduler)
|
|
|
|
|
|
|
|
|
|
# Save the best model
|
|
|
|
|
# 确保使用的是最佳模型
|
|
|
|
|
self.model.load_state_dict(best_model)
|
|
|
|
|
|
|
|
|
|
# 保存最佳模型
|
|
|
|
|
self.save_model(self.config['train_model_path'])
|
|
|
|
|
|
|
|
|
|
log_print(f"Best Validation F1 Score (Macro): {best_val_f1:.4f}")
|
|
|
|
|
|