diff --git a/utils/model_trainer.py b/utils/model_trainer.py index bb38684..6f360e1 100644 --- a/utils/model_trainer.py +++ b/utils/model_trainer.py @@ -111,7 +111,7 @@ class MLModel: else: raise FileNotFoundError(f"No model found in either path: {model_path} or {default_model_path}") - self.model.load_state_dict(torch.load(model_path, map_location=self.device)) + self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=False)) def plot_correlation(self, X: np.ndarray, feature_names: List[str]) -> None: """绘制特征相关性矩阵"""