Coverage for tests/test_abstractmodel.py: 96%
53 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
1import unittest
2import tempfile
3import os
4import pandas as pd
5import numpy as np
6from addmo.s3_model_tuning.models.abstract_model import AbstractMLModel
7from addmo.s3_model_tuning.models.keras_models import BaseKerasModel
8from addmo.s3_model_tuning.models.scikit_learn_models import BaseScikitLearnModel
9from addmo.s3_model_tuning.models.abstract_model import ModelMetadata
10from addmo.s3_model_tuning.models.model_factory import ModelFactory
def get_subclasses(base_class):
    """
    Return every direct and indirect subclass of ``base_class``.

    The hierarchy is walked with an explicit stack of classes whose
    ``__subclasses__()`` still need visiting. A class reachable through
    several parents is reported only once. ``base_class`` itself is not
    part of the result, and the order of the returned list is unspecified.
    """
    found = set()
    pending = [base_class]
    while pending:
        current = pending.pop()
        # Only classes not seen yet are recorded and queued for further descent.
        unseen = [child for child in current.__subclasses__() if child not in found]
        found.update(unseen)
        pending.extend(unseen)
    return list(found)
def train_and_check_model(self, model, x_sample, y_sample):
    """
    Fit ``model`` on the sample data and verify its in-sample predictions.

    Args:
        self: Object providing ``assertEqual`` and ``assertIsInstance``;
            in this file it is the calling ``unittest.TestCase``.
        model: Model exposing ``fit(x, y)`` and ``predict(x)``.
        x_sample: Feature data used for both training and prediction.
        y_sample: Target values used for training.

    The predictions must have one entry per target value and must be a
    ``numpy.ndarray``; otherwise the corresponding assertion fails.
    """
    model.fit(x_sample, y_sample)
    predicted = model.predict(x_sample)
    n_predicted, n_targets = len(predicted), len(y_sample)
    self.assertEqual(n_predicted, n_targets)
    self.assertIsInstance(predicted, np.ndarray)
class TestBaseMLModel(unittest.TestCase):
    """
    Unit tests for every model class derived from ``base_class``.

    Each discovered subclass is instantiated, fitted on a small random sample,
    saved to a temporary directory, reloaded through ``ModelFactory`` and
    checked for matching metadata and type.
    """

    base_class = BaseScikitLearnModel  # Change this to test different base classes

    # Fixed seed so a failing model can be re-run on exactly the same data.
    random_seed = 42

    @classmethod
    def setUpClass(cls):
        """
        Find all subclasses of the base class and create the shared temp directory.

        Raises:
            ValueError: If ``base_class`` has no subclasses.
        """
        # Discover subclasses *before* creating the temp directory. unittest does
        # not call tearDownClass when setUpClass raises, so a directory created
        # first would never be cleaned up.
        cls.subclasses = get_subclasses(cls.base_class)

        if not cls.subclasses:
            raise ValueError(f"No subclasses found for {cls.base_class.__name__}")

        cls.temp_dir = tempfile.TemporaryDirectory()

    @classmethod
    def tearDownClass(cls):
        """
        Cleanup temp directory after tests.
        """
        cls.temp_dir.cleanup()

    def test_all_models(self):
        """
        Test every subclass of ``base_class``: fit/predict, save/load and metadata.
        """
        # get_subclasses returns set order; sort so subtests run in a stable order.
        ordered = sorted(self.subclasses, key=lambda c: (c.__module__, c.__qualname__))
        for model_class in ordered:
            with self.subTest(model=model_class.__name__):
                model = model_class()  # Instantiate model
                print(f"\n Testing model {model_class.__name__}")

                # Ensure regressor is not None
                self.assertIsNotNone(
                    model.regressor, f"{model_class.__name__} should have a regressor"
                )

                # A fresh seeded generator per model: every model sees identical,
                # reproducible data, and no model can influence another's sample.
                rng = np.random.default_rng(self.random_seed)
                x_sample = pd.DataFrame(rng.random((100, 2)), columns=["A", "B"])
                y_sample = pd.Series(rng.random(100), name="Target")

                train_and_check_model(self, model, x_sample, y_sample)

                # Test model serialization
                file_type = model.save_regressor(self.temp_dir.name, "test_model")

                path_to_regressor = os.path.join(self.temp_dir.name, f"test_model.{file_type}")
                test_regressor: AbstractMLModel = ModelFactory.load_model(path_to_regressor)

                # Test metadata
                self.assertIsInstance(model.metadata, ModelMetadata, "Meta data not defined properly")
                self.assertEqual(model.metadata.addmo_class, model_class.__name__, "Model class mismatch")
                self.assertIsInstance(test_regressor, type(model), "Loaded model class mismatch")
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()