Coverage for agentlib/models/scipy_model.py: 87%

68 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-04-07 16:27 +0000

1"""This module contains the ScipyStateSpaceModel class.""" 

2 

import logging
from typing import Union

import numpy as np
from pydantic import ValidationError, model_validator

from agentlib.core.errors import OptionalDependencyError

# scipy is an optional dependency of agentlib; fail with a descriptive
# installation hint instead of a bare ImportError when it is missing.
try:
    from scipy import signal, interpolate, integrate
except ImportError as exc:
    raise OptionalDependencyError(
        dependency_name="scipy",
        dependency_install="scipy",
        used_object="scipy-model",
    ) from exc

from agentlib.core import Model, ModelConfig

logger = logging.getLogger(__name__)

24 

25 

class ScipyStateSpaceModelConfig(ModelConfig):
    """Config of the ScipyStateSpaceModel.

    ``system`` may be given in one of these forms:

    - a list/tuple of the four matrices ``(A, B, C, D)``;
    - a dict with the keys ``"A"``, ``"B"``, ``"C"`` and ``"D"``;
    - an already constructed continuous-time ``scipy.signal.StateSpace``.

    The ``check_system`` validator converts every form to a
    ``scipy.signal.StateSpace``. It then checks the matrix shapes against
    the number of configured inputs, outputs and states.
    """

    system: Union[dict, list, tuple, signal.StateSpace]

    @model_validator(mode="before")
    @classmethod
    def check_system(cls, values):
        """Normalize ``system`` to a StateSpace and validate its dimensions.

        Args:
            values: Raw (pre-validation) config values. ``"system"`` is read
                and replaced. ``"inputs"``, ``"outputs"`` and ``"states"``
                are only used for their lengths (missing keys count as 0).

        Returns:
            ``values`` with ``"system"`` set to a ``signal.StateSpace``.

        Raises:
            ValueError: If the system has an unsupported type, lacks one of
                the four matrices, or its matrix shapes do not match the
                configured inputs/outputs/states. Raising ValueError (not
                ``assert``, which ``python -O`` strips) lets pydantic report
                it as a ValidationError.
        """
        system = values.get("system")
        if isinstance(system, (tuple, list)):
            if len(system) != 4:
                raise ValueError(
                    "State space representation requires exactly 4 matrices"
                )
            system = signal.StateSpace(*system)
        elif isinstance(system, dict):
            for key in ("A", "B", "C", "D"):
                if key not in system:
                    raise ValueError(
                        f"State space representation requires key '{key}'"
                    )
            system = signal.StateSpace(
                system["A"], system["B"], system["C"], system["D"]
            )
        elif not (isinstance(system, signal.StateSpace) and system.dt is None):
            # A ready-made StateSpace is only accepted in continuous time
            # (dt is None); discrete systems and any other type are rejected.
            logger.error(
                "Given system is of type %s but should be list, tuple, dict "
                "or a continuous scipy.signal.StateSpace",
                type(system),
            )
            raise ValueError(
                f"Unsupported system type {type(system)}; expected list, "
                "tuple, dict or a continuous scipy.signal.StateSpace"
            )

        # Check dimensions against inputs, states and outputs. This also
        # applies to systems passed in as a ready-made StateSpace.
        n_inputs = len(values.get("inputs", []))
        n_outputs = len(values.get("outputs", []))
        n_states = len(values.get("states", []))
        # Each entry: (actual size, expected size, message on mismatch).
        shape_checks = (
            (
                system.A.shape[0],
                n_states,
                "Given system matrix A does not match size of states",
            ),
            (
                system.A.shape[1],
                n_states,
                "Given system matrix A does not match size of states",
            ),
            (
                system.B.shape[0],
                n_states,
                "Given system matrix B does not match size of states",
            ),
            (
                system.B.shape[1],
                n_inputs,
                "Given system matrix B does not match size of inputs",
            ),
            (
                system.C.shape[0],
                n_outputs,
                "Given system matrix C does not match size of outputs",
            ),
            (
                system.C.shape[1],
                n_states,
                "Given system matrix C does not match size of states",
            ),
            (
                system.D.shape[0],
                n_outputs,
                "Given system matrix D does not match size of outputs",
            ),
            (
                system.D.shape[1],
                n_inputs,
                "Given system matrix D does not match size of inputs",
            ),
        )
        for actual, expected, message in shape_checks:
            if actual != expected:
                raise ValueError(f"{message} (got {actual}, expected {expected})")

        values["system"] = system
        return values

88 

89 

class ScipyStateSpaceModel(Model):
    """Model wrapping a linear ``scipy.signal.StateSpace`` system.

    The system ``dx/dt = A x + B u`` and ``y = C x + D u`` is taken from
    ``config.system`` and integrated with ``scipy.integrate.odeint``.
    Input values are held constant over each simulation step.
    """

    config: ScipyStateSpaceModelConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Invariant: the config validator converts every accepted form of
        # ``system`` into a scipy StateSpace.
        assert isinstance(self.config.system, signal.StateSpace)

    def do_step(self, *, t_start, t_sample=None):
        """Simulate one step and store the final states and outputs.

        Args:
            t_start: Simulation time at the start of the step.
            t_sample: Length of the step; defaults to ``self.dt``.

        Returns:
            True once the state and output values have been updated.
        """
        if t_sample is None:
            t_sample = self.dt
        # Presumably an increasing grid of sample times over
        # [0, t_sample]; shifted here to absolute time.
        t = self._create_time_samples(t_sample=t_sample) + t_start
        system = self.config.system

        u = np.array([inp.value for inp in self.inputs])
        x0 = np.array([sta.value for sta in self.states])

        # The input is constant over the step, so B @ u is constant too.
        # Using it directly, rather than interpolating u over t, keeps the
        # input defined when odeint's LSODA solver evaluates the
        # right-hand side slightly beyond the last sample time.
        b_u = system.B @ u

        def f_dot(x, _t):
            """Vector field of the linear system for a constant input."""
            return system.A @ x + b_u

        x = integrate.odeint(f_dot, x0, t)
        # odeint returns one row per sample time; only the final one is kept.
        x_end = x[-1]
        y_end = system.C @ x_end + system.D @ u

        self._set_output_values(
            names=self.get_output_names(), values=y_end.tolist()
        )
        self._set_state_values(names=self.get_state_names(), values=x_end.tolist())
        return True

    def initialize(self, **kwargs):
        """Do nothing: this model needs no extra initialization step."""