Source code for agentlib_mpc.utils.plotting.basic

"""Some basic plotting utilities"""

import logging
import typing
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple, Callable, TypedDict, Annotated

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator


# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


@dataclass
class ValueRange:
    """Pair of numeric bounds, used in this module as ``Annotated`` metadata.

    The two values are only stored; this class performs no validation of
    its own.
    """

    min: float  # lower bound of the range
    max: float  # upper bound of the range
# A float carrying ValueRange(0.0, 1.0) as Annotated metadata. The range is
# attached as metadata only; nothing here enforces it at runtime.
Float0to1 = Annotated[float, ValueRange(0.0, 1.0)]
# Triple of Float0to1 components -- presumably (red, green, blue); confirm
# against the consumers of the colour constants.
ColorTuple = tuple[Float0to1, Float0to1, Float0to1]
class EBCColors:
    """Named colours and two orderings of them as palettes.

    Every colour is a 3-tuple whose components are written as an 8-bit
    channel value divided by 255. The two palettes are class-level lists
    containing the same nine colours in different orders.
    """

    # Reds
    dark_red: ColorTuple = (172 / 255, 43 / 255, 28 / 255)
    red: ColorTuple = (221 / 255, 64 / 255, 45 / 255)
    light_red: ColorTuple = (235 / 255, 140 / 255, 129 / 255)

    # Green
    green: ColorTuple = (112 / 255, 173 / 255, 71 / 255)

    # Greys
    light_grey: ColorTuple = (217 / 255, 217 / 255, 217 / 255)
    grey: ColorTuple = (157 / 255, 158 / 255, 160 / 255)
    dark_grey: ColorTuple = (78 / 255, 79 / 255, 80 / 255)

    # Blues
    light_blue: ColorTuple = (157 / 255, 195 / 255, 230 / 255)
    blue: ColorTuple = (0 / 255, 84 / 255, 159 / 255)

    # Palette grouped by colour family, each family running dark -> light.
    ebc_palette_sort_1: list[ColorTuple] = [
        dark_red,
        red,
        light_red,
        dark_grey,
        grey,
        light_grey,
        blue,
        light_blue,
        green,
    ]

    # Palette that alternates between colour families from entry to entry.
    ebc_palette_sort_2: list[ColorTuple] = [
        red,
        blue,
        grey,
        green,
        dark_red,
        dark_grey,
        light_red,
        light_blue,
        light_grey,
    ]
class FontDict(TypedDict):
    """Typed mapping for font settings; its only key is ``fontsize``."""

    fontsize: float
