You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
psy/api_send_train.py

89 lines
3.3 KiB
Python

import requests
import pandas as pd
import os
from utils.data_process import preprocess_data, convert_to_list
# Files download_file() knows how to fetch, as
# (key in the info dict, local file name, success-message label, message when the key is absent).
_DOWNLOAD_TARGETS = (
    ("model_file_url", "model.pth", "Model saved to", "无法获取模型文件信息"),
    ("log_file_url", "log.txt", "Log file saved to", "无法获取日志文件信息"),
    ("data_file_url", "data.xlsx", "Data file saved to", "无法获取数据文件信息"),
)

# Seconds to wait (per connect / per read) before giving up on a single download.
DEFAULT_DOWNLOAD_TIMEOUT = 60

# Size of each chunk written to disk while streaming a download (1 MiB).
_DOWNLOAD_CHUNK_SIZE = 1 << 20


def _save_url_to_file(url, save_path, timeout):
    """Stream the body of *url* into *save_path*.

    Returns True when the file was written. Returns False, without creating the
    file, when the server answers with an error status (>= 400). Network errors
    such as ``requests.ConnectionError`` or ``requests.Timeout`` propagate.
    """
    # stream=True avoids holding a whole model file in memory; the `with` block
    # releases the connection even if writing fails part-way.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if not response.ok:
            # Without this check the server's error page would be saved as model.pth etc.
            print(f"下载失败 ({response.status_code}): {url}")
            return False
        directory = os.path.dirname(save_path)
        if directory:
            # The target directory (e.g. client/) may not exist yet.
            os.makedirs(directory, exist_ok=True)
        with open(save_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
                file.write(chunk)
    return True


def download_file(data_file_info, save_dir, timeout=DEFAULT_DOWNLOAD_TIMEOUT):
    """Download the files described by *data_file_info* into *save_dir*.

    Parameters:
        data_file_info (dict | None): may contain "model_file_url", "log_file_url"
            and "data_file_url". Each non-empty URL is saved as model.pth, log.txt
            and data.xlsx respectively. ``None`` is treated as an empty dict.
        save_dir (str): directory to save into; created on first successful download.
        timeout (float | tuple): requests timeout applied to each download.

    For every file that is absent, a message is printed and that file is skipped.
    """
    info = data_file_info or {}
    for key, file_name, label, missing_message in _DOWNLOAD_TARGETS:
        url = info.get(key)
        if not url:
            print(missing_message)
            continue
        save_path = os.path.join(save_dir, file_name)
        if _save_url_to_file(url, save_path, timeout):
            print(f"{label}: {save_path}")


# Training endpoint of the local classification service.
DEFAULT_TRAIN_URL = "http://127.0.0.1:3397/train/"
# (connect, read) timeout in seconds: fail fast if the server is unreachable.
# The read timeout is None (wait indefinitely), as the original call did, since the
# /train/ endpoint may need a long time to answer — tighten if appropriate.
DEFAULT_TRAIN_TIMEOUT = (10, None)


def classify_features(features_data_list, url=DEFAULT_TRAIN_URL, timeout=DEFAULT_TRAIN_TIMEOUT):
    """Send feature data to the server for classification and fetch the result.

    On success, the precision/recall/F1/wrong-percentage metrics are printed. The
    files listed under the response's "data_file" entry are then downloaded into
    the ``client`` directory next to this script.

    Parameters:
        features_data_list (list): feature records, sent as the JSON request body.
        url (str): endpoint to POST to (defaults to the local training service).
        timeout (float | tuple): requests timeout for the POST.

    Returns:
        dict | None: the decoded JSON response on HTTP 200, otherwise None (the
        response text is printed).
    """
    response = requests.post(url, json=features_data_list, timeout=timeout)
    if response.status_code != 200:
        print("请求失败:", response.text)
        return None

    results = response.json()
    metrics = results["classification_result"]
    print("Precision:", metrics["precision"])
    print("Recall:", metrics["recall"])
    print("F1:", metrics["f1"])
    print("Wrong Percentage:", metrics["wrong_percentage"])

    # The server may omit "data_file"; fall back to an empty mapping instead of raising KeyError.
    data_file_info = results.get("data_file") or {}
    data_save_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "client"))
    download_file(data_file_info, data_save_dir)
    return results


if __name__ == "__main__":
    # Load the raw source table, then the two auxiliary tables (leave records and
    # dropout warnings) from data_processed/. Paths are relative to the working directory.
    source_frame = pd.read_excel("data/data_src.xlsx")
    leave_frame = pd.read_excel("data_processed/Leave_Record_RES.xlsx")
    warning_frame = pd.read_excel("data_processed/Dropout_Warning_RES.xlsx")

    # Run the project's preprocessing helper and flatten its output into the
    # list that classify_features() posts as the JSON body.
    payload = convert_to_list(preprocess_data(source_frame, leave_frame, warning_frame))

    # Submit the payload for training/classification; None signals a failed request.
    outcome = classify_features(payload)
    if outcome:
        print("Classification results:", outcome)