from pathlib import Path
from typing import List, Optional, Union, Literal
import numpy as np
import pandas as pd
from pydantic import Field, field_validator, FilePath, model_validator
from agentlib import Agent
from agentlib.core import BaseModule, BaseModuleConfig, AgentVariable
class CSVDataSourceConfig(BaseModuleConfig):
    """Configuration for the CSVDataSource module."""
    data: Union[pd.DataFrame, FilePath] = Field(
        title="Data",
        default=pd.DataFrame(),
        description="Data that should be communicated during execution. Index should "
        "be either numeric or Datetime, numeric values are interpreted as"
        " seconds.",
        validate_default=True,
    )
    outputs: List[AgentVariable] = Field(
        title="Outputs",
        default=[],
        description="Optional list of columns of data frame that should be sent. If "
        "ommited, all datapoint in frame are sent.",
    )
    t_sample: Union[float, int] = Field(
        title="t_sample",
        default=1,
        ge=0,
        description="Sampling time. Data source sends an interpolated value from the "
        "data every <t_sample> seconds. Default is 1 s.",
    )
    data_offset: Optional[Union[pd.Timedelta, float]] = Field(
        title="data_offset",
        default=0,
        description="Offset will be subtracted from index, allowing you to start at "
        "any point in your data. I.e. if your environment starts at 0, "
        "and you want your data-source to start at 1000 seconds, "
        "you should set this to 1000.",
    )
    extrapolation: Literal["constant", "repeat", "backwards"] = Field(
        title="Extrapolation",
        default="constant",
        description="Determines what to do, when the data source runs out. 'constant' "
        "returns the last value, 'repeat' repeats the data from the "
        "start, and 'backwards' goes through the data backwards, bouncing "
        "indefinitely.",
    )
    shared_variable_fields: List[str] = ["outputs"]
    @field_validator("data")
    @classmethod
    def check_data(cls, data):
        """Makes sure data is a data frame, and loads it if required."""
        if isinstance(data, (str, Path)) and Path(data).is_file():
            data = pd.read_csv(data, engine="python", index_col=0)
        if not isinstance(data, pd.DataFrame):
            raise ValueError(
                f"Data {data} is not a valid DataFrame or the path is not found."
            )
        if data.empty:
            raise ValueError("Provided data is empty.")
        if len(data) < 2:
            raise ValueError(
                "The dataframe must contain at least two rows for interpolation."
            )
        return data 
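
    # Illustrative CSV layout expected by the loader above: the first column is
    # used as the index (index_col=0) and each remaining column is one variable.
    # The column names here are assumptions for the example:
    #
    #   time,T_room,T_ambient
    #   0,20.0,5.0
    #   60,20.5,4.8
    #   120,21.0,4.5
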
    @model_validator(mode="after")
    def validate_data(self):
        """Checks if outputs and data columns match, and ensures a numeric index."""
        if self.outputs:
            columns = set(self.data.columns)
            output_names = set(o.name for o in self.outputs)
            missing_columns = output_names - columns
            if missing_columns:
                raise ValueError(
                    f"The following output columns are not present in the dataframe: "
                    f"{', '.join(missing_columns)}"
                )
        self.transform_to_numeric_index(self.data)
        return self  
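
# Illustrative example of the configuration fields defined above (not part of
# the module). The file path and variable name are assumptions; depending on
# the agentlib version, additional BaseModuleConfig fields (e.g. a module id
# and type) may be required when the module is used inside an agent config.
EXAMPLE_CSV_SOURCE_CONFIG = {
    "data": "room_temperatures.csv",  # loaded via pd.read_csv(..., index_col=0)
    "outputs": [{"name": "T_room"}],  # columns of the CSV to send as AgentVariables
    "t_sample": 60,  # send an interpolated value every 60 seconds
    "data_offset": 0,  # shift subtracted from the data index
    "extrapolation": "constant",  # hold the last value once the data runs out
}
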
class CSVDataSource(BaseModule):
    """Module that periodically sends interpolated values from a CSV file or
    DataFrame to the agent's data broker."""
    config: CSVDataSourceConfig
    def __init__(self, config: dict, agent: Agent):
        super().__init__(config=config, agent=agent)
        data = self.config.data
        # Interpolate the dataframe
        start_time = data.index[0]
        end_time = data.index[-1]
        new_index = np.arange(
            start_time, end_time + self.config.t_sample, self.config.t_sample
        )
        interpolated_data = (
            data.reindex(data.index.union(new_index))
            .interpolate(method="index")
            .loc[new_index]
        )
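        # Resampling sketch: for an index of [0, 10] and t_sample=2, new_index
        # is [0, 2, 4, 6, 8, 10]; interpolate(method="index") fills the added
        # points linearly with respect to the index, so a column going from
        # 20.0 to 21.0 yields 20.0, 20.2, 20.4, 20.6, 20.8, 21.0.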
        # Transform to list of tuples
        self.data_tuples = list(interpolated_data.itertuples(index=False, name=None))
        self.data_iterator = self.create_iterator()
    def _get_next_data(self):
        """Yield the next data point"""
        data = next(self.data_iterator)
        return data
    def create_iterator(self):
        """Create a custom iterator based on the extrapolation method"""
        while True:
            for item in self.data_tuples:
                yield item
            if self.config.extrapolation == "constant":
                self.logger.warning(
                    "Data source has been exhausted. Returning last value indefinitely."
                )
                while True:
                    yield self.data_tuples[-1]
            elif self.config.extrapolation == "repeat":
                self.logger.warning(
                    "Data source has been exhausted. Repeating data from the start."
                )
                continue  # This will restart the outer loop
            elif self.config.extrapolation == "backwards":
                self.logger.warning(
                    "Data source has been exhausted. Going through data backwards."
                )
                yield from self.backwards_iterator() 
    def backwards_iterator(self):
        """Iterator for backwards extrapolation"""
        while True:
            for item in reversed(
                self.data_tuples[:-1]
            ):  # Exclude the last item to avoid repetition
                yield item
            for item in self.data_tuples[
                1:
            ]:  # Exclude the first item to avoid repetition
                yield item 
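
    # For data_tuples [a, b, c], create_iterator with "backwards" yields
    # a, b, c, b, a, b, c, b, a, ..., bouncing between the endpoints and
    # emitting each endpoint once per pass.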
    def process(self):
        """Write the current data values into data_broker every t_sample"""
        while True:
            current_data = self._get_next_data()
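            # Values arrive in dataframe column order, so each configured
            # output is paired with the column at the same position.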
            for output, value in zip(self.config.outputs, current_data):
                self.logger.debug(
                    f"At {self.env.time}: Sending variable {output.name} with value {value} to data broker."
                )
                self.set(output.name, value)
            yield self.env.timeout(self.config.t_sample) 
    def register_callbacks(self):
        """Don't do anything as this module is not event-triggered"""