You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
80 lines
2.8 KiB
Python
80 lines
2.8 KiB
Python
from pydantic import BaseModel
|
|
import requests
|
|
import pandas as pd
|
|
import os
|
|
from utils.data_process import preprocess_data, convert_to_list
|
|
|
|
def download_file(data_file_info, save_dir, timeout=30):
    """
    Download the log and data files described by ``data_file_info`` into ``save_dir``.

    The log file is saved as ``log.txt`` and the data file as ``data.xlsx``.
    Each download is best-effort: a missing URL key, a non-2xx HTTP status or a
    network error is reported with a printed message, and the other file is
    still attempted.

    Args:
        data_file_info (dict | None): Mapping that may contain the keys
            "log_file_url" and "data_file_url". ``None`` is treated as empty.
            NOTE(review): the original docstring also mentions "model_file_url",
            but no download is performed for that key.
        save_dir (str): Directory the files are written to; created if missing.
        timeout (float | None): Per-request connect/read timeout in seconds, so
            an unresponsive server cannot block the caller indefinitely.
    """
    info = data_file_info or {}

    # (URL key, target file name, label for the success message, message when the key is absent)
    targets = (
        ("log_file_url", "log.txt", "Log file", "无法获取日志文件信息"),
        ("data_file_url", "data.xlsx", "Data file", "无法获取数据文件信息"),
    )
    for url_key, file_name, label, missing_msg in targets:
        if url_key not in info:
            print(missing_msg)
            continue
        save_path = os.path.join(save_dir, file_name)
        if _download_to(info[url_key], save_path, timeout):
            print(f"{label} saved to: {save_path}")


def _download_to(url, save_path, timeout):
    """Fetch ``url`` and write its body to ``save_path``; return True on success.

    Only successful (non-4xx/5xx) responses are written, so an HTTP error page is
    never saved under a data-file name. Failures are printed, not raised.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as exc:
        print(f"下载失败 ({url}): {exc}")
        return False
    # Make sure the destination directory exists before opening the file.
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    with open(save_path, "wb") as file:
        file.write(response.content)
    return True
|
|
|
|
def classify_features(features_data_list, url="http://127.0.0.1:8088/evaluate/", timeout=None, save_dir=None):
    """
    Send feature data to the evaluation service and return its parsed response.

    On HTTP 200 this prints the precision, recall, F1 and wrong percentage found
    under ``results["classification_result"]``. If the response has a
    ``"data_file"`` entry, it then downloads the files listed there via
    ``download_file``.

    Args:
        features_data_list (list): Feature records, sent as the JSON request body.
        url (str): Evaluation endpoint; defaults to the local service.
        timeout (float | None): Request timeout in seconds. The default ``None``
            waits indefinitely, as before, because evaluation may be slow;
            pass a number to bound it.
        save_dir (str | None): Directory for downloaded files; defaults to the
            directory containing this script.

    Returns:
        dict | None: The JSON response on HTTP 200. Returns None when the
        request fails, including network/connection errors.
    """
    try:
        response = requests.post(url, json=features_data_list, timeout=timeout)
    except requests.RequestException as exc:
        # Keep the "None on failure" contract instead of crashing the caller.
        print("请求失败:", exc)
        return None

    if response.status_code != 200:
        print("请求失败:", response.text)
        return None

    results = response.json()

    metrics = results["classification_result"]
    print("Precision:", metrics["precision"])
    print("Recall:", metrics["recall"])
    print("F1:", metrics["f1"])
    print("Wrong Percentage:", metrics["wrong_percentage"])

    # The file listing is optional; the metrics are still returned without it.
    data_file_info = results.get("data_file")
    if data_file_info:
        target_dir = save_dir if save_dir is not None else os.path.dirname(__file__)
        download_file(data_file_info, target_dir)
    else:
        print("No data_file info in response; skipping download")

    return results
|
|
|
|
if __name__ == "__main__":
    # Load the raw source sheet plus the two pre-processed auxiliary sheets.
    source_frame = pd.read_excel("data/data_src.xlsx")
    leave_frame = pd.read_excel("data_processed/Leave_Record_RES.xlsx")
    warning_frame = pd.read_excel("data_processed/Dropout_Warning_RES.xlsx")

    # Combine and clean the inputs, then flatten them into the request payload.
    cleaned_frame = preprocess_data(source_frame, leave_frame, warning_frame)
    payload = convert_to_list(cleaned_frame)

    # Ask the service to classify; a falsy result means the request failed.
    outcome = classify_features(payload)
    if outcome:
        print("Classification results:", outcome)