Skip to content

Commit

Permalink
Add wandb helper.
Browse files Browse the repository at this point in the history
  • Loading branch information
lxuechen committed Aug 13, 2022
1 parent 0b834c1 commit 85d6cf1
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,5 @@ pip install git+https://github.com/lxuechen/ml-swissknife.git
- [x] Confidence interval utils
- Data
- [ ] UTKFaces
- wandb
- [ ] Project-based helper for downloading files
53 changes: 53 additions & 0 deletions ml_swissknife/wandb_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import logging
from typing import Sequence

import wandb

from . import utils


class WandbHelper(object):
"""Project-based wandb helper."""

def __init__(self, project_name, user='lxuechen'):
super(WandbHelper, self).__init__()
self._user = user
self._api = wandb.Api()
self._name_to_run_map = self._create_name_to_run_map(project_name=project_name)

def _create_name_to_run_map(self, project_name):
base_dir = utils.join(self._user, project_name)
runs = self._api.runs(base_dir)

name_to_run_map = dict()
for run in runs:
if run.name in name_to_run_map:
logging.warning(f"Observed repeated run name in {base_dir}; old value will be overridden.")
name_to_run_map[run.name] = run
return name_to_run_map

def name_to_run(self, name: str) -> wandb.apis.public.Run:
"""Retrieve the run based on name.
Note that `wandb.Api().run(<user>/<project>/<run_id>)` requires the `run_id` to retrieve, which is by default
a random hash. This makes finding runs very inconvenient.
Example usage to retrieve the run `<user>/<project>/example_run`:
wbhelper = WandbHelper(...).name_to_run('example_run')
For reference, the API for wandb.apis.public.Run:
https://docs.wandb.ai/ref/python/public-api/run
"""
return self._name_to_run_map[name]

def download_run(self, name, root='.', replace=False):
"""Download all files associated with a run."""
base_dir = utils.join(root, name)
run = self.name_to_run(name)
for file in run.files():
file.download(root=base_dir, replace=replace)

def download_runs(self, names=Sequence[str], root='.', replace=False):
"""Download all files associated with multiple runs."""
for name in names:
self.download_run(name, root=root, replace=replace)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
url="https://github.com/lxuechen/ml-swissknife",
install_requires=[
'torch', 'torchvision', 'spacy', 'tqdm', 'numpy', 'scipy', 'gputil', 'fire', 'requests', 'nltk', 'transformers',
'datasets', 'gdown>=4.4.0', 'pandas', 'pytest', 'matplotlib', 'seaborn', 'cvxpy', 'imageio'
'datasets', 'gdown>=4.4.0', 'pandas', 'pytest', 'matplotlib', 'seaborn', 'cvxpy', 'imageio', 'wandb'
],
extras_require=extras_require,
python_requires='~=3.7',
Expand Down

0 comments on commit 85d6cf1

Please sign in to comment.