Coverage for addmo/s3_model_tuning/scoring/metrics/metric_factory.py: 76%
17 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
1import inspect
3from sklearn import metrics as sk_metrics
5from addmo.s3_model_tuning.scoring.metrics import custom_metrics
6from addmo.s3_model_tuning.scoring.metrics.abstract_metric import AbstractMetric
class MetricFactory:
    """
    Factory for creating metric instances.

    A metric name is resolved first against the concrete ``AbstractMetric``
    subclasses defined in ``custom_metrics`` and then against the scorer
    names registered in scikit-learn.
    """

    @staticmethod
    def _custom_metric_classes():
        """Return ``{name: class}`` for each concrete AbstractMetric subclass in
        the ``custom_metrics`` module."""
        return {
            name: obj
            for name, obj in inspect.getmembers(custom_metrics, inspect.isclass)
            if issubclass(obj, AbstractMetric) and not inspect.isabstract(obj)
        }

    @staticmethod
    def _customize_sklearn_scorer(scorer, metric_kwargs: dict):
        """Rebuild a registered scikit-learn scorer with extra keyword arguments.

        The scorer's own kwargs, optimization direction and (when supported by
        the installed scikit-learn) prediction method are preserved; entries in
        ``metric_kwargs`` override the scorer's own kwargs.
        """
        user_kwargs = dict(metric_kwargs)

        # Registered "neg_*" scorers carry a negative sign. make_scorer defaults
        # to greater_is_better=True, so the sign must be forwarded explicitly,
        # otherwise the rebuilt scorer would reward larger errors. An explicit
        # user-supplied greater_is_better still wins (backward compatible).
        greater_is_better = user_kwargs.pop(
            "greater_is_better", getattr(scorer, "_sign", 1) > 0
        )

        # Keep kwargs baked into the registered scorer (e.g. squared=False for
        # RMSE in some scikit-learn versions); user kwargs take precedence.
        score_kwargs = {**(getattr(scorer, "_kwargs", None) or {}), **user_kwargs}

        make_scorer_kwargs = {"greater_is_better": greater_is_better}

        # Newer scikit-learn scorers record which estimator method they call
        # (predict / predict_proba / decision_function). Forward it only when
        # make_scorer accepts it and the caller did not choose one themselves,
        # so probability/threshold based scorers keep working after rebuild.
        response_method = getattr(scorer, "_response_method", None)
        caller_chose_method = any(
            key in user_kwargs
            for key in ("response_method", "needs_proba", "needs_threshold")
        )
        if (
            response_method is not None
            and not caller_chose_method
            and "response_method"
            in inspect.signature(sk_metrics.make_scorer).parameters
        ):
            make_scorer_kwargs["response_method"] = response_method

        return sk_metrics.make_scorer(
            scorer._score_func, **make_scorer_kwargs, **score_kwargs
        )

    @staticmethod
    def metric_factory(metric_name, metric_kwargs: dict = None):
        """Get a custom metric instance or a scikit-learn scorer by name.

        Args:
            metric_name: Name of a concrete ``AbstractMetric`` subclass in
                ``custom_metrics`` or of a scikit-learn scorer (see
                ``sklearn.metrics.get_scorer_names()``).
            metric_kwargs: Optional keyword arguments. For a custom metric the
                dict (or None) is passed positionally to its constructor. For
                a scikit-learn scorer a non-empty dict is merged into the
                scoring function's kwargs.

        Returns:
            An instance of the custom metric class, or a scikit-learn scorer.

        Raises:
            ValueError: If ``metric_name`` matches neither a custom metric nor
                a scikit-learn scorer.
        """
        custom_metric_classes = MetricFactory._custom_metric_classes()

        # If metric is custom: only real metric classes are accepted, not
        # arbitrary module attributes (helper imports, dunders, ...).
        if metric_name in custom_metric_classes:
            return custom_metric_classes[metric_name](metric_kwargs)

        # If metric is from scikit-learn
        if metric_name in sk_metrics.get_scorer_names():
            # Possible regression metrics are:
            # explained_variance, max_error, neg_mean_absolute_error, neg_mean_squared_error,
            # neg_root_mean_squared_error, neg_mean_squared_log_error,
            # neg_root_mean_squared_log_error, neg_median_absolute_error,
            # r2, neg_mean_poisson_deviance, neg_mean_gamma_deviance,
            # neg_mean_absolute_percentage_error, d2_absolute_error_score,
            # d2_pinball_score, d2_tweedie_score
            sk_metric = sk_metrics.get_scorer(metric_name)

            # Customize metric only when there is something to customize; the
            # registered scorer is already correct as-is.
            if metric_kwargs:
                sk_metric = MetricFactory._customize_sklearn_scorer(
                    sk_metric, metric_kwargs
                )
            return sk_metric

        # If metric is not found: list the custom metrics for the error message
        raise ValueError(
            f"Unknown metric type: {metric_name}. "
            f"Available custom metrics are:"
            f" {', '.join(custom_metric_classes)}. "
            f"You can also use any metric from scikit-learn, "
            f"like r2, neg_mean_absolute_error, d2_pinball_score, etc."
        )