Coverage for addmo/s3_model_tuning/hyperparameter_tuning/hyparam_tuning_factory.py: 73%
15 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
1from addmo.s3_model_tuning.hyperparameter_tuning.hyperparameter_tuner import (
2 OptunaTuner,
3 NoTuningTuner,
4 GridSearchTuner,
5)
6from addmo.s3_model_tuning.models.abstract_model import AbstractMLModel
7from addmo.s3_model_tuning.config.model_tuning_config import ModelTunerConfig
8from addmo.s3_model_tuning.hyperparameter_tuning.abstract_hyparam_tuner import (
9 AbstractHyParamTuner,
10)
11from addmo.s3_model_tuning.scoring.abstract_scorer import ValidationScoring
class HyperparameterTunerFactory:
    """
    Factory for creating hyperparameter tuner instances.
    """

    # Accepted values of ``config.hyperparameter_tuning_type`` mapped to the
    # tuner class that handles them. Every class is constructed the same way:
    # ``TunerClass(config, scorer)``. Add new tuners here.
    _TUNERS = {
        "NoTuningTuner": NoTuningTuner,
        "OptunaTuner": OptunaTuner,
        "GridSearchTuner": GridSearchTuner,
    }

    @staticmethod
    def tuner_factory(
        config: ModelTunerConfig, scorer: ValidationScoring
    ) -> AbstractHyParamTuner:
        """
        Create the hyperparameter tuner selected by the model tuning config.

        :param config: Model tuning configuration. Its
            ``hyperparameter_tuning_type`` attribute selects the tuner and must
            be one of ``"NoTuningTuner"``, ``"OptunaTuner"`` or
            ``"GridSearchTuner"``. The config is also passed to the tuner's
            constructor.
        :param scorer: Validation scoring object passed to the tuner's
            constructor.
        :return: A new tuner instance, constructed as
            ``TunerClass(config, scorer)``.
        :raises ValueError: If ``config.hyperparameter_tuning_type`` is not one
            of the supported tuner types.
        """
        tuning_type = config.hyperparameter_tuning_type
        try:
            tuner_cls = HyperparameterTunerFactory._TUNERS[tuning_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which the original equality
            # chain also rejected with ValueError.
            raise ValueError(f"Unknown tuner type: {tuning_type}") from None
        return tuner_cls(config, scorer)