class Style:
    """Context manager that activates the ``ebc.paper.mplstyle`` style sheet.

    On entry the style sheet located next to this module is applied and,
    if requested, LaTeX text rendering is switched on. On exit matplotlib's
    rcParams are overwritten with ``matplotlib.rcParamsDefault``.

    Args:
        use_tex: When true, ``text.usetex`` is enabled and ``amsmath`` is
            set as the LaTeX preamble while the context is active.
    """

    def __init__(self, use_tex: bool = False):
        self.font_dict: FontDict = {"fontsize": 11}
        self.use_tex = use_tex

    def __enter__(self):
        """Apply the style sheet (and optional TeX settings); return self."""
        style_sheet = Path(__file__).parent / "ebc.paper.mplstyle"
        try:
            plt.style.use(style_sheet)
        except OSError:
            # Best effort: plotting continues with whatever style is active.
            logger.warning("Style Sheet could not be loaded, using default style.")

        if not self.use_tex:
            return self

        matplotlib.rc("text", usetex=True)
        matplotlib.rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Overwrite rcParams with matplotlib's defaults.

        Note that this restores the defaults, not the values that were
        active before the context was entered. Nothing is returned, so an
        exception raised inside the ``with`` body propagates.
        """
        matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Callback for a figure with a single Axes: it is called with the figure, the
# Axes and the active Style, and is annotated as returning (figure, axes).
Customizer = Callable[[plt.Figure, plt.Axes, Style], Tuple[plt.Figure, plt.Axes]]
# Callback for a figure created with an explicit number of rows: same contract,
# but the axes are passed as a collection rather than a single object.
MultiCustomizer = Callable[
    [plt.Figure, tuple[plt.Axes], Style], Tuple[plt.Figure, tuple[plt.Axes]]
]


# The overloads below describe what make_fig() does at runtime:
#   * ``rows`` given  -> the customizer receives, and the call returns, a
#     collection of axes (MultiCustomizer);
#   * ``rows`` absent -> the customizer receives, and the call returns, one
#     Axes object (Customizer).
# ``rows`` carries no default in the multi-axes overloads so that a plain
# ``make_fig(style)`` resolves to the single-Axes signature.
@typing.overload
def make_fig(
    style: Style,
    customizer: typing.Optional[MultiCustomizer] = None,
    *,
    rows: int,
) -> tuple[plt.Figure, tuple[plt.Axes, ...]]: ...


# Same as above for callers that pass ``rows`` positionally.
@typing.overload
def make_fig(
    style: Style,
    customizer: typing.Optional[MultiCustomizer],
    rows: int,
) -> tuple[plt.Figure, tuple[plt.Axes, ...]]: ...


@typing.overload
def make_fig(
    style: Style,
    customizer: typing.Optional[Customizer] = None,
    rows: None = None,
) -> tuple[plt.Figure, plt.Axes]: ...
def make_fig(
    style: Style, customizer: Customizer = None, rows=None
) -> Tuple[plt.Figure, tuple[plt.Axes]]:
    """Create a figure whose subplots are stacked in one column and share x.

    Args:
        style: Supplies ``font_dict["fontsize"]`` for the major tick labels.
        customizer: Optional callable invoked as ``customizer(fig, axes,
            style)`` after the ticks are configured. Its return value is
            not used.
        rows: Number of subplot rows. When omitted, one subplot is created
            and returned as a bare Axes object.

    Returns:
        ``(fig, axes)``. ``axes`` is a single Axes when ``rows`` is None, a
        1-tuple when ``rows == 1``, and otherwise whatever ``plt.subplots``
        returned for that many rows.
    """
    single_axes = rows is None
    fig, axes = plt.subplots(1 if single_axes else rows, 1, sharex=True)

    # An explicit rows=1 still promises a collection, so wrap the lone Axes.
    if rows == 1:
        axes = (axes,)

    # Iterate uniformly; the object handed back to the caller is left untouched.
    for subplot in (axes,) if single_axes else axes:
        subplot.tick_params(
            axis="both",
            which="major",
            labelsize=style.font_dict["fontsize"],
            left=False,
        )

    if customizer:
        customizer(fig, axes, style)
    return fig, axes
def make_grid(ax: plt.Axes):
    """Draw dashed major and minor grid lines on both axes of *ax*.

    An ``AutoMinorLocator`` is installed on the x and the y axis first. The
    major grid is then drawn in black and the minor grid in grey ("0.7"),
    both with a line width of 0.5 and ``zorder=0``.
    """
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_minor_locator(AutoMinorLocator())

    common = {"axis": "both", "linestyle": "--", "linewidth": 0.5, "zorder": 0}
    for which, color in (("major", "black"), ("minor", "0.7")):
        ax.grid(which=which, color=color, **common)
def make_side_legend(ax: plt.Axes, fig: plt.Figure = None, right_position: float = 1):
    """Attach a frameless legend anchored at the right edge of *ax*.

    Args:
        ax: Axes that receives the legend.
        fig: Optional figure. When given together with a positive
            ``right_position``, the right subplot margin is adjusted.
        right_position: Value passed as ``right`` to
            ``fig.subplots_adjust``; ignored unless it is greater than 0.
    """
    ax.legend(
        loc="center left",
        bbox_to_anchor=(1, 0.5),
        frameon=False,
        handlelength=1,
    )

    # Nothing to adjust without a figure or with a non-positive margin.
    if not fig or not right_position > 0:
        return
    fig.subplots_adjust(right=right_position)