Skip to content

Commit

Permalink
update
Browse files (browse the repository at this point in the history)
JavaZeroo committed Aug 18, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent d8c6c3b commit 8955724
Showing 5 changed files with 238 additions and 58 deletions.
198 changes: 175 additions & 23 deletions check_model.ipynb

Large diffs are not rendered by default.

21 changes: 12 additions & 9 deletions test.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
from rich.pretty import Pretty
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
import time
import time as tt

from utils.Datasets import BBdataset, MNISTdataset
from utils.utils import plot_source_and_target_mnist, binary, save_gif_frame_mnist
@@ -59,7 +59,7 @@ def main():
np.random.seed(seed)

experiment_name = args.task
log_dir = Path('experiments') / experiment_name / 'test' / time.strftime("%Y-%m-%d/%H_%M_%S/")
log_dir = Path('experiments') / experiment_name / 'test' / tt.strftime("%Y-%m-%d/%H_%M_%S/")
ds_cached_dir = Path('experiments') / experiment_name / 'data'
log_dir.mkdir(parents=True, exist_ok=True)
ds_cached_dir.mkdir(parents=True, exist_ok=True)
@@ -82,7 +82,7 @@ def main_worker(args):
console.log(f"Saving to {Path.absolute(args.log_dir)}")

model, before_train, after_train = get_model_before_after(args)
# console.log(model)
if args.checkpoint is not None:
try:
model.load_state_dict(torch.load(args.checkpoint))
@@ -116,8 +116,7 @@ def main_worker(args):
pred_drift = torch.zeros_like(test_drift)

pred_bridge[0, :] = test_source
model.to(args.device)
model.eval()
# model.eval()

sigma=1
console.rule("[bold deep_sky_blue1 blink]Testing")
@@ -132,7 +131,10 @@ def main_worker(args):
for i in range(len(test_ts) - 1):
dt = (test_ts[i+1] - test_ts[i])
test_source_reshaped = test_source
test_ts_reshaped = test_ts[i].repeat(test_source.shape[0]).reshape(-1, 1, 1, 1).repeat(1, 1, 28, 28)
if args.time_expand:
test_ts_reshaped = test_ts[i].repeat(test_source.shape[0]).reshape(-1, 1, 1, 1).repeat(1, 1, 28, 28)
else:
test_ts_reshaped = torch.unsqueeze(test_ts[i], dim=0).T
pred_bridge_reshaped = pred_bridge[i]

ret = normalize_dataset_with_metadata(real_metadata, source=test_source_reshaped, ts=test_ts_reshaped, bridge=pred_bridge_reshaped)
@@ -145,11 +147,12 @@ def main_worker(args):
else:
x = torch.concat([test_source_reshaped, pred_bridge_reshaped], axis=1)
time = test_ts_reshaped.to(args.device)
x.to(args.device)

if before_train is not None:
x = before_train(x)
dydt = model(x, time) if time else model(x)

x = x.to(args.device)
model = model.to(args.device)
dydt = model(x, time) if time is not None else model(x)
dydt = dydt.cpu()
if after_train is not None:
dydt = after_train(dydt)
18 changes: 12 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@
import argparse

def check_model_task(args):
if args.task == 'gaussian2mnist':
if args.task.startswith('gaussian2mnist'):
assert args.model in ['tunet++', 'unet++', 'unet']
args.time_expand = False
else:
@@ -46,7 +46,8 @@ def main():
parser.add_argument('--batch_size', type=int, default=8000)
parser.add_argument('-n','--normalize', action='store_true')
parser.add_argument('--num_workers', type=int, default=20)

parser.add_argument('--filter_number', type=int)

args = parser.parse_args()
check_model_task(args)

@@ -57,13 +58,18 @@ def main():
np.random.seed(seed)

experiment_name = args.task
if args.change_epsilons:
experiment_name += '_change_epsilons'
if args.filter_number is not None and 'mnist' in args.task:
experiment_name += f'_filter{args.filter_number}'

