|
|
import os
|
|
|
import sys
|
|
|
root_path = os.getcwd()
|
|
|
sys.path.append(root_path)
|
|
|
|
|
|
import time
|
|
|
import signal
|
|
|
import uvicorn
|
|
|
import pandas as pd
|
|
|
from fastapi import FastAPI
|
|
|
from pydantic import BaseModel
|
|
|
from typing import List
|
|
|
from inference import predict_with_model
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
import threading
|
|
|
from http.server import SimpleHTTPRequestHandler, HTTPServer
|
|
|
|
|
|
app = FastAPI()
|
|
|
|
|
|
# 允许所有域名的跨域请求
|
|
|
app.add_middleware(
|
|
|
CORSMiddleware,
|
|
|
allow_origins=["*"],
|
|
|
allow_credentials=True,
|
|
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
|
|
allow_headers=["*"],
|
|
|
)
|
|
|
|
|
|
# 定义请求处理类
|
|
|
class MyRequestHandler(SimpleHTTPRequestHandler):
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
super().__init__(*args, **kwargs)
|
|
|
self.directory = '.' # 设置根目录
|
|
|
|
|
|
# 定义返回结构
|
|
|
class PredictionResult(BaseModel):
|
|
|
predictions: list
|
|
|
|
|
|
# 定义请求体模型
|
|
|
class Features(BaseModel):
|
|
|
# 10个SCL评测量(后续再处理)
|
|
|
somatization: float
|
|
|
obsessive_compulsive: float
|
|
|
interpersonal_sensitivity: float
|
|
|
depression: float
|
|
|
anxiety: float
|
|
|
hostility: float
|
|
|
terror: float
|
|
|
paranoia: float
|
|
|
psychoticism: float
|
|
|
other: float
|
|
|
# 基本信息特征量
|
|
|
father_parenting_style: int # 温暖与理解:1;其他:0
|
|
|
mother_parenting_style: int # 温暖与理解:1;其他:0
|
|
|
self_assessed_family_economic_condition: int # 贫困:2;较差:1;其他:0
|
|
|
history_of_psychological_counseling: bool # 有:1;无:0
|
|
|
# 日常行为特征量
|
|
|
absenteeism_above_average: bool # 大于平均次数:1;小于等于:0
|
|
|
academic_warning: bool # 有预警:1;无预警:0
|
|
|
|
|
|
# 定义接口
|
|
|
@app.post("/classify/")
|
|
|
async def classify_features(features_list: List[Features]):
|
|
|
|
|
|
# 定义一个空的 DataFrame 用于存储所有特征
|
|
|
all_features = pd.DataFrame()
|
|
|
|
|
|
# 遍历每个特征对象,为每个对象创建一个 DataFrame,并将其添加到 all_features 中
|
|
|
for features in features_list:
|
|
|
|
|
|
relevant_features = {
|
|
|
"somatization": features.somatization,
|
|
|
"obsessive_compulsive": features.obsessive_compulsive,
|
|
|
"interpersonal_sensitivity": features.interpersonal_sensitivity,
|
|
|
"depression": features.depression,
|
|
|
"anxiety": features.anxiety,
|
|
|
"hostility": features.hostility,
|
|
|
"terror": features.terror,
|
|
|
"paranoia": features.paranoia,
|
|
|
"psychoticism": features.psychoticism,
|
|
|
"other": features.other
|
|
|
}
|
|
|
|
|
|
# 创建只有一行的 DataFrame,并直接赋值给列
|
|
|
df_feature = pd.DataFrame({
|
|
|
# 数字化特征--基本信息
|
|
|
'父亲教养方式数字化': [(lambda x: 0.59 if x == 1 else 0.46)(features.father_parenting_style)],
|
|
|
'母亲教养方式数字化': [(lambda x: 0.69 if x == 1 else 0.56)(features.mother_parenting_style)],
|
|
|
'自评家庭经济条件数字化': [(lambda x: 0.54 if x in [2, 1] else 0.47)(features.self_assessed_family_economic_condition)],
|
|
|
'有无心理治疗(咨询)史数字化': [(lambda x: 0.21 if x else 0.09)(features.history_of_psychological_counseling)],
|
|
|
# 数字化特征--症状因子
|
|
|
'强迫症状数字化': [features.obsessive_compulsive / 4],
|
|
|
'人际关系敏感数字化': [features.interpersonal_sensitivity / 4],
|
|
|
'抑郁数字化': [features.depression / 4],
|
|
|
'多因子症状': [(lambda x: sum(1 for value in x.values() if value > 3.0) / 10)(relevant_features)],
|
|
|
# 数字化特征--日常行为
|
|
|
'出勤情况数字化': [0.74 if features.absenteeism_above_average else 0.67],
|
|
|
'学业情况数字化': [0.59 if features.academic_warning else 0.50]
|
|
|
})
|
|
|
|
|
|
all_features = pd.concat([all_features, df_feature], ignore_index=True)
|
|
|
|
|
|
# # 将 DataFrame 转换为字典
|
|
|
# df_dict = df_feature.to_dict(orient='records')
|
|
|
|
|
|
# print(df_dict)
|
|
|
# # 返回 FastAPI 响应
|
|
|
# return df_dict
|
|
|
# print(all_features)
|
|
|
|
|
|
start_time = time.time() # 记录开始时间
|
|
|
predictions = predict_with_model(all_features)
|
|
|
end_time = time.time() # 记录结束时间
|
|
|
print("预测耗时:", end_time - start_time, "秒") # 打印执行时间
|
|
|
|
|
|
print("预测结果:", predictions)
|
|
|
# 返回预测结果
|
|
|
return PredictionResult(predictions=predictions)
|
|
|
|
|
|
# 信号处理函数
|
|
|
def signal_handler(sig, frame):
|
|
|
print("Ctrl+C detected, shutting down the server...")
|
|
|
# 在这里执行关闭服务器的操作
|
|
|
sys.exit(0)
|
|
|
|
|
|
# 启动服务器的函数
|
|
|
def run_server():
|
|
|
port = 8080
|
|
|
server = HTTPServer(('0.0.0.0', port), MyRequestHandler)
|
|
|
print(f'Web HTTP Server listening on http://0.0.0.0:{port}')
|
|
|
server.serve_forever()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
# 注册信号处理函数
|
|
|
signal.signal(signal.SIGINT, signal_handler)
|
|
|
|
|
|
# 创建线程并启动前端
|
|
|
server_thread = threading.Thread(target=run_server)
|
|
|
server_thread.start()
|
|
|
|
|
|
name_app = os.path.basename(__file__)[0:-3] # Get the name of the script
|
|
|
log_config = {
|
|
|
"version": 1,
|
|
|
"disable_existing_loggers": True,
|
|
|
"handlers": {
|
|
|
"file_handler": {
|
|
|
"class": "logging.FileHandler",
|
|
|
"filename": "logfile.log",
|
|
|
},
|
|
|
},
|
|
|
"root": {
|
|
|
"handlers": ["file_handler"],
|
|
|
"level": "INFO",
|
|
|
},
|
|
|
}
|
|
|
#uvicorn.run(f'{name_app}:app', host="0.0.0.0", port=3397, reload=False,log_config=log_config)
|
|
|
uvicorn.run(app, host="0.0.0.0", port=3397, reload=False)
|
|
|
|