Skip to content

Commit

Permalink
add env
Browse files Browse the repository at this point in the history
  • Loading branch information
JavaZeroo committed Aug 18, 2023
1 parent 8955724 commit 562278e
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 53 deletions.
125 changes: 125 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
name: sde
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- asttokens=2.2.1=pyhd8ed1ab_0
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=pyhd8ed1ab_3
- backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.7.22=hbcca054_0
- comm=0.1.4=pyhd8ed1ab_0
- debugpy=1.6.7=py311h6a678d5_0
- decorator=5.1.1=pyhd8ed1ab_0
- executing=1.2.0=pyhd8ed1ab_0
- importlib-metadata=6.8.0=pyha770c72_0
- importlib_metadata=6.8.0=hd8ed1ab_0
- ipykernel=6.25.1=pyh71e2992_0
- ipython=8.14.0=pyh41d4057_0
- jedi=0.19.0=pyhd8ed1ab_0
- jupyter_client=8.3.0=pyhd8ed1ab_0
- jupyter_core=5.3.1=py311h38be061_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libsodium=1.0.18=h36c2ea0_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
- ncurses=6.4=h6a678d5_0
- nest-asyncio=1.5.6=pyhd8ed1ab_0
- openssl=3.0.10=h7f8727e_0
- packaging=23.1=pyhd8ed1ab_0
- parso=0.8.3=pyhd8ed1ab_0
- pexpect=4.8.0=pyh1a96a4e_2
- pickleshare=0.7.5=py_1003
- pip=23.2.1=py311h06a4308_0
- platformdirs=3.10.0=pyhd8ed1ab_0
- prompt-toolkit=3.0.39=pyha770c72_0
- prompt_toolkit=3.0.39=hd8ed1ab_0
- psutil=5.9.0=py311h5eee18b_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pygments=2.16.1=pyhd8ed1ab_0
- python=3.11.4=h955ad1f_0
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python_abi=3.11=2_cp311
- pyzmq=25.1.0=py311h6a678d5_0
- readline=8.2=h5eee18b_0
- setuptools=68.0.0=py311h06a4308_0
- six=1.16.0=pyh6c4a22f_0
- sqlite=3.41.2=h5eee18b_0
- stack_data=0.6.2=pyhd8ed1ab_0
- tk=8.6.12=h1ccaba5_0
- tornado=6.3.2=py311h5eee18b_0
- traitlets=5.9.0=pyhd8ed1ab_0
- typing-extensions=4.7.1=hd8ed1ab_0
- typing_extensions=4.7.1=pyha770c72_0
- tzdata=2023c=h04d1e81_0
- wcwidth=0.2.6=pyhd8ed1ab_0
- wheel=0.38.4=py311h06a4308_0
- xz=5.4.2=h5eee18b_0
- zeromq=4.3.4=h9c3ff4c_1
- zipp=3.16.2=pyhd8ed1ab_0
- zlib=1.2.13=h5eee18b_0
- pip:
- certifi==2023.7.22
- charset-normalizer==3.2.0
- cmake==3.27.2
- contourpy==1.1.0
- cycler==0.11.0
- easydict==1.10
- efficientnet-pytorch==0.7.1
- filelock==3.12.2
- fonttools==4.42.0
- fsspec==2023.6.0
- huggingface-hub==0.16.4
- idna==3.4
- imageio==2.31.1
- jinja2==3.1.2
- joblib==1.3.2
- kiwisolver==1.4.4
- lit==16.0.6
- markdown-it-py==3.0.0
- markupsafe==2.1.3
- matplotlib==3.7.2
- mdurl==0.1.2
- mpmath==1.3.0
- munch==4.0.0
- networkx==3.1
- numpy==1.25.2
- nvidia-cublas-cu11==11.10.3.66
- nvidia-cuda-cupti-cu11==11.7.101
- nvidia-cuda-nvrtc-cu11==11.7.99
- nvidia-cuda-runtime-cu11==11.7.99
- nvidia-cudnn-cu11==8.5.0.96
- nvidia-cufft-cu11==10.9.0.58
- nvidia-curand-cu11==10.2.10.91
- nvidia-cusolver-cu11==11.4.0.1
- nvidia-cusparse-cu11==11.7.4.91
- nvidia-nccl-cu11==2.14.3
- nvidia-nvtx-cu11==11.7.91
- pillow==10.0.0
- plotly==5.16.0
- pretrainedmodels==0.7.4
- pyparsing==3.0.9
- pyyaml==6.0.1
- requests==2.31.0
- rich==13.5.2
- safetensors==0.3.2
- scikit-learn==1.3.0
- scipy==1.11.1
- segmentation-models-pytorch==0.3.3
- sympy==1.12
- tenacity==8.2.3
- threadpoolctl==3.2.0
- timm==0.9.2
- torch==2.0.1
- torchvision==0.15.2
- tqdm==4.66.1
- triton==2.0.0
- urllib3==2.0.4
245 changes: 199 additions & 46 deletions gaussian2ring.ipynb

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def main():
parser.add_argument('--batch_size', type=int, default=8000)
parser.add_argument('-n','--normalize', action='store_true')
parser.add_argument('--tarined_data', action='store_true')

