Coverage for addmo/s3_model_tuning/models/model_factory.py: 65%
57 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
1import inspect
2import sys
3import json
4import joblib
5import os
6from addmo.s3_model_tuning.models import scikit_learn_models
7from addmo.s3_model_tuning.models import keras_models
8from addmo.s3_model_tuning.models.abstract_model import AbstractMLModel
9from addmo.s3_model_tuning.models.abstract_model import PredictorOnnx
10from tensorflow import keras
11from tensorflow.keras.models import load_model
class ModelFactory:
    """
    Creates and returns an instance of the specified machine learning model,
    and restores previously saved models from disk.
    """

    # File extensions recognised by ``load_model`` (matched case-insensitively).
    _SUPPORTED_EXTENSIONS = (".joblib", ".onnx", ".h5", ".keras")

    @staticmethod
    def _is_model_class(obj) -> bool:
        """Return True if ``obj`` is a concrete (non-abstract) AbstractMLModel subclass."""
        return (
            inspect.isclass(obj)
            and issubclass(obj, AbstractMLModel)
            and not inspect.isabstract(obj)
        )

    @staticmethod
    def _resolve_model_class(model_type):
        """Find the model class named ``model_type``.

        ``scikit_learn_models`` is searched first, then ``keras_models``. Only
        attributes accepted by ``_is_model_class`` count as a match, so any
        other name those modules happen to expose is never instantiated.

        Returns:
            ``(module, model_class)`` for the first match, or ``(None, None)``
            if neither module provides a usable class (including when
            ``model_type`` is not a string, e.g. ``None`` from missing metadata).
        """
        if not isinstance(model_type, str):
            return None, None
        for module in (scikit_learn_models, keras_models):
            candidate = getattr(module, model_type, None)
            if ModelFactory._is_model_class(candidate):
                return module, candidate
        return None, None

    @staticmethod
    def _available_model_names() -> list:
        """Return the names of all concrete models (deduplicated) for error messages."""
        # A dict keeps first-seen order while dropping names exposed by both modules.
        names = {}
        for module in (scikit_learn_models, keras_models):
            for name, _ in inspect.getmembers(module, ModelFactory._is_model_class):
                names.setdefault(name, None)
        return list(names)

    @staticmethod
    def model_factory(model_type: str, **kwargs) -> AbstractMLModel:
        """Get the model instance dynamically.

        Args:
            model_type: Class name of a model defined in ``scikit_learn_models``
                or ``keras_models``.
            **kwargs: Constructor arguments; only accepted for Keras-based models.

        Returns:
            A new instance of the requested model.

        Raises:
            ValueError: If keyword arguments are passed for a scikit-learn model,
                or if ``model_type`` does not name a concrete ``AbstractMLModel``
                subclass in either module.
        """
        module, model_class = ModelFactory._resolve_model_class(model_type)

        # If model is based on scikit-learn: it is instantiated without arguments,
        # so reject kwargs explicitly instead of silently ignoring them.
        if module is scikit_learn_models:
            if kwargs:
                raise ValueError("No keyword arguments allowed for scikit-learn models.")
            return model_class()

        # If model is based on e.g. Keras: forward the constructor arguments.
        if module is keras_models:
            return model_class(**kwargs)

        # If model is not found
        custom_model_names = ModelFactory._available_model_names()
        raise ValueError(
            f"Unknown model type: {model_type}. "
            f"Available custom models are: {', '.join(custom_model_names)}. "
        )

    @staticmethod
    def _load_metadata(abs_path: str) -> dict:
        """Read the metadata file belonging to a saved model.

        The metadata is read from ``<model path without extension>_metadata.json``,
        e.g. ``model.joblib`` -> ``model_metadata.json``.

        Raises:
            FileNotFoundError: If the metadata file does not exist.
        """
        filename = os.path.splitext(abs_path)[0]
        metadata_path = f"{filename}_metadata.json"

        try:
            # JSON text is UTF-8 by specification (RFC 8259); don't rely on the locale.
            with open(metadata_path, encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(
                f'The metadata file {metadata_path} does not exist. Try saving the model before loading it or specify the path where the model is saved.'
            ) from None

    @staticmethod
    def _required_metadata(metadata: dict, key: str, abs_path: str):
        """Return ``metadata[key]``, raising a descriptive ValueError if it is missing."""
        value = metadata.get(key)
        if value is None:
            raise ValueError(
                f"Metadata for model {abs_path} is missing the required entry '{key}'."
            )
        return value

    @staticmethod
    def load_model(abs_path: str) -> AbstractMLModel:
        """Load the model from the specified path and return the model instance.

        The loader is chosen by file extension (case-insensitive):

        * ``.joblib`` - regressor loaded with ``joblib`` into the addmo class named
          in the metadata file.
        * ``.onnx`` - wrapped in ``PredictorOnnx``; no metadata file is read.
        * ``.h5`` / ``.keras`` - Keras model loaded into the addmo class named in
          the metadata file, with ``len(features_ordered)`` as the input shape.

        Args:
            abs_path: Path of the saved model file (path-like objects are accepted).

        Returns:
            The addmo model instance wrapping the loaded regressor.

        Raises:
            FileNotFoundError: If the metadata file is missing, or if the file
                extension is not supported.
            ValueError: If the metadata lacks a required entry or names an
                unknown model class.
        """
        path = os.fspath(abs_path)
        lowered = path.lower()

        # Load regressor from joblib file to addmo model class
        if lowered.endswith(".joblib"):
            metadata = ModelFactory._load_metadata(path)
            addmo_class_name = ModelFactory._required_metadata(metadata, "addmo_class", path)
            addmo_class = ModelFactory.model_factory(addmo_class_name)
            regressor = joblib.load(path)
            addmo_class.load_regressor(regressor)
            addmo_class.metadata = metadata

        # Load the regressor from onnx file to PredictorOnnx class
        elif lowered.endswith(".onnx"):
            addmo_class = PredictorOnnx()
            addmo_class.load_regressor(path)

        # Load the regressor from keras file to addmo model class
        elif lowered.endswith((".h5", ".keras")):
            metadata = ModelFactory._load_metadata(path)
            addmo_class_name = ModelFactory._required_metadata(metadata, "addmo_class", path)
            addmo_class = ModelFactory.model_factory(addmo_class_name)
            features = ModelFactory._required_metadata(metadata, "features_ordered", path)
            input_shape = len(features)
            # load_model restores architecture and weights from the same file,
            # so a separate load_weights() call on that file is redundant.
            regressor = keras.models.load_model(path)
            addmo_class.load_regressor(regressor, input_shape)
            addmo_class.metadata = metadata

        else:
            raise FileNotFoundError(
                f"No model file found at {path}: unsupported file extension "
                f"(expected one of {', '.join(ModelFactory._SUPPORTED_EXTENSIONS)})."
            )

        return addmo_class