Coverage for tests/test_config_loading.py: 98%

43 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

1import unittest 

2from pydantic import ValidationError 

3from addmo.s1_data_tuning_auto.config.data_tuning_auto_config import DataTuningAutoSetup 

4from addmo.s2_data_tuning.config.data_tuning_config import DataTuningFixedConfig 

5from addmo.s3_model_tuning.config.model_tuning_config import ModelTunerConfig, ModelTuningExperimentConfig 

6 

class TestConfigLoading(unittest.TestCase):
    """Sanity checks for the pydantic config classes used by ADDMo.

    Use this test if new pydantic config classes are added to ADDMo: append
    the new class to the matching ``config_classes`` list in the relevant
    test method.
    """

    def _assert_has_fields(self, config_classes, required_fields):
        """Assert that a default instance of every class exposes every field.

        Args:
            config_classes: Config classes that can be instantiated without
                arguments.
            required_fields: Attribute names each instance must provide.

        Each field is checked in its own ``subTest`` so that one run reports
        every missing field instead of stopping at the first failure.
        """
        for ConfigClass in config_classes:
            with self.subTest(config=ConfigClass.__name__):
                config = ConfigClass()
                for field in required_fields:
                    with self.subTest(field=field):
                        self.assertTrue(
                            hasattr(config, field),
                            msg=f"{ConfigClass.__name__} is missing required field '{field}'",
                        )

    def test_required_fields_data_config(self):
        """
        Check if fields necessary to run the training and plotting utilities
        for data tuning are added to config classes.
        """
        required_fields = [
            "abs_path_to_data",
            "name_of_target",
            "name_of_raw_data",
            "name_of_tuning",
        ]
        config_classes = [DataTuningAutoSetup, DataTuningFixedConfig]  # Add new config classes here
        self._assert_has_fields(config_classes, required_fields)

    def test_required_fields_model_config(self):
        """
        Check if fields necessary to run the model tuner are added to model
        experiment config class.
        """
        required_fields = [
            "abs_path_to_data",
            "name_of_data_tuning_experiment",
            "name_of_model_tuning_experiment",
            "name_of_target",
        ]
        config_classes = [ModelTuningExperimentConfig]  # Add model tuning experiment classes here
        self._assert_has_fields(config_classes, required_fields)

    def test_required_fields_model_tuner(self):
        """
        Check if fields necessary to run the model tuner are added to config class.
        """
        required_fields = ["models", "trainings_per_model"]
        config_classes = [ModelTunerConfig]
        self._assert_has_fields(config_classes, required_fields)

    def test_data_tuning_config_enforces_str_types(self):
        """
        For every pydantic config that should use str fields, injecting a
        non-str must raise ValidationError.

        Only fields annotated exactly as ``str`` are exercised; a class
        without such fields passes vacuously.
        NOTE(review): this relies on pydantic rejecting an int for a ``str``
        field rather than coercing it -- confirm against the pinned pydantic
        version.
        """
        config_classes = [DataTuningAutoSetup, DataTuningFixedConfig]
        for ConfigClass in config_classes:
            with self.subTest(config=ConfigClass.__name__):
                str_fields = [
                    name
                    for name, field in ConfigClass.model_fields.items()
                    if field.annotation is str
                ]
                # Baseline kwargs in which every str field holds a valid value.
                valid_kwargs = dict.fromkeys(str_fields, "dummy")
                # Test each str field rejects non-str, one field at a time so
                # that the offending field is the only invalid input.
                for name in str_fields:
                    with self.subTest(field=name):
                        bad_kwargs = {**valid_kwargs, name: 123}
                        with self.assertRaises(
                            ValidationError,
                            msg=f"{ConfigClass.__name__}.{name} did not validate",
                        ):
                            ConfigClass(**bad_kwargs)

# Script entry point: run this module's tests when the file is executed directly.
if __name__ == "__main__":
    unittest.main()