Coverage for addmo/s3_model_tuning/scoring/validation_splitting/splitter_factory.py: 70%
20 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
1import inspect
3from sklearn import model_selection
5from addmo.s3_model_tuning.scoring.validation_splitting import custom_splitters
6from addmo.s3_model_tuning.config.model_tuning_config import ModelTunerConfig
7from addmo.s3_model_tuning.scoring.validation_splitting.abstract_splitter import (
8 AbstractSplitter,
9)
class SplitterFactory:
    """
    Factory for creating validation splitter instances.

    A splitter is looked up by the name stored in the tuning config. The
    project's own ``custom_splitters`` module is searched first, then
    scikit-learn's ``model_selection`` module.
    """

    @staticmethod
    def _find_splitter_class(module, name: str, excluded: tuple = ()):
        """Return the class called ``name`` in ``module``, or ``None``.

        Only classes are accepted. Functions (e.g. scikit-learn's
        ``train_test_split``), sub-modules, constants, and any class listed
        in ``excluded`` are ignored. A misconfigured name therefore ends in
        the factory's descriptive ``ValueError`` rather than an obscure
        ``TypeError`` from calling something that is not a splitter.
        """
        candidate = getattr(module, name, None)
        if not inspect.isclass(candidate) or candidate in excluded:
            return None
        return candidate

    @staticmethod
    def _custom_splitter_names() -> list:
        """Return the names of the concrete custom splitters (for error messages)."""
        return [
            name
            for name, obj in inspect.getmembers(custom_splitters)
            if inspect.isclass(obj)
            and issubclass(obj, AbstractSplitter)
            and name != "AbstractSplitter"
        ]

    @staticmethod
    def splitter_factory(config: ModelTunerConfig) -> AbstractSplitter:
        """Get the custom splitter instance dynamically or use scikit-learn splitters.

        Args:
            config: Tuning configuration. ``config.validation_score_splitting``
                is the splitter class name. ``config.validation_score_splitting_kwargs``
                holds the keyword arguments for its constructor; ``None``
                means the constructor is called with no arguments.

        Returns:
            A new instance of the resolved splitter class. Custom splitters
            take precedence over a scikit-learn class with the same name.

        Raises:
            ValueError: If neither module defines a class with that name.
                Errors raised by the splitter's own constructor (e.g. for
                invalid kwargs) propagate unchanged.
        """
        name = config.validation_score_splitting

        # Custom splitters first; the abstract base is never a valid choice.
        splitter_class = SplitterFactory._find_splitter_class(
            custom_splitters, name, excluded=(AbstractSplitter,)
        )
        # Fall back to scikit-learn's model_selection splitters.
        if splitter_class is None:
            splitter_class = SplitterFactory._find_splitter_class(
                model_selection, name
            )

        if splitter_class is None:
            custom_splitter_names = SplitterFactory._custom_splitter_names()
            raise ValueError(
                f"Unknown splitter type: {config.validation_score_splitting}. "
                f"Available custom splitters are:"
                f" {', '.join(custom_splitter_names)}. "
                f"You can also use any splitter from scikit-learn, like KFold, "
                f"PredefinedSplit, TimeSeriesSplit, etc."
            )

        kwargs = config.validation_score_splitting_kwargs
        if kwargs is None:
            return splitter_class()
        return splitter_class(**kwargs)