parser.add_argument('--filter_number', type=int)

args = parser.parse_args()
check_model_task(args)

Expand All @@ -58,7 +59,13 @@ def main():
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

experiment_name = args.task
experiment_name = args.task
if args.change_epsilons:
experiment_name += '_change_epsilons'
if args.filter_number is not None and 'mnist' in args.task:
experiment_name += f'_filter{args.filter_number}'


log_dir = Path('experiments') / experiment_name / 'test' / tt.strftime("%Y-%m-%d/%H_%M_%S/")
ds_cached_dir = Path('experiments') / experiment_name / 'data'
log_dir.mkdir(parents=True, exist_ok=True)
Expand Down
10 changes: 5 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,12 @@ def main_worker(args):
TimeElapsedColumn(),
transient=False,
) as progress:
task1 = progress.add_task("[gold]Training whole dataset (lr: X) (loss=X)", total=ds_info['nums_sub_ds']*args.epoch_nums)
task1 = progress.add_task("[red]Training whole dataset (lr: X) (loss=X)", total=ds_info['nums_sub_ds']*args.epoch_nums)
while not progress.finished:
if ds_info['nums_sub_ds'] == 1:
new_dl = read_ds_from_pkl(args,
real_metadata,
args.ds_cached_dir / f"new_ds_{int(iter%ds_info['nums_sub_ds'])}.pkl"
args.ds_cached_dir / f"new_ds_0.pkl"
)
for iter in range(ds_info['nums_sub_ds']*args.epoch_nums):
if ds_info['nums_sub_ds'] > 1:
Expand All @@ -158,7 +158,7 @@ def main_worker(args):
args.ds_cached_dir / f"new_ds_{int(iter%ds_info['nums_sub_ds'])}.pkl"
)

task2 = progress.add_task(f"[green]Training sub dataset {int(iter%ds_info['nums_sub_ds'])}", total=args.iter_nums)
task2 = progress.add_task(f"[dark_orange]Training sub dataset {int(iter%ds_info['nums_sub_ds'])}", total=args.iter_nums)
for _ in range(args.iter_nums):
now_loss = train(args, model ,new_dl, optimizer, scheduler, loss_fn, before_train, after_train)
loss_list.append(now_loss)
Expand All @@ -167,8 +167,8 @@ def main_worker(args):
progress.update(task2, visible=False)
progress.remove_task(task2)
torch.save(model.state_dict(), args.log_dir / f'model_{model.__class__.__name__}_{int(iter)}.pth')
progress.update(task1, advance=1, description="[red]Training whole dataset (l r: %2.5f) (loss=%2.5f)" % (cur_lr, now_loss))

progress.update(task1, advance=1, description="[red]Training whole dataset (lr: %2.5f) (loss=%2.5f)" % (cur_lr, now_loss))
progress.log(f"[green]sub dataset {int(iter%ds_info['nums_sub_ds'])} finished; Loss: {now_loss}")
# Draw loss curve
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(loss_list)
Expand Down

0 comments on commit 562278e

Please sign in to comment.