Source code for agentlib.models.scipy_model

"""This module contains the ScipyStateSpaceModel class."""

import logging
from typing import Union

import numpy as np
from pydantic import ValidationError, model_validator

from agentlib.core.errors import OptionalDependencyError

try:
    from scipy import signal
    from scipy import interpolate, integrate
except ImportError as err:
    raise OptionalDependencyError(
        dependency_name="scipy", dependency_install="scipy", used_object="scipy-model"
    ) from err


from agentlib.core import Model, ModelConfig


logger = logging.getLogger(__name__)


class ScipyStateSpaceModelConfig(ModelConfig):
    """Customize config of Model."""

    # State-space description of the model: a sequence ``(A, B, C, D)``, a
    # dict with the keys "A", "B", "C" and "D", or a ready-made continuous
    # scipy ``StateSpace`` object.
    system: Union[dict, list, tuple, signal.StateSpace]

    @model_validator(mode="before")
    @classmethod
    def check_system(cls, values):
        """Root validator to check if the given system is valid.

        A list/tuple or dict description is converted into a
        ``scipy.signal.StateSpace`` object and its matrix dimensions are
        compared against the number of configured ``inputs``, ``outputs``
        and ``states``. An already constructed continuous-time
        ``StateSpace`` object is accepted as is, without dimension checks.

        Args:
            values: Raw configuration data handed to the model.

        Returns:
            ``values``, with ``values["system"]`` replaced by the
            constructed ``StateSpace`` object where a conversion took place.

        Raises:
            ValueError: If the system has an unsupported type, is missing
                matrices, or its dimensions do not match inputs, outputs or
                states. Pydantic reports this as a validation error.
        """
        # pylint: disable=no-self-argument,no-self-use
        system = values.get("system")
        if isinstance(system, (tuple, list)):
            # Check correct input size
            if len(system) != 4:
                raise ValueError(
                    "State space representation requires exactly 4 matrices"
                )
        elif isinstance(system, dict):
            for key in ("A", "B", "C", "D"):
                if key not in system:
                    raise ValueError(
                        f"State space representation requires key '{key}'"
                    )
            system = [system["A"], system["B"], system["C"], system["D"]]
        elif isinstance(system, signal.StateSpace) and isinstance(
            system, signal.lti
        ):
            # Continuous-time state space built by the caller. Identified via
            # the public classes instead of the deprecated
            # ``scipy.signal.ltisys`` namespace.
            return values
        else:
            logger.error(
                "Given system is of type %s but should be list, tuple or dict",
                type(system),
            )
            # A plain ValueError is used because pydantic's ValidationError
            # cannot be raised without arguments; pydantic converts the
            # ValueError into a validation error for the caller.
            raise ValueError(
                f"Given system is of type {type(system)} "
                "but should be list, tuple or dict"
            )

        # Setup the system
        system = signal.StateSpace(*system)

        # Check dimensions with inputs, states and outputs:
        n_inputs = len(values.get("inputs", []))
        n_outputs = len(values.get("outputs", []))
        n_states = len(values.get("states", []))
        # (matrix name, axis, expected size, what the axis has to match)
        expected_shapes = (
            ("A", 0, n_states, "states"),
            ("A", 1, n_states, "states"),
            ("B", 0, n_states, "states"),
            ("B", 1, n_inputs, "inputs"),
            ("C", 0, n_outputs, "outputs"),
            ("C", 1, n_states, "states"),
            ("D", 0, n_outputs, "outputs"),
            ("D", 1, n_inputs, "inputs"),
        )
        for name, axis, expected, kind in expected_shapes:
            if getattr(system, name).shape[axis] != expected:
                raise ValueError(
                    f"Given system matrix {name} does not match size of {kind}"
                )

        values["system"] = system
        return values
class ScipyStateSpaceModel(Model):
    """
    This class holds a scipy StateSpace model.
    It uses scipy.signal.lti as a system and the odeint as integrator.
    """

    config: ScipyStateSpaceModelConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Check if system was correctly set up
        assert isinstance(self.config.system, signal.StateSpace)

    def do_step(self, *, t_start, t_sample=None):
        """Integrate the linear system over one step and store the result.

        The current input values are held constant over the whole step. The
        states are integrated with ``scipy.integrate.odeint`` starting from
        the current state values; the states and outputs at the last time
        sample are written back to the model.

        Args:
            t_start: Offset added to the time samples of this step.
            t_sample: Passed to ``_create_time_samples``; defaults to
                ``self.dt`` when None.

        Returns:
            True
        """
        if t_sample is None:
            t_sample = self.dt
        t = self._create_time_samples(t_sample=t_sample) + t_start
        system = self.config.system

        u_held = np.array([inp.value for inp in self.inputs])
        x0 = np.array([sta.value for sta in self.states])

        # The input does not change during the step, so B*u is constant.
        # odeint may evaluate the vector field beyond the last requested
        # sample; using the held value directly keeps the input valid there
        # instead of falling back to zero outside the sampled window.
        # NaN inputs are still treated as zero inside the dynamics.
        forcing = np.dot(system.B, np.nan_to_num(u_held))

        def f_dot(x, _t):
            """The vector field of the linear system."""
            return np.dot(system.A, x) + forcing

        # x has shape (len(t), n_states); only the last sample is stored.
        x = integrate.odeint(f_dot, x0, t)
        x_last = x[-1, :]
        y_last = np.dot(system.C, x_last) + np.dot(system.D, u_held)

        # Index the last sample explicitly rather than squeezing, so that a
        # single output or a single time sample is handled like any other
        # case and every name receives exactly one value.
        self._set_output_values(
            names=self.get_output_names(), values=np.atleast_1d(y_last).tolist()
        )
        self._set_state_values(names=self.get_state_names(), values=x_last.tolist())
        return True

    def initialize(self, **kwargs):
        """Do nothing; keyword arguments are accepted and ignored."""
        pass