diff --git a/train_img.py b/train_img.py
index fa226348..1bba7ca9 100644
--- a/train_img.py
+++ b/train_img.py
@@ -103,8 +103,6 @@ def main(args):
         dtype = torch.bfloat16
     elif args.mixed_precision == "fp16":
         dtype = torch.float16
-    elif args.mixed_precision == "fp32":
-        dtype = torch.float32
     else:
         raise ValueError(f"Unknown mixed precision {args.mixed_precision}")
     model: DiT = (
@@ -283,7 +281,7 @@ def main(args):
     parser.add_argument("--log-every", type=int, default=10)
     parser.add_argument("--ckpt-every", type=int, default=1000)
-    parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
+    parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16"])
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
     parser.add_argument("--lr", type=float, default=1e-4, help="Gradient clipping value")
     parser.add_argument("--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
diff --git a/train_video.py b/train_video.py
index 49ce9942..0e55d65d 100644
--- a/train_video.py
+++ b/train_video.py
@@ -6,6 +6,7 @@
 
 import colossalai
 import torch
+import torch.distributed as dist
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
@@ -49,6 +50,7 @@ def main(args):
     model_string_name = args.model.replace("/", "-")
     # Create an experiment folder
     experiment_dir = f"{args.outputs}/{experiment_index:03d}-{model_string_name}"
+    dist.barrier()
     if coordinator.is_master():
         os.makedirs(experiment_dir, exist_ok=True)
         with open(f"{experiment_dir}/config.txt", "w") as f:
@@ -97,7 +99,12 @@ def main(args):
 
     # Create model
     img_size = dataset[0][0].shape[-1]
-    dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+    if args.mixed_precision == "bf16":
+        dtype = torch.bfloat16
+    elif args.mixed_precision == "fp16":
+        dtype = torch.float16
+    else:
+        raise ValueError(f"Unknown mixed precision {args.mixed_precision}")
     model: DiT = (
         DiT_models[args.model](
             input_size=img_size,
@@ -196,11 +203,15 @@ def main(args):
 
                 # Log loss values:
                 all_reduce_mean(loss)
-                if coordinator.is_master() and (step + 1) % args.log_every == 0:
-                    pbar.set_postfix({"loss": loss.item()})
-                    writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)
+                global_step = epoch * num_steps_per_epoch + step
+                pbar.set_postfix({"loss": loss.item(), "step": step, "global_step": global_step})
+
+                # Log to tensorboard
+                if coordinator.is_master() and (global_step + 1) % args.log_every == 0:
+                    writer.add_scalar("loss", loss.item(), global_step)
 
-                if args.ckpt_every > 0 and (step + 1) % args.ckpt_every == 0:
+                # Save checkpoint
+                if args.ckpt_every > 0 and (global_step + 1) % args.ckpt_every == 0:
                     logger.info(f"Saving checkpoint")
                     save(
                         booster,
@@ -210,12 +221,15 @@ def main(args):
                         lr_scheduler,
                         epoch,
                         step + 1,
+                        global_step + 1,
                         args.batch_size,
                         coordinator,
                         experiment_dir,
                         ema_shape_dict,
                     )
-                    logger.info(f"Saved checkpoint at epoch {epoch} step {step + 1} to {experiment_dir}")
+                    logger.info(
+                        f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {experiment_dir}"
+                    )
 
         # the continue epochs are not resumed, so we need to reset the sampler start index and start step
         dataloader.sampler.set_start_index(0)
@@ -242,7 +256,7 @@ def main(args):
     parser.add_argument("--batch-size", type=int, default=2)
     parser.add_argument("--global-seed", type=int, default=42)
     parser.add_argument("--num-workers", type=int, default=4)
-    parser.add_argument("--log-every", type=int, default=50)
+    parser.add_argument("--log-every", type=int, default=10)
     parser.add_argument("--ckpt-every", type=int, default=1000)
     parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16"])
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")