调优,差训练过程图

v2
wangchunlin 1 year ago
parent 2412af6abe
commit 9e6f84b449

@ -18,9 +18,9 @@ def download_file(data_file_info, save_dir):
log_file_url = data_file_info["log_file_url"]
log_save_path = os.path.join(save_dir, "log.txt")
response = requests.get(log_file_url)
with open(log_save_path, 'wb') as file:
file.write(response.content)
print(f"Log file saved to: {log_save_path}")
# with open(log_save_path, 'wb') as file:
# file.write(response.content)
# print(f"Log file saved to: {log_save_path}")
else:
print("无法获取日志文件信息")
@ -29,9 +29,9 @@ def download_file(data_file_info, save_dir):
data_file_url = data_file_info["data_file_url"]
data_save_path = os.path.join(save_dir, "data.xlsx")
response = requests.get(data_file_url)
with open(data_save_path, 'wb') as file:
file.write(response.content)
print(f"Data file saved to: {data_save_path}")
# with open(data_save_path, 'wb') as file:
# file.write(response.content)
# print(f"Data file saved to: {data_save_path}")
else:
print("无法获取数据文件信息")

@ -16,9 +16,9 @@ def download_file(data_file_info, save_dir):
model_file_url = data_file_info["model_file_url"]
model_save_path = os.path.join(save_dir, "model.pth")
response = requests.get(model_file_url)
with open(model_save_path, 'wb') as file:
file.write(response.content)
print(f"Model saved to: {model_save_path}")
# with open(model_save_path, 'wb') as file:
# file.write(response.content)
# print(f"Model saved to: {model_save_path}")
else:
print("无法获取模型文件信息")
@ -27,9 +27,9 @@ def download_file(data_file_info, save_dir):
log_file_url = data_file_info["log_file_url"]
log_save_path = os.path.join(save_dir, "log.txt")
response = requests.get(log_file_url)
with open(log_save_path, 'wb') as file:
file.write(response.content)
print(f"Log file saved to: {log_save_path}")
# with open(log_save_path, 'wb') as file:
# file.write(response.content)
# print(f"Log file saved to: {log_save_path}")
else:
print("无法获取日志文件信息")
@ -38,9 +38,9 @@ def download_file(data_file_info, save_dir):
data_file_url = data_file_info["data_file_url"]
data_save_path = os.path.join(save_dir, "data.xlsx")
response = requests.get(data_file_url)
with open(data_save_path, 'wb') as file:
file.write(response.content)
print(f"Data file saved to: {data_save_path}")
# with open(data_save_path, 'wb') as file:
# file.write(response.content)
# print(f"Data file saved to: {data_save_path}")
else:
print("无法获取数据文件信息")

@ -23,21 +23,19 @@ paths:
#---训练配置---#
training:
n_epochs: 300
batch_size: 16 # CPU
learning_rate: 0.001 # CPU
# batch_size: 128 # GPU
# learning_rate: 0.001 # GPU
early_stop_patience: 50
n_epochs: 500
batch_size: 1024
learning_rate: 0.001
early_stop_patience: 100
scheduler:
gamma: 0.98
step_size: 10
step_size: 20
regularization:
l1_lambda: 1e-5
l2_lambda: 1e-4
l1_lambda: 1e-6 # 默认1e-5, 降低L1因为特征已经通过权重控制
l2_lambda: 5e-4 # 默认1e-4, 增加L2加强过拟合控制
dropout_rate: 0.2
experimental_mode: false
experiments_count: 50
experimental_mode: true
experiments_count: 10
replace_model: true
data_mode: "train_val" # 可选: "train", "train_val", "all"
@ -45,7 +43,7 @@ training:
model:
num_classes: 4 # nc
input_dim: 10
architecture: "mlp" # 可选: "mlp", "transformer"
architecture: "transformer" # 可选: "mlp", "transformer"
mlp:
layers:
- output_dim: 32

@ -324,7 +324,17 @@ class MLModel:
torch.save(self.model.state_dict(), path)
def train_detect(self) -> Tuple[float, float, List[float], List[float], List[float]]:
"""训练和检测模型"""
"""
训练和检测模型
Returns:
Tuple[float, float, List[float], List[float], List[float]]:
- average_f1: 平均F1分数
- wrong_percentage: 错误率
- precision_scores: 各类别精确率
- recall_scores: 各类别召回率
- f1_scores: 各类别F1分数
"""
# 从配置文件中读取数据路径
data = pd.read_excel(self.config['data_path'])
@ -377,7 +387,9 @@ class MLModel:
if self.config['training']['experimental_mode']:
return self._run_experiments(X_train, y_train, X_val, y_val, class_weights)
else:
return self._single_train_detect(X_train, y_train, X_val, y_val, class_weights)
results = self._single_train_detect(X_train, y_train, X_val, y_val, class_weights)
# 只返回前5个值去掉 best_epoch
return results[:-1]
def _update_metrics(self, train_metrics: Dict, val_metrics: Dict,
train_loss: float, train_acc: float,
@ -476,7 +488,11 @@ class MLModel:
logging.info(log_message)
print(log_message)
return best_experiment_results[:-1]
return (best_experiment_results[0], # average_f1
best_experiment_results[1], # wrong_percentage
best_experiment_results[2], # precision
best_experiment_results[3], # recall
best_experiment_results[4]) # f1
def _single_train_detect(self, X_train, y_train, X_val, y_val, class_weights, model_path=None):
"""单次训练过程"""

Loading…
Cancel
Save