Coverage for addmo/s3_model_tuning/hyperparameter_tuning/hyperparameter_tuner.py: 77% (43 statements)


import inspect
import optuna
import wandb
from sklearn.model_selection import GridSearchCV
from addmo.s3_model_tuning.models.abstract_model import AbstractMLModel
from addmo.s3_model_tuning.hyperparameter_tuning.abstract_hyparam_tuner import (
    AbstractHyParamTuner,
)
from addmo.util.experiment_logger import ExperimentLogger
from addmo.util.experiment_logger import WandbLogger


class NoTuningTuner(AbstractHyParamTuner):
    """
    Tuner implementation that performs no tuning.
    """

    def tune(self, model: AbstractMLModel, x_train_val, y_train_val, **kwargs):
        """
        Returns the default hyperparameters without any tuning.
        """
        # If no hyperparameter set is given, fall back to the model's defaults.
        hyperparameter_set = kwargs.get("hyperparameter_set", None)
        if hyperparameter_set is None:
            hyperparameter_set = model.default_hyperparameter()
            print("No hyperparameter set given, will use default hyperparameters.")
        return hyperparameter_set
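
# Usage sketch for NoTuningTuner (hypothetical; assumes a constructed `tuner`
# instance and an AbstractMLModel subclass `MyModel` from elsewhere in addmo):
#
#     params = tuner.tune(MyModel(), x_train_val, y_train_val)
#     # -> MyModel().default_hyperparameter(), since no "hyperparameter_set"
#     #    keyword argument was passed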


class OptunaTuner(AbstractHyParamTuner):
    def tune(self, model: AbstractMLModel, x_train_val, y_train_val, **kwargs):
        """
        Perform hyperparameter tuning using Optuna.
        """
        # Hard-coded exception: ScikitSVR does not support parallel jobs
        # (parallel trials would otherwise yield equal model scores).
        if model.__class__.__name__ == "ScikitSVR":
            n_jobs = 1
        else:
            n_jobs = -1

        def objective(trial):
            hyperparameters = model.optuna_hyperparameter_suggest(trial)
            model.set_params(hyperparameters)
            score = self.scorer.score_validation(model, x_train_val, y_train_val)
            return score

        study = optuna.create_study(direction="maximize")
        study.optimize(
            objective,
            n_trials=self.config.hyperparameter_tuning_kwargs["n_trials"],
            n_jobs=n_jobs,  # Number of parallel jobs; -1 uses the CPU count.
            # Parallel jobs may break sequential logging to wandb.
        )

        # Logging
        self._log_optuna_study(study, model)

        # Convert the Optuna params of the best trial back to model params.
        best_params = model.optuna_hyperparameter_suggest(study.best_trial)
        return best_params
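
    # Standalone sketch of the objective/study pattern used in tune() above
    # (assumes only optuna; the quadratic objective is a hypothetical stand-in
    # for self.scorer.score_validation):
    #
    #     import optuna
    #
    #     def objective(trial):
    #         x = trial.suggest_float("x", -10.0, 10.0)
    #         return -(x - 2.0) ** 2  # direction="maximize" -> best near x = 2
    #
    #     study = optuna.create_study(direction="maximize")
    #     study.optimize(objective, n_trials=50, n_jobs=1)
    #     print(study.best_params, study.best_value)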


    @staticmethod
    def _log_optuna_study(study, model):
        ExperimentLogger.log_artifact(study, "optuna_study", art_type="pkl")
        hyperparameter_range = {
            "hyperparameter_range": inspect.getsource(
                model.optuna_hyperparameter_suggest
            )
        }
        study_results = {
            "optuna_study_best_params": study.best_params,
            "optuna_study_best_validation_score": study.best_value,
        }
        plots = {
            "optuna_plot_optimization_history": optuna.visualization.plot_optimization_history(
                study
            ),
            "optuna_plot_parallel_coordinate": optuna.visualization.plot_parallel_coordinate(
                study
            ),
            "optuna_plot_contour": optuna.visualization.plot_contour(study),
            "optuna_plot_param_importances": optuna.visualization.plot_param_importances(
                study
            ),
        }
        ExperimentLogger.log({**hyperparameter_range, **study_results, **plots})
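
# Why inspect.getsource is logged above: it records the searched hyperparameter
# ranges (the body of optuna_hyperparameter_suggest) alongside the study
# results. Minimal sketch (standard library only; `suggest` is hypothetical):
#
#     import inspect
#
#     def suggest(trial):
#         return {"alpha": trial.suggest_float("alpha", 1e-4, 1.0, log=True)}
#
#     print(inspect.getsource(suggest))  # prints the function's source text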


class GridSearchTuner(AbstractHyParamTuner):
    """
    Tuner implementation using Grid Search.
    """

    def tune(self, model: AbstractMLModel, x_train_val, y_train_val, **kwargs):
        """
        Perform hyperparameter tuning using Grid Search.
        """
        grid_search = GridSearchCV(
            model,
            model.grid_search_hyperparameter(),
            cv=self.scorer.splitter,
            scoring=self.scorer.metric,
        )
        grid_search.fit(x_train_val, y_train_val)

        best_params = grid_search.best_params_
        return best_params
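
# Standalone sketch of the grid-search pattern above (assumes only
# scikit-learn; GridSearchCV is passed `model` directly in tune(), so the
# model must expose the scikit-learn estimator interface):
#
#     from sklearn.model_selection import GridSearchCV
#     from sklearn.svm import SVR
#
#     grid = GridSearchCV(SVR(), {"C": [0.1, 1.0, 10.0]}, cv=3, scoring="r2")
#     grid.fit(x_train_val, y_train_val)  # x/y as in tune() above
#     print(grid.best_params_)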