# Source code for agentlib_mpc.modules.ml_model_simulator

"""
Module contains the MLModelSimulator, which simulates models that are built
from ML models. The class inherits from the Simulator class of the agentlib core.
"""

import pydantic
from agentlib.core import AgentVariable, AgentVariables
from agentlib.core.errors import ConfigurationError
from agentlib.modules.simulation.simulator import SimulatorConfig, Simulator
from pydantic_core.core_schema import FieldValidationInfo

from agentlib_mpc.models.casadi_ml_model import CasadiMLModel
from agentlib_mpc.models.serialized_ml_model import SerializedMLModel
from pydantic import field_validator


class MLModelSimulatorConfig(SimulatorConfig):
    """Configuration of the MLModelSimulator.

    Extends the agentlib ``SimulatorConfig`` with the agent variables over which
    serialized ML models can be received, and validates that the simulator's
    sampling time is compatible with the model's time step.
    """

    # Agent variables whose values are serialized ML models. The simulator
    # registers a callback for each of them to update the simulated model.
    serialized_ml_models: AgentVariables = []

    @field_validator("t_sample")
    @classmethod
    def check_t_sample(cls, t_sample, info: FieldValidationInfo):
        """Check that t_sample is at least dt and an integer multiple of it.

        ``dt`` is read from the already-validated ``model`` field, so the model
        must be validated before ``t_sample``.

        The multiple check is tolerance-based: an exact float modulo would
        reject valid pairs such as ``t_sample=1.0, dt=0.1`` because
        ``1.0 % 0.1 == 0.0999...`` in binary floating point.

        Args:
            t_sample: Sampling time of the simulator.
            info: Pydantic validation info holding previously validated fields.

        Returns:
            The unchanged ``t_sample``.

        Raises:
            ConfigurationError: If the ``model`` field is missing or failed
                validation, or if ``t_sample`` is smaller than ``dt``, not an
                integer multiple of it, or ``dt`` is not positive.
        """
        import math

        if "model" not in info.data:
            raise ConfigurationError(
                "Model validation failed: the 'model' field is missing or invalid in the configuration. "
                "Please verify your model configuration."
            )
        dt = info.data["model"].dt

        is_multiple = False
        # dt <= 0 can never be a valid step; checking it also avoids a division
        # by zero below.
        if dt > 0 and t_sample >= dt:
            steps = t_sample / dt
            # rel_tol absorbs the relative rounding error of the division for
            # large step counts, abs_tol the error for small ones.
            is_multiple = math.isclose(
                steps, round(steps), rel_tol=1e-12, abs_tol=1e-9
            )

        if not is_multiple:
            raise ConfigurationError(
                f"Sampling Time of Simulator must be multiple of MLModel time step. Current"
                f" MLModel time step is {dt} and chosen sampling time is {t_sample}."
            )
        return t_sample
class MLModelSimulator(Simulator):
    """Simulator for models that contain ML models (``CasadiMLModel``).

    Besides the regular simulator behaviour, it listens to the agent variables
    configured in ``serialized_ml_models``. Whenever one of them delivers a
    serialized ML model, that model is loaded and passed to the simulated
    model via ``update_ml_models``.
    """

    config: MLModelSimulatorConfig
    model: CasadiMLModel

    def _callback_update_model_input(self, inp: AgentVariable, name: str):
        """Set the value of an incoming agent variable as a model input.

        The value is stored together with the variable's timestamp.

        Args:
            inp: Agent variable carrying the new value and its timestamp.
            name: Name of the model input to set.
        """
        self.logger.debug("Updating model input %s=%s", inp.name, inp.value)
        self.model.set_with_timestamp(
            name=name, value=inp.value, timestamp=inp.timestamp
        )

    def register_callbacks(self):
        """Register the base simulator callbacks plus one per ML model variable.

        For each entry in ``config.serialized_ml_models``, a data broker
        callback to ``_update_ml_model_callback`` is registered, using the
        variable's alias and source. The variable's name is passed through as
        the callback's ``name`` argument.
        """
        # The base Simulator registers its own callbacks. Judging by the
        # override of _callback_update_model_input above, these include the
        # model-input callbacks. Skipping this call would leave the simulator
        # deaf to its inputs.
        super().register_callbacks()
        for ml_model_var in self.config.serialized_ml_models:
            self.agent.data_broker.register_callback(
                callback=self._update_ml_model_callback,
                alias=ml_model_var.alias,
                source=ml_model_var.source,
                name=ml_model_var.name,
            )

    def _update_ml_model_callback(self, variable: AgentVariable, name: str):
        """Update the ML models of the underlying model from a serialized model.

        Invalid payloads and models rejected by the simulated model (raised as
        ``ConfigurationError``) are logged as errors and otherwise ignored, so
        a bad message does not abort the callback.

        Args:
            variable: Agent variable whose value is expected to be a
                serialized ML model string.
            name: Name of the configured ML model variable (used for logging).
        """
        try:
            ml_model = SerializedMLModel.load_serialized_model_from_string(
                variable.value
            )
        # NOTE(review): ValueError/TypeError are caught in addition to the
        # pydantic error, presumably covering malformed strings
        # (e.g. JSON decode errors) and non-string values sent by other
        # agents. Confirm against SerializedMLModel's loader.
        except (pydantic.ValidationError, ValueError, TypeError):
            self.logger.error(
                "Callback 'update_ml_model' got activated for variable %s, but the "
                "received AgentVariable did not contain a valid MLModel. Got "
                "%s of type '%s' instead.",
                name,
                variable.value,
                type(variable.value),
            )
            return

        try:
            self.model.update_ml_models(ml_model, time=self.env.now)
        except ConfigurationError as e:
            self.logger.error(
                "Tried to update the MLModels, but new MLModels do not have matching 'dt'. "
                "Error message from model: '%s'.",
                e,
            )
        else:
            self.logger.info("Successfully updated MLModel for variable %s.", name)