You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
323 lines
15 KiB
Python
323 lines
15 KiB
Python
import os
|
|
import yaml
|
|
import pandas as pd
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader, TensorDataset
|
|
from sklearn.model_selection import StratifiedKFold
|
|
from sklearn.metrics import precision_score, recall_score, f1_score
|
|
from sklearn.utils.class_weight import compute_class_weight
|
|
import logging
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.font_manager import FontProperties
|
|
|
|
# Macro-style flag: when True, log_print() echoes messages to stdout as well
# as the logging module.
PRINT_LOG = True

# Path to the bundled SimHei font, used so plots can render Chinese labels.
font_path = os.path.join(os.path.dirname(__file__), '../fonts', 'simhei.ttf')
font_prop = FontProperties(fname=font_path)

# Configure matplotlib for Chinese text: use the SimHei family and keep the
# minus sign renderable when a non-ASCII font is active.
plt.rcParams['font.sans-serif'] = [font_prop.get_name()]
plt.rcParams['axes.unicode_minus'] = False
|
|
|
|
def log_print(message):
    """Record *message* via logging.info and, when PRINT_LOG is set, echo it to stdout."""
    logging.info(message)
    if not PRINT_LOG:
        return
    print(message)
|
|
|
|
class MLModel:
    """Wrapper around a PyTorch classifier: data loading/splitting, training
    with early stopping, evaluation with per-class metrics, and inference.

    ``model_config`` is a dict (presumably loaded from YAML — see the
    ``yaml`` import and the "Set data_train first in yaml" error) providing
    keys such as 'device', 'data_path', 'feature_names', 'label_name',
    'model_path', 'n_epochs', 'batch_size', 'learning_rate', 'step_size',
    'gamma', 'early_stop_patience', 'data_train', 'train_model_path',
    'train_process_path', 'evaluate_result_path', plus the 'mlp' sub-config
    consumed by the MLP class.
    """

    def __init__(self, model_config):
        # Configuration dict driving every stage (paths, hyperparameters, device).
        self.config = model_config
        # Created lazily by create_model() / load_model().
        self.model = None

    def create_model(self):
        """Instantiate a fresh MLP on the configured device."""
        self.model = MLP(self.config).to(self.config['device'])
        # self.model = TransformerModel(self.config).to(self.config['device'])

    def load_model(self):
        """Build the architecture, then load weights from config['model_path']."""
        self.create_model()
        self.model.load_state_dict(torch.load(self.config['model_path'], map_location=self.config['device']))

    def load_and_split_data(self):
        """Read the Excel dataset and produce stratified train/val/infer splits.

        Returns the 10-tuple
        (X, y, X_train_val, y_train_val, X_train, y_train, X_val, y_val,
        X_infer, y_infer).

        NOTE(review): both StratifiedKFold instances are created without
        random_state, so the splits differ on every run — confirm this
        non-reproducibility is intended.
        """
        # Data file path is resolved relative to the parent of this module's directory.
        parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        file_path = os.path.join(parent_dir, self.config['data_path'])

        data = pd.read_excel(file_path)

        X = data[self.config['feature_names']].values
        y = data[self.config['label_name']].values

        # Outer split: only the first of 5 stratified folds is used, holding
        # out ~1/5 of the data for final evaluation (X_infer/y_infer).
        skf_outer = StratifiedKFold(n_splits=5, shuffle=True)
        train_index_outer, test_index_outer = next(skf_outer.split(X, y))
        X_train_val, X_infer = X[train_index_outer], X[test_index_outer]
        y_train_val, y_infer = y[train_index_outer], y[test_index_outer]

        # Inner split: likewise takes the first fold to carve a validation
        # set out of the remaining train+val data.
        skf_inner = StratifiedKFold(n_splits=5, shuffle=True)
        train_index_inner, test_index_inner = next(skf_inner.split(X_train_val, y_train_val))
        X_train, X_val = X_train_val[train_index_inner], X_train_val[test_index_inner]
        y_train, y_val = y_train_val[train_index_inner], y_train_val[test_index_inner]

        return X, y, X_train_val, y_train_val, X_train, y_train, X_val, y_val, X_infer, y_infer

    def save_model(self, model_path):
        """Persist the model weights (state_dict only) to *model_path*."""
        torch.save(self.model.state_dict(), model_path)

    def evaluate_model(self, X_infer, y_infer):
        """Evaluate the model on (X_infer, y_infer) and report per-class metrics.

        Logs precision/recall/F1 per class plus macro averages, saves a
        three-panel bar chart to config['evaluate_result_path'], and returns
        (mean_f1, wrong_percentage, precision, recall, f1) where the last
        three are per-class arrays.
        """
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(torch.from_numpy(X_infer).float().to(self.config['device']))

        # Predicted class = argmax over the class dimension of the logits.
        _, predictions = torch.max(outputs, 1)
        # average=None yields one score per class.
        precision = precision_score(y_infer, predictions.cpu().numpy(), average=None)
        recall = recall_score(y_infer, predictions.cpu().numpy(), average=None)
        f1 = f1_score(y_infer, predictions.cpu().numpy(), average=None)
        wrong_count = len(np.where(y_infer != predictions.cpu().numpy())[0])
        total_count = len(y_infer)
        wrong_percentage = (wrong_count / total_count) * 100

        log_print("Evaluate Result: ")

        log_print(f"Prediction errors: {wrong_count}")
        log_print(f"Prediction error percentage: {wrong_percentage:.2f}%")
        log_print(f"Total samples: {total_count}")

        avg_precision = np.mean(precision)
        avg_recall = np.mean(recall)
        avg_f1 = np.mean(f1)

        for i in range(len(precision)):
            log_print(f"Class {i} Precision: {precision[i]:.4f}, Recall: {recall[i]:.4f}, F1: {f1[i]:.4f}")

        log_print("精确率:" + str(precision))
        log_print("召回率:" + str(recall))
        log_print("F1得分:" + str(f1))
        log_print("平均精确率:" + str(avg_precision))
        log_print("平均召回率:" + str(avg_recall))
        log_print("平均F1得分:" + str(avg_f1))
        log_print("Evaluate Result End: ")

        # Bar charts of per-class precision / recall / F1.
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
        ax1.bar(np.arange(len(precision)), precision)
        ax1.set_title('Precision(精确率)', fontproperties=font_prop)
        ax2.bar(np.arange(len(recall)), recall)
        ax2.set_title('Recall(召回率)', fontproperties=font_prop)
        ax3.bar(np.arange(len(f1)), f1)
        ax3.set_title('F1 Score(F1得分)', fontproperties=font_prop)
        # Save the figure
        parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        evaluate_result_path = os.path.join(parent_dir, self.config['evaluate_result_path'])
        plt.savefig(evaluate_result_path)

        return np.mean(f1), wrong_percentage, precision, recall, f1

    def inference_model(self, X_infer):
        """Return predicted class indices for *X_infer* as a plain Python list."""
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(torch.from_numpy(X_infer).float().to(self.config['device']))

        _, predictions = torch.max(outputs, 1)
        return predictions.cpu().numpy().tolist()

    def train_model(self, train_loader, val_loader, criterion, optimizer, scheduler):
        """Run the epoch loop with validation, early stopping and plotting.

        Trains for up to config['n_epochs'] epochs, stepping *scheduler* once
        per epoch, validating after each epoch, and stopping early when the
        macro-F1 on the validation set has not improved for
        config['early_stop_patience'] consecutive epochs.  After every epoch
        the training curves are re-rendered and saved to
        config['train_process_path'].

        Returns (best_val_f1, best_val_recall, best_val_precision,
        best_epoch, best_model) where best_model is the state_dict captured
        at the best-F1 epoch.

        NOTE(review): the final weights saved by the caller are those of the
        LAST epoch, not best_model — the returned best_model is never loaded
        back; confirm this is intended.
        """
        n_epochs = self.config['n_epochs']
        best_val_f1 = 0.0
        best_val_recall = 0.0
        best_val_precision = 0.0
        best_epoch = -1
        best_model = None
        patience = self.config['early_stop_patience']
        trigger_times = 0

        # Per-epoch metric histories for the training-process plots.
        train_loss_history, train_acc_history, val_loss_history, val_acc_history, val_f1_history, val_precision_history, val_recall_history = [[] for _ in range(7)]

        # Raise matplotlib's open-figure warning threshold, since a new
        # figure is created every epoch.
        plt.rcParams['figure.max_open_warning'] = 50

        for epoch in range(n_epochs):
            # Training phase
            self.model.train()
            train_loss, train_acc = 0, 0
            for inputs, targets in train_loader:
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                # Weight the batch-mean loss by batch size so the epoch
                # average below is per-sample.
                train_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                train_acc += torch.sum(preds == targets.data)

            train_loss /= len(train_loader.dataset)
            train_acc = train_acc.double().cpu() / len(train_loader.dataset)

            # Update the learning rate
            scheduler.step()

            # Validation phase
            val_loss, val_acc, all_preds, all_targets = 0, 0, [], []
            self.model.eval()
            with torch.no_grad():
                for inputs, targets in val_loader:
                    outputs = self.model(inputs)
                    loss = criterion(outputs, targets)
                    val_loss += loss.item() * inputs.size(0)
                    _, preds = torch.max(outputs, 1)
                    val_acc += torch.sum(preds == targets.data)
                    all_preds.extend(preds.cpu().numpy())
                    all_targets.extend(targets.cpu().numpy())

            val_loss /= len(val_loader.dataset)
            val_acc = val_acc.double().cpu() / len(val_loader.dataset)

            # Macro-averaged validation metrics drive early stopping.
            class_precisions_m = precision_score(all_targets, all_preds, average='macro')
            class_recalls_m = recall_score(all_targets, all_preds, average='macro')
            class_f1_scores_m = f1_score(all_targets, all_preds, average='macro')

            log_print(f'Epoch {epoch+1:0{3}d} | Train Loss: {train_loss:.4f} | Train Accuracy: {train_acc:.4f} | Validation Loss: {val_loss:.4f} | Validation Accuracy: {val_acc:.4f} | Validation Mean Precision: {class_precisions_m:.4f} | Validation Mean Recall: {class_recalls_m:.4f} | Validation Mean F1_score: {class_f1_scores_m:.4f}')

            train_loss_history.append(train_loss)
            train_acc_history.append(train_acc)
            val_loss_history.append(val_loss)
            val_acc_history.append(val_acc)
            val_f1_history.append(class_f1_scores_m)
            val_precision_history.append(class_precisions_m)
            val_recall_history.append(class_recalls_m)

            # Render and save visualization of the training/validation curves
            plt.close('all')
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
            ax1.plot(train_loss_history, label='Train Loss(训练损失)')
            ax1.plot(val_loss_history, label='Validation Loss(验证损失)')
            ax1.set_title('Loss(损失)', fontproperties=font_prop)
            ax1.legend()
            ax2.plot(train_acc_history, label='Train Accuracy(训练正确率)')
            ax2.plot(val_acc_history, label='Validation Accuracy(验证正确率)')
            ax2.set_title('Accuracy(正确率)', fontproperties=font_prop)
            ax2.legend()
            ax3.plot(val_f1_history, label='Validation F1(验证F1得分)')
            ax3.plot(val_precision_history, label='Validation Precision(验证精确率)')
            ax3.plot(val_recall_history, label='Validation Recall(验证召回率)')
            ax3.set_title('Precision Recall F1-Score (Macro Mean)(宏平均)', fontproperties=font_prop)
            ax3.legend()
            # Save the figure
            parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
            train_process_path = os.path.join(parent_dir, self.config['train_process_path'])
            plt.savefig(train_process_path)

            # Early stopping on macro F1: reset the counter on improvement,
            # otherwise count towards patience.
            if class_f1_scores_m > best_val_f1:
                best_val_f1 = class_f1_scores_m
                best_val_recall = class_recalls_m
                best_val_precision = class_precisions_m
                best_epoch = epoch
                best_model = self.model.state_dict()
                trigger_times = 0
            else:
                trigger_times += 1
                if trigger_times >= patience:
                    log_print(f'Early stopping at epoch {epoch} | Best epoch : {best_epoch + 1}')
                    break

        return best_val_f1, best_val_recall, best_val_precision, best_epoch, best_model

    def train_detect(self):
        """End-to-end pipeline: load data, build loaders, train, save, evaluate.

        config['data_train'] selects the training regime:
          - 'train_val': train on train+val, validate on the held-out infer fold
          - 'train':     train on the inner train fold, validate on the val fold
          - 'all':       train, validate AND evaluate on the full dataset

        Raises ValueError when data_train is not one of the above.
        Returns the tuple produced by evaluate_model().
        """
        X, y, X_train_val, y_train_val, X_train, y_train, X_val, y_val, X_infer, y_infer = self.load_and_split_data()

        if self.config['data_train'] == 'train_val':
            train_dataset = TensorDataset(torch.from_numpy(X_train_val).float().to(self.config['device']), torch.from_numpy(y_train_val).long().to(self.config['device']))
            val_dataset = TensorDataset(torch.from_numpy(X_infer).float().to(self.config['device']), torch.from_numpy(y_infer).long().to(self.config['device']))
            # Balanced class weights computed from the training labels,
            # passed to CrossEntropyLoss to counter class imbalance.
            class_weights = torch.tensor(compute_class_weight('balanced', classes=np.unique(y_train_val), y=y_train_val), dtype=torch.float32).to(self.config['device'])
        elif self.config['data_train'] == 'train':
            train_dataset = TensorDataset(torch.from_numpy(X_train).float().to(self.config['device']), torch.from_numpy(y_train).long().to(self.config['device']))
            val_dataset = TensorDataset(torch.from_numpy(X_val).float().to(self.config['device']), torch.from_numpy(y_val).long().to(self.config['device']))
            class_weights = torch.tensor(compute_class_weight('balanced', classes=np.unique(y_train), y=y_train), dtype=torch.float32).to(self.config['device'])
        elif self.config['data_train'] == 'all':
            train_dataset = TensorDataset(torch.from_numpy(X).float().to(self.config['device']), torch.from_numpy(y).long().to(self.config['device']))
            val_dataset = TensorDataset(torch.from_numpy(X).float().to(self.config['device']), torch.from_numpy(y).long().to(self.config['device']))
            # In 'all' mode the final evaluation also runs on the full set.
            X_infer = X
            y_infer = y
            class_weights = torch.tensor(compute_class_weight('balanced', classes=np.unique(y), y=y), dtype=torch.float32).to(self.config['device'])
        else:
            logging.error("Error: Set data_train first in yaml!")
            raise ValueError("Error: Set data_train first in yaml!")

        log_print(f"Class weights: {class_weights}")

        train_loader = DataLoader(train_dataset, batch_size=self.config['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.config['batch_size'])

        self.create_model()

        criterion = nn.CrossEntropyLoss(weight=class_weights)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['learning_rate'])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.config['step_size'], self.config['gamma'])

        best_val_f1, best_val_recall, best_val_precision, best_epoch, best_model = self.train_model(train_loader, val_loader, criterion, optimizer, scheduler)

        # Save the best model
        # NOTE(review): this saves the model's CURRENT (last-epoch) weights,
        # not the returned best_model state_dict — confirm intended.
        self.save_model(self.config['train_model_path'])

        log_print(f"Best Validation F1 Score (Macro): {best_val_f1:.4f}")
        log_print(f"Best Validation Recall (Macro): {best_val_recall:.4f}")
        log_print(f"Best Validation Precision (Macro): {best_val_precision:.4f}")
        log_print(f"Best Epoch: {best_epoch + 1}")

        avg_f1, wrong_percentage, precision, recall, f1 = self.evaluate_model(X_infer, y_infer)

        return avg_f1, wrong_percentage, precision, recall, f1
