Coverage for aixcalibuha/sensitivity_analysis/pawn.py: 95%
41 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-01-27 10:48 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-01-27 10:48 +0000
1"""
2Adds the PAWNAnalyzer to the available
3classes of sensitivity analysis.
4"""
6from SALib.sample import sobol
7from SALib.sample import morris
8from SALib.sample import fast_sampler as fast
9from SALib.analyze import pawn as analyze_pawn
10import numpy as np
11from aixcalibuha.sensitivity_analysis import SenAnalyzer
12from aixcalibuha import CalibrationClass
class PAWNAnalyzer(SenAnalyzer):
    """
    Sensitivity analyzer built on the PAWN method of SALib, see
    https://salib.readthedocs.io/en/latest/api.html#pawn-sensitivity-analysis

    PAWN is a density-based method. The analysis variables offered by
    this class are 'minimum', 'mean', 'median', 'maximum' and the
    coefficient of variation 'CV' of the PAWN index.

    Besides the keyword arguments handed on to the parent class,
    the following ones are evaluated here:

    :keyword bool calc_second_order:
        Default True. Only forwarded to the sampler of the sobol-method.
    :keyword int s:
        Default 10. Forwarded as ``S`` to the pawn analysis.
    :keyword str sampler:
        Name of the sampler to draw the samples with. Default 'sobol'.
        Supported are 'sobol', 'morris' and 'fast'.
    :keyword int num_levels:
        Default is num_samples. Only forwarded to the sampler of the morris-method.
    :keyword optimal_trajectories:
        Default None. Only forwarded to the sampler of the morris-method.
    :keyword bool local_optimization:
        Default True. Only forwarded to the sampler of the morris-method.
    :keyword int M:
        Default 4. Only forwarded to the sampler of the fast-method.
    """

    def __init__(self, sim_api, **kwargs):
        """
        Set up the parent analyzer and store the PAWN specific options.

        :param sim_api:
            Passed on unchanged to the parent class.
        :param kwargs:
            Passed on to the parent class; the keywords listed in the
            class documentation are additionally stored on the instance.
        """
        # The parent has to be initialized first, as the default of
        # num_levels is read from self.num_samples afterwards.
        super().__init__(sim_api=sim_api, **kwargs)

        # Options for the pawn analysis itself
        self.calc_second_order = kwargs.pop("calc_second_order", True)
        self.s = kwargs.pop("s", 10)
        # Choice of the sampler and the sampler specific options
        self.sampler = kwargs.pop("sampler", 'sobol')
        self.num_levels = kwargs.pop("num_levels", self.num_samples)
        self.optimal_trajectories = kwargs.pop("optimal_trajectories", None)
        self.local_optimization = kwargs.pop("local_optimization", True)
        self.M = kwargs.pop("M", 4)

    @property
    def analysis_variables(self):
        """Names of the result measures provided by the PAWN method"""
        return ['minimum', 'mean', 'median', 'maximum', 'CV']

    def analysis_function(self, x, y):
        """
        Analyze the simulation results with SALib.analyze.pawn.

        :param np.array x:
            The sample matrix, forwarded as model inputs to the pawn analysis
        :param np.array y:
            The NumPy array containing the model outputs
        :return:
            The result of SALib.analyze.pawn.analyze, called with
            ``S=self.s``. According to the SALib documentation, the PAWN
            index is reported at the min, mean, median and max across the
            slides/conditioning intervals, together with the coefficient
            of variation (``CV``). The median is the typically reported
            value. ``CV`` is (standard deviation / mean) and indicates the
            variability across the slides; values closer to zero indicate
            lower variation.
        """
        return analyze_pawn.analyze(self.problem, x, y, S=self.s)

    def create_sampler_demand(self):
        """
        Collect the keyword arguments required by the selected sampler.

        :return:
            Dictionary with the sampler specific keyword arguments.
        :rtype: dict
        :raises NotImplementedError:
            If ``self.sampler`` is none of 'sobol', 'morris' or 'fast'.
        """
        if self.sampler == 'sobol':
            demand = {'calc_second_order': self.calc_second_order}
        elif self.sampler == 'morris':
            demand = {
                'num_levels': self.num_levels,
                'optimal_trajectories': self.optimal_trajectories,
                'local_optimization': self.local_optimization,
            }
        elif self.sampler == 'fast':
            demand = {'M': self.M}
        else:
            raise NotImplementedError(f'{self.sampler} is not implemented yet')
        return demand

    def generate_samples(self):
        """
        Draw the samples with the selected sampler.

        :return:
            The list of samples generated as a NumPy array with one row per sample
            and each row containing one value for each variable name in `problem['names']`.
        :rtype: np.ndarray
        :raises NotImplementedError:
            If ``self.sampler`` is none of 'sobol', 'morris' or 'fast'.
        """
        # Select the sample function first; all samplers are then
        # called with the same signature.
        if self.sampler == 'sobol':
            draw = sobol.sample
        elif self.sampler == 'morris':
            draw = morris.sample
        elif self.sampler == 'fast':
            draw = fast.sample
        else:
            raise NotImplementedError(f'{self.sampler} is not implemented yet')
        return draw(self.problem,
                    N=self.num_samples,
                    **self.create_sampler_demand())

    def _get_res_dict(self, result: dict, cal_class: CalibrationClass, analysis_variable: str):
        """
        Map every variable name to its value of the requested
        analysis variable.

        :param dict result:
            Result of the analysis function, or None.
        :param CalibrationClass cal_class:
            Only used if result is None, to get the names of the tuner parameters.
        :param str analysis_variable:
            Key of the measure to extract from result.
        :return:
            Dictionary with the variable names as keys. If result
            is None, every tuner parameter is mapped to zero.
        :rtype: dict
        """
        if result is not None:
            return dict(zip(result['names'], result[analysis_variable]))
        # No result available: fill in zeros for all tuner parameters
        tuner_names = cal_class.tuner_paras.get_names()
        return dict(zip(tuner_names, np.zeros(len(tuner_names))))