You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

127 lines
3.9 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
文件名: detect_num.py
推理部分代码
作者: 王春林
创建日期: 2023年7月14日
最后修改日期: 2023年7月18日
版本号: 1.0.0
"""
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
# Run on the GPU when CUDA is available, otherwise on the CPU.
# To force CPU inference, use: device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Read the feature/label spreadsheet (one row per sample).
data = pd.read_excel('feature_label.xlsx')

# Sample IDs from the "编号" column, used to identify each sample in the output.
sample_ids = data['编号'].values

# Input feature columns (17 of them), in the order they are fed to the network.
feature_names = [
    "躯体化", "强迫症状", "人际关系敏感", "抑郁", "焦虑",
    "敌对", "恐怖", "偏执", "精神病性", "其他",
    "父亲教养方式数字化", "母亲教养方式数字化", "自评家庭经济条件数字化",
    "有无心理治疗(咨询)史数字化", "出勤情况数字化", "学业情况数字化",
    "权重数字化值",
]

# Separate features from labels; labels are stored as 1-4 in the sheet and
# are shifted to 0-3 class indices here.
raw_features = data[feature_names].values
y = data['label'].values - 1

# Min-max normalise every feature column (MinMaxScaler default range [0, 1]).
# NOTE(review): the scaler is fit on this file's own min/max. If the model was
# trained with a scaler fit on different data, features here are scaled
# differently from what the model saw — consider persisting the training
# scaler and calling transform() instead. Confirm against the training script.
scaler = MinMaxScaler()
X = scaler.fit_transform(raw_features)
# Define the MLP classifier.
class MLP(nn.Module):
    """Fully connected classifier: in_features -> 32 -> 128 -> 32 -> num_classes.

    Args:
        in_features: Number of input features per sample (default 17, matching
            the number of feature columns used by this script).
        num_classes: Number of output logits (default 4 classes).
    """

    def __init__(self, in_features=17, num_classes=4):
        super().__init__()
        # Keep the layer order unchanged: Sequential indices determine the
        # state_dict keys ("model.0.weight", "model.2.weight", ...), which
        # must match the saved checkpoint.
        self.model = nn.Sequential(
            nn.Linear(in_features, 32),  # input layer
            nn.ReLU(),
            nn.Linear(32, 128),  # hidden layer
            nn.ReLU(),
            nn.Linear(128, 32),  # hidden layer
            nn.ReLU(),
            nn.Linear(32, num_classes),  # output layer: one logit per class
        )

    def forward(self, x):
        """Return raw (un-normalised) class logits.

        x is a float tensor whose last dimension is in_features; the result
        has the same leading dimensions and a last dimension of num_classes.
        """
        return self.model(x)
# Build the network on the selected device and load the trained weights.
model = MLP().to(device)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU also loads on a CPU-only machine (without it torch.load raises there).
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load('Psychological_Classification_4Classes.pth', map_location=device)
model.load_state_dict(state_dict)
# Put the model in evaluation mode before running inference.
model.eval()
# Wrap the scaled features (float32) and 0-based labels (int64) as tensors on
# the selected device, then iterate them in batches of 32. DataLoader does not
# shuffle by default, so batches follow the spreadsheet's row order.
feature_tensor = torch.from_numpy(X).float().to(device)
label_tensor = torch.from_numpy(y).long().to(device)
dataset = TensorDataset(feature_tensor, label_tensor)
loader = DataLoader(dataset, batch_size=32)
# First pass: print the prediction for every sample, then the overall accuracy.
# Targets and predictions are 0-based class indices, i.e. the Excel label - 1.
corrects = 0
sample_index = 0
# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    for inputs, targets in loader:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        corrects += (preds == targets).sum().item()
        # Print each sample's result. The loader is not shuffled, so batches
        # follow row order and sample_index lines up with sample_ids.
        # tolist() yields plain ints, so the output shows "2" rather than a
        # tensor repr such as "tensor(2, device='cuda:0')".
        for target, pred in zip(targets.tolist(), preds.tolist()):
            print(f'Sample ID: {sample_ids[sample_index]} | Target: {target} | Prediction: {pred} (-1 in excel)')
            sample_index += 1
# Overall accuracy; NaN instead of a crash when the dataset is empty.
total = len(loader.dataset)
accuracy = corrects / total if total else float("nan")
print(f'Overall Accuracy: {accuracy:.4f}')
# Second pass: export early-warning lists and report per-class accuracy.
# Class indices are 0-based here (the Excel labels 1-4 shifted down by one).
# Samples predicted as class 0 are written to "一级预警.txt" (level-1 warning)
# and samples predicted as class 1 to "二级预警.txt" (level-2 warning), one
# sample ID per line.
num_classes = 4
# Per ground-truth class: number of samples, and how many were predicted correctly.
class_counts = {c: 0 for c in range(num_classes)}
class_corrects = {c: 0 for c in range(num_classes)}
corrects = 0
sample_index = 0
# `with` guarantees both output files are closed even if inference raises.
with open("一级预警.txt", "w", encoding="utf-8") as file_1st_warning, \
        open("二级预警.txt", "w", encoding="utf-8") as file_2nd_warning:
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        for inputs, targets in loader:
            preds = model(inputs).argmax(dim=1)
            # Copy each batch to host once rather than syncing per element
            # with .item().
            for pred, target in zip(preds.tolist(), targets.tolist()):
                # The loader is not shuffled, so batches follow row order and
                # sample_index lines up with sample_ids.
                sample_id = sample_ids[sample_index]
                sample_index += 1
                if pred == 0:
                    file_1st_warning.write(f"{sample_id}\n")
                elif pred == 1:
                    file_2nd_warning.write(f"{sample_id}\n")
                class_counts[target] += 1
                if pred == target:
                    class_corrects[target] += 1
                    corrects += 1
# Overall accuracy; NaN instead of a crash when the dataset is empty.
total = len(loader.dataset)
accuracy = corrects / total if total else float("nan")
print(f'整体准确率: {accuracy:.4f}')
# Per-class report. The count is the number of ground-truth samples of that
# class (not the number of predictions), so the ratio is the class's recall.
# A class with no samples reports NaN instead of raising ZeroDivisionError.
for class_idx in range(num_classes):
    count = class_counts[class_idx]
    class_accuracy = class_corrects[class_idx] / count if count else float("nan")
    print(f'类别 {class_idx + 1} | 样本数量: {count} | 准确率: {class_accuracy:.4f}')