Source code for agentlib_mpc.optimization_backends.backend

import abc
import os
from datetime import datetime
import logging
from pathlib import Path
from typing import Dict, Union, Callable, TypeVar, Optional

import pandas as pd
import pydantic
from agentlib.core.errors import ConfigurationError
from pydantic import ConfigDict
from pydantic_core.core_schema import FieldValidationInfo

from agentlib.utils import custom_injection
from agentlib.core import AgentVariable, Model
from agentlib_mpc.data_structures import mpc_datamodels
from agentlib_mpc.data_structures.mpc_datamodels import (
    DiscretizationOptions,
)
from agentlib_mpc.data_structures.mpc_datamodels import Results

# Module-level logger. Backends start out logging here (see
# ``OptimizationBackend.__init__``) until ``register_logger`` replaces it.
logger = logging.getLogger(__name__)

# Type variable for model classes handled by a backend; bounded by ``Model``.
ModelT = TypeVar("ModelT", bound=Model)


class BackendConfig(pydantic.BaseModel):
    """Configuration of an optimization backend.

    Unknown keys are rejected (``extra="forbid"``). The validators run in field
    declaration order, so ``save_results`` can look at ``results_file`` and
    ``overwrite_result_file`` can look at both.
    """

    # Model configuration; ``OptimizationBackend.model_from_config`` expects a
    # "type" key plus the keyword arguments of the model class.
    model: dict
    discretization_options: DiscretizationOptions
    name: Optional[str] = None
    # Target file for results; must have a ".csv" suffix when given.
    results_file: Optional[Path] = pydantic.Field(default=None)
    # None means "decide automatically": save iff a results file is given.
    save_results: Optional[bool] = pydantic.Field(validate_default=True, default=None)
    # Validated even when defaulted, so the existence check always runs.
    overwrite_result_file: Optional[bool] = pydantic.Field(
        default=False, validate_default=True
    )
    model_config = ConfigDict(extra="forbid")

    @pydantic.field_validator("results_file")
    @classmethod
    def check_csv(cls, file: Optional[Path]):
        """Ensure the results file, if one is given, is a csv file.

        Raises:
            ConfigurationError: If the file suffix is not ``.csv``.
        """
        # An explicitly passed ``None`` reaches this validator (only the
        # default is skipped), so it must not be dereferenced.
        if file is not None and file.suffix != ".csv":
            raise ConfigurationError(
                f"Results filename has to be a 'csv' file. Got {file} instead."
            )
        return file

    @pydantic.field_validator("save_results")
    @classmethod
    def disable_results_if_no_file(
        cls, save_results: Optional[bool], info: FieldValidationInfo
    ):
        """Resolve ``save_results`` against the presence of a results file.

        Returns:
            The given value, or - if the user did not specify it - whether a
            results file was configured.

        Raises:
            ConfigurationError: If saving is requested without a results file.
        """
        if "results_file" not in info.data:
            # 'results_file' did not pass its own validation; leave the value
            # untouched so that error is reported instead of a KeyError here.
            return save_results
        results_file = info.data["results_file"]
        if save_results is None:
            # if user did not specify if results should be saved, we save them
            # if a file is specified.
            return bool(results_file)
        if save_results and results_file is None:
            raise ConfigurationError(
                "'save_results' was true, however there was no results file provided."
            )
        return save_results

    @pydantic.field_validator("overwrite_result_file")
    @classmethod
    def check_overwrite(cls, overwrite_result_file: bool, info: FieldValidationInfo):
        """Checks whether the overwrite results settings are valid, and deletes
        existing result files if applicable.

        Raises:
            FileExistsError: If results are to be saved, the results file
                already exists and overwriting is not enabled.
        """
        res_file = info.data.get("results_file")
        # ``.get`` keeps a failed 'save_results' validation from surfacing as
        # a KeyError in this validator.
        if not (res_file and info.data.get("save_results")):
            return overwrite_result_file

        if overwrite_result_file:
            # Remove each file on its own: a missing results file must not
            # keep a stale stats file from being deleted.
            for stale_file in (res_file, mpc_datamodels.stats_path(res_file)):
                try:
                    os.remove(stale_file)
                except FileNotFoundError:
                    pass
        elif os.path.isfile(res_file):
            raise FileExistsError(
                f"Results file {res_file} already exists and will not be "
                f"overwritten automatically. Set 'overwrite_result_file' to "
                f"True to enable automatic overwrite it."
            )
        return overwrite_result_file
