import os import sys root_path = os.getcwd() sys.path.append(root_path) import time import signal import uvicorn import pandas as pd from fastapi import FastAPI from pydantic import BaseModel from typing import List from inference import predict_with_model from fastapi.middleware.cors import CORSMiddleware import threading from http.server import SimpleHTTPRequestHandler, HTTPServer app = FastAPI() # 允许所有域名的跨域请求 app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["*"], ) # 定义请求处理类 class MyRequestHandler(SimpleHTTPRequestHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.directory = '.' # 设置根目录 # 定义返回结构 class PredictionResult(BaseModel): predictions: list # 定义请求体模型 class Features(BaseModel): # 10个SCL评测量(后续再处理) somatization: float obsessive_compulsive: float interpersonal_sensitivity: float depression: float anxiety: float hostility: float terror: float paranoia: float psychoticism: float other: float # 基本信息特征量 father_parenting_style: int # 温暖与理解:1;其他:0 mother_parenting_style: int # 温暖与理解:1;其他:0 self_assessed_family_economic_condition: int # 贫困:2;较差:1;其他:0 history_of_psychological_counseling: bool # 有:1;无:0 # 日常行为特征量 absenteeism_above_average: bool # 大于平均次数:1;小于等于:0 academic_warning: bool # 有预警:1;无预警:0 # 定义接口 @app.post("/classify/") async def classify_features(features_list: List[Features]): # 定义一个空的 DataFrame 用于存储所有特征 all_features = pd.DataFrame() # 遍历每个特征对象,为每个对象创建一个 DataFrame,并将其添加到 all_features 中 for features in features_list: relevant_features = { "somatization": features.somatization, "obsessive_compulsive": features.obsessive_compulsive, "interpersonal_sensitivity": features.interpersonal_sensitivity, "depression": features.depression, "anxiety": features.anxiety, "hostility": features.hostility, "terror": features.terror, "paranoia": features.paranoia, "psychoticism": features.psychoticism, "other": features.other } # 创建只有一行的 DataFrame,并直接赋值给列 df_feature = pd.DataFrame({ # 数字化特征--基本信息 '父亲教养方式数字化': [(lambda x: 0.59 if x == 1 else 0.46)(features.father_parenting_style)], '母亲教养方式数字化': [(lambda x: 0.69 if x == 1 else 0.56)(features.mother_parenting_style)], '自评家庭经济条件数字化': [(lambda x: 0.54 if x in [2, 1] else 0.47)(features.self_assessed_family_economic_condition)], '有无心理治疗(咨询)史数字化': [(lambda x: 0.21 if x else 0.09)(features.history_of_psychological_counseling)], # 数字化特征--症状因子 '强迫症状数字化': [features.obsessive_compulsive / 4], '人际关系敏感数字化': [features.interpersonal_sensitivity / 4], '抑郁数字化': [features.depression / 4], '多因子症状': [(lambda x: sum(1 for value in x.values() if value > 3.0) / 10)(relevant_features)], # 数字化特征--日常行为 '出勤情况数字化': [0.74 if features.absenteeism_above_average else 0.67], '学业情况数字化': [0.59 if features.academic_warning else 0.50] }) all_features = pd.concat([all_features, df_feature], ignore_index=True) # # 将 DataFrame 转换为字典 # df_dict = df_feature.to_dict(orient='records') # print(df_dict) # # 返回 FastAPI 响应 # return df_dict # print(all_features) start_time = time.time() # 记录开始时间 predictions = predict_with_model(all_features) end_time = time.time() # 记录结束时间 print("预测耗时:", end_time - start_time, "秒") # 打印执行时间 print("预测结果:", predictions) # 返回预测结果 return PredictionResult(predictions=predictions) # 信号处理函数 def signal_handler(sig, frame): print("Ctrl+C detected, shutting down the server...") # 在这里执行关闭服务器的操作 sys.exit(0) # 启动服务器的函数 def run_server(): port = 8080 server = HTTPServer(('0.0.0.0', port), MyRequestHandler) print(f'Web HTTP Server listening on http://0.0.0.0:{port}') server.serve_forever() if __name__ == "__main__": # 注册信号处理函数 signal.signal(signal.SIGINT, signal_handler) # 创建线程并启动前端 server_thread = threading.Thread(target=run_server) server_thread.start() name_app = os.path.basename(__file__)[0:-3] # Get the name of the script log_config = { "version": 1, "disable_existing_loggers": True, "handlers": { "file_handler": { "class": "logging.FileHandler", "filename": "logfile.log", }, }, "root": { "handlers": ["file_handler"], "level": "INFO", }, } #uvicorn.run(f'{name_app}:app', host="0.0.0.0", port=3397, reload=False,log_config=log_config) uvicorn.run(app, host="0.0.0.0", port=3397, reload=False)