Coverage for addmo/s3_model_tuning/scoring/validation_splitting/splitter_factory.py: 70%

20 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

1import inspect 

2 

3from sklearn import model_selection 

4 

5from addmo.s3_model_tuning.scoring.validation_splitting import custom_splitters 

6from addmo.s3_model_tuning.config.model_tuning_config import ModelTunerConfig 

7from addmo.s3_model_tuning.scoring.validation_splitting.abstract_splitter import ( 

8 AbstractSplitter, 

9) 

10 

11 

class SplitterFactory:
    """
    Factory for creating validation splitter instances.

    The splitter is looked up by class name (``config.validation_score_splitting``):
    first among the custom splitters defined in ``custom_splitters``, then among
    the classes of ``sklearn.model_selection``.
    """

    @staticmethod
    def _is_custom_splitter(obj) -> bool:
        """Return True if ``obj`` is a concrete custom splitter class (not the abstract base)."""
        return (
            inspect.isclass(obj)
            and issubclass(obj, AbstractSplitter)
            and obj is not AbstractSplitter
        )

    @staticmethod
    def _is_sklearn_splitter(obj) -> bool:
        """Return True if ``obj`` is a class exposing a ``split`` method (CV splitter protocol)."""
        # Excludes module-level functions such as ``train_test_split`` and
        # non-splitter classes such as ``GridSearchCV``.
        return inspect.isclass(obj) and callable(getattr(obj, "split", None))

    @staticmethod
    def splitter_factory(config: ModelTunerConfig) -> AbstractSplitter:
        """Get the custom splitter instance dynamically or use scikit-learn splitters.

        Args:
            config: Tuning configuration. ``validation_score_splitting`` names the
                splitter class; ``validation_score_splitting_kwargs`` (may be
                ``None``) holds keyword arguments for its constructor.

        Returns:
            A new instance of the resolved splitter class.

        Raises:
            ValueError: If the name matches neither a custom splitter nor a
                scikit-learn splitter class.
        """
        name = config.validation_score_splitting
        # ``None`` (and an empty mapping) both mean "use the constructor defaults".
        kwargs = config.validation_score_splitting_kwargs or {}

        # if splitter is custom
        custom_splitter_class = getattr(custom_splitters, name, None)
        if SplitterFactory._is_custom_splitter(custom_splitter_class):
            return custom_splitter_class(**kwargs)

        # if splitter is from scikit-learn
        scikit_learn_splitter_class = getattr(model_selection, name, None)
        if SplitterFactory._is_sklearn_splitter(scikit_learn_splitter_class):
            return scikit_learn_splitter_class(**kwargs)

        # if splitter is not found: list the custom splitters for the error message
        custom_splitter_names = [
            member_name
            for member_name, obj in inspect.getmembers(custom_splitters)
            if SplitterFactory._is_custom_splitter(obj)
        ]

        raise ValueError(
            f"Unknown splitter type: {config.validation_score_splitting}. "
            f"Available custom splitter are:"
            f" {', '.join(custom_splitter_names)}. "
            f"You can also use any splitter from scikit-learn, like KFold, "
            f"PredefinedSplit, TimeSeriesSplit, etc."
        )