|
|
|
|
|
|
# class MLP(nn.Module):
|
|
# def __init__(self, config):
|
|
# super(MLP, self).__init__()
|
|
# self.model = nn.Sequential(
|
|
# nn.Linear(len(config['feature_names']), 32),
|
|
# nn.ReLU(),
|
|
# nn.Linear(32, 128),
|
|
# nn.ReLU(),
|
|
# nn.Linear(128, 32),
|
|
# nn.ReLU(),
|
|
# nn.Linear(32, config['nc']),
|
|
# )
|
|
|
|
# def forward(self, x):
|
|
# return self.model(x)
|
|
# 20260605
|
|
|
|
class MLP(nn.Module):
    """Config-driven feed-forward classifier.

    Reads config['mlp']: 'input_dim' sets the input width, each entry of
    'layers' supplies an 'output_dim' (and optionally 'activation': 'relu'),
    and a final Linear projects to 'output_dim'.
    """

    def __init__(self, config):
        super(MLP, self).__init__()
        mlp_cfg = config['mlp']
        modules = []
        in_features = mlp_cfg['input_dim']
        for spec in mlp_cfg['layers']:
            out_features = spec['output_dim']
            modules.append(nn.Linear(in_features, out_features))
            if spec.get('activation', None) == 'relu':
                modules.append(nn.ReLU())
            in_features = out_features
        # Final projection to the configured output width (class count).
        modules.append(nn.Linear(in_features, mlp_cfg['output_dim']))
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the sequential stack to the input batch."""
        return self.model(x)
|
|
|
|
class TransformerModel(nn.Module):
    """Classifier that routes embedded feature vectors through nn.Transformer.

    All hyperparameters (input_dim, d_model, nhead, encoder/decoder layer
    counts, dim_feedforward, dropout, output_dim) come from
    config['transformer'].
    """

    def __init__(self, config):
        super(TransformerModel, self).__init__()
        t_cfg = config['transformer']
        self.embedding = nn.Linear(t_cfg['input_dim'], t_cfg['d_model'])
        self.transformer = nn.Transformer(
            d_model=t_cfg['d_model'],
            nhead=t_cfg['nhead'],
            num_encoder_layers=t_cfg['num_encoder_layers'],
            num_decoder_layers=t_cfg['num_decoder_layers'],
            dim_feedforward=t_cfg['dim_feedforward'],
            dropout=t_cfg['dropout']
        )
        self.fc = nn.Linear(t_cfg['d_model'], t_cfg['output_dim'])

    def forward(self, x):
        embedded = self.embedding(x).unsqueeze(1)  # Add sequence dimension
        # NOTE(review): nn.Transformer defaults to batch_first=False, so the
        # unsqueezed tensor is interpreted as (seq_len=N, batch=1, d_model)
        # rather than (batch=N, seq_len=1, d_model) — confirm intended.
        attended = self.transformer(embedded, embedded)
        return self.fc(attended.squeeze(1))  # Remove sequence dimension
|