import dataclasses
from pathlib import Path
from typing import List, Union, TypeVar, Protocol, Sequence, Iterable
from itertools import chain
import attrs
import numpy as np
import pandas as pd
import pydantic
from enum import Enum, auto
from agentlib.core import AgentVariable
from agentlib.core.module import BaseModuleConfigClass
from agentlib_mpc.data_structures.interpolation import InterpolationMethods
from pydantic import ConfigDict
[docs]class InitStatus(str, Enum):
"""Keep track of the readyness status of the MPC."""
pre_module_init = auto()
during_update = auto()
ready = auto()
[docs]class DiscretizationOptions(pydantic.BaseModel):
"""Class defining the options to discretize an MPC. Can be extended for different
optimization implementations."""
model_config = ConfigDict(extra="allow")
time_step: float = pydantic.Field(
default=60,
ge=0,
description="Time step of the MPC.",
)
prediction_horizon: int = pydantic.Field(
default=5,
ge=0,
description="Prediction horizon of the MPC.",
)
[docs]class Results(Protocol):
df: pd.DataFrame
def __getitem__(self, item: str) -> Sequence[float]: ...
[docs]@dataclasses.dataclass
class BaseVariableReference:
[docs] def all_variables(self) -> List[str]:
"""Returns a list of all variables registered in the var_ref"""
return list(chain.from_iterable(self.__dict__.values()))
[docs] @classmethod
def from_config(cls, config: BaseModuleConfigClass):
"""Creates an instance from a pydantic values dict which includes lists of
AgentVariables with the keys corresponding to 'states', 'inputs', etc.."""
def names_list(ls: List[AgentVariable]):
return [item.name for item in ls]
field_names = set(f.name for f in dataclasses.fields(cls))
variables = {
k: names_list(v) for k, v in config.__dict__.items() if k in field_names
}
return cls(**variables)
def __contains__(self, item):
all_variables = set(chain.from_iterable(self.__dict__.values()))
return item in all_variables
VariableReferenceT = TypeVar("VariableReferenceT", bound=BaseVariableReference)
[docs]@dataclasses.dataclass
class VariableReference(BaseVariableReference):
states: List[str] = dataclasses.field(default_factory=list)
controls: List[str] = dataclasses.field(default_factory=list)
inputs: List[str] = dataclasses.field(default_factory=list)
parameters: List[str] = dataclasses.field(default_factory=list)
outputs: List[str] = dataclasses.field(default_factory=list)
[docs]def r_del_u_convention(name: str) -> str:
"""Turns the name of a control variable into its weight via convention"""
return f"r_del_u_{name}"
[docs]@dataclasses.dataclass
class FullVariableReference(VariableReference):
@property
def r_del_u(self) -> List[str]:
return [r_del_u_convention(cont) for cont in self.controls]
[docs]@dataclasses.dataclass
class MINLPVariableReference(VariableReference):
binary_controls: List[str] = dataclasses.field(default_factory=list)
[docs]@dataclasses.dataclass
class MHEVariableReference(BaseVariableReference):
states: List[str] = dataclasses.field(default_factory=list)
measured_states: List[str] = dataclasses.field(default_factory=list)
weights_states: List[str] = dataclasses.field(default_factory=list)
estimated_inputs: List[str] = dataclasses.field(default_factory=list)
estimated_parameters: List[str] = dataclasses.field(default_factory=list)
known_inputs: List[str] = dataclasses.field(default_factory=list)
known_parameters: List[str] = dataclasses.field(default_factory=list)
outputs: List[str] = dataclasses.field(default_factory=list)
[docs] def all_variables(self) -> Iterable[str]:
"""Returns a list of all variables registered in the var_ref which the MHE can
get from the config with get()"""
return (
set(super().all_variables())
- set(self.measured_states)
- set(self.weights_states)
)
[docs]@attrs.define(slots=True, weakref_slot=False, kw_only=True)
class MPCVariable(AgentVariable):
"""AgentVariable used to define input variables of MPC."""
interpolation_method: InterpolationMethods = attrs.field(
default=InterpolationMethods.linear,
metadata={
"description": "Defines which method is used for interpolation of "
"boundaries or values for this variable. Default is linear.",
"title": "Interpolation Method",
},
)
MPCVariables = List[MPCVariable]
[docs]def stats_path(path: Union[Path, str]) -> Path:
res_file = Path(path)
return Path(res_file.parent, "stats_" + res_file.name)
[docs]def cia_relaxed_results_path(path: Union[Path, str]) -> Path:
res_file = Path(path)
return Path(res_file.parent, "relaxed_" + res_file.name)
MPCValue = Union[int, float, list[Union[int, float]], pd.Series, np.ndarray]