图片打印加上中文

master
wangchunlin 2 years ago
parent d957e4f3e6
commit a2a1028f47

@ -92,11 +92,11 @@ class MLModel:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5)) fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
ax1.bar(np.arange(len(precision)), precision) ax1.bar(np.arange(len(precision)), precision)
ax1.set_title('Precision') ax1.set_title('Precision(精确率)')
ax2.bar(np.arange(len(recall)), recall) ax2.bar(np.arange(len(recall)), recall)
ax2.set_title('Recall') ax2.set_title('Recall(召回率)')
ax3.bar(np.arange(len(f1)), f1) ax3.bar(np.arange(len(f1)), f1)
ax3.set_title('F1 Score') ax3.set_title('F1 Score(F1得分)')
# 保存图片 # 保存图片
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
evaluate_result_path = os.path.join(parent_dir, self.config['evaluate_result_path']) evaluate_result_path = os.path.join(parent_dir, self.config['evaluate_result_path'])
@ -179,18 +179,18 @@ class MLModel:
# 打印训练和验证过程的可视化图片 # 打印训练和验证过程的可视化图片
plt.close('all') plt.close('all')
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5)) fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
ax1.plot(train_loss_history, label='Train Loss') ax1.plot(train_loss_history, label='Train Loss(训练损失)')
ax1.plot(val_loss_history, label='Validation Loss') ax1.plot(val_loss_history, label='Validation Loss(验证损失)')
ax1.set_title('Loss') ax1.set_title('Loss')
ax1.legend() ax1.legend()
ax2.plot(train_acc_history, label='Train Accuracy') ax2.plot(train_acc_history, label='Train Accuracy(训练正确率)')
ax2.plot(val_acc_history, label='Validation Accuracy') ax2.plot(val_acc_history, label='Validation Accuracy(验证正确率)')
ax2.set_title('Accuracy') ax2.set_title('Accuracy')
ax2.legend() ax2.legend()
ax3.plot(val_f1_history, label='Validation F1') ax3.plot(val_f1_history, label='Validation F1(验证F1得分)')
ax3.plot(val_precision_history, label='Validation Precision') ax3.plot(val_precision_history, label='Validation Precision(验证精确率)')
ax3.plot(val_recall_history, label='Validation Recall') ax3.plot(val_recall_history, label='Validation Recall(验证召回率)')
ax3.set_title('Precision Recall F1-Score (Macro Mean)') ax3.set_title('Precision Recall F1-Score (Macro Mean)(宏平均)')
ax3.legend() ax3.legend()
# 保存图片 # 保存图片
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

Loading…
Cancel
Save