Coverage for agentlib/models/scipy_model.py: 87%
68 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-04-07 16:27 +0000
1"""This module contains the ScipyStateSpaceModel class."""
3import logging
4from typing import Union
6import numpy as np
7from pydantic import ValidationError, model_validator
9from agentlib.core.errors import OptionalDependencyError
11try:
12 from scipy import signal
13 from scipy import interpolate, integrate
14except ImportError as err:
15 raise OptionalDependencyError(
16 dependency_name="scipy", dependency_install="scipy", used_object="scipy-model"
17 ) from err
20from agentlib.core import Model, ModelConfig
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
class ScipyStateSpaceModelConfig(ModelConfig):
    """Config of a scipy state space model.

    Adds the ``system`` field holding the continuous-time state space
    matrices. It may be given as a list/tuple ``[A, B, C, D]``, as a dict
    with the keys ``"A"``, ``"B"``, ``"C"`` and ``"D"``, or as a
    continuous-time ``scipy.signal.StateSpace``. After validation it is
    always a ``scipy.signal.StateSpace`` whose matrix shapes match the
    number of configured inputs, outputs and states.
    """

    # Raw matrices (list/tuple/dict) are converted by ``check_system``.
    system: Union[dict, list, tuple, signal.StateSpace]

    @model_validator(mode="before")
    @classmethod
    def check_system(cls, values):
        """Convert ``system`` to a StateSpace and validate its dimensions.

        Args:
            values: Raw input data of the config, before field validation.

        Returns:
            A copy of ``values`` in which ``"system"`` is a
            ``scipy.signal.StateSpace``.

        Raises:
            ValueError: If the system has an unsupported type, lacks a
                matrix, is discrete-time, or its matrix shapes do not match
                the number of inputs, outputs and states. Pydantic reports
                this to the caller as a ``ValidationError``. Explicit raises
                are used instead of ``assert`` so validation still runs under
                ``python -O``.
        """
        # pylint: disable=no-self-argument,no-self-use
        if not isinstance(values, dict):
            # Not raw dict input; let pydantic handle it.
            return values
        system = values.get("system")
        if isinstance(system, (tuple, list)):
            if len(system) != 4:
                raise ValueError(
                    "State space representation requires exactly 4 matrices"
                )
            system = signal.StateSpace(*system)
        elif isinstance(system, dict):
            for key in ("A", "B", "C", "D"):
                if key not in system:
                    raise ValueError(
                        f"State space representation requires key '{key}'"
                    )
            system = signal.StateSpace(
                system["A"], system["B"], system["C"], system["D"]
            )
        elif isinstance(system, signal.StateSpace):
            # Continuous-time lti systems have ``dt is None``; the model
            # integrates in continuous time, so discrete systems are rejected.
            if system.dt is not None:
                raise ValueError(
                    "Given system is discrete-time (dt="
                    f"{system.dt}), but a continuous-time system is required"
                )
        else:
            logger.error(
                "Given system is of type %s but should be list, tuple, dict "
                "or scipy.signal.StateSpace",
                type(system),
            )
            raise ValueError(
                f"Given system is of type {type(system)} but should be list, "
                "tuple, dict or scipy.signal.StateSpace"
            )

        # Check dimensions with inputs, states and outputs. This also applies
        # to pre-built StateSpace objects, which previously skipped it.
        n_inputs = len(values.get("inputs") or [])
        n_outputs = len(values.get("outputs") or [])
        n_states = len(values.get("states") or [])
        checks = (
            (system.A.shape[0], n_states, "A", "states"),
            (system.A.shape[1], n_states, "A", "states"),
            (system.B.shape[0], n_states, "B", "states"),
            (system.B.shape[1], n_inputs, "B", "inputs"),
            (system.C.shape[0], n_outputs, "C", "outputs"),
            (system.C.shape[1], n_states, "C", "states"),
            (system.D.shape[0], n_outputs, "D", "outputs"),
            (system.D.shape[1], n_inputs, "D", "inputs"),
        )
        for actual, expected, matrix, name in checks:
            if actual != expected:
                raise ValueError(
                    f"Given system matrix {matrix} does not match size of "
                    f"{name} (got {actual}, expected {expected})"
                )
        # Return a copy so the caller's input dict is not mutated.
        return {**values, "system": system}
class ScipyStateSpaceModel(Model):
    """
    This class holds a scipy StateSpace model.
    It uses a continuous-time scipy.signal.StateSpace as system and
    scipy.integrate.odeint as integrator. The inputs are held constant at
    their current values for the duration of each step.
    """

    config: ScipyStateSpaceModelConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Check if system was correctly set up by the config validator.
        assert isinstance(self.config.system, signal.StateSpace)

    def do_step(self, *, t_start, t_sample=None):
        """Integrate the system from ``t_start`` over one sample interval.

        Args:
            t_start: Start time of the step.
            t_sample: Length of the step; defaults to ``self.dt``.

        Returns:
            True after the state and output values have been updated with
            their values at the end of the step.
        """
        if t_sample is None:
            t_sample = self.dt
        system = self.config.system
        t = self._create_time_samples(t_sample=t_sample) + t_start
        # Inputs are constant over the step (zero-order hold). Using the
        # vector directly avoids interpolation, which returned NaN (later
        # replaced by 0) whenever the solver evaluated past the last sample.
        u = np.array([inp.value for inp in self.inputs], dtype=float)
        x0 = np.array([sta.value for sta in self.states], dtype=float)
        bu = system.B @ u  # constant input contribution to dx/dt

        def f_dot(x, _t):
            """The vector field dx/dt = A x + B u of the linear system."""
            return system.A @ x + bu

        # odeint always returns a 2-D array of shape (len(t), n_states).
        x = integrate.odeint(f_dot, x0, t)
        x_end = x[-1]
        y_end = system.C @ x_end + system.D @ u

        self._set_output_values(
            names=self.get_output_names(), values=y_end.tolist()
        )
        self._set_state_values(
            names=self.get_state_names(), values=x_end.tolist()
        )
        return True

    def initialize(self, **kwargs):
        """Do nothing; the system is fully set up by the config."""