From 9e6f84b4496aedf75cffce79bfcba9d1b574235d Mon Sep 17 00:00:00 2001 From: wangchunlin Date: Sun, 1 Dec 2024 14:12:43 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E4=BC=98=EF=BC=8C=E5=B7=AE=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E8=BF=87=E7=A8=8B=E5=9B=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api_send_evaluate.py | 12 ++++++------ api_send_train.py | 18 +++++++++--------- config/config.yaml | 22 ++++++++++------------ api_server.py => psy_api.py | 0 utils/model_trainer.py | 22 +++++++++++++++++++--- 5 files changed, 44 insertions(+), 30 deletions(-) rename api_server.py => psy_api.py (100%) diff --git a/api_send_evaluate.py b/api_send_evaluate.py index fb877db..e879018 100644 --- a/api_send_evaluate.py +++ b/api_send_evaluate.py @@ -18,9 +18,9 @@ def download_file(data_file_info, save_dir): log_file_url = data_file_info["log_file_url"] log_save_path = os.path.join(save_dir, "log.txt") response = requests.get(log_file_url) - with open(log_save_path, 'wb') as file: - file.write(response.content) - print(f"Log file saved to: {log_save_path}") + # with open(log_save_path, 'wb') as file: + # file.write(response.content) + # print(f"Log file saved to: {log_save_path}") else: print("无法获取日志文件信息") @@ -29,9 +29,9 @@ def download_file(data_file_info, save_dir): data_file_url = data_file_info["data_file_url"] data_save_path = os.path.join(save_dir, "data.xlsx") response = requests.get(data_file_url) - with open(data_save_path, 'wb') as file: - file.write(response.content) - print(f"Data file saved to: {data_save_path}") + # with open(data_save_path, 'wb') as file: + # file.write(response.content) + # print(f"Data file saved to: {data_save_path}") else: print("无法获取数据文件信息") diff --git a/api_send_train.py b/api_send_train.py index bbe383a..1a143e6 100644 --- a/api_send_train.py +++ b/api_send_train.py @@ -16,9 +16,9 @@ def download_file(data_file_info, save_dir): model_file_url = data_file_info["model_file_url"] model_save_path = os.path.join(save_dir, "model.pth") response = requests.get(model_file_url) - with open(model_save_path, 'wb') as file: - file.write(response.content) - print(f"Model saved to: {model_save_path}") + # with open(model_save_path, 'wb') as file: + # file.write(response.content) + # print(f"Model saved to: {model_save_path}") else: print("无法获取模型文件信息") @@ -27,9 +27,9 @@ def download_file(data_file_info, save_dir): log_file_url = data_file_info["log_file_url"] log_save_path = os.path.join(save_dir, "log.txt") response = requests.get(log_file_url) - with open(log_save_path, 'wb') as file: - file.write(response.content) - print(f"Log file saved to: {log_save_path}") + # with open(log_save_path, 'wb') as file: + # file.write(response.content) + # print(f"Log file saved to: {log_save_path}") else: print("无法获取日志文件信息") @@ -38,9 +38,9 @@ def download_file(data_file_info, save_dir): data_file_url = data_file_info["data_file_url"] data_save_path = os.path.join(save_dir, "data.xlsx") response = requests.get(data_file_url) - with open(data_save_path, 'wb') as file: - file.write(response.content) - print(f"Data file saved to: {data_save_path}") + # with open(data_save_path, 'wb') as file: + # file.write(response.content) + # print(f"Data file saved to: {data_save_path}") else: print("无法获取数据文件信息") diff --git a/config/config.yaml b/config/config.yaml index ecccbd2..a200a7f 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -23,21 +23,19 @@ paths: #---训练配置---# training: - n_epochs: 300 - batch_size: 16 # CPU - learning_rate: 0.001 # CPU - # batch_size: 128 # GPU - # learning_rate: 0.001 # GPU - early_stop_patience: 50 + n_epochs: 500 + batch_size: 1024 + learning_rate: 0.001 + early_stop_patience: 100 scheduler: gamma: 0.98 - step_size: 10 + step_size: 20 regularization: - l1_lambda: 1e-5 - l2_lambda: 1e-4 + l1_lambda: 1e-6 # 默认1e-5, 降低L1,因为特征已经通过权重控制 + l2_lambda: 5e-4 # 默认1e-4, 增加L2,加强过拟合控制 dropout_rate: 0.2 - experimental_mode: false - experiments_count: 50 + experimental_mode: true + experiments_count: 10 replace_model: true data_mode: "train_val" # 可选: "train", "train_val", "all" @@ -45,7 +43,7 @@ training: model: num_classes: 4 # nc input_dim: 10 - architecture: "mlp" # 可选: "mlp", "transformer" + architecture: "transformer" # 可选: "mlp", "transformer" mlp: layers: - output_dim: 32 diff --git a/api_server.py b/psy_api.py similarity index 100% rename from api_server.py rename to psy_api.py diff --git a/utils/model_trainer.py b/utils/model_trainer.py index 62ce7b9..d1d3838 100644 --- a/utils/model_trainer.py +++ b/utils/model_trainer.py @@ -324,7 +324,17 @@ class MLModel: torch.save(self.model.state_dict(), path) def train_detect(self) -> Tuple[float, float, List[float], List[float], List[float]]: - """训练和检测模型""" + """ + 训练和检测模型 + + Returns: + Tuple[float, float, List[float], List[float], List[float]]: + - average_f1: 平均F1分数 + - wrong_percentage: 错误率 + - precision_scores: 各类别精确率 + - recall_scores: 各类别召回率 + - f1_scores: 各类别F1分数 + """ # 从配置文件中读取数据路径 data = pd.read_excel(self.config['data_path']) @@ -377,7 +387,9 @@ class MLModel: if self.config['training']['experimental_mode']: return self._run_experiments(X_train, y_train, X_val, y_val, class_weights) else: - return self._single_train_detect(X_train, y_train, X_val, y_val, class_weights) + results = self._single_train_detect(X_train, y_train, X_val, y_val, class_weights) + # 只返回前5个值,去掉 best_epoch + return results[:-1] def _update_metrics(self, train_metrics: Dict, val_metrics: Dict, train_loss: float, train_acc: float, @@ -476,7 +488,11 @@ class MLModel: logging.info(log_message) print(log_message) - return best_experiment_results[:-1] + return (best_experiment_results[0], # average_f1 + best_experiment_results[1], # wrong_percentage + best_experiment_results[2], # precision + best_experiment_results[3], # recall + best_experiment_results[4]) # f1 def _single_train_detect(self, X_train, y_train, X_val, y_val, class_weights, model_path=None): """单次训练过程"""