Coverage for agentlib/modules/simulation/csv_data_source.py: 71%
84 statements
from pathlib import Path
from typing import List, Optional, Union, Literal

import numpy as np
import pandas as pd
from pydantic import Field, field_validator, FilePath, model_validator

from agentlib import Agent
from agentlib.core import BaseModule, BaseModuleConfig, AgentVariable


class CSVDataSourceConfig(BaseModuleConfig):
    data: Union[pd.DataFrame, FilePath] = Field(
        title="Data",
        default=pd.DataFrame(),
        description="Data that should be communicated during execution. The index "
        "should be either numeric or datetime; numeric values are "
        "interpreted as seconds.",
        validate_default=True,
    )
    outputs: List[AgentVariable] = Field(
        title="Outputs",
        default=[],
        description="Optional list of columns of the data frame that should be "
        "sent. If omitted, all data points in the frame are sent.",
    )
    t_sample: Union[float, int] = Field(
        title="t_sample",
        default=1,
        ge=0,
        description="Sampling time. The data source sends an interpolated value "
        "from the data every <t_sample> seconds. Default is 1 s.",
    )
    data_offset: Optional[Union[pd.Timedelta, float]] = Field(
        title="data_offset",
        default=0,
        description="Offset that is subtracted from the index, allowing you to "
        "start at any point in your data. For example, if your "
        "environment starts at 0 and you want your data source to "
        "start at 1000 seconds, you should set this to 1000.",
    )
    extrapolation: Literal["constant", "repeat", "backwards"] = Field(
        title="Extrapolation",
        default="constant",
        description="Determines what to do when the data source runs out: "
        "'constant' returns the last value, 'repeat' repeats the data "
        "from the start, and 'backwards' goes through the data "
        "backwards, bouncing indefinitely.",
    )
    shared_variable_fields: List[str] = ["outputs"]

    @field_validator("data")
    @classmethod
    def check_data(cls, data):
        """Makes sure data is a data frame, and loads it if required."""
        if isinstance(data, (str, Path)) and Path(data).is_file():
            data = pd.read_csv(data, engine="python", index_col=0)
        if not isinstance(data, pd.DataFrame):
            raise ValueError(
                f"Data {data} is not a valid DataFrame or the path is not found."
            )
        if data.empty:
            raise ValueError("Provided data is empty.")
        if len(data) < 2:
            raise ValueError(
                "The dataframe must contain at least two rows for interpolation."
            )
        return data

    def transform_to_numeric_index(self, data: pd.DataFrame) -> pd.DataFrame:
        """Handles the index and ensures it is numeric, with the correct offset."""
        # Convert the offset to seconds if it is a Timedelta
        offset = self.data_offset
        if isinstance(offset, pd.Timedelta):
            offset = offset.total_seconds()
        # Handle the different index types
        if isinstance(data.index, pd.DatetimeIndex):
            data.index = (data.index - data.index[0]).total_seconds()
        else:
            # Try to convert the index to numeric in case it is a string
            try:
                data.index = pd.to_numeric(data.index)
                data.index = data.index - data.index[0]
            except ValueError:
                # If conversion to numeric fails, try to convert to datetime
                try:
                    data.index = pd.to_datetime(data.index)
                    data.index = (data.index - data.index[0]).total_seconds()
                except ValueError:
                    raise ValueError("Unable to convert index to numeric format")

        data.index = data.index.astype(float) - offset
        return data
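    # E.g., a DatetimeIndex with one-minute spacing becomes
    # [0.0, 60.0, 120.0, ...]; with data_offset=60, the same index becomes
    # [-60.0, 0.0, 60.0, ...].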

    @model_validator(mode="after")
    def validate_data(self):
        """Checks if outputs and data columns match, and ensures a numeric index."""
        if self.outputs:
            columns = set(self.data.columns)
            output_names = set(o.name for o in self.outputs)

            missing_columns = output_names - columns
            if missing_columns:
                raise ValueError(
                    f"The following output columns are not present in the dataframe: "
                    f"{', '.join(missing_columns)}"
                )
        self.transform_to_numeric_index(self.data)
        return self
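

# A config for this module is typically given as a dict. A minimal sketch
# (the field names beyond this module's own fields follow agentlib's
# BaseModuleConfig and may differ; the file path is illustrative):
#
#     {
#         "module_id": "csv_source",
#         "type": "csv_data_source",
#         "data": "data/measurements.csv",
#         "t_sample": 60,
#         "extrapolation": "repeat",
#     }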


class CSVDataSource(BaseModule):

    config: CSVDataSourceConfig

    def __init__(self, config: dict, agent: Agent):
        super().__init__(config=config, agent=agent)

        data = self.config.data

        # Interpolate the dataframe
        start_time = data.index[0]
        end_time = data.index[-1]
        new_index = np.arange(
            start_time, end_time + self.config.t_sample, self.config.t_sample
        )
        interpolated_data = (
            data.reindex(data.index.union(new_index))
            .interpolate(method="index")
            .loc[new_index]
        )
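        # Reindexing to the union of the original and the new index, then
        # interpolating on the index values, gives linear interpolation at the
        # new grid points; .loc[new_index] drops the original rows again.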

        # Transform to a list of tuples
        self.data_tuples = list(interpolated_data.itertuples(index=False, name=None))
        self.data_iterator = self.create_iterator()

    def _get_next_data(self):
        """Return the next data point from the iterator."""
        data = next(self.data_iterator)
        return data

    def create_iterator(self):
        """Create a custom iterator based on the extrapolation method."""
        while True:
            for item in self.data_tuples:
                yield item

            if self.config.extrapolation == "constant":
                self.logger.warning(
                    "Data source has been exhausted. Returning last value indefinitely."
                )
                while True:
                    yield self.data_tuples[-1]
            elif self.config.extrapolation == "repeat":
                self.logger.warning(
                    "Data source has been exhausted. Repeating data from the start."
                )
                continue  # restarts the outer while loop
            elif self.config.extrapolation == "backwards":
                self.logger.warning(
                    "Data source has been exhausted. Going through data backwards."
                )
                yield from self.backwards_iterator()

    def backwards_iterator(self):
        """Iterator for backwards extrapolation."""
        while True:
            # Exclude the last item to avoid repeating it at the turning point
            for item in reversed(self.data_tuples[:-1]):
                yield item
            # Exclude the first item to avoid repeating it at the turning point
            for item in self.data_tuples[1:]:
                yield item
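    # With data_tuples [a, b, c] and extrapolation="backwards", the overall
    # sequence is a, b, c, b, a, b, c, b, a, ...: each sample at a turning
    # point is emitted only once.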

    def process(self):
        """Write the current data values into the data broker every t_sample seconds."""
        while True:
            current_data = self._get_next_data()
            for output, value in zip(self.config.outputs, current_data):
                self.logger.debug(
                    f"At {self.env.time}: Sending variable {output.name} "
                    f"with value {value} to the data broker."
                )
                self.set(output.name, value)
            yield self.env.timeout(self.config.t_sample)

    def register_callbacks(self):
        """Does nothing, as this module is not event-triggered."""