Coverage for addmo/s3_model_tuning/scoring/metrics/metric_factory.py: 76%

17 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

1import inspect 

2 

3from sklearn import metrics as sk_metrics 

4 

5from addmo.s3_model_tuning.scoring.metrics import custom_metrics 

6from addmo.s3_model_tuning.scoring.metrics.abstract_metric import AbstractMetric 

7 

8 

class MetricFactory:
    """
    Factory for creating metric instances.

    A metric name is resolved either to a concrete ``AbstractMetric`` subclass
    defined in ``custom_metrics`` or to a predefined scikit-learn scorer.
    """

    @staticmethod
    def _custom_metric_classes():
        """Return ``{name: class}`` for every concrete AbstractMetric subclass in custom_metrics."""
        # inspect.getmembers returns members sorted by name; the dict keeps that order.
        return {
            name: obj
            for name, obj in inspect.getmembers(custom_metrics, inspect.isclass)
            if issubclass(obj, AbstractMetric) and not inspect.isabstract(obj)
        }

    @staticmethod
    def _customize_sk_scorer(sk_scorer, metric_kwargs: dict):
        """Rebuild a predefined scikit-learn scorer with additional score-function kwargs.

        scikit-learn has no public API for this, so the private scorer attributes
        ``_score_func``, ``_sign``, ``_kwargs`` and (scikit-learn >= 1.4)
        ``_response_method`` are read to reproduce the original scorer's behaviour.

        Args:
            sk_scorer: A scorer as returned by ``sklearn.metrics.get_scorer``.
            metric_kwargs: Extra keyword arguments. Keys that ``make_scorer`` itself
                understands (e.g. ``greater_is_better``) override the defaults taken
                from ``sk_scorer``; all other keys are forwarded to the score function.

        Returns:
            A new scorer callable.
        """
        defaults = {
            # Keep the optimisation direction: predefined "neg_*" scorers have
            # _sign == -1, whereas make_scorer defaults to greater_is_better=True.
            "greater_is_better": sk_scorer._sign > 0,
        }

        # The attribute only exists on scikit-learn >= 1.4, where make_scorer also
        # accepts ``response_method``. Skip it if the caller uses the legacy
        # needs_proba / needs_threshold flags, which cannot be combined with it.
        response_method = getattr(sk_scorer, "_response_method", None)
        legacy_flags = {"needs_proba", "needs_threshold"}
        if response_method is not None and not legacy_flags & metric_kwargs.keys():
            defaults["response_method"] = response_method

        # Keep the kwargs the predefined scorer already binds (e.g. squared=False
        # for RMSE on older scikit-learn); user-supplied kwargs take precedence.
        params = {**defaults, **getattr(sk_scorer, "_kwargs", {}), **metric_kwargs}
        return sk_metrics.make_scorer(sk_scorer._score_func, **params)

    @staticmethod
    def metric_factory(metric_name, metric_kwargs: dict = None):
        """Create a metric instance from its name.

        Args:
            metric_name: Name of a concrete custom metric class in ``custom_metrics``,
                or of a scikit-learn scorer (see ``sklearn.metrics.get_scorer_names()``).
            metric_kwargs: Optional keyword arguments.
                - For custom metrics, the dict (or None) is passed positionally to the
                  metric class constructor.
                - For scikit-learn scorers, a non-empty dict rebuilds the scorer with
                  these kwargs forwarded to its score function, keeping the original
                  scorer's sign and bound kwargs.

        Returns:
            A custom metric instance or a scikit-learn scorer callable.

        Raises:
            ValueError: If ``metric_name`` is neither a custom metric nor a
                scikit-learn scorer name.
        """
        custom_metric_classes = MetricFactory._custom_metric_classes()

        # If metric is custom
        if metric_name in custom_metric_classes:
            return custom_metric_classes[metric_name](metric_kwargs)

        # If metric is from scikit-learn
        if metric_name in sk_metrics.get_scorer_names():
            # Possible regression metrics are:
            # explained_variance, max_error, neg_mean_absolute_error, neg_mean_squared_error,
            # neg_root_mean_squared_error, neg_mean_squared_log_error,
            # neg_root_mean_squared_log_error, neg_median_absolute_error,
            # r2, neg_mean_poisson_deviance, neg_mean_gamma_deviance,
            # neg_mean_absolute_percentage_error, d2_absolute_error_score,
            # d2_pinball_score, d2_tweedie_score
            sk_metric = sk_metrics.get_scorer(metric_name)

            # Customize metric with additional kwargs if provided
            if metric_kwargs:
                sk_metric = MetricFactory._customize_sk_scorer(sk_metric, metric_kwargs)
            return sk_metric

        # If metric is not found
        raise ValueError(
            f"Unknown metric type: {metric_name}. "
            f"Available custom metrics are:"
            f" {', '.join(custom_metric_classes)}. "
            f"You can also use any metric from scikit-learn, "
            f"like r2, neg_mean_absolute_error, d2_pinball_score, etc."
        )