Coverage for tests/test_abstractmodel.py: 96%

53 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

1import unittest 

2import tempfile 

3import os 

4import pandas as pd 

5import numpy as np 

6from addmo.s3_model_tuning.models.abstract_model import AbstractMLModel 

7from addmo.s3_model_tuning.models.keras_models import BaseKerasModel 

8from addmo.s3_model_tuning.models.scikit_learn_models import BaseScikitLearnModel 

9from addmo.s3_model_tuning.models.abstract_model import ModelMetadata 

10from addmo.s3_model_tuning.models.model_factory import ModelFactory 

11 

12 

def get_subclasses(base_class):
    """
    Dynamically collect every direct and indirect subclass of the given base class.

    The hierarchy is walked with an explicit work list. A class that is
    reachable through several parents (diamond inheritance) is reported once.

    :param base_class: Class whose subclass tree is explored. The class itself
        is not part of the result.
    :return: List of subclasses in discovery order. For a fixed set of class
        definitions the order is the same on every run, which keeps the order
        of the tests driven by this list reproducible.
    """
    # A dict preserves insertion order (a set does not), so the final list is
    # deterministic while the "already seen" check stays a hash lookup.
    discovered = {}
    pending = [base_class]
    while pending:
        parent = pending.pop()
        for child in parent.__subclasses__():
            if child not in discovered:
                discovered[child] = None
                pending.append(child)
    return list(discovered)

26 

def train_and_check_model(self, model, x_sample, y_sample):
    """
    Fit the model on the sample data and sanity-check its predictions.

    :param self: Test case object providing ``assertEqual`` and ``assertIsInstance``.
    :param model: Object exposing ``fit(x, y)`` and ``predict(x)``.
    :param x_sample: Feature data used for both fitting and predicting.
    :param y_sample: Target data used for fitting.
    """
    model.fit(x_sample, y_sample)
    predicted = model.predict(x_sample)

    # Exactly one prediction per target value is expected.
    actual_count = len(predicted)
    expected_count = len(y_sample)
    self.assertEqual(actual_count, expected_count)

    # The predictions must come back as a NumPy array.
    self.assertIsInstance(predicted, np.ndarray)

35 

class TestBaseMLModel(unittest.TestCase):
    """
    Unit tests for base class models.

    Every subclass of ``base_class`` is instantiated, fitted, saved, reloaded
    and checked for consistent metadata.
    """

    base_class = BaseScikitLearnModel  # Change this to test different base classes

    @classmethod
    def setUpClass(cls):
        """
        Find all subclasses of the base class and create the shared temp directory.

        :raises ValueError: If no subclass of ``base_class`` is found.
        """
        # Discover and validate first: unittest does not run tearDownClass when
        # setUpClass raises, so the temp directory is only created once nothing
        # below can fail any more.
        cls.subclasses = get_subclasses(cls.base_class)
        if not cls.subclasses:
            raise ValueError(f"No subclasses found for {cls.base_class.__name__}")

        cls.temp_dir = tempfile.TemporaryDirectory()

    @classmethod
    def tearDownClass(cls):
        """
        Cleanup temp directory after tests.
        """
        cls.temp_dir.cleanup()

    def test_all_models(self):
        """
        Test fit/predict, serialization and metadata for every subclass of ``base_class``.
        """
        for model_class in self.subclasses:
            with self.subTest(model=model_class.__name__):
                model = model_class()  # Instantiate model
                print(f"\n Testing model {model_class.__name__}")

                # Ensure regressor is not None
                self.assertIsNotNone(model.regressor, f"{model_class.__name__} should have a regressor")

                # A freshly seeded generator per model: every model is trained
                # on the same data, and failures are reproducible across runs.
                rng = np.random.default_rng(0)
                x_sample = pd.DataFrame(rng.random((100, 2)), columns=["A", "B"])
                y_sample = pd.Series(rng.random(100), name="Target")

                train_and_check_model(self, model, x_sample, y_sample)

                # Test model serialization; the value returned by
                # save_regressor is used as the file extension of the artifact.
                file_type = model.save_regressor(self.temp_dir.name, "test_model")
                path_to_regressor = os.path.join(self.temp_dir.name, f'test_model.{file_type}')
                test_regressor: AbstractMLModel = ModelFactory.load_model(path_to_regressor)

                # Test metadata
                self.assertIsInstance(model.metadata, ModelMetadata, "Meta data not defined properly")
                self.assertEqual(model.metadata.addmo_class, model_class.__name__, "Model class mismatch")
                self.assertIsInstance(test_regressor, type(model), "Loaded model class mismatch")

91 

92 

# Run the tests of this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

95