Coverage for tests/test_data_tuning.py: 92%

52 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

1import os 

2import unittest 

3import inspect 

4import tempfile 

5import pandas as pd 

6from unittest.mock import patch, MagicMock 

7from addmo.util.experiment_logger import ExperimentLogger,LocalLogger, WandbLogger 

8from addmo.s1_data_tuning_auto.config.data_tuning_auto_config import DataTuningAutoSetup 

9from addmo.s1_data_tuning_auto.data_tuner_auto import DataTunerAuto 

10from addmo.s2_data_tuning.config.data_tuning_config import DataTuningFixedConfig 

11from addmo.s2_data_tuning.data_tuner_fixed import DataTunerByConfig 

12from addmo.util.load_save import load_data 

13from addmo.util.data_handling import split_target_features 

14 

15 

# Map each tuner class under test to the config class it is constructed with.
# TUNER_CLASSES below is derived from this mapping so the two registries can
# never drift apart (a tuner listed without a config entry would otherwise
# fail with a KeyError in TestAllDataTuners._instantiate_config).
CONFIG_MAP = {
    DataTunerAuto: DataTuningAutoSetup,
    DataTunerByConfig: DataTuningFixedConfig,
}

# Tuner classes exercised by the tests, in CONFIG_MAP insertion order.
TUNER_CLASSES = list(CONFIG_MAP)

class TestAllDataTuners(unittest.TestCase):
    """Smoke-test every registered data tuner's ``tune_*`` methods.

    For each class in ``TUNER_CLASSES``, discover all bound methods whose name
    starts with ``tune_`` (e.g. ``tune_auto``, ``tune_fixed``) and invoke them
    with arguments chosen from the method's parameter count:

    * 0 parameters  -> ``method()``
    * 1 parameter   -> ``method(raw)`` (the loaded, unsplit data)
    * 2 parameters  -> ``method(x, y)`` (features and target)
    * anything else -> that method's subtest is skipped

    Each result must be a non-empty ``pd.DataFrame``. The tuner's ``y``
    attribute (or the split target when the tuner has no ``y``) must be a
    non-empty ``pd.Series``. Concatenating the two column-wise must produce a
    non-empty DataFrame. A tuner that exposes no ``tune_*`` methods at all is
    reported as a failure rather than passing silently.
    """

    def _instantiate_config(self, tuner_cls):
        """Return a default-constructed config object for ``tuner_cls``.

        Raises:
            KeyError: if ``tuner_cls`` has no entry in ``CONFIG_MAP``.
        """
        config_cls = CONFIG_MAP[tuner_cls]
        return config_cls()

    def _make_data(self, config):
        """Load the dataset referenced by ``config`` and split off its target.

        Returns:
            tuple: ``(raw, x, y)`` -- the data returned by ``load_data`` for
            ``config.abs_path_to_data``, followed by the feature/target split
            that ``split_target_features`` produces for
            ``config.name_of_target``.
        """
        raw = load_data(config.abs_path_to_data)
        x, y = split_target_features(config.name_of_target, raw)
        return raw, x, y

    def _check_tune_method(self, tuner, name, method, raw, x, y):
        """Invoke a single ``tune_*`` method and assert on its output.

        Calls ``self.skipTest`` (which raises ``unittest.SkipTest``) when the
        method's parameter count matches none of the supported call shapes.
        """
        sig = inspect.signature(method)
        # ``method`` is bound, so ``self`` is excluded; this counts every
        # remaining parameter, including ones that have defaults.
        num_args = len(sig.parameters)

        if num_args == 0:
            tuned_x = method()
        elif num_args == 1:
            tuned_x = method(raw)
        elif num_args == 2:
            tuned_x = method(x, y)
        else:
            self.skipTest(f"{name} has unexpected signature {sig}")

        self.assertIsInstance(tuned_x, pd.DataFrame,
                              f"{name} must return DataFrame")
        self.assertFalse(tuned_x.empty,
                         f"{name} returned empty DataFrame")

        # Prefer the tuner's own target if it exposes one; otherwise fall back
        # to the target produced by the initial split.
        y_out = getattr(tuner, "y", y)
        self.assertIsInstance(y_out, pd.Series,
                              f"{name}: y must be Series")
        self.assertFalse(y_out.empty,
                         f"{name}: y must not be empty")

        # NOTE(review): bfill() only fills NaNs introduced by index
        # misalignment; it cannot change whether the frame is empty -- confirm
        # whether a stronger alignment check was intended here.
        joined = pd.concat([y_out, tuned_x], axis=1).bfill()
        self.assertIsInstance(joined, pd.DataFrame)
        self.assertFalse(joined.empty,
                         f"{name}: joined DataFrame empty")

    def test_tuners(self):
        """Run every ``tune_*`` method of every registered tuner class."""
        for tuner_cls in TUNER_CLASSES:
            with self.subTest(tuner=tuner_cls.__name__):
                config = self._instantiate_config(tuner_cls)
                tuner = tuner_cls(config=config)
                raw, x, y = self._make_data(config)

                tune_methods = [
                    (name, method)
                    for name, method in inspect.getmembers(
                        tuner, predicate=inspect.ismethod
                    )
                    if name.startswith("tune_")
                ]
                # Without this guard a tuner exposing no ``tune_*`` methods
                # (e.g. after a rename) would let the test pass while
                # exercising nothing.
                self.assertTrue(
                    tune_methods,
                    f"{tuner_cls.__name__} exposes no tune_* methods",
                )

                for name, method in tune_methods:
                    with self.subTest(method=name):
                        self._check_tune_method(tuner, name, method, raw, x, y)

76 

77 

if __name__ == "__main__":
    # Executed only when this file is run as a script: hand control to the
    # unittest CLI runner for the TestCase classes defined in this module.
    unittest.main(module="__main__")