
AssertionError: Invalid: `_post_backward_hook_state` when training a 4-step SDXL model #41

Open
joelulu opened this issue Aug 9, 2024 · 8 comments


@joelulu

joelulu commented Aug 9, 2024

Thanks for the excellent work!
When I try to train a 4-step SDXL model (2 nodes, 16 GPUs), I get the following error:

```
[rank2]: Traceback (most recent call last):
[rank2]:   File "/mnt/nas/gaohl/project/DMD2-main/main/train_sd.py", line 739, in <module>
[rank2]:     trainer.train()
[rank2]:   File "/mnt/nas/gaohl/project/DMD2-main/main/train_sd.py", line 633, in train
[rank2]:     self.train_one_step()
[rank2]:   File "/mnt/nas/gaohl/project/DMD2-main/main/train_sd.py", line 390, in train_one_step
[rank2]:     self.accelerator.backward(generator_loss)
[rank2]:   File "/usr/local/lib/python3.10/site-packages/accelerate/accelerator.py", line 2159, in backward
[rank2]:     loss.backward(**kwargs)
[rank2]:   File "/usr/local/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank2]:     torch.autograd.backward(
[rank2]:   File "/usr/local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank2]:     _engine_run_backward(
[rank2]:   File "/usr/local/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank2]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank2]:   File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank2]:     return func(*args, **kwargs)
[rank2]:   File "/usr/local/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1099, in _post_backward_final_callback
[rank2]:     _finalize_params(fsdp_state)
[rank2]:   File "/usr/local/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1168, in _finalize_params
[rank2]:     _p_assert(
[rank2]:   File "/usr/local/lib/python3.10/site-packages/torch/distributed/utils.py", line 146, in _p_assert
[rank2]:     raise AssertionError(s)
[rank2]: AssertionError: Invalid: _post_backward_hook_state: (<torch.autograd.graph.register_multi_grad_hook.<locals>.Handle object at 0x7f0bfb6a3a00>,)
```

My launch command is:

```bash
accelerate launch \
    --main_process_port $MASTER_PORT \
    --main_process_ip $MASTER_ADDR \
    --config_file fsdp_configs/fsdp_8node_debug_joe.yaml \
    --machine_rank $RANK \
    main/train_sd.py \
    --generator_lr 5e-7 \
    --guidance_lr 5e-7 \
    --train_iters 100000000 \
    --output_path $CHECKPOINT_PATH/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch \
    --batch_size 2 \
    --grid_size 2 \
    --initialie_generator \
    --log_iters 1000 \
    --resolution 1024 \
    --latent_resolution 128 \
    --seed 10 \
    --real_guidance_scale 8 \
    --fake_guidance_scale 1.0 \
    --max_grad_norm 10.0 \
    --model_id "/mnt/nas/gaohl/models/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/462165984030d82259a11f4367a4eed129e94a7b/" \
    --wandb_iters 100 \
    --wandb_entity $WANDB_ENTITY \
    --wandb_project $WANDB_PROJECT \
    --wandb_name "sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch" \
    --log_loss \
    --dfake_gen_update_ratio 5 \
    --fsdp \
    --sdxl \
    --use_fp16 \
    --max_step_percent 0.98 \
    --cls_on_clean_image \
    --gen_cls_loss \
    --gen_cls_loss_weight 5e-3 \
    --guidance_cls_loss_weight 1e-2 \
    --diffusion_gan \
    --diffusion_gan_max_timestep 1000 \
    --denoising \
    --num_denoising_step 4 \
    --denoising_timestep 1000 \
    --backward_simulation \
    --train_prompt_path $CHECKPOINT_PATH/captions_laion_score6.25.pkl \
    --real_image_path $CHECKPOINT_PATH/sdxl_vae_latents_laion_500k_lmdb/
```

fsdp_8node_debug_joe.yaml:
```yaml
compute_environment: LOCAL_MACHINE
debug: true
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_min_num_params: 3000
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
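For context, a rough sketch of what those `fsdp_config` settings correspond to in raw PyTorch FSDP (this is not the repo's actual wrapping code, and `wrap_with_fsdp` is a hypothetical helper): `SIZE_BASED_WRAP` with `fsdp_min_num_params: 3000` becomes a `size_based_auto_wrap_policy`, and `fsdp_sharding_strategy: 1` is `FULL_SHARD`.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def wrap_with_fsdp(model: nn.Module) -> FSDP:
    """Hypothetical helper mirroring the accelerate yaml above."""
    return FSDP(
        model,
        # SIZE_BASED_WRAP with fsdp_min_num_params: 3000
        auto_wrap_policy=functools.partial(
            size_based_auto_wrap_policy, min_num_params=3000
        ),
        # fsdp_sharding_strategy: 1 == FULL_SHARD
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # fsdp_backward_prefetch_policy: BACKWARD_PRE
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        forward_prefetch=False,   # fsdp_forward_prefetch: false
        use_orig_params=False,    # fsdp_use_orig_params: false
    )
```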

@tianweiy
Owner

I have no clue how to solve this error.

But the most likely cause of the error is a mismatched torch and accelerate version. Could you double-check that your torch and accelerate versions match the ones in the README?

Other versions simply don't work, unfortunately...

Also related to #25 (comment).
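A quick way to compare the installed versions against the README pins (a minimal sketch; run it inside the training environment):

```python
# Print the installed torch/accelerate versions so they can be
# compared against the pins in the DMD2 README; a mismatched pair
# is the usual suspect for FSDP hook-state assertions like this one.
import accelerate
import torch

print(f"torch=={torch.__version__}")
print(f"accelerate=={accelerate.__version__}")
```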

@fire2323

@tianweiy, do you use TorchDynamo with torch FSDP? I found that this error might be related to TorchDynamo compilation. One way to rule it out is sketched below.
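A minimal sketch for disabling TorchDynamo globally, assuming the standard `TORCHDYNAMO_DISABLE` environment switch (it must be set before anything is traced or compiled):

```python
import os

# Turn TorchDynamo off globally to rule it out; the env var must be
# set before torch traces/compiles anything, so put this at the very
# top of main/train_sd.py (or export it in the launch script).
os.environ["TORCHDYNAMO_DISABLE"] = "1"

import torch  # noqa: E402  (imported after the env var on purpose)
```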

@tianweiy
Owner

I didn't use TorchDynamo.

@BeBuBu

BeBuBu commented Aug 12, 2024

Is there any new progress on this issue? My accelerate and torch versions are as follows:
accelerate 0.25.0
pytorch 2.1.2 py3.8_cuda11.8_cudnn8.7.0_0

@tianweiy
Owner

tianweiy commented Aug 12, 2024 via email

Please just try the one specified in the readme first. Other versions are not tested and likely just don't work.

@joelulu
Author

joelulu commented Aug 12, 2024

> I have no clue how to solve this error.
>
> But the most likely cause of the error is a mismatched torch and accelerate version. Could you double-check that your torch and accelerate versions match the ones in the README?
>
> Other versions simply don't work, unfortunately...
>
> Also related to #25 (comment).

Thanks a lot. I solved my problem.

@BeBuBu

BeBuBu commented Aug 12, 2024

> Please just try the one specified in the readme first. Other versions are not tested and likely just don't work.

Thank you. Problem solved.

@joelulu
Author

joelulu commented Aug 20, 2024

Does DMD2 have plans to support StableCascade distillation?
