Coverage for addmo/s3_model_tuning/models/model_factory.py: 65%

57 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

1import inspect 

2import sys 

3import json 

4import joblib 

5import os 

6from addmo.s3_model_tuning.models import scikit_learn_models 

7from addmo.s3_model_tuning.models import keras_models 

8from addmo.s3_model_tuning.models.abstract_model import AbstractMLModel 

9from addmo.s3_model_tuning.models.abstract_model import PredictorOnnx 

10from tensorflow import keras 

11from tensorflow.keras.models import load_model 

12 

class ModelFactory:
    """
    Creates, loads and returns instances of the addmo machine learning models.

    Model classes are looked up by name, first in the ``scikit_learn_models``
    module and then in the ``keras_models`` module.
    """

    @staticmethod
    def model_factory(model_type: str, **kwargs) -> "AbstractMLModel":
        """Get the model instance dynamically.

        Args:
            model_type: Name of a model class defined in ``scikit_learn_models``
                or ``keras_models``.
            **kwargs: Forwarded to the constructor of Keras-based models. Must
                be empty for scikit-learn based models.

        Returns:
            A new instance of the requested model class.

        Raises:
            ValueError: If ``model_type`` is not a non-empty string, if keyword
                arguments are passed for a scikit-learn based model, or if no
                ``AbstractMLModel`` subclass with that name is found.
        """
        # hasattr() raises an opaque TypeError for non-string names (e.g. a
        # missing 'addmo_class' entry in the metadata), so reject them early.
        if not isinstance(model_type, str) or not model_type:
            raise ValueError(
                f"Model type must be a non-empty string, got {model_type!r}."
            )

        custom_model_class = None

        # If model is based on scikit-learn
        if hasattr(scikit_learn_models, model_type):
            if kwargs:
                raise ValueError("No keyword arguments allowed for scikit-learn models.")
            custom_model_class = getattr(scikit_learn_models, model_type)

        # If model is based on e.g. Keras
        elif hasattr(keras_models, model_type):
            custom_model_class = getattr(keras_models, model_type)

        # Return model if found and a subclass of AbstractMLModel. The
        # isclass() guard keeps issubclass() from raising TypeError when the
        # name resolves to a non-class module member (imported module,
        # function, constant). kwargs is always empty on the scikit-learn path.
        if inspect.isclass(custom_model_class) and issubclass(
            custom_model_class, AbstractMLModel
        ):
            return custom_model_class(**kwargs)

        # Model not found: collect the names of all concrete custom models for
        # the error message (a set removes names visible in both modules).
        custom_model_names = sorted(
            {
                name
                for name, obj in inspect.getmembers(scikit_learn_models)
                + inspect.getmembers(keras_models)
                if inspect.isclass(obj)
                and issubclass(obj, AbstractMLModel)
                and not inspect.isabstract(obj)
            }
        )

        raise ValueError(
            f"Unknown model type: {model_type}. "
            f"Available custom models are: {', '.join(custom_model_names)}. "
        )

    @staticmethod
    def _load_metadata(abs_path: str) -> dict:
        """Read the metadata file that belongs to a saved model.

        The metadata file is expected next to the model file, named like the
        model file with its extension replaced by ``_metadata.json``.

        Args:
            abs_path: Path of the saved model file.

        Returns:
            The parsed content of the metadata JSON file.

        Raises:
            FileNotFoundError: If the metadata file does not exist.
        """
        filename = os.path.splitext(abs_path)[0]
        metadata_path = f"{filename}_metadata.json"

        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f'The metadata file {metadata_path} does not exist. Try saving the model before loading it or specify the path where the model is saved.')

        with open(metadata_path, encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def load_model(abs_path: str) -> "AbstractMLModel":
        """Load the model from the specified path and return the model instance.

        The loader is chosen by file extension: ``.joblib`` and
        ``.h5``/``.keras`` files are loaded into the addmo class named by the
        ``addmo_class`` entry of the accompanying metadata file, ``.onnx``
        files are loaded into a ``PredictorOnnx``.

        Args:
            abs_path: Path of the saved model file.

        Returns:
            The model instance with the stored regressor loaded.

        Raises:
            FileNotFoundError: If the file extension is not supported or the
                metadata file of a joblib/Keras model is missing.
            ValueError: If the metadata lacks a usable ``addmo_class`` entry or,
                for Keras models, the ``features_ordered`` entry.
        """
        # Load regressor from joblib file to addmo model class
        if abs_path.endswith('.joblib'):
            metadata = ModelFactory._load_metadata(abs_path)
            addmo_class = ModelFactory.model_factory(metadata.get('addmo_class'))
            # NOTE: joblib.load unpickles the file and can execute arbitrary
            # code - only load model files from trusted sources.
            regressor = joblib.load(abs_path)
            addmo_class.load_regressor(regressor)
            addmo_class.metadata = metadata

        # Load the regressor from onnx file to PredictorOnnx class
        elif abs_path.endswith('.onnx'):
            addmo_class = PredictorOnnx()
            addmo_class.load_regressor(abs_path)

        # Load the regressor from keras file to addmo model class
        elif abs_path.endswith(('.h5', '.keras')):
            metadata = ModelFactory._load_metadata(abs_path)
            features_ordered = metadata.get('features_ordered')
            if features_ordered is None:
                raise ValueError(
                    f"The metadata of {abs_path} has no 'features_ordered' entry."
                )
            # The number of input features is handed to load_regressor.
            input_shape = len(features_ordered)
            addmo_class = ModelFactory.model_factory(metadata.get('addmo_class'))
            regressor = keras.models.load_model(abs_path)
            # NOTE(review): load_model presumably restores the weights already;
            # the explicit reload is kept to preserve the existing behavior.
            regressor.load_weights(abs_path)
            addmo_class.load_regressor(regressor, input_shape)
            addmo_class.metadata = metadata

        else:
            # This branch is reached for an unsupported extension; the
            # exception type is kept for callers that already catch it.
            raise FileNotFoundError(
                f"No model file found at {abs_path}. "
                f"Supported file extensions are: .joblib, .onnx, .h5, .keras.")

        return addmo_class