log_dir = Path('experiments') / experiment_name / 'train' / time.strftime("%Y-%m-%d/%H_%M_%S/")
ds_cached_dir = Path('experiments') / experiment_name / 'data'
log_dir.mkdir(parents=True, exist_ok=True)
ds_cached_dir.mkdir(parents=True, exist_ok=True)
args.log_dir = log_dir
args.ds_cached_dir = ds_cached_dir
if args.task == 'gaussian2mnist':
if args.task.startswith('gaussian2mnist'):
args.dim = 1
else:
args.dim = 2
@@ -86,8 +92,8 @@ def train(args, model, train_dl, optimizer, scheduler, loss_fn, before_train=Non
x = before_train(x)
x = x.to(args.device)
y = y.to(args.device)
time = time.to(args.device) if time else None
pred = model(x, time) if time else model(x)
time = time.to(args.device) if time is not None else None
pred = model(x, time) if time is not None else model(x)
if after_train is not None:
pred = after_train(pred)
loss = loss_fn(pred, y)
@@ -161,7 +167,7 @@ def main_worker(args):
progress.update(task2, visible=False)
progress.remove_task(task2)
torch.save(model.state_dict(), args.log_dir / f'model_{model.__class__.__name__}_{int(iter)}.pth')
progress.update(task1, advance=1, description="[red]Training whole dataset (lr: %2.5f) (loss=%2.5f)" % (cur_lr, now_loss))
progress.update(task1, advance=1, description="[red]Training whole dataset (lr: %2.5f) (loss=%2.5f)" % (cur_lr, now_loss))

# Draw loss curve
fig, ax = plt.subplots(figsize=(10, 5))
3 changes: 2 additions & 1 deletion utils/Models.py
Original file line number Diff line number Diff line change
@@ -46,7 +46,8 @@ def timestep_embedding(self, timesteps, dim, max_period=10000):
return embedding

def forward(self, x, timesteps):
emb = self.time_embed(self.timestep_embedding(timesteps, self.model_channels))
emb = self.timestep_embedding(timesteps, self.model_channels)
emb = self.time_embed(emb)
self.check_input_shape(x)
features = self.encoder(x)
for index, f in enumerate(features):
56 changes: 37 additions & 19 deletions utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
from rich.progress import track
from utils.normalize import get_total_mean_std
from torch.utils.data import Dataset, DataLoader
import numpy as np

def gen_bridge_2d(x, y, ts, T, num_samples):
"""
@@ -139,16 +140,19 @@ def normalize_dataset_with_metadata(metadata, ts=None, bridge=None, drift=None,
return retdata


def gen_mnist_array_in_order(range=(0, 1000)):
def gen_mnist_array_in_order(range=(0, 1000), data=None):
"""
Generate MNIST array in order
"""
train_ds = torchvision.datasets.MNIST(
root="./data/",
train=True,
download=True
)
target = train_ds.data.view(-1, 1, 28, 28).float()
if data is None:
train_ds = torchvision.datasets.MNIST(
root="./data/",
train=True,
download=True
)
target = train_ds.data.view(-1, 1, 28, 28).float()
else:
target = data.view(-1, 1, 28, 28).float()

# random choice nums samples
target = target[range[0]:range[1]]
@@ -189,14 +193,12 @@ def gen_bridge(x, y, ts, T):
bridge[i+1, :] += diffusion
return bridge, drift

def gen_mnist_data_in_order(range=(0, 1000), change_epsilons=False):
def gen_mnist_data_in_order(range=(0, 1000), data=None, change_epsilons=False):
"""
Generate MNIST dataset in order
"""
source, target = gen_mnist_array_in_order(range)
epsilon = 0.001
T = 1
ts = torch.arange(0, T+epsilon, epsilon)

source, target = gen_mnist_array_in_order(range, data)

T = 1
if change_epsilons:
@@ -221,11 +223,11 @@ def gen_mnist_data(nums=100, change_epsilons=False):

T = 1
if change_epsilons:
epsilon1 = 0.001
epsilon2 = 0.0001
epsilon1 = 0.01
epsilon2 = 0.001

t1 = torch.arange(0, 0.91, epsilon1)
t2 = torch.arange(0.91, T, epsilon2)
t1 = torch.arange(0, 0.99, epsilon1)
t2 = torch.arange(0.99, T, epsilon2)
ts = torch.concatenate((t1, t2))
else:
epsilon = 0.001
@@ -237,11 +239,27 @@ def gen_mnist_data(nums=100, change_epsilons=False):


def preprocess_mnist_data(args):
if args.filter_number is not None:
train_ds = torchvision.datasets.MNIST(
root="./data/",
train=True,
download=True
)
imgs = []
for img, label in train_ds:
if label != args.filter_number:
continue
imgs.append(torch.Tensor(np.array(img)))
length = int(len(imgs)/1000)*1000
imgs = torch.stack(imgs[:length])
else:
imgs = None
length = 60000
# check data pickle file
for i in track(range(60), description="Preprocessing dataset"):
for i in track(range(int(length/1e3)), description="Preprocessing dataset"):
if (args.ds_cached_dir / f'new_ds_{i}.pkl').exists():
continue
ts, bridge, drift, source, target = gen_mnist_data_in_order((i*1000, (i+1)*1000))
ts, bridge, drift, source, target = gen_mnist_data_in_order((i*1000, (i+1)*1000), data=imgs, change_epsilons=args.change_epsilons)
_, metadata = normalize_dataset(ts, bridge, drift, source, target)
new_ds = MNISTdataset(ts, bridge, drift, source, target)
new_ds.metadata = metadata
@@ -250,7 +268,7 @@ def preprocess_mnist_data(args):
get_total_mean_std(args)

ret = {
"nums_sub_ds": 60
"nums_sub_ds": int(length/1e3)
}

return ret

0 comments on commit 8955724

Please sign in to comment.