Coverage for addmo/util/experiment_logger.py: 67%
167 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
1import os
2from abc import ABC, abstractmethod
3import pickle
4import pandas as pd
5from pandas import ExcelWriter
6import wandb
7from addmo.s3_model_tuning.models.abstract_model import AbstractMLModel
8from addmo.util.load_save import save_config_to_json
9from addmo.util.load_save_utils import create_or_clean_directory
10from addmo.util.load_save_utils import create_path_or_ask_to_override
11from addmo.s3_model_tuning.models.model_factory import ModelFactory
class AbstractLogger(ABC):
    """Common interface for the experiment loggers in this module.

    All operations are static methods, so concrete loggers are used through
    the class itself rather than through instances. Besides the abstract
    interface, the class provides two helpers that write an artifact into a
    directory and report which files belong to it.
    """

    @staticmethod
    @abstractmethod
    def start_experiment(config: dict = None, **kwargs):
        """Start an experiment; ``config`` holds the run configuration."""

    @staticmethod
    @abstractmethod
    def finish_experiment():
        """Finish the current experiment."""

    @staticmethod
    @abstractmethod
    def log(log: dict):
        """Log the name -> value pairs contained in ``log``."""

    @staticmethod
    @abstractmethod
    def log_artifact(data, name: str, art_type: str):
        """Store ``data`` as an artifact called ``name`` of kind ``art_type``."""

    @staticmethod
    @abstractmethod
    def use_artifact(name: str, alias: str = "latest"):
        """Retrieve the artifact ``name``; ``alias`` selects the version."""

    @staticmethod
    def _handle_pkl(data, name, art_type, directory):
        """Pickle ``data`` to ``<directory>/<name>.<art_type>``.

        Returns:
            A tuple ``("pkl", [path])`` with the path of the written file.
        """
        pickle_name = f"{name}.{art_type}"
        # NOTE(review): presumably prepares the target location and asks
        # before overwriting an existing file -- confirm in load_save_utils.
        create_path_or_ask_to_override(pickle_name, directory)
        target = os.path.join(directory, pickle_name)
        with open(target, "wb") as stream:
            pickle.dump(data, stream)
        return "pkl", [target]

    @staticmethod
    def _handle_model(data, name, art_type, directory):
        """Save a model through its own ``save_regressor`` method.

        Returns:
            A tuple ``("regressor", [model_path, metadata_path])``. Both paths
            are only computed here; writing them is left to
            ``save_regressor`` (assumed to create ``<name>.<art_type>`` and
            ``<name>_metadata.json`` -- TODO confirm).
        """
        regressor: AbstractMLModel = data
        expected_files = [
            os.path.join(directory, f"{name}.{art_type}"),
            os.path.join(directory, f"{name}_metadata.json"),
        ]
        regressor.save_regressor(directory, name, art_type)
        return "regressor", expected_files
class WandbLogger(AbstractLogger):
    """Logger that forwards runs, metrics and artifacts to Weights & Biases.

    State is kept in class attributes and every method is a no-op (returning
    ``None``) unless ``active`` is set to ``True``.
    """

    active: bool = False
    project = None  # Name of the Weights & Biases project that you created in the browser wandb.ai
    directory = None  # Local directory where a backup of the uploaded files is stored

    # File suffixes that `use_artifact` accepts as the payload of an artifact.
    _MODEL_SUFFIXES = (".joblib", ".onnx", ".h5", ".keras", ".pkl")
    # Artifact types that are stored through the model's own save routine.
    _MODEL_ART_TYPES = ("h5", "keras", "joblib", "onnx")

    @staticmethod
    def start_experiment(config=None, **kwargs):
        """
        Starts a new experiment and logs the config to wandb.

        Args:
            config: Run configuration passed on to ``wandb.init``.
            **kwargs: Additional keyword arguments for ``wandb.init``.

        Returns:
            ``wandb.config`` when the logger is active, otherwise ``None``.
        """
        if not WandbLogger.active:
            return None

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(WandbLogger.directory, exist_ok=True)
        wandb.init(
            project=WandbLogger.project,
            config=config,
            dir=WandbLogger.directory,
            **kwargs,
        )
        return wandb.config

    @staticmethod
    def finish_experiment():
        """
        Finishes the current experiment.
        """
        if WandbLogger.active:
            wandb.finish()

    @staticmethod
    def log(log: dict):
        """
        Logs run data.

        Values are converted before being handed to ``wandb.log``:
        DataFrames/Series become a ``wandb.Table`` (index turned into a
        column), lists and tuples are logged as their string representation,
        dictionaries are flattened one level into ``"<name>.<key>"`` entries,
        and everything else is passed through unchanged.
        """
        if not WandbLogger.active:
            return

        processed_log = {}
        for name, data in log.items():
            if isinstance(data, (pd.DataFrame, pd.Series)):
                processed_log[name] = wandb.Table(dataframe=data.reset_index())
            elif isinstance(data, (list, tuple)):
                processed_log[name] = str(data)
            elif isinstance(data, dict):
                # flatten the dictionary (one level only)
                for key, value in data.items():
                    processed_log[f"{name}.{key}"] = value
            else:
                processed_log[name] = data
        wandb.log(processed_log)

    @staticmethod
    def log_artifact(
        data,
        name: str,
        art_type: str,
        description: str = None,
        metadata: dict = None,
    ):
        """
        Logs artifact data.

        The data is first written to ``WandbLogger.directory`` and the
        resulting files are then attached to a new ``wandb.Artifact``.

        Args:
            data: Object to store; a model for the model types, any picklable
                object for ``"pkl"``.
            name: Artifact name, also used as the file stem.
            art_type: One of ``"pkl"``, ``"h5"``, ``"keras"``, ``"joblib"``,
                ``"onnx"``.
            description: Optional artifact description.
            metadata: Optional artifact metadata.

        Raises:
            ValueError: If ``art_type`` is not supported.
        """
        if not WandbLogger.active:
            return

        if art_type == "pkl":
            handler = WandbLogger._handle_pkl
        elif art_type in WandbLogger._MODEL_ART_TYPES:
            handler = WandbLogger._handle_model
        else:
            raise ValueError(f"Unsupported artifact type: {art_type}")

        artifact_type, files_to_add = handler(
            data, name, art_type, WandbLogger.directory
        )

        artifact = wandb.Artifact(
            name=name, type=artifact_type, description=description, metadata=metadata
        )
        for file in files_to_add:
            artifact.add_file(file)

        wandb.run.log_artifact(artifact)
        # Block on the artifact before returning to the caller.
        artifact.wait()

    @staticmethod
    def use_artifact(name: str, alias: str = "latest"):
        """
        Downloads logged model artifact from wandb.

        Args:
            name: Name of the artifact.
            alias: Artifact alias/version, ``"latest"`` by default.

        Returns:
            The unpickled object for ``.pkl`` files, otherwise the model
            returned by ``ModelFactory().load_model``; ``None`` when the
            logger is inactive.

        Raises:
            FileNotFoundError: If the downloaded artifact contains no file
                with a supported suffix.
        """
        if not WandbLogger.active:
            return None

        artifact = wandb.use_artifact(f"{name}:{alias}")
        artifact_dir = artifact.download()

        # Find the model file; sorting makes the choice independent of the
        # arbitrary order in which os.listdir reports entries.
        candidates = sorted(
            entry
            for entry in os.listdir(artifact_dir)
            if entry.endswith(WandbLogger._MODEL_SUFFIXES)
        )
        if not candidates:
            raise FileNotFoundError(
                f"Artifact '{name}:{alias}' contains no file ending in one of "
                f"{WandbLogger._MODEL_SUFFIXES} (looked in {artifact_dir})"
            )
        model_file = candidates[-1]
        model_path = os.path.join(artifact_dir, model_file)

        if model_file.endswith(".pkl"):
            # SECURITY: unpickling executes arbitrary code; only use artifacts
            # from trusted wandb projects.
            with open(model_path, "rb") as f:
                return pickle.load(f)
        return ModelFactory().load_model(model_path)
class LocalLogger(AbstractLogger):  # TODO: possibly delete entirely and switch to the regular save functions?
    """Logger that keeps experiment data on the local file system.

    State is kept in class attributes and every method is a no-op (returning
    ``None``) unless ``active`` is set to ``True``.
    """

    active: bool = False  # Activate local logging
    directory = None  # Directory to store artifacts locally
    run_time_storage = {}  # Storage for the current run

    @staticmethod
    def start_experiment(config, **kwargs):
        """
        Starts a new experiment and logs the config to LocalLogger.

        The config is written to ``config.json`` inside
        ``LocalLogger.directory``.

        Returns:
            ``config`` when the logger is active, otherwise ``None``.
        """
        if not LocalLogger.active:
            return None
        # create_or_clean_directory(LocalLogger.directory)
        config_path = os.path.join(LocalLogger.directory, "config.json")
        save_config_to_json(config, config_path)
        return config

    @staticmethod
    def finish_experiment():
        """
        Finishes the current experiment.

        Not implemented yet: intended to save ``run_time_storage`` to disk.
        """
        if not LocalLogger.active:
            return
        # TODO: implement finish experiment logic here

    @staticmethod
    def log(log: dict):
        """
        Logs run data.

        Not implemented yet: intended to save ``log`` to ``run_time_storage``.
        """
        if not LocalLogger.active:
            return
        # TODO: implement log logic here

    @staticmethod
    def log_artifact(data, name: str, art_type: str):
        """
        Logs artifact data.

        ``"system_data"`` is written as ``<name>.csv`` through
        ``data.to_csv``; the remaining supported types (``"pkl"``, ``"h5"``,
        ``"keras"``, ``"joblib"``, ``"onnx"``) are saved through the helpers
        of ``AbstractLogger``.

        Raises:
            ValueError: If ``art_type`` is not supported.
        """
        if not LocalLogger.active:
            return

        if art_type == "system_data":
            csv_path = os.path.join(LocalLogger.directory, name) + ".csv"
            data.to_csv(csv_path)
            return

        save_model = AbstractLogger._handle_model
        savers = {
            "pkl": AbstractLogger._handle_pkl,
            "h5": save_model,
            "keras": save_model,
            "joblib": save_model,
            "onnx": save_model,
        }
        saver = savers.get(art_type)
        if saver is None:
            raise ValueError(f"Unsupported artifact type: {art_type}")

        artifact_type, files_to_add = saver(
            data, name, art_type, LocalLogger.directory
        )
        print(f"Saved {artifact_type} files: {files_to_add}")

    @staticmethod
    def use_artifact(name: str, alias: str = "latest"):
        """
        Reads a previously stored ``<name>.csv`` from the local directory.

        ``alias`` is accepted for interface compatibility and is not used.

        Returns:
            The CSV content as a DataFrame, or ``None`` when the logger is
            inactive or the file does not exist.
        """
        if not LocalLogger.active:
            return None
        csv_path = os.path.join(LocalLogger.directory, name + '.csv')
        if not os.path.exists(csv_path):
            # A missing file is not an error: return None silently.
            return None
        return pd.read_csv(csv_path)
class ExperimentLogger(AbstractLogger):
    """Static class to trigger the different loggers. A static class can be used throughout the
    whole code without the need to pass it as an argument.

    Every call is forwarded to ``WandbLogger`` first and to ``LocalLogger``
    second; each of them decides on its own ``active`` flag whether to act.
    """

    @staticmethod
    def start_experiment(config=None, **kwargs):
        """Start the experiment on both loggers.

        Returns:
            The config returned by ``WandbLogger`` if it is truthy, otherwise
            the one returned by ``LocalLogger``.
        """
        config_wandb = WandbLogger.start_experiment(config, **kwargs)
        config_local = LocalLogger.start_experiment(config, **kwargs)
        return config_wandb or config_local

    @staticmethod
    def finish_experiment():
        """Finish the experiment on both loggers."""
        WandbLogger.finish_experiment()
        LocalLogger.finish_experiment()

    @staticmethod
    def log(log: dict):
        """Forward the ``log`` dictionary to both loggers."""
        WandbLogger.log(log)
        LocalLogger.log(log)

    @staticmethod
    def log_artifact(data, name: str, art_type: str):
        """Store ``data`` as artifact ``name`` of kind ``art_type`` on both loggers."""
        WandbLogger.log_artifact(data, name, art_type)
        LocalLogger.log_artifact(data, name, art_type)

    @staticmethod
    def use_artifact(name: str, alias: str = "latest"):
        """Retrieve an artifact, preferring wandb over the local copy.

        Returns:
            The result of ``WandbLogger.use_artifact`` unless it is ``None``,
            in which case the result of ``LocalLogger.use_artifact`` is
            returned (which may itself be ``None``).
        """
        data_wandb = WandbLogger.use_artifact(name, alias)
        data_local = LocalLogger.use_artifact(name, alias)
        # Previously the local result was computed and then dropped. Compare
        # against None explicitly: the local logger returns a DataFrame, whose
        # truth value is ambiguous, so `or` cannot be used here.
        if data_wandb is not None:
            return data_wandb
        return data_local