|
|
|
@ -111,7 +111,7 @@ class MLModel:
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise FileNotFoundError(f"No model found in either path: {model_path} or {default_model_path}")
|
|
|
|
raise FileNotFoundError(f"No model found in either path: {model_path} or {default_model_path}")
|
|
|
|
|
|
|
|
|
|
|
|
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
|
|
|
|
self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=False))
|
|
|
|
|
|
|
|
|
|
|
|
def plot_correlation(self, X: np.ndarray, feature_names: List[str]) -> None:
|
|
|
|
def plot_correlation(self, X: np.ndarray, feature_names: List[str]) -> None:
|
|
|
|
"""绘制特征相关性矩阵"""
|
|
|
|
"""绘制特征相关性矩阵"""
|
|
|
|
|