Coverage for addmo/s3_model_tuning/hyperparameter_tuning/hyparam_tuning_factory.py: 73%

15 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

1from addmo.s3_model_tuning.hyperparameter_tuning.hyperparameter_tuner import ( 

2 OptunaTuner, 

3 NoTuningTuner, 

4 GridSearchTuner, 

5) 

6from addmo.s3_model_tuning.models.abstract_model import AbstractMLModel 

7from addmo.s3_model_tuning.config.model_tuning_config import ModelTunerConfig 

8from addmo.s3_model_tuning.hyperparameter_tuning.abstract_hyparam_tuner import ( 

9 AbstractHyParamTuner, 

10) 

11from addmo.s3_model_tuning.scoring.abstract_scorer import ValidationScoring 

12 

13 

class HyperparameterTunerFactory:
    """
    Factory for creating hyperparameter tuner instances.
    """

    @staticmethod
    def tuner_factory(
        config: ModelTunerConfig, scorer: ValidationScoring
    ) -> AbstractHyParamTuner:
        """
        Create the hyperparameter tuner selected by the configuration.

        :param config: Model tuning configuration. Its
            ``hyperparameter_tuning_type`` attribute selects the tuner class
            and must equal one of ``"NoTuningTuner"``, ``"OptunaTuner"`` or
            ``"GridSearchTuner"``. The whole config object is passed on to
            the tuner's constructor.
        :param scorer: Validation scoring object, passed on to the tuner's
            constructor.
        :return: A new instance of the selected tuner, constructed as
            ``TunerClass(config, scorer)``.
        :raises ValueError: If ``config.hyperparameter_tuning_type`` does not
            name one of the supported tuners.
        """
        # Single source of truth for the supported tuner names. It is built
        # per call rather than at class-definition time so the tuner classes
        # are looked up in the module namespace when the factory runs, just
        # as the original if/elif chain did.
        tuner_classes = {
            "NoTuningTuner": NoTuningTuner,
            "OptunaTuner": OptunaTuner,
            "GridSearchTuner": GridSearchTuner,
        }
        tuner_type = config.hyperparameter_tuning_type

        # Compare with ``==`` (not a dict lookup) to keep the original
        # matching semantics exactly. An unhashable or otherwise unexpected
        # value still ends in ValueError rather than TypeError.
        for name, tuner_cls in tuner_classes.items():
            if tuner_type == name:
                return tuner_cls(config, scorer)

        valid = ", ".join(tuner_classes)
        raise ValueError(
            f"Unknown tuner type: {tuner_type}. Expected one of: {valid}"
        )