|
|
import os
|
|
|
import sys
|
|
|
# Put the current working directory on the import path.
# NOTE(review): presumably this lets the project-local ``utils`` package
# imported further down resolve when the script is started from the project
# root -- confirm against the deployment layout.
root_path = os.getcwd()
sys.path.append(root_path)
|
|
|
|
|
|
import time
|
|
|
import datetime
|
|
|
import signal
|
|
|
import uvicorn
|
|
|
import pandas as pd
|
|
|
from fastapi import FastAPI, Request
|
|
|
from pydantic import BaseModel
|
|
|
from typing import List
|
|
|
from utils.common import inference_model
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
import logging
|
|
|
import matplotlib.pyplot as plt
|
|
|
import argparse
|
|
|
import numpy as np
|
|
|
import yaml
|
|
|
import threading
|
|
|
import pickle
|
|
|
from fastapi.responses import FileResponse
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
from utils.feature_process import create_feature_df, apply_feature_weights, Features
|
|
|
|
|
|
|
|
|
# Application instance on which the middleware and routes below are registered.
app = FastAPI()
|
|
|
|
|
|
# Response model returned by the FastAPI inference endpoint.
class PredictionResult(BaseModel):
    """Response body that carries the model outputs in ``predictions``."""

    # Untyped list: element types are not constrained by this model.
    predictions: list
|
|
|
|
|
|
# Allow cross-origin requests from any domain.
# NOTE(review): FastAPI's CORS documentation states that wildcard values for
# allow_origins / allow_headers should not be combined with
# allow_credentials=True. Behaviour is left unchanged here; confirm whether
# credentials are really needed, or list the permitted origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# Inference endpoint and its private helpers.
def _load_inference_config(base_dir):
    """Parse ``config/config.yaml`` located under *base_dir* and return it."""
    config_path = os.path.abspath(os.path.join(base_dir, "config/config.yaml"))
    # Explicit UTF-8 so parsing does not depend on the platform's default
    # text encoding.
    with open(config_path, 'r', encoding="utf-8") as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def _open_request_log(log_path):
    """Attach a file handler writing to *log_path* to the root logger.

    ``logging.basicConfig`` only configures the root logger while it has no
    handlers, so calling it once per request never switches to a new file.
    A dedicated handler per request does.  The caller must detach and close
    the returned handler when the request is finished.
    """
    handler = logging.FileHandler(log_path)
    handler.setLevel(logging.INFO)
    handler.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s %(levelname)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
        )
    )
    root_logger = logging.getLogger()
    # Only lower the threshold; never make an explicitly more verbose
    # configuration quieter.
    if root_logger.getEffectiveLevel() > logging.INFO:
        root_logger.setLevel(logging.INFO)
    root_logger.addHandler(handler)
    return handler


@app.post("/inference/")
async def classify_features(request: Request, features_list: List[Features]):
    """Run the configured model on the posted features and return predictions.

    Steps: build a feature DataFrame from ``features_list``, apply the
    per-feature weights from ``config/config.yaml``, dump the weighted table
    to a timestamped Excel file under ``inference_api/`` next to this script,
    then call ``inference_model`` with the feature and label columns.

    Args:
        request: Incoming request object; not used by the handler.
        features_list: Feature records parsed from the request body.

    Returns:
        PredictionResult: wrapper around the value returned by
        ``inference_model``.

    NOTE(review): nothing in this coroutine is awaited, so the event loop is
    blocked for the whole request; consider a plain ``def`` handler or a
    thread pool if concurrent requests matter.
    """
    base_dir = os.path.dirname(__file__)

    # Collect every posted feature object into one DataFrame.
    all_features = create_feature_df(features_list)

    # The YAML configuration is re-read on every request.
    config = _load_inference_config(base_dir)
    feature_names = config['feature_names']
    feature_weights = config['feature_weights']

    # Apply the feature weights.
    feature_label_weighted = apply_feature_weights(all_features, feature_names, feature_weights)

    # Monotonic clock: immune to wall-clock adjustments while timing.
    start_time = time.perf_counter()

    # Output directory for per-request artefacts, sibling of this file.
    static_dir = os.path.abspath(os.path.join(base_dir, "inference_api"))
    os.makedirs(static_dir, exist_ok=True)

    # NOTE: the timestamp has one-second resolution, so two requests handled
    # within the same second share (and overwrite) the same file names.
    now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    data_path = os.path.join(static_dir, f"all_features_label_{now}.xlsx")
    config['data_path'] = data_path
    feature_label_weighted.to_excel(data_path, index=False)

    # Per-request log file, detached again once inference has finished.
    log_handler = _open_request_log(os.path.join(static_dir, f"inference_{now}.log"))
    try:
        # Feature matrix and label vector.
        X = feature_label_weighted[feature_names].values
        y = feature_label_weighted[config['label_name']].values

        predictions = inference_model(config["model_path"], X, y, config)
    finally:
        logging.getLogger().removeHandler(log_handler)
        log_handler.close()

    elapsed = time.perf_counter() - start_time
    print("预测耗时:", elapsed, "秒")  # elapsed prediction time
    print("预测结果:", predictions)

    # The return type of inference_model is not visible here; if it is a
    # NumPy array, convert it so the ``list`` field receives plain Python
    # values.
    if isinstance(predictions, np.ndarray):
        predictions = predictions.tolist()

    return PredictionResult(predictions=predictions)
|
|
|
|
|
|
# FastAPI startup configuration (runs only when executed as a script).
if __name__ == "__main__":
    # Directory next to this file that is served as static content.
    static_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "inference_api"))
    os.makedirs(static_dir, exist_ok=True)

    # Expose that directory over HTTP under /inference_api.
    app.mount("/inference_api", StaticFiles(directory=static_dir), name="static")

    # NOTE: the unused ``name_app`` value and a logging dict that was built
    # here but never handed to ``uvicorn.run`` have been dropped. uvicorn's
    # default logging configuration was, and still is, the one in effect.
    uvicorn.run(app, host="0.0.0.0", port=3397, reload=False)
|
|
|
|