Coverage for tests/test_data_tuning.py: 92%
52 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
1import os
2import unittest
3import inspect
4import tempfile
5import pandas as pd
6from unittest.mock import patch, MagicMock
7from addmo.util.experiment_logger import ExperimentLogger,LocalLogger, WandbLogger
8from addmo.s1_data_tuning_auto.config.data_tuning_auto_config import DataTuningAutoSetup
9from addmo.s1_data_tuning_auto.data_tuner_auto import DataTunerAuto
10from addmo.s2_data_tuning.config.data_tuning_config import DataTuningFixedConfig
11from addmo.s2_data_tuning.data_tuner_fixed import DataTunerByConfig
12from addmo.util.load_save import load_data
13from addmo.util.data_handling import split_target_features
# Each tuner class under test, mapped to the config class whose
# no-argument constructor supplies that tuner's configuration.
CONFIG_MAP = {
    DataTunerAuto: DataTuningAutoSetup,
    DataTunerByConfig: DataTuningFixedConfig,
}

# The tuners exercised by the test suite. Derived from CONFIG_MAP (dicts keep
# insertion order) so every tuner listed here is guaranteed a config entry.
TUNER_CLASSES = list(CONFIG_MAP)
class TestAllDataTuners(unittest.TestCase):
    """Smoke-test every data tuner against its default configuration.

    For each class in ``TUNER_CLASSES``, every bound method whose name starts
    with ``tune_`` (e.g. ``tune_auto``, ``tune_fixed``) is discovered with
    ``inspect`` and called with arguments chosen by its parameter count:

    * 0 parameters -> ``method()``
    * 1 parameter  -> ``method(raw)``  (the data as returned by ``load_data``)
    * 2 parameters -> ``method(x, y)`` (features and target)

    Any other parameter count is skipped. The returned ``tuned_x`` must be a
    non-empty DataFrame, and the target (``tuner.y`` if present, otherwise the
    split target) must be a non-empty Series.
    """

    def _instantiate_config(self, tuner_cls):
        """Return a default-constructed config object for ``tuner_cls``."""
        config_cls = CONFIG_MAP[tuner_cls]
        return config_cls()

    def _make_data(self, config):
        """Load the data referenced by ``config`` and split off its target.

        Returns ``(raw, x, y)``: the loaded data plus the features and target
        produced by ``split_target_features(config.name_of_target, raw)``.
        """
        raw = load_data(config.abs_path_to_data)
        x, y = split_target_features(config.name_of_target, raw)
        return raw, x, y

    def _call_tune_method(self, name, method, raw, x, y):
        """Invoke ``method`` with arguments chosen from its parameter count.

        Skips the current subtest when the signature has more than two
        parameters, since there is no known way to call it.
        """
        # ``method`` is bound, so ``self`` is already excluded from the signature.
        sig = inspect.signature(method)
        num_args = len(sig.parameters)

        if num_args == 0:
            return method()
        if num_args == 1:
            return method(raw)
        if num_args == 2:
            return method(x, y)
        self.skipTest(f"{name} has unexpected signature {sig}")

    def _assert_tuned_output(self, name, tuner, tuned_x, y):
        """Check the tuned features and target produced by ``name``."""
        self.assertIsInstance(tuned_x, pd.DataFrame,
                              f"{name} must return DataFrame")
        self.assertFalse(tuned_x.empty,
                         f"{name} returned empty DataFrame")

        # Prefer a ``y`` attribute stored on the tuner, if any; otherwise fall
        # back to the target split from the raw data.
        y_out = getattr(tuner, "y", y)
        self.assertIsInstance(y_out, pd.Series,
                              f"{name}: y must be Series")
        self.assertFalse(y_out.empty,
                         f"{name}: y must not be empty")

        # Column-wise concat aligns on the index; bfill fills the NaNs that
        # misaligned rows introduce before checking the result is non-empty.
        joined = pd.concat([y_out, tuned_x], axis=1).bfill()
        self.assertIsInstance(joined, pd.DataFrame)
        self.assertFalse(joined.empty,
                         f"{name}: joined DataFrame empty")

    def test_tuners(self):
        """Run every discovered ``tune_*`` method of every tuner and check its output."""
        for tuner_cls in TUNER_CLASSES:
            with self.subTest(tuner=tuner_cls.__name__):
                config = self._instantiate_config(tuner_cls)
                tuner = tuner_cls(config=config)
                raw, x, y = self._make_data(config)

                tune_methods = [
                    (name, method)
                    for name, method in inspect.getmembers(
                        tuner, predicate=inspect.ismethod)
                    if name.startswith("tune_")
                ]
                # Without this guard a tuner whose tune_* methods were renamed
                # or removed would make the test pass without testing anything.
                self.assertTrue(
                    tune_methods,
                    f"{tuner_cls.__name__} exposes no tune_* methods",
                )

                for name, method in tune_methods:
                    with self.subTest(method=name):
                        tuned_x = self._call_tune_method(name, method, raw, x, y)
                        self._assert_tuned_output(name, tuner, tuned_x, y)
if __name__ == "__main__":
    # Discover and run every TestCase in this module via the unittest CLI;
    # verbosity=1 is unittest's default and can still be raised with -v.
    unittest.main(verbosity=1)