Coverage for addmo/s3_model_tuning/model_tuner.py: 100%
39 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
1from copy import deepcopy
3from sklearn.metrics import root_mean_squared_error
5from addmo.s3_model_tuning.config.model_tuning_config import ModelTunerConfig
6from addmo.s3_model_tuning.models.model_factory import ModelFactory
7from addmo.s3_model_tuning.hyperparameter_tuning.hyparam_tuning_factory import (
8 HyperparameterTunerFactory,
9)
10from addmo.s3_model_tuning.scoring.validator_factory import ValidatorFactory
class ModelTuner:
    """Tune, refit and select models as described by a ModelTunerConfig.

    The validation scorer and the hyperparameter tuner are both built once
    from the config in ``__init__`` and reused for every model tuned.
    """

    def __init__(self, config: "ModelTunerConfig"):
        """
        Args:
            config: Tuning configuration. This class reads
                ``hyperparameter_tuning_kwargs``, ``trainings_per_model`` and
                ``models`` from it; the whole object is also handed to the
                validator and hyperparameter-tuner factories.
        """
        self.config = config
        self.scorer = ValidatorFactory.ValidatorFactory(self.config)
        self.tuner = HyperparameterTunerFactory.tuner_factory(self.config, self.scorer)

    def tune_model(self, model_name: str, x_train_val, y_train_val):
        """
        Tune a model and return the best model fitted to training and validation system_data.

        The model is created by ``ModelFactory``, its hyperparameters are
        searched by the configured tuner, its validation score is computed,
        and it is then refit ``config.trainings_per_model`` times on the full
        training/validation data. The refit with the lowest in-sample RMSE
        (stored as ``fit_error``) is returned.

        Args:
            model_name: Name understood by ``ModelFactory.model_factory``.
            x_train_val: Features used for tuning, validation and refitting.
            y_train_val: Targets matching ``x_train_val``.

        Returns:
            The fitted model with the lowest ``fit_error``. It also carries the
            ``validation_score`` computed before refitting.

        Raises:
            ValueError: If ``config.trainings_per_model`` is smaller than 1.
                This is checked before any tuning so the error is not
                delayed until after an expensive search.
        """
        n_trainings = self.config.trainings_per_model
        if n_trainings < 1:
            raise ValueError(
                f"trainings_per_model must be at least 1, got {n_trainings}"
            )

        model = ModelFactory.model_factory(model_name)

        best_params = self.tuner.tune(
            model, x_train_val, y_train_val, **self.config.hyperparameter_tuning_kwargs
        )

        model.set_params(best_params)

        # due to refitting on each fold, the validation score must be calculated before fitting
        model.validation_score = self.scorer.score_validation(
            model, x_train_val, y_train_val
        )

        # refit the model on the whole training and validation system_data and get best model.
        # Each refit works on a deep copy so the runs do not share fitted state;
        # the copies inherit validation_score from `model`.
        fitted_models = []
        for i in range(n_trainings):
            _model = deepcopy(model)
            _model.fit(x_train_val, y_train_val)
            y_pred = _model.predict(x_train_val)
            _model.fit_error = root_mean_squared_error(y_train_val, y_pred)
            print(f"Model training {i} fit error: {_model.fit_error}")
            fitted_models.append(_model)

        # Non-empty thanks to the guard above; ties keep the earliest refit.
        return min(fitted_models, key=lambda fitted: fitted.fit_error)

    def tune_all_models(self, x_train_val, y_train_val):
        """
        Tune all models and return the best model fitted to training and validation system_data.

        Args:
            x_train_val: Features passed to ``tune_model`` for every model.
            y_train_val: Targets matching ``x_train_val``.

        Returns:
            dict mapping each name in ``config.models`` to its tuned, fitted
            model, in the order the names appear in the config.
        """
        model_dict = {}
        for model_name in self.config.models:
            model = self.tune_model(model_name, x_train_val, y_train_val)
            model_dict[model_name] = model
        return model_dict

    def get_model_validation_score(self, model_dict, model_name):
        """
        Get the model validation score from the model dictionary.

        Returns:
            ``model_dict[model_name].validation_score``.

        Raises:
            KeyError: If ``model_name`` is not in ``model_dict``.
        """
        return model_dict[model_name].validation_score

    def get_best_model_name(self, model_dict):
        """
        Get the best model name from the model dictionary.

        The best model is the one with the highest ``validation_score``. This
        assumes the scorer follows a higher-is-better convention; confirm
        against the validator in use. On ties, the first name in dict
        iteration order wins.

        Raises:
            ValueError: If ``model_dict`` is empty.
        """
        if not model_dict:
            raise ValueError("model_dict is empty; there is no model to select")
        best_model_name = max(
            model_dict, key=lambda x: self.get_model_validation_score(model_dict, x)
        )
        return best_model_name

    def get_model(self, model_dict, model_name):
        """
        Get the model from the model dictionary.

        Raises:
            KeyError: If ``model_name`` is not in ``model_dict``.
        """
        return model_dict[model_name]