"""This module contains the ScipyStateSpaceModel class."""
import logging
from typing import Union
import numpy as np
from pydantic import ValidationError, model_validator
from agentlib.core.errors import OptionalDependencyError
try:
    from scipy import signal
    from scipy import interpolate, integrate
except ImportError as err:
    raise OptionalDependencyError(
        dependency_name="scipy", dependency_install="scipy", used_object="scipy-model"
    ) from err
from agentlib.core import Model, ModelConfig
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class ScipyStateSpaceModelConfig(ModelConfig):
    """Config of the scipy state space model.

    In addition to the usual model config entries it holds the linear
    continuous-time system ``dx/dt = A x + B u``, ``y = C x + D u``.
    """

    # Accepted raw forms: [A, B, C, D] (list/tuple), {"A": .., "B": .., "C": ..,
    # "D": ..} or a continuous-time scipy.signal.StateSpace. After validation
    # (see check_system) this is always a scipy.signal.StateSpace.
    system: Union[dict, list, tuple, signal.StateSpace]

    @model_validator(mode="before")
    @classmethod
    def check_system(cls, values):
        """Normalize ``system`` to a continuous StateSpace and check its size.

        Args:
            values: Raw (not yet validated) config data. ``values["system"]``
                is either a sequence ``[A, B, C, D]``, a dict with the keys
                ``"A"``, ``"B"``, ``"C"`` and ``"D"``, or a continuous-time
                ``scipy.signal.StateSpace``.

        Returns:
            ``values`` with ``"system"`` replaced by a
            ``scipy.signal.StateSpace`` whose matrix dimensions match the
            number of configured ``states``, ``inputs`` and ``outputs``.

        Raises:
            ValueError: If ``system`` has an unsupported type, is a
                discrete-time system, misses matrices, or its dimensions do
                not match states/inputs/outputs. Pydantic converts this into
                a ``ValidationError`` for the caller.
        """
        # pylint: disable=no-self-argument,no-self-use
        system = values.get("system")
        if isinstance(system, signal.StateSpace):
            # Continuous-time systems are instances of signal.lti; discrete
            # ones (dlti) cannot be integrated with odeint by the model.
            if not isinstance(system, signal.lti):
                raise ValueError(
                    "Given system is a discrete-time StateSpace, but only "
                    "continuous-time systems are supported"
                )
        elif isinstance(system, (tuple, list)):
            if len(system) != 4:
                raise ValueError(
                    "State space representation requires exactly 4 matrices"
                )
            system = signal.StateSpace(*system)
        elif isinstance(system, dict):
            for key in ("A", "B", "C", "D"):
                if key not in system:
                    raise ValueError(
                        f"State space representation requires key '{key}'"
                    )
            system = signal.StateSpace(
                system["A"], system["B"], system["C"], system["D"]
            )
        else:
            msg = (
                f"Given system is of type {type(system)} but should be list, "
                "tuple, dict or a continuous scipy.signal.StateSpace"
            )
            logger.error(msg)
            # Raising ValueError (not pydantic's ValidationError, which has no
            # public constructor) lets pydantic report a proper validation error.
            raise ValueError(msg)

        # Check dimensions against inputs, states and outputs. This also
        # applies to systems that were passed as a ready-made StateSpace.
        n_inputs = len(values.get("inputs") or [])
        n_outputs = len(values.get("outputs") or [])
        n_states = len(values.get("states") or [])
        dimension_checks = (
            (system.A.shape[0], n_states, "A", "states"),
            (system.A.shape[1], n_states, "A", "states"),
            (system.B.shape[0], n_states, "B", "states"),
            (system.B.shape[1], n_inputs, "B", "inputs"),
            (system.C.shape[0], n_outputs, "C", "outputs"),
            (system.C.shape[1], n_states, "C", "states"),
            (system.D.shape[0], n_outputs, "D", "outputs"),
            (system.D.shape[1], n_inputs, "D", "inputs"),
        )
        for actual, expected, matrix, quantity in dimension_checks:
            # Explicit raise instead of assert: asserts vanish under `python -O`.
            if actual != expected:
                raise ValueError(
                    f"Given system matrix {matrix} does not match size of "
                    f"{quantity} (got {actual}, expected {expected})"
                )
        values["system"] = system
        return values
class ScipyStateSpaceModel(Model):
    """
    This class holds a scipy StateSpace model.
    It uses a continuous-time ``scipy.signal.StateSpace`` as system and
    ``scipy.integrate.odeint`` as integrator.
    """

    config: ScipyStateSpaceModelConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Check if system was correctly set up by the config validator
        assert isinstance(self.config.system, signal.StateSpace)

    def do_step(self, *, t_start, t_sample=None):
        """Simulate one step of the linear system and store the results.

        The current input values are held constant over the whole step. The
        states are integrated with ``scipy.integrate.odeint`` starting from
        the current state values. The states and outputs at the end of the
        step are written back to the model.

        Args:
            t_start: Start time of the step (offset added to the time samples).
            t_sample: Step length passed to ``_create_time_samples``;
                defaults to ``self.dt``.

        Returns:
            True.
        """
        if t_sample is None:
            t_sample = self.dt
        system = self.config.system
        t = self._create_time_samples(t_sample=t_sample) + t_start
        u = np.array([inp.value for inp in self.inputs], dtype=float)
        x0 = np.array([sta.value for sta in self.states], dtype=float)

        # The inputs are constant during the step, so B @ u is a constant
        # forcing term. Using it directly (instead of an interp1d interpolant
        # that returns NaN -> 0 past t[-1]) keeps the input correct when
        # odeint's solver steps beyond the last output time and interpolates.
        forcing = np.dot(system.B, u)

        def f_dot(x, _t):
            """The vector field of the linear system."""
            return np.dot(system.A, x) + forcing

        if x0.size:
            # odeint returns one row per time in t; the last row is the
            # state at the end of the step.
            x_end = integrate.odeint(f_dot, x0, t)[-1]
        else:
            # Pure feedthrough system without states: nothing to integrate.
            x_end = x0
        # Computing y only at the end avoids the shape ambiguities of
        # squeezing the full output trajectory.
        y_end = np.dot(system.C, x_end) + np.dot(system.D, u)

        self._set_output_values(
            names=self.get_output_names(), values=y_end.tolist()
        )
        self._set_state_values(names=self.get_state_names(), values=x_end.tolist())
        return True

    def initialize(self, **kwargs):
        """Do nothing; the system is built and validated by the config."""