You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

162 lines
5.6 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import sys
root_path = os.getcwd()
sys.path.append(root_path)
import time
import signal
import uvicorn
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
from inference import predict_with_model
from fastapi.middleware.cors import CORSMiddleware
import threading
from http.server import SimpleHTTPRequestHandler, HTTPServer
app = FastAPI()
# 允许所有域名的跨域请求
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
# 定义请求处理类
class MyRequestHandler(SimpleHTTPRequestHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.directory = '.' # 设置根目录
# 定义返回结构
class PredictionResult(BaseModel):
predictions: list
# 定义请求体模型
class Features(BaseModel):
# 10个SCL评测量后续再处理
somatization: float
obsessive_compulsive: float
interpersonal_sensitivity: float
depression: float
anxiety: float
hostility: float
terror: float
paranoia: float
psychoticism: float
other: float
# 基本信息特征量
father_parenting_style: int # 温暖与理解1其他0
mother_parenting_style: int # 温暖与理解1其他0
self_assessed_family_economic_condition: int # 贫困2较差1其他0
history_of_psychological_counseling: bool # 有10
# 日常行为特征量
absenteeism_above_average: bool # 大于平均次数1小于等于0
academic_warning: bool # 有预警1无预警0
# 定义接口
@app.post("/classify/")
async def classify_features(features_list: List[Features]):
# 定义一个空的 DataFrame 用于存储所有特征
all_features = pd.DataFrame()
# 遍历每个特征对象,为每个对象创建一个 DataFrame并将其添加到 all_features 中
for features in features_list:
relevant_features = {
"somatization": features.somatization,
"obsessive_compulsive": features.obsessive_compulsive,
"interpersonal_sensitivity": features.interpersonal_sensitivity,
"depression": features.depression,
"anxiety": features.anxiety,
"hostility": features.hostility,
"terror": features.terror,
"paranoia": features.paranoia,
"psychoticism": features.psychoticism,
"other": features.other
}
# 创建只有一行的 DataFrame并直接赋值给列
df_feature = pd.DataFrame({
# 数字化特征--基本信息
'父亲教养方式数字化': [(lambda x: 0.59 if x == 1 else 0.46)(features.father_parenting_style)],
'母亲教养方式数字化': [(lambda x: 0.69 if x == 1 else 0.56)(features.mother_parenting_style)],
'自评家庭经济条件数字化': [(lambda x: 0.54 if x in [2, 1] else 0.47)(features.self_assessed_family_economic_condition)],
'有无心理治疗(咨询)史数字化': [(lambda x: 0.21 if x else 0.09)(features.history_of_psychological_counseling)],
# 数字化特征--症状因子
'强迫症状数字化': [features.obsessive_compulsive / 4],
'人际关系敏感数字化': [features.interpersonal_sensitivity / 4],
'抑郁数字化': [features.depression / 4],
'多因子症状': [(lambda x: sum(1 for value in x.values() if value > 3.0) / 10)(relevant_features)],
# 数字化特征--日常行为
'出勤情况数字化': [0.74 if features.absenteeism_above_average else 0.67],
'学业情况数字化': [0.59 if features.academic_warning else 0.50]
})
all_features = pd.concat([all_features, df_feature], ignore_index=True)
# # 将 DataFrame 转换为字典
# df_dict = df_feature.to_dict(orient='records')
# print(df_dict)
# # 返回 FastAPI 响应
# return df_dict
# print(all_features)
start_time = time.time() # 记录开始时间
predictions = predict_with_model(all_features)
end_time = time.time() # 记录结束时间
print("预测耗时:", end_time - start_time, "") # 打印执行时间
print("预测结果:", predictions)
# 返回预测结果
return PredictionResult(predictions=predictions)
# 信号处理函数
def signal_handler(sig, frame):
print("Ctrl+C detected, shutting down the server...")
# 在这里执行关闭服务器的操作
sys.exit(0)
# 启动服务器的函数
def run_server():
port = 8080
server = HTTPServer(('0.0.0.0', port), MyRequestHandler)
print(f'Web HTTP Server listening on http://0.0.0.0:{port}')
server.serve_forever()
if __name__ == "__main__":
# 注册信号处理函数
signal.signal(signal.SIGINT, signal_handler)
# 创建线程并启动前端
server_thread = threading.Thread(target=run_server)
server_thread.start()
name_app = os.path.basename(__file__)[0:-3] # Get the name of the script
log_config = {
"version": 1,
"disable_existing_loggers": True,
"handlers": {
"file_handler": {
"class": "logging.FileHandler",
"filename": "logfile.log",
},
},
"root": {
"handlers": ["file_handler"],
"level": "INFO",
},
}
#uvicorn.run(f'{name_app}:app', host="0.0.0.0", port=3397, reload=False,log_config=log_config)
uvicorn.run(app, host="0.0.0.0", port=3397, reload=False)