Coverage for agentlib/modules/simulation/csv_data_source.py: 71%

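"""Module containing the CSVDataSource, a simulation module that plays back a
CSV file or DataFrame: it interpolates the data onto a fixed sampling grid and
sends one value per configured output to the agent's data broker every
t_sample seconds."""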

from pathlib import Path
from typing import List, Optional, Union, Literal

import numpy as np
import pandas as pd
from pydantic import Field, field_validator, FilePath, model_validator

from agentlib import Agent
from agentlib.core import BaseModule, BaseModuleConfig, AgentVariable


class CSVDataSourceConfig(BaseModuleConfig):
    data: Union[pd.DataFrame, FilePath] = Field(
        title="Data",
        default=pd.DataFrame(),
        description="Data that should be communicated during execution. The index "
        "should be either numeric or datetime; numeric values are "
        "interpreted as seconds.",
        validate_default=True,
    )
    outputs: List[AgentVariable] = Field(
        title="Outputs",
        default=[],
        description="Optional list of columns of the data frame that should be sent. "
        "If omitted, all datapoints in the frame are sent.",
    )
    t_sample: Union[float, int] = Field(
        title="t_sample",
        default=1,
        ge=0,
        description="Sampling time. The data source sends an interpolated value from "
        "the data every <t_sample> seconds. Default is 1 s.",
    )
    data_offset: Optional[Union[pd.Timedelta, float]] = Field(
        title="data_offset",
        default=0,
        description="Offset that is subtracted from the index, allowing you to start "
        "at any point in your data. E.g. if your environment starts at 0 "
        "and you want the data source to start 1000 seconds into the "
        "data, set this to 1000.",
    )
    extrapolation: Literal["constant", "repeat", "backwards"] = Field(
        title="Extrapolation",
        default="constant",
        description="Determines what to do when the data source runs out. 'constant' "
        "returns the last value, 'repeat' repeats the data from the "
        "start, and 'backwards' goes through the data backwards, "
        "bouncing indefinitely.",
    )
    shared_variable_fields: List[str] = ["outputs"]

    @field_validator("data")
    @classmethod
    def check_data(cls, data):
        """Makes sure data is a data frame, and loads it if required."""
        if isinstance(data, (str, Path)) and Path(data).is_file():
            data = pd.read_csv(data, engine="python", index_col=0)
        if not isinstance(data, pd.DataFrame):
            raise ValueError(
                f"Data {data} is not a valid DataFrame or the path is not found."
            )
        if data.empty:
            raise ValueError("Provided data is empty.")
        if len(data) < 2:
            raise ValueError(
                "The dataframe must contain at least two rows for interpolation."
            )
        return data
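        # Example (illustrative file content, not part of this module): a CSV
        # accepted by this validator, whose first column becomes the index:
        #     time,T_in
        #     0,20.0
        #     10,21.0
        # pd.read_csv(path, engine="python", index_col=0) then yields a
        # DataFrame indexed by "time" with one column "T_in".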

    def transform_to_numeric_index(self, data: pd.DataFrame) -> None:
        """Ensures the index is numeric (in seconds) and applies the offset.

        The DataFrame's index is modified in place; nothing is returned.
        """
        # Convert offset to seconds if it's a Timedelta
        offset = self.data_offset
        if isinstance(offset, pd.Timedelta):
            offset = offset.total_seconds()
        # Handle different index types
        if isinstance(data.index, pd.DatetimeIndex):
            data.index = (data.index - data.index[0]).total_seconds()
        else:
            # Try to convert to numeric if it's a string
            try:
                data.index = pd.to_numeric(data.index)
                data.index = data.index - data.index[0]
            except ValueError:
                # If conversion to numeric fails, try to convert to datetime
                try:
                    data.index = pd.to_datetime(data.index)
                    data.index = (data.index - data.index[0]).total_seconds()
                except ValueError:
                    raise ValueError("Unable to convert index to numeric format")

        data.index = data.index.astype(float) - offset
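        # Worked example (illustrative values): for a datetime index
        # ["2024-01-01 00:00", "2024-01-01 00:10"] and data_offset=300, the
        # branch above first yields [0.0, 600.0] and the final line shifts it
        # to [-300.0, 300.0].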


    @model_validator(mode="after")
    def validate_data(self):
        """Checks if outputs and data columns match, and ensures a numeric index."""
        if self.outputs:
            columns = set(self.data.columns)
            output_names = set(o.name for o in self.outputs)

            missing_columns = output_names - columns
            if missing_columns:
                raise ValueError(
                    f"The following output columns are not present in the dataframe: "
                    f"{', '.join(missing_columns)}"
                )
        self.transform_to_numeric_index(self.data)
        return self
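
# Validation pipeline example (illustrative; "measurements.csv" is a
# hypothetical file): a config with data="measurements.csv" is loaded by
# check_data via pd.read_csv(..., index_col=0); validate_data then checks the
# configured output names against the columns and converts the index to
# seconds via transform_to_numeric_index, so self.data is always a
# numeric-indexed DataFrame by the time CSVDataSource.__init__ runs.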


class CSVDataSource(BaseModule):

    config: CSVDataSourceConfig

    def __init__(self, config: dict, agent: Agent):
        super().__init__(config=config, agent=agent)

        data = self.config.data
        if self.config.outputs:
            # Restrict and order the columns to match the configured outputs,
            # so that process() pairs each output with its own column.
            data = data[[output.name for output in self.config.outputs]]

        # Interpolate the dataframe onto an equidistant grid with step t_sample
        start_time = data.index[0]
        end_time = data.index[-1]
        new_index = np.arange(
            start_time, end_time + self.config.t_sample, self.config.t_sample
        )
        interpolated_data = (
            data.reindex(data.index.union(new_index))
            .interpolate(method="index")
            .loc[new_index]
        )

        # Transform to list of tuples
        self.data_tuples = list(interpolated_data.itertuples(index=False, name=None))
        self.data_iterator = self.create_iterator()
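        # The reindex/interpolate/loc pattern above resamples the data onto
        # the grid. Standalone sketch of the same idea (values illustrative):
        #     df = pd.DataFrame({"x": [0.0, 10.0]}, index=[0.0, 20.0])
        #     grid = np.arange(0.0, 25.0, 5.0)  # [0, 5, 10, 15, 20]
        #     df.reindex(df.index.union(grid)).interpolate(method="index").loc[grid]
        #     # -> x: [0.0, 2.5, 5.0, 7.5, 10.0]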


    def _get_next_data(self):
        """Return the next data point from the iterator."""
        return next(self.data_iterator)


    def create_iterator(self):
        """Create a custom iterator based on the extrapolation method."""
        while True:
            for item in self.data_tuples:
                yield item

            if self.config.extrapolation == "constant":
                self.logger.warning(
                    "Data source has been exhausted. Returning last value indefinitely."
                )
                while True:
                    yield self.data_tuples[-1]
            elif self.config.extrapolation == "repeat":
                self.logger.warning(
                    "Data source has been exhausted. Repeating data from the start."
                )
                continue  # restarts the outer while loop
            elif self.config.extrapolation == "backwards":
                self.logger.warning(
                    "Data source has been exhausted. Going through data backwards."
                )
                yield from self.backwards_iterator()
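        # Example (illustrative): for data_tuples [a, b, c]:
        #     "constant" yields a, b, c, c, c, ...
        #     "repeat"   yields a, b, c, a, b, c, ...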


    def backwards_iterator(self):
        """Iterator for backwards extrapolation, bouncing between both ends."""
        while True:
            # Exclude the last item to avoid repeating it at the turning point
            for item in reversed(self.data_tuples[:-1]):
                yield item
            # Exclude the first item to avoid repeating it at the turning point
            for item in self.data_tuples[1:]:
                yield item
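        # Example (illustrative): for data_tuples [a, b, c], "backwards" mode
        # yields a, b, c (from create_iterator), then b, a, b, c, b, a, ...,
        # bouncing indefinitely without repeating the turning points.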


    def process(self):
        """Write the current data values into the data broker every t_sample seconds."""
        while True:
            current_data = self._get_next_data()
            for output, value in zip(self.config.outputs, current_data):
                self.logger.debug(
                    f"At {self.env.time}: Sending variable {output.name} "
                    f"with value {value} to the data broker."
                )
                self.set(output.name, value)
            yield self.env.timeout(self.config.t_sample)

    def register_callbacks(self):
        """Do nothing, as this module is not event-triggered."""