Skip to content

Commit

Permalink
Add utils for config instantiation. (#250)
Browse files Browse the repository at this point in the history
* Add batch_c15n for [0,1] image input and imagenet-normalized input.

* Turn off inference mode before creating perturbations.

* Switch to training mode before running LightningModule.training_step().

* Add utils for config instantiation.

* Comment

* Clean up.
  • Loading branch information
mzweilin authored May 16, 2024
1 parent c117823 commit 07409b4
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion mart/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@

import os

import hydra
from hydra import compose as hydra_compose
from hydra import initialize_config_dir
from lightning.pytorch.callbacks.callback import Callback
from omegaconf import OmegaConf

DEFAULT_VERSION_BASE = "1.2"
DEFAULT_CONFIG_DIR = "."
DEFAULT_CONFIG_NAME = "lightning.yaml"

__all__ = ["compose"]
__all__ = ["compose", "instantiate", "Instantiator", "CallbackInstantiator"]


def compose(
Expand All @@ -40,3 +43,28 @@ def compose(
cfg = cfg[key]

return cfg


def instantiate(cfg_path):
"""Instantiate an object from a Hydra yaml config file."""
config = OmegaConf.load(cfg_path)
obj = hydra.utils.instantiate(config)
return obj


class Instantiator:
def __new__(cls, cfg_path):
return instantiate(cfg_path)


class CallbackInstantiator(Callback):
"""Satisfying type checking for Lightning Callback."""

def __new__(cls, cfg_path):
obj = instantiate(cfg_path)
if isinstance(obj, Callback):
return obj
else:
raise ValueError(
f"We expect to instantiate a lightning Callback from {cfg_path}, but we get {type(obj)} instead."
)

0 comments on commit 07409b4

Please sign in to comment.