class OptimizationBackend(abc.ABC):
    """
    Base class for all optimization backends. OptimizationBackends are a
    plugin for the 'mpc' module. They provide means to setup and solve the
    underlying optimization problem of the MPC. They also can save data of
    the solutions.
    """

    # Maps a display name to a model class accepted by this backend; checked
    # with ``isinstance`` in ``model_from_config``. Empty here, so subclasses
    # have to provide their own mapping.
    _supported_models: dict[str, type[ModelT]] = {}
    mpc_backend_parameters = ("time_step", "prediction_horizon")
    config_type = BackendConfig

    def __init__(self, config: dict):
        """Validate ``config`` with ``config_type`` and build the model.

        Args:
            config: Keyword arguments for ``config_type``.
        """
        self.logger = logger
        self.config = self.config_type(**config)
        self.model: ModelT = self.model_from_config(self.config.model)
        self.var_ref: Optional[mpc_datamodels.VariableReference] = None
        self.cost_function: Optional[Callable] = None
        self.stats = {}
        self._created_file: bool = False  # flag if we checked the file location

    def register_logger(self, logger: logging.Logger):
        """Registers a logger, can be used to use the module logger"""
        self.logger = logger

    @abc.abstractmethod
    def setup_optimization(self, var_ref: mpc_datamodels.VariableReference):
        """
        Performs all necessary steps to make the ``solve`` method usable.

        Args:
            var_ref: Variable Reference that specifies the role of each model
                variable in the mpc
        """
        self.var_ref = var_ref

    @abc.abstractmethod
    def solve(
        self, now: Union[float, datetime], current_vars: Dict[str, AgentVariable]
    ) -> Results:
        """
        Solves the optimization problem given the current values of the
        corresponding AgentVariables and system time. The standardization of
        return values is a work in progress.

        Args:
            now: Current time used for interpolation of input trajectories.
            current_vars: Dict of AgentVariables holding the values relevant to
                the optimization problem. Keys are the names

        Returns:
            The results (annotated as ``Results``) with all optimization
            variables over their respective grids. Depending on
            discretization, can include many nan's, so care should be taken
            when using this, e.g. always use dropna() after accessing a
            column.

            Example:
                      variables   mDot | T_0 | slack_T
                 time
                 0                0.1  | 298 | nan
                 230              nan  | 297 | 3
                 470              nan  | 296 | 2
                 588              nan  | 295 | 1
                 700              0.05 | 294 | nan
                 930              nan  | 294 | 0.1
        """
        raise NotImplementedError(
            "The 'OptimizationBackend' class does not implement this because "
            "it is individual to the subclasses"
        )

    def update_discretization_options(self, opts: dict):
        """Updates the discretization options with the new dict.

        The optimization is set up again afterwards with the current
        ``var_ref``.
        """
        self.config.discretization_options = (
            self.config.discretization_options.model_copy(update=opts)
        )
        self.setup_optimization(var_ref=self.var_ref)

    def model_from_config(self, model: dict):
        """Set the model to the backend.

        Args:
            model: Model configuration. Its "type" entry is resolved to a class
                through ``custom_injection``; the remaining entries are passed
                to that class as keyword arguments.

        Returns:
            The instantiated model.

        Raises:
            TypeError: If the model is not an instance of any class in
                ``_supported_models``.
        """
        # work on a copy so the caller's dict keeps its "type" entry
        model = model.copy()
        _type = model.pop("type")
        custom_cls = custom_injection(config=_type)
        model = custom_cls(**model)
        if not isinstance(model, tuple(self._supported_models.values())):
            raise TypeError(
                f"Given model is of type {type(model)} but "
                f"should be instance of one of:"
                f"{', '.join(list(self._supported_models.keys()))}"
            )
        return model

    def get_lags_per_variable(self) -> dict[str, float]:
        """Returns the name of variables which include lags and their lag in
        seconds. The MPC module can use this information to save relevant past
        data of lagged variables"""
        return {}

    def results_file_exists(self) -> bool:
        """Checks if the results file already exists.

        If it does not, the parent directory is created and False is returned
        once; every later call returns True without touching the file system.
        """
        if self._created_file:
            return True
        if self.config.results_file.is_file():
            # todo, this case is weird, as it is the mistake-append
            self._created_file = True
            return True
        # we only check the file location once to save system calls
        self.config.results_file.parent.mkdir(parents=True, exist_ok=True)
        self._created_file = True
        return False

    def update_model_variables(self, current_vars: Dict[str, AgentVariable]):
        """
        Internal method to write current data_broker to model variables.
        Only update values, not other module_types.
        """
        for inp in current_vars.values():
            # Log through the instance logger so a logger installed with
            # ``register_logger`` is honoured; lazy %-args skip the string
            # formatting when debug output is disabled.
            self.logger.debug("Updating model variable %s=%s", inp.name, inp.value)
            self.model.set(name=inp.name, value=inp.value)
# Type variable for backend classes; bounded by ``OptimizationBackend``.
OptimizationBackendT = TypeVar("OptimizationBackendT", bound=OptimizationBackend)
class ADMMBackend(OptimizationBackend):
    """Abstract base for optimization backends that serve ADMM algorithms."""

    @property
    @abc.abstractmethod
    def coupling_grid(self) -> list[float]:
        """Grid on which the coupling variables are discretized.

        Concrete backends have to override this property.
        """
        raise NotImplementedError