Coverage for addmo/util/experiment_logger.py: 67%

167 statements  

coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

import os
from abc import ABC, abstractmethod
import pickle
import pandas as pd
from pandas import ExcelWriter
import wandb
from addmo.s3_model_tuning.models.abstract_model import AbstractMLModel
from addmo.util.load_save import save_config_to_json
from addmo.util.load_save_utils import create_or_clean_directory
from addmo.util.load_save_utils import create_path_or_ask_to_override
from addmo.s3_model_tuning.models.model_factory import ModelFactory


class AbstractLogger(ABC):
    @staticmethod
    @abstractmethod
    def start_experiment(config: dict = None, **kwargs):
        pass

    @staticmethod
    @abstractmethod
    def finish_experiment():
        pass

    @staticmethod
    @abstractmethod
    def log(log: dict):
        pass

    @staticmethod
    @abstractmethod
    def log_artifact(data, name: str, art_type: str):
        pass

    @staticmethod
    @abstractmethod
    def use_artifact(name: str, alias: str = "latest"):
        pass

    @staticmethod
    def _handle_pkl(data, name, art_type, directory):
        filename = f"{name}.{art_type}"
        create_path_or_ask_to_override(filename, directory)
        filepath = os.path.join(directory, filename)
        with open(filepath, "wb") as f:
            pickle.dump(data, f)
        return "pkl", [filepath]

    @staticmethod
    def _handle_model(data, name, art_type, directory):
        model: AbstractMLModel = data
        filepath = os.path.join(directory, f"{name}.{art_type}")
        metadata_filepath = os.path.join(directory, f"{name}_metadata.json")
        model.save_regressor(directory, name, art_type)
        return "regressor", [filepath, metadata_filepath]
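
# Both _handle_* helpers return a (artifact_type, files) pair: a short type label
# for the artifact plus the list of file paths under `directory` that make up the
# artifact. The concrete loggers below reuse them via
# super(<Logger>, <Logger>)._handle_pkl(...) and ._handle_model(...).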


class WandbLogger(AbstractLogger):
    active: bool = False
    project = None  # Name of the Weights & Biases project created in the browser at wandb.ai
    directory = None  # Local directory where a backup of the uploaded files is stored

    @staticmethod
    def start_experiment(config=None, **kwargs):
        """
        Starts a new experiment and logs the config to wandb.
        """
        if WandbLogger.active:
            if not os.path.exists(WandbLogger.directory):
                os.makedirs(WandbLogger.directory)
            wandb.init(
                project=WandbLogger.project,
                config=config,
                dir=WandbLogger.directory,
                **kwargs,
            )

            return wandb.config

    @staticmethod
    def finish_experiment():
        """
        Finishes the current experiment.
        """
        if WandbLogger.active:
            wandb.finish()

    @staticmethod
    def log(log: dict):
        """
        Logs run data.
        """
        if WandbLogger.active:
            processed_log = {}
            for name, data in log.items():
                if isinstance(data, (pd.DataFrame, pd.Series)):
                    data = data.reset_index()
                    processed_log[name] = wandb.Table(dataframe=data)
                elif isinstance(data, list):
                    processed_log[name] = str(data)
                elif isinstance(data, tuple):
                    processed_log[name] = str(data)
                elif isinstance(data, dict):
                    # flatten the dictionary
                    for key, value in data.items():
                        processed_log[f"{name}.{key}"] = value
                else:
                    processed_log[name] = data
            wandb.log(processed_log)

    @staticmethod
    def log_artifact(
        data,
        name: str,
        art_type: str,
        description: str = None,
        metadata: dict = None
    ):
        """
        Logs artifact data.
        """
        if not WandbLogger.active:
            return

        type_handlers = {
            "pkl": lambda d: super(WandbLogger, WandbLogger)._handle_pkl(d, name, art_type, WandbLogger.directory),
            "h5": lambda d: super(WandbLogger, WandbLogger)._handle_model(d, name, art_type, WandbLogger.directory),
            "keras": lambda d: super(WandbLogger, WandbLogger)._handle_model(d, name, art_type, WandbLogger.directory),
            "joblib": lambda d: super(WandbLogger, WandbLogger)._handle_model(d, name, art_type, WandbLogger.directory),
            "onnx": lambda d: super(WandbLogger, WandbLogger)._handle_model(d, name, art_type, WandbLogger.directory),
        }

        if art_type not in type_handlers:
            raise ValueError(f"Unsupported artifact type: {art_type}")

        artifact_type, files_to_add = type_handlers[art_type](data)

        artifact = wandb.Artifact(
            name=name, type=artifact_type, description=description, metadata=metadata
        )

        for file in files_to_add:
            artifact.add_file(file)

        wandb.run.log_artifact(artifact)
        artifact.wait()

    @staticmethod
    def use_artifact(name: str, alias: str = "latest"):
        """
        Downloads logged model artifact from wandb.
        """
        if WandbLogger.active:
            artifact = wandb.use_artifact(f"{name}:{alias}")
            artifact_dir = artifact.download()

            # Find the model and metadata files
            for file in os.listdir(artifact_dir):
                if file.endswith(('.joblib', '.onnx', '.h5', '.keras', '.pkl')):
                    model_file = file

            model_path = os.path.join(artifact_dir, model_file)

            if model_file.endswith('.pkl'):
                with open(model_path, "rb") as f:
                    loaded_model = pickle.load(f)
            else:
                loaded_model = ModelFactory().load_model(model_path)

            return loaded_model


class LocalLogger(AbstractLogger):  # TODO: possibly remove this entirely and switch to the regular save functions?
    active: bool = False  # Activate local logging
    directory = None  # Directory to store artifacts locally
    run_time_storage = {}  # Storage for the current run

    @staticmethod
    def start_experiment(config, **kwargs):
        """
        Starts a new experiment and logs the config to LocalLogger.
        """
        if LocalLogger.active:
            # create_or_clean_directory(LocalLogger.directory)
            path = os.path.join(LocalLogger.directory, "config.json")
            save_config_to_json(config, path)
        return config

    @staticmethod
    def finish_experiment():
        """
        Finishes the current experiment.
        """
        if LocalLogger.active:
            # save run_time_storage to disk
            pass  # Implement finish experiment logic here

    @staticmethod
    def log(log: dict):
        if LocalLogger.active:
            # save to run_time_storage
            pass  # Implement log logic here

    @staticmethod
    def log_artifact(data, name: str, art_type: str):
        """
        Logs artifact data.
        """
        if LocalLogger.active:
            if art_type == "system_data":
                file_path = os.path.join(LocalLogger.directory, name)
                data.to_csv(file_path + ".csv")

                return

            type_handlers = {
                "pkl": lambda d: super(LocalLogger, LocalLogger)._handle_pkl(d, name, art_type, LocalLogger.directory),
                "h5": lambda d: super(LocalLogger, LocalLogger)._handle_model(d, name, art_type, LocalLogger.directory),
                "keras": lambda d: super(LocalLogger, LocalLogger)._handle_model(d, name, art_type, LocalLogger.directory),
                "joblib": lambda d: super(LocalLogger, LocalLogger)._handle_model(d, name, art_type, LocalLogger.directory),
                "onnx": lambda d: super(LocalLogger, LocalLogger)._handle_model(d, name, art_type, LocalLogger.directory),
            }

            if art_type not in type_handlers:
                raise ValueError(f"Unsupported artifact type: {art_type}")

            artifact_type, files_to_add = type_handlers[art_type](data)
            print(f"Saved {artifact_type} files: {files_to_add}")

    @staticmethod
    def use_artifact(name: str, alias: str = "latest"):
        """
        Downloads a logged model artifact.
        """
        if LocalLogger.active:
            filename = name + '.csv'
            file_path = os.path.join(LocalLogger.directory, filename)
            if os.path.exists(file_path):  # Check if the file exists
                return pd.read_csv(file_path)
            else:
                # If the file does not exist, return None silently
                return None
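
# Example (illustrative sketch; directory and data names are placeholders): with
# LocalLogger active, tabular "system_data" artifacts are written as CSV files in
# LocalLogger.directory and can be read back by name:
#
#     LocalLogger.active = True
#     LocalLogger.directory = "local_runs"
#     df = pd.DataFrame({"t": [0, 1], "y": [0.1, 0.2]})
#     LocalLogger.log_artifact(df, name="cleaned_data", art_type="system_data")
#     df_again = LocalLogger.use_artifact("cleaned_data")  # reads cleaned_data.csv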


class ExperimentLogger(AbstractLogger):
    """Static class that triggers the individual loggers. Being static, it can be
    used throughout the whole codebase without having to be passed around as an
    argument."""

    @staticmethod
    def start_experiment(config=None, **kwargs):
        config_wandb = WandbLogger.start_experiment(config, **kwargs)
        config_local = LocalLogger.start_experiment(config, **kwargs)
        return config_wandb or config_local

    @staticmethod
    def finish_experiment():
        WandbLogger.finish_experiment()
        LocalLogger.finish_experiment()

    @staticmethod
    def log(log: dict):
        WandbLogger.log(log)
        LocalLogger.log(log)

    @staticmethod
    def log_artifact(data, name: str, art_type: str):
        WandbLogger.log_artifact(data, name, art_type)
        LocalLogger.log_artifact(data, name, art_type)

    @staticmethod
    def use_artifact(name: str, alias: str = "latest"):
        data_wandb = WandbLogger.use_artifact(name, alias)
        data_local = LocalLogger.use_artifact(name, alias)
        # Prefer the wandb artifact; fall back to the local copy if wandb returned nothing.
        return data_wandb if data_wandb is not None else data_local
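
For reference, a minimal way to drive the module from calling code could look like the
sketch below (the project name, directories, and config values are illustrative
placeholders, not defaults of this module):

    from addmo.util.experiment_logger import ExperimentLogger, LocalLogger, WandbLogger

    # Configure the static loggers once, before the run starts.
    WandbLogger.active = True
    WandbLogger.project = "addmo-experiments"  # wandb project created on wandb.ai
    WandbLogger.directory = "wandb_backup"     # local backup of uploaded files
    LocalLogger.active = True
    LocalLogger.directory = "local_runs"       # receives config.json and artifacts

    # start_experiment logs the config and returns it (wandb.config if wandb is active).
    config = ExperimentLogger.start_experiment(config={"model": "ann", "epochs": 50})

    # Scalars, DataFrames/Series, lists, tuples, and nested dicts are all accepted by log().
    ExperimentLogger.log({"val_rmse": 0.42, "hyperparameters": {"lr": 1e-3}})

    # "pkl" artifacts are pickled and logged by both loggers; model artifact types
    # (h5/keras/joblib/onnx) expect an AbstractMLModel instance instead.
    ExperimentLogger.log_artifact({"scaler": "minmax"}, name="preprocessing", art_type="pkl")

    ExperimentLogger.finish_experiment()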