|
|
|
@ -324,7 +324,17 @@ class MLModel:
|
|
|
|
torch.save(self.model.state_dict(), path)
|
|
|
|
torch.save(self.model.state_dict(), path)
|
|
|
|
|
|
|
|
|
|
|
|
def train_detect(self) -> Tuple[float, float, List[float], List[float], List[float]]:
|
|
|
|
def train_detect(self) -> Tuple[float, float, List[float], List[float], List[float]]:
|
|
|
|
"""训练和检测模型"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
训练和检测模型
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
Tuple[float, float, List[float], List[float], List[float]]:
|
|
|
|
|
|
|
|
- average_f1: 平均F1分数
|
|
|
|
|
|
|
|
- wrong_percentage: 错误率
|
|
|
|
|
|
|
|
- precision_scores: 各类别精确率
|
|
|
|
|
|
|
|
- recall_scores: 各类别召回率
|
|
|
|
|
|
|
|
- f1_scores: 各类别F1分数
|
|
|
|
|
|
|
|
"""
|
|
|
|
# 从配置文件中读取数据路径
|
|
|
|
# 从配置文件中读取数据路径
|
|
|
|
data = pd.read_excel(self.config['data_path'])
|
|
|
|
data = pd.read_excel(self.config['data_path'])
|
|
|
|
|
|
|
|
|
|
|
|
@ -377,7 +387,9 @@ class MLModel:
|
|
|
|
if self.config['training']['experimental_mode']:
|
|
|
|
if self.config['training']['experimental_mode']:
|
|
|
|
return self._run_experiments(X_train, y_train, X_val, y_val, class_weights)
|
|
|
|
return self._run_experiments(X_train, y_train, X_val, y_val, class_weights)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
return self._single_train_detect(X_train, y_train, X_val, y_val, class_weights)
|
|
|
|
results = self._single_train_detect(X_train, y_train, X_val, y_val, class_weights)
|
|
|
|
|
|
|
|
# 只返回前5个值,去掉 best_epoch
|
|
|
|
|
|
|
|
return results[:-1]
|
|
|
|
|
|
|
|
|
|
|
|
def _update_metrics(self, train_metrics: Dict, val_metrics: Dict,
|
|
|
|
def _update_metrics(self, train_metrics: Dict, val_metrics: Dict,
|
|
|
|
train_loss: float, train_acc: float,
|
|
|
|
train_loss: float, train_acc: float,
|
|
|
|
@ -476,7 +488,11 @@ class MLModel:
|
|
|
|
logging.info(log_message)
|
|
|
|
logging.info(log_message)
|
|
|
|
print(log_message)
|
|
|
|
print(log_message)
|
|
|
|
|
|
|
|
|
|
|
|
return best_experiment_results[:-1]
|
|
|
|
return (best_experiment_results[0], # average_f1
|
|
|
|
|
|
|
|
best_experiment_results[1], # wrong_percentage
|
|
|
|
|
|
|
|
best_experiment_results[2], # precision
|
|
|
|
|
|
|
|
best_experiment_results[3], # recall
|
|
|
|
|
|
|
|
best_experiment_results[4]) # f1
|
|
|
|
|
|
|
|
|
|
|
|
def _single_train_detect(self, X_train, y_train, X_val, y_val, class_weights, model_path=None):
|
|
|
|
def _single_train_detect(self, X_train, y_train, X_val, y_val, class_weights, model_path=None):
|
|
|
|
"""单次训练过程"""
|
|
|
|
"""单次训练过程"""
|
|
|
|
|