Skip to content

Commit

Permalink
add logging messages and edit requirements.txt
Browse files Browse the repository at this point in the history
  • Loading branch information
Munkhtenger19 committed May 17, 2024
1 parent 129b097 commit 27bccb5
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 22 deletions.
8 changes: 0 additions & 8 deletions gnn_toolbox/custom_components/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,12 @@ def grad_with_checkpoint(outputs: Union[torch.Tensor, Sequence[torch.Tensor]],
inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)

for input in inputs:
# input.retain_grad()
if not input.is_leaf:
input.retain_grad()

torch.autograd.backward(outputs)

grad_outputs = []
# for input in inputs:
# if input.grad is not None:
# grad_outputs.append(input.grad.clone())
# input.grad.zero_()
# else:
# # Append zeros in the same shape as the input if no gradient was computed
# grad_outputs.append(torch.ones_like(input))

for input in inputs:
grad_outputs.append(input.grad.clone())
Expand Down
7 changes: 2 additions & 5 deletions gnn_toolbox/experiment_handler/artifact_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def save_model(self, model, params, result, is_unattacked_model):
with result_path.open('w') as file:
json.dump(result, file, indent=4)
except Exception as e:
logging.error(f"Failed to save model or results: {e}")
logging.error(f"Failed to save model or results to {model_dir}: {e}")
else:
logging.info(f'Saved the model {model_suffix} at {model_dir}')
logging.info(f'Saved the model {model_suffix} to {model_dir} for caching')

def model_exists(self, params, is_unattacked_model):
""" Check if a model with the given parameters already exists. """
Expand All @@ -51,9 +51,6 @@ def model_exists(self, params, is_unattacked_model):
hash_id = self.hash_parameters(params)
params_dir = self.cache_directory / f"{hash_id}"
if self.folder_exists(params_dir):
# params_path = params_dir / 'params.json'
# with params_path.open('r') as file:
# saved_params = json.load(file)
if is_unattacked_model:
model_name = f"{params['model']['name']}_{params['dataset']['name']}.pt"
model_path = params_dir / model_name
Expand Down
2 changes: 1 addition & 1 deletion gnn_toolbox/experiment_handler/create_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def create_transforms(configs):
transform_cls = get_from_registry("transform", name, registry)
if transform_cls is None:
raise ValueError(f"Transform '{name}' not found in the registry.")
logging.info(f"Creating transform '{name}' with parameters: {params}")
logging.debug(f"For dataset, creating transform '{name}' with parameters: {params}")
transforms.append(transform_cls(**params))

return Compose(transforms)
Expand Down
12 changes: 6 additions & 6 deletions gnn_toolbox/experiment_handler/result_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def log_experiment_config(self):
file_path = self.log_to / "experiment_config.json"
with open(file_path, 'w') as f:
json.dump(self.experiment_config, f, indent=4)
logging.info("Experiment configuration logged successfully.")
logging.info(f"Experiment configuration logged to {file_path} successfully.")
except Exception as e:
logging.error(f"Failed to log experiment configuration: {e}")
logging.error(f"Failed to log experiment configuration to {file_path}: {e}")

def log_results(self):
try:
Expand All @@ -35,9 +35,9 @@ def log_results(self):
json.dump(value, f, indent=2)
if self.csv_save and (key=='clean_result' or self.perturbed_result2csv):
self.save_to_csv(key, value)
logging.info("Results logged successfully.")
logging.info(f"Results logged successfully to {self.log_to}.")
except Exception as e:
logging.error(f"Failed to log results: {e}")
logging.error(f"Failed to log results to {self.log_to}: {e}")

def save_to_csv(self, key, value):
try:
Expand All @@ -47,7 +47,7 @@ def save_to_csv(self, key, value):
writer.writerow(value[0].keys())
for result in value:
writer.writerow(result.values())
logging.info(f"{key} saved to CSV successfully.")
logging.info(f"{key} saved to CSV at {file_path} successfully.")
except Exception as e:
logging.error(f"Failed to save {key} to CSV: {e}")
logging.error(f"Failed to save {key} to CSV at {file_path}: {e}")

4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
numba==0.59.1
numpy==1.26.4
ogb==1.3.6
pydantic==2.7.1
PyYAML==6.0.1
scipy==1.13.0
tensorboardX==2.6.2.2
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==2.2.1+cu118
torch_geometric==2.5.3
torch_scatter==2.1.2+pt22cu118
torch_sparse==0.6.18+pt22cu118
torch_geometric==2.5.3
torchtyping==0.1.4
tqdm==4.66.2
typeguard==4.2.1
Expand Down

0 comments on commit 27bccb5

Please sign in to comment.