# agentlib_mpc/modules/dmpc/admm/admm_coordinator.py

"""
Defines classes that coordinate an ADMM process.
"""

import os
import time
from ast import literal_eval
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple
import queue
import logging
from dataclasses import asdict
import threading
import math

from pydantic import field_validator, Field
import numpy as np
import pandas as pd

from agentlib.core.agent import Agent
from agentlib.core.datamodels import AgentVariable, Source
from pydantic_core.core_schema import FieldValidationInfo

from agentlib_mpc.data_structures import coordinator_datatypes as cdt
from agentlib_mpc.modules.dmpc.coordinator import Coordinator, CoordinatorConfig
import agentlib_mpc.data_structures.admm_datatypes as adt

logger = logging.getLogger(__name__)


class ADMMCoordinatorConfig(CoordinatorConfig):
    """Hold the config for ADMMCoordinator"""

    penalty_factor: float = Field(
        title="penalty_factor",
        default=10,
        description="Penalty factor of the ADMM algorithm. Should be equal "
        "for all agents.",
    )
    wait_time_on_start_iters: float = Field(
        title="wait_on_start_iterations",
        default=0.1,
        description="Wait time after the start-iterations signal is sent, "
        "before checking which agents are ready.",
    )
    registration_period: float = Field(
        title="registration_period",
        default=5,
        description="Time spent on registration before each optimization",
    )
    admm_iter_max: int = Field(
        title="admm_iter_max",
        default=20,
        description="Maximum number of ADMM iterations before termination of "
        "the control step.",
    )
    time_step: float = Field(
        title="time_step",
        default=600,  # seconds
        description="Sampling interval between two control steps. Will be used "
        "in the discretization for MPC.",
    )
    sampling_time: Optional[float] = Field(
        default=None,  # seconds
        description="Sampling interval for control steps. If None, will be the "
        "same as time_step. Does not affect the discretization of the MPC, "
        "only the interval at which optimization steps are triggered.",
        validate_default=True,
    )
    prediction_horizon: int = Field(
        title="prediction_horizon",
        default=10,
        description="Prediction horizon of participating agents.",
    )
    abs_tol: float = Field(
        title="abs_tol",
        default=1e-3,
        description="Absolute stopping criterion.",
    )
    rel_tol: float = Field(
        title="rel_tol",
        default=1e-3,
        description="Relative stopping criterion.",
    )
    primal_tol: float = Field(
        default=1e-3,
        description="Absolute primal stopping criterion.",
    )
    dual_tol: float = Field(
        default=1e-3,
        description="Absolute dual stopping criterion.",
    )
    use_relative_tolerances: bool = Field(
        default=True,
        description="If True, use abs_tol and rel_tol; if False, use "
        "primal_tol and dual_tol.",
    )
    penalty_change_threshold: float = Field(
        default=-1,
        description="If the primal residual exceeds the dual residual by this "
        "factor (or vice versa), the penalty parameter is varied.",
    )
    penalty_change_factor: float = Field(
        default=2,
        description="Factor by which the penalty parameter is varied.",
    )
    save_solve_stats: bool = Field(
        default=False,
        description="When True, saves the solve stats to a file.",
    )
    solve_stats_file: str = Field(
        default="admm_stats.csv",
        description="File name for the solve stats.",
    )
    save_iter_interval: int = Field(
        default=1000,
    )
[docs] @field_validator("solve_stats_file") @classmethod def solve_stats_file_is_csv(cls, file: str): assert file.endswith(".csv") return file
[docs] @field_validator("sampling_time") @classmethod def default_sampling_time(cls, samp_time, info: FieldValidationInfo): if samp_time is None: samp_time = info.data["time_step"] return samp_time
class ADMMCoordinator(Coordinator):
    config: ADMMCoordinatorConfig

    def __init__(self, *, config: dict, agent: Agent):
        if agent.env.config.rt:
            self.process = self._realtime_process
            self.registration_callback = self._real_time_registration_callback
        else:
            self.process = self._fast_process
            self.registration_callback = self._sequential_registration_callback
        super().__init__(config=config, agent=agent)
        self._coupling_variables: Dict[str, adt.ConsensusVariable] = {}
        self._exchange_variables: Dict[str, adt.ExchangeVariable] = {}
        self._agents_to_register = queue.Queue()
        self.agent_dict: Dict[str, adt.AgentDictEntry] = {}
        self._registration_queue: queue.Queue = queue.Queue()
        self._registration_lock: threading.Lock = threading.Lock()
        self.penalty_parameter = self.config.penalty_factor
        self._iteration_stats: pd.DataFrame = pd.DataFrame(
            columns=["primal_residual", "dual_residual"]
        )
        self._primal_residuals_tracker: List[float] = []
        self._dual_residuals_tracker: List[float] = []
        self._penalty_tracker: List[float] = []
        self._performance_tracker: List[float] = []
        self.start_algorithm_at: float = 0
        self._performance_counter: float = time.perf_counter()

    def _realtime_process(self):
        """Starts a thread to run next to the environment (to prevent a long
        blocking process). Periodically informs the thread of the next
        optimization."""
        self._start_algorithm = threading.Event()
        thread_proc = threading.Thread(
            target=self._realtime_process_thread,
            name=f"{self.source}_ProcessThread",
            daemon=True,
        )
        thread_proc.start()
        self.agent.register_thread(thread=thread_proc)
        thread_reg = threading.Thread(
            target=self._handle_registrations,
            name=f"{self.source}_RegistrationThread",
            daemon=True,
        )
        thread_reg.start()
        self.agent.register_thread(thread=thread_reg)
        while True:
            self._start_algorithm.set()
            yield self.env.timeout(self.config.sampling_time)

    def _realtime_process_thread(self):
        while True:
            self._status = cdt.CoordinatorStatus.sleeping
            self._start_algorithm.wait()
            self._start_algorithm.clear()
            with self._registration_lock:
                self._realtime_step()
            if self._start_algorithm.is_set():
                self.logger.error(
                    "%s: Start of ADMM round was requested before "
                    "last one finished. Skipping cycle.",
                    self.source,
                )
                self._start_algorithm.clear()

    def _realtime_step(self):
        # ------------------
        # start iteration
        # ------------------
        self.status = cdt.CoordinatorStatus.init_iterations
        self.start_algorithm_at = self.env.time
        self._performance_counter = time.perf_counter()
        # maybe this will hold information instead of "True"
        self.set(cdt.START_ITERATION_C2A, True)
        # check for all_finished here
        time.sleep(self.config.wait_time_on_start_iters)
        if not list(self._agents_with_status(status=cdt.AgentStatus.ready)):
            self.logger.info(f"No Agents available at time {self.env.now}.")
            return  # if no agents registered, return early
        self._update_mean_coupling_variables()
        self._shift_coupling_variables()

        # ------------------
        # iteration loop
        # ------------------
        admm_iter = 0
        for admm_iter in range(1, self.config.admm_iter_max + 1):
            # ------------------
            # optimization
            # ------------------
            self.status = cdt.CoordinatorStatus.optimization
            # set all agents to busy and send them their coupling inputs
            self.trigger_optimizations()
            # check for all finished here
            self._wait_for_ready()

            # ------------------
            # perform update steps
            # ------------------
            self.status = cdt.CoordinatorStatus.updating
            self._update_mean_coupling_variables()
            self._update_multipliers()

            # ------------------
            # check convergence
            # ------------------
            converged = self._check_convergence(admm_iter)
            if converged:
                self.logger.info("Converged within %s iterations.", admm_iter)
                break
        else:
            self.logger.warning(
                "Did not converge within the maximum number of iterations %s.",
                self.config.admm_iter_max,
            )
        self._wrap_up_algorithm(iterations=admm_iter)
        self.set(cdt.START_ITERATION_C2A, False)  # this signals the finish

    def _wait_non_rt(self):
        """Returns a triggered event. Cedes control to the simpy event queue
        for a short moment. This is required in fast-as-possible simulations,
        to allow other agents to react via callbacks."""
        return self.env.timeout(0.001)

    def _fast_process(self):
        """Process function for use in fast-as-possible simulations. Regularly
        yields control back to the environment, to allow the callbacks to
        run."""
        yield self._wait_non_rt()
        while True:
            # ------------------
            # start iteration
            # ------------------
            self.status = cdt.CoordinatorStatus.init_iterations
            self.start_algorithm_at = self.env.time
            self._performance_counter = time.perf_counter()
            self.set(cdt.START_ITERATION_C2A, True)
            yield self._wait_non_rt()
            if not list(self._agents_with_status(status=cdt.AgentStatus.ready)):
                self.logger.info(f"No Agents available at time {self.env.now}.")
                communication_time = self.env.time - self.start_algorithm_at
                yield self.env.timeout(self.config.sampling_time - communication_time)
                continue  # if no agents registered, skip this interval
            self._update_mean_coupling_variables()
            self._shift_coupling_variables()

            # ------------------
            # iteration loop
            # ------------------
            admm_iter = 0
            for admm_iter in range(1, self.config.admm_iter_max + 1):
                # ------------------
                # optimization
                # ------------------
                self.status = cdt.CoordinatorStatus.optimization
                # set all agents to busy and send them their coupling inputs
                self.trigger_optimizations()
                yield self._wait_non_rt()
                # check for all finished here
                self._wait_for_ready()

                # ------------------
                # perform update steps
                # ------------------
                self.status = cdt.CoordinatorStatus.updating
                self._update_mean_coupling_variables()
                self._update_multipliers()

                # ------------------
                # check convergence
                # ------------------
                converged = self._check_convergence(admm_iter)
                if converged:
                    self.logger.info("Converged within %s iterations.", admm_iter)
                    break
            else:
                self.logger.warning(
                    "Did not converge within the maximum number of iterations %s.",
                    self.config.admm_iter_max,
                )
            self._wrap_up_algorithm(iterations=admm_iter)
            self.set(cdt.START_ITERATION_C2A, False)  # this signals the finish
            self.status = cdt.CoordinatorStatus.sleeping
            time_spent_on_communication = self.env.time - self.start_algorithm_at
            yield self.env.timeout(
                self.config.sampling_time - time_spent_on_communication
            )

    def _update_mean_coupling_variables(self):
        """Calculates a new mean of the coupling variables."""
        active_agents = self._agents_with_status(cdt.AgentStatus.ready)
        for variable in self._coupling_variables.values():
            variable.update_mean_trajectory(sources=active_agents)
        for variable in self._exchange_variables.values():
            variable.update_diff_trajectories(sources=active_agents)

    def _shift_coupling_variables(self):
        """Shifts the trajectories of all coupling variables by one time step."""
        for variable in self._coupling_variables.values():
            variable.shift_values_by_one(horizon=self.config.prediction_horizon)
        for variable in self._exchange_variables.values():
            variable.shift_values_by_one(horizon=self.config.prediction_horizon)

    def _update_multipliers(self):
        """Performs the multiplier update for the coupling variables."""
        rho = self.penalty_parameter
        active_agents = self._agents_with_status(cdt.AgentStatus.ready)
        for variable in self._coupling_variables.values():
            variable.update_multipliers(rho=rho, sources=active_agents)
        for variable in self._exchange_variables.values():
            variable.update_multiplier(rho=rho)

    def _agents_with_status(self, status: cdt.AgentStatus) -> List[Source]:
        """Returns a list of the sources of all agents that currently have
        the given status."""
        active_agents = [s for (s, a) in self.agent_dict.items() if a.status == status]
        return active_agents

    def _check_convergence(self, iteration) -> bool:
        """Checks the convergence of the algorithm.

        Returns:
            True if converged, False otherwise.
        """
        primal_residuals = []
        dual_residuals = []
        active_agents = self._agents_with_status(cdt.AgentStatus.ready)
        flat_locals = []
        flat_means = []
        flat_multipliers = []
        for var in self._coupling_variables.values():
            prim, dual = var.get_residual(rho=self.penalty_parameter)
            primal_residuals.extend(prim)
            dual_residuals.extend(dual)
            locs = var.flat_locals(sources=active_agents)
            muls = var.flat_multipliers(active_agents)
            flat_locals.extend(locs)
            flat_multipliers.extend(muls)
            flat_means.extend(var.mean_trajectory)
        for var in self._exchange_variables.values():
            prim, dual = var.get_residual(rho=self.penalty_parameter)
            primal_residuals.extend(prim)
            dual_residuals.extend(dual)
            locs = var.flat_locals(sources=active_agents)
            muls = var.multiplier
            flat_locals.extend(locs)
            flat_multipliers.extend(muls)
            flat_means.extend(var.mean_trajectory)
        # primal_residual = np.concatenate(primal_residuals)
        # dual_residual = np.concatenate(dual_residuals)

        # compute residual norms
        prim_norm = np.linalg.norm(primal_residuals)
        dual_norm = np.linalg.norm(dual_residuals)
        self._vary_penalty_parameter(primal_residual=prim_norm, dual_residual=dual_norm)
        self._penalty_tracker.append(self.penalty_parameter)
        self._primal_residuals_tracker.append(prim_norm)
        self._dual_residuals_tracker.append(dual_norm)
        self._performance_tracker.append(
            time.perf_counter() - self._performance_counter
        )
        self.logger.debug(
            "Finished iteration %s.\n Primal residual: %s\n Dual residual: %s",
            iteration,
            prim_norm,
            dual_norm,
        )
        if iteration % self.config.save_iter_interval == 0:
            self._save_stats(iterations=iteration)

        if self.config.use_relative_tolerances:
            # scaling factors for the relative criterion
            primal_scaling = max(
                np.linalg.norm(flat_locals),  # Ax
                np.linalg.norm(flat_means),  # Bz
            )
            dual_scaling = np.linalg.norm(flat_multipliers)
            # compute tolerances for this iteration
            sqrt_p = math.sqrt(len(flat_multipliers))
            sqrt_n = math.sqrt(len(flat_locals))  # not actually n, but best we can do
            eps_pri = (
                sqrt_p * self.config.abs_tol + self.config.rel_tol * primal_scaling
            )
            eps_dual = sqrt_n * self.config.abs_tol + self.config.rel_tol * dual_scaling
            converged = prim_norm < eps_pri and dual_norm < eps_dual
        else:
            converged = (
                prim_norm < self.config.primal_tol
                and dual_norm < self.config.dual_tol
            )
        return converged

    def _save_stats(self, iterations: int) -> None:
        """Appends the tracked iteration statistics to the solve stats file.

        Args:
            iterations: Which iteration of the ADMM algorithm we are at when
                this function is called.
        """
        section_length = len(self._penalty_tracker)
        section_start = iterations - section_length
        index = [
            (self.start_algorithm_at, i + section_start) for i in range(section_length)
        ]
        path = Path(self.config.solve_stats_file)
        header = not path.is_file()
        stats = pd.DataFrame(
            {
                "primal_residual": self._primal_residuals_tracker,
                "dual_residual": self._dual_residuals_tracker,
                "penalty_parameter": self._penalty_tracker,
                "wall_time": self._performance_tracker,
            },
            index=index,
        )
        self._penalty_tracker = []
        self._dual_residuals_tracker = []
        self._primal_residuals_tracker = []
        self._performance_tracker = []
        path.parent.mkdir(exist_ok=True, parents=True)
        stats.to_csv(path_or_buf=path, header=header, mode="a")

    def _vary_penalty_parameter(self, primal_residual: float, dual_residual: float):
        """Determines a new value for the penalty parameter based on the
        residuals."""
        mu = self.config.penalty_change_threshold
        tau = self.config.penalty_change_factor
        if mu <= 1:
            # do not perform the varying penalty method if the threshold is
            # set below 1
            return
        if primal_residual > mu * dual_residual:
            self.penalty_parameter = self.penalty_parameter * tau
        elif dual_residual > mu * primal_residual:
            self.penalty_parameter = self.penalty_parameter / tau
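
    # Note on the two methods above: with r and s the stacked primal and dual
    # residuals, x the stacked local trajectories, z the stacked means, and y
    # the stacked multipliers, the relative criterion in _check_convergence
    # matches the standard scaling from the ADMM literature (Boyd et al.,
    # 2011):
    #
    #     eps_pri  = sqrt(p) * abs_tol + rel_tol * max(||x||, ||z||)
    #     eps_dual = sqrt(n) * abs_tol + rel_tol * ||y||
    #
    # and the algorithm stops once ||r|| < eps_pri and ||s|| < eps_dual.
    # _vary_penalty_parameter implements residual balancing: with threshold
    # mu and factor tau,
    #
    #     rho <- tau * rho   if ||r|| > mu * ||s||
    #     rho <- rho / tau   if ||s|| > mu * ||r||
    #
    # For example, with mu=10, tau=2, ||r||=0.5 and ||s||=0.01, the primal
    # residual dominates and rho is doubled to penalize disagreement harder.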

    def trigger_optimizations(self):
        """Triggers the optimization for all agents with status ready."""
        # create an iterator over all agents which are ready for this round
        active_agents: Iterator[Tuple[Source, adt.AgentDictEntry]] = (
            (s, a)
            for (s, a) in self.agent_dict.items()
            if a.status == cdt.AgentStatus.ready
        )
        # aggregate and send trajectories per agent
        for source, agent in active_agents:
            # collect mean and multiplier per coupling variable
            mean_trajectories = {}
            multipliers = {}
            for alias in agent.coup_vars:
                coup_var = self._coupling_variables[alias]
                mean_trajectories[alias] = coup_var.mean_trajectory
                multipliers[alias] = coup_var.multipliers[source]
            diff_trajectories = {}
            multiplier = {}
            for alias in agent.exchange_vars:
                coup_var = self._exchange_variables[alias]
                diff_trajectories[alias] = coup_var.diff_trajectories[source]
                multiplier[alias] = coup_var.multiplier
            # package all coupling inputs needed for an agent
            coordi_to_agent = adt.CoordinatorToAgent(
                mean_trajectory=mean_trajectories,
                multiplier=multipliers,
                exchange_multiplier=multiplier,
                mean_diff_trajectory=diff_trajectories,
                target=source.agent_id,
                penalty_parameter=self.penalty_parameter,
            )
            self.logger.debug("Sending to %s with source %s", agent.name, source)
            self.logger.debug("Set %s to busy.", agent.name)
            # send values
            agent.status = cdt.AgentStatus.busy
            self.set(cdt.OPTIMIZATION_C2A, coordi_to_agent.to_json())
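
    # A sketch of the per-iteration payload built in trigger_optimizations.
    # The keyword fields are taken from the call above; the aliases, vector
    # values, and agent id are purely illustrative assumptions:
    #
    #     coordi_to_agent = adt.CoordinatorToAgent(
    #         mean_trajectory={"T_room": [293.15, 293.2]},
    #         multiplier={"T_room": [0.0, 0.0]},
    #         mean_diff_trajectory={"Q_flow": [1.5, 1.4]},
    #         exchange_multiplier={"Q_flow": [0.0, 0.0]},
    #         target="agent_1",
    #         penalty_parameter=10.0,
    #     )
    #
    # Each agent filters on `target`, solves its local problem with these
    # coupling inputs, and replies via optim_results_callback below.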

    def register_agent(self, variable: AgentVariable):
        """Registers the agent, after it sent its initial guess with correct
        vector length."""
        value = adt.AgentToCoordinator.from_json(variable.value)
        src = variable.source
        ag_dict_entry = self.agent_dict[variable.source]

        # loop over consensus coupling variables of this agent
        for alias, traj in value.local_trajectory.items():
            coup_var = self._coupling_variables.setdefault(
                alias, adt.ConsensusVariable()
            )
            # initialize Lagrange multipliers and local solution
            coup_var.multipliers[src] = [0] * len(traj)
            coup_var.local_trajectories[src] = traj
            ag_dict_entry.coup_vars.append(alias)

        # loop over exchange variables of this agent
        for alias, traj in value.local_exchange_trajectory.items():
            coup_var = self._exchange_variables.setdefault(
                alias, adt.ExchangeVariable()
            )
            # initialize Lagrange multiplier and local solution
            coup_var.multiplier = [0] * len(traj)
            coup_var.local_trajectories[src] = traj
            ag_dict_entry.exchange_vars.append(alias)

        # set agent from pending to standby
        ag_dict_entry.status = cdt.AgentStatus.standby
        self.logger.info(
            f"Coordinator successfully registered agent {variable.source}."
        )
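
    # The registration message handled above has, in sketch form (assuming
    # adt.AgentToCoordinator is constructible from keyword arguments; aliases
    # and horizon lengths are illustrative only):
    #
    #     a2c = adt.AgentToCoordinator(
    #         local_trajectory={"T_room": [293.15] * 11},
    #         local_exchange_trajectory={"Q_flow": [0.0] * 10},
    #     )
    #     # sent by the agent as the JSON value of an AgentVariable
    #
    # register_agent sizes the Lagrange multipliers from these initial
    # trajectories, which is why the vector lengths must already be correct.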

    def optim_results_callback(self, variable: AgentVariable):
        """Saves the results of a local optimization.

        Args:
            variable: The AgentVariable carrying an agent's local solution.
        """
        local_result = adt.AgentToCoordinator.from_json(variable.value)
        source = variable.source
        for alias, trajectory in local_result.local_trajectory.items():
            coup_var = self._coupling_variables[alias]
            coup_var.local_trajectories[source] = trajectory
        for alias, trajectory in local_result.local_exchange_trajectory.items():
            coup_var = self._exchange_variables[alias]
            coup_var.local_trajectories[source] = trajectory
        self.agent_dict[variable.source].status = cdt.AgentStatus.ready
        self.received_variable.set()

    def _send_parameters_to_agent(self, variable: AgentVariable):
        """Sends an agent the global parameters after a signup request."""
        admm_parameters = adt.ADMMParameters(
            prediction_horizon=self.config.prediction_horizon,
            time_step=self.config.time_step,
            penalty_factor=self.config.penalty_factor,
        )
        message = cdt.RegistrationMessage(
            agent_id=variable.source.agent_id, opts=asdict(admm_parameters)
        )
        self.set(cdt.REGISTRATION_C2A, asdict(message))

    def registration_callback(self, variable: AgentVariable):
        self.logger.debug(f"receiving {variable.name} from {variable.source}")
        if variable.source not in self.agent_dict:
            self.agent_dict[variable.source] = adt.AgentDictEntry(
                name=variable.source,
                status=cdt.AgentStatus.pending,
            )
            self._send_parameters_to_agent(variable)
            self.logger.info(
                f"Coordinator got a request from agent {variable.source} and "
                f"set it to 'pending'."
            )
            return
        # complete registration of pending agents
        if self.agent_dict[variable.source].status is cdt.AgentStatus.pending:
            self.register_agent(variable=variable)

    def _sequential_registration_callback(self, variable: AgentVariable):
        """Handles the registration for sequential, i.e. local, coordinators.
        Variables are handled immediately."""
        self.logger.debug(f"receiving {variable.name} from {variable.source}")
        self._initial_registration(variable)

    def _real_time_registration_callback(self, variable: AgentVariable):
        """Handles the registration for realtime coordinators. Variables are
        put in a queue and a thread registers them when it is safe to do
        so."""
        self.logger.debug(f"receiving {variable.name} from {variable.source}")
        self._registration_queue.put(variable)

    def _initial_registration(self, variable: AgentVariable):
        """Handles the initial registration of a variable. If its sender is
        unknown, adds it to the agent_dict and sends it the global
        parameters. If it is sending its confirmation with initial
        trajectories, refers to the actual registration function."""
        if variable.source not in self.agent_dict:
            self.agent_dict[variable.source] = adt.AgentDictEntry(
                name=variable.source,
                status=cdt.AgentStatus.pending,
            )
            self._send_parameters_to_agent(variable)
            self.logger.info(
                f"Coordinator got a request from agent {variable.source} and "
                f"set it to 'pending'."
            )
        # complete registration of pending agents
        elif self.agent_dict[variable.source].status is cdt.AgentStatus.pending:
            self.register_agent(variable=variable)

    def _handle_registrations(self):
        """Performs registration tasks while the algorithm is on standby."""
        while True:
            # add new agents to the dict and send them the global parameters
            variable = self._registration_queue.get()
            with self._registration_lock:
                self._initial_registration(variable)

    def _wrap_up_algorithm(self, iterations):
        self._save_stats(iterations=iterations)
        self.penalty_parameter = self.config.penalty_factor
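
    # Design note: in real time, registrations arrive from communicator
    # threads while an ADMM round may be running, so they are queued and only
    # applied by _handle_registrations once _registration_lock is free, i.e.
    # between rounds. In fast-as-possible simulations everything runs in one
    # thread, so _sequential_registration_callback can register agents
    # immediately.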

    def get_results(self) -> pd.DataFrame:
        """Reads the results on iteration data if they were saved."""
        results_file = self.config.solve_stats_file
        try:
            df = pd.read_csv(results_file, index_col=0, header=0)
            new_ind = [literal_eval(i) for i in df.index]
            df.index = pd.MultiIndex.from_tuples(new_ind)
            return df
        except FileNotFoundError:
            self.logger.error("Results file %s was not found.", results_file)
            return pd.DataFrame()
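
    # A minimal usage sketch (the `coordinator` instance name is purely
    # illustrative):
    #
    #     stats = coordinator.get_results()
    #     # MultiIndex of (algorithm start time, ADMM iteration), columns:
    #     # primal_residual, dual_residual, penalty_parameter, wall_time
    #     last_run = stats.loc[stats.index.get_level_values(0).max()]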

    def cleanup_results(self):
        results_file = self.config.solve_stats_file
        if not results_file:
            return
        os.remove(results_file)