Coverage for agentlib_flexquant/optimization_backends/constrained_cia.py: 75%
76 statements
« prev ^ index » next coverage.py v7.4.4, created at 2026-03-26 09:43 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2026-03-26 09:43 +0000
1import pydantic
2import numpy as np
3from agentlib.core.errors import OptionalDependencyError
4from agentlib_mpc.optimization_backends.casadi_.minlp_cia import CasADiCIABackend
5from agentlib_mpc.optimization_backends.casadi_.core.casadi_backend import CasadiBackendConfig
6from agentlib_mpc.data_structures.mpc_datamodels import MINLPVariableReference, MPCVariable
7from agentlib_flexquant.data_structures.globals import full_trajectory_suffix
8from agentlib_mpc.optimization_backends.casadi_.core.discretization import Results
# pycombina is an optional dependency that this module cannot work without, so
# fail at import time with agentlib's dedicated error instead of a bare
# ImportError later on.
try:
    import pycombina
except ImportError as err:
    raise OptionalDependencyError(
        used_object="Pycombina",
        # "\\ " keeps the exact runtime text (backslash + space) without relying
        # on the invalid escape sequence "\ ", which warns on modern Python.
        dependency_install=".\\ after cloning pycombina. Instructions: "
        "https://pycombina.readthedocs.io/en/latest/install.html#",
    ) from err
class ConstrainedCIABackendConfig(CasadiBackendConfig):
    """Configuration of the constrained CasADi CIA backend.

    Adds to the CasADi backend config:
    * ``market_time``: length (s) of the window at the start of the horizon
      during which a shadow MPC is pinned to the baseline's binary controls,
    * ``use_rounding``: replace CIA by plain rounding of the relaxed binaries,
    * ``full_controls_dict``: the baseline's full control trajectories; it is
      empty for the baseline MPC itself.
    """

    market_time: int = pydantic.Field(
        default=900,
        ge=0,
        unit="s",
        description="Time for market interaction",
    )
    use_rounding: bool = pydantic.Field(
        default=False,
        description="If True, CIA is skipped and plain rounding is used.",
    )
    # default_factory gives every config instance its own empty dict instead
    # of declaring a shared mutable literal as the default.
    full_controls_dict: dict = pydantic.Field(
        default_factory=dict,
        description="Holds a key value pair for each full control of the Baseline",
    )

    class Config:
        # Reject fields that are not declared on this model, so a misspelled
        # option raises a validation error instead of being silently accepted.
        extra = "forbid"
41class ConstrainedCasADiCIABackend(CasADiCIABackend):
42 var_ref: MINLPVariableReference
43 config_type = ConstrainedCIABackendConfig
45 def __init__(self, *args, **kwargs):
46 super().__init__(*args, **kwargs)
48 def solve(self, now: float, current_vars: dict[str, MPCVariable]) -> Results:
49 # collect and format inputs
50 mpc_inputs = self._get_current_mpc_inputs(agent_variables=current_vars, now=now)
52 # solve NLP with relaxed binaries
53 relaxed_results = self.discretization.solve(mpc_inputs)
55 if self.config.use_rounding:
56 b_rel = [relaxed_results[var] for var in self.var_ref.binary_controls]
57 b_rel_np = np.transpose(np.vstack(b_rel))
59 # List to collect all rounded/overwritten binary arrays
60 binary_arrays = []
62 # constrain shadow MPCs to values of baseline for time<market_time
63 for bin_con, bin_rel in zip(self.var_ref.binary_controls, b_rel_np):
64 cons = self.get_baseline_binary_solution(bin_con)
65 # round binaries
66 binary_array = np.round(bin_rel)
67 if cons is not None:
68 # Determine how many sample times are before the market time
69 sample_time = self.config.discretization_options.time_step
70 market_time = self.config.market_time
71 num_samples_before_market = int(market_time / sample_time)
72 # Overwrite the market time entries with baseline values
73 for i in range(num_samples_before_market):
74 time_point = i * sample_time
75 if time_point in cons.index:
76 binary_array[i] = cons.loc[time_point]
77 binary_arrays.append(binary_array)
78 # Stack all binary arrays back together (transpose to match expected shape)
79 binary_array = np.transpose(np.vstack(binary_arrays))
81 else:
82 relaxed_binary_array = self.make_binary_array(full_results=relaxed_results)
83 binary_array = self.do_pycombina(b_rel=relaxed_binary_array)
85 mpc_inputs_new = self.constrain_binary_inputs(
86 mpc_inputs_old=mpc_inputs,
87 binary_array=binary_array,
88 )
89 # solve NLP with fixed binaries
90 full_results_final = self.discretization.solve(mpc_inputs_new)
92 self.save_rel_result_df(relaxed_results, now=now)
93 self.save_result_df(full_results_final, now=now)
95 return full_results_final
97 def do_pycombina(self, b_rel: np.array) -> np.array:
99 grid = self.discretization.grid(self.system.binary_controls).copy()
100 grid.append(grid[-1] + self.config.discretization_options.time_step)
102 binapprox = pycombina.BinApprox(
103 t=grid,
104 b_rel=b_rel,
105 )
107 # constrain shadow MPCs to values of baseline for time<market_time
108 for bin_con in self.var_ref.binary_controls:
109 cons = self.get_baseline_binary_solution(bin_con)
110 if cons is not None:
111 last_idx = 0
112 for idx, value in cons.items():
113 # constrain every timestep before market_time
114 # with values of baseline
115 binapprox.set_valid_controls_for_interval(
116 (last_idx, idx), [value, 1 - value]
117 )
118 last_idx = idx
120 bnb = pycombina.CombinaBnB(binapprox)
121 bnb.solve(
122 use_warm_start=False,
123 max_cpu_time=15,
124 verbosity=0,
125 )
126 b_bin = binapprox.b_bin
128 # if there is only one mode, we created a dummy mode which we remove now
129 if len(self.var_ref.binary_controls) == 1:
130 b_bin = b_bin[0, :].reshape(1, -1)
132 return b_bin
134 def get_baseline_binary_solution(self, bin_con):
135 # check for baseline or shadow MPC
136 if not self.config.full_controls_dict:
137 # if baseline, return
138 return None
139 name = bin_con + full_trajectory_suffix
140 # if shadow MPC, get current value send by baseline
141 if self.config.full_controls_dict[name] is not None:
142 cons = self.config.full_controls_dict[name]
143 # the index of constraints starts at the absolute current environment
144 # time, while the market time is relative time on mpc horizon
145 cons.index -= cons.index[0]
146 # get the constraints in the market time
147 cons = cons[cons.index <= self.config.market_time]
148 return cons