[WIP] Fused CEL #457
Conversation
Thanks Jerome and fantastic work - will check this!
This is very much a WIP. The PR description isn't rendering nicely -- here's a more readable version (located in the test folder). Aside from integrating with unsloth's custom language models, I need to check how the kernel interplays with gradient accumulation and gradient checkpointing. I will also likely need to rebase to tidy up the commit history.
Additional benchmarking for QLoRA. Benchmark config:

```json
{
  "max_steps": 50,
  "dtype": "bfloat16",
  "model_id": "meta-llama/Meta-Llama-3-8B",
  "batch_size": 2,
  "max_seq_len": 512,
  "packing": true,
  "grad_accum_steps": 1,
  "load_in_4bit": true,
  "use_lora": true,
  "fused_cel_n_loop_iters": [1, 2, 4, 8]
}
```

QLoRA config:

Some observations:
@jeromeku Thanks for testing again! Hmm, weird that the training loss is noticeably higher; that is really strange. I can understand why the VRAM reductions are less pronounced, but I'm unsure about the training loss issue.
Trying to figure out why the divergence occurs when using the fused kernel.
Efficient Fused Cross Entropy Loss
Memory-efficient cross entropy implementation that materializes only the derivatives of the language modeling head layer, never the logits themselves: the logits computation is chunked so that the full logits tensor is never realized.
This is a direct adaptation of this repo.
Contents

- Overview
- Changes
- Tests
- Benchmarks
- Next Steps
Overview
In short: the derivative with respect to the input of the language modeling head (`dX` hereafter) and the derivative with respect to the logits projection weights (`dW` hereafter) are computed in chunks. See the original repo for an excellent explanation of the design. A rough sketch of the chunked idea follows.
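To make the chunking concrete, here is a minimal, forward-only PyTorch sketch (an illustration, not the PR's kernel; `chunked_cel` and its signature are hypothetical, though `n_loop_iters` mirrors the benchmark option of the same name):

```python
import torch
import torch.nn.functional as F

def chunked_cel(x, W, labels, n_loop_iters=4):
    """Cross entropy over the LM head computed chunk-by-chunk, so only a
    (chunk, vocab) slice of the logits is live at any one time.

    x: (N, hidden) flattened hidden states; W: (vocab, hidden) head weight.
    """
    N = x.shape[0]
    chunk = (N + n_loop_iters - 1) // n_loop_iters
    total = x.new_zeros((), dtype=torch.float32)      # accumulate loss in fp32
    n_valid = (labels != -100).sum().clamp(min=1)     # ignore masked positions
    for i in range(0, N, chunk):
        logits = x[i:i + chunk] @ W.t()               # (chunk, vocab) only
        total = total + F.cross_entropy(
            logits.float(), labels[i:i + chunk],
            ignore_index=-100, reduction="sum")
    return total / n_valid
```

Note the catch: under plain autograd, each chunk's logits are still saved for the backward pass, so this version alone does not reduce peak memory much. That is exactly why the kernel also computes `dX` and `dW` inside the same chunked loop rather than relying on autograd.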
Changes
The following changes were made to the original kernel:
- Handling of `3-D` language modeling tensors to conform with the required shapes of the kernel.
- `loss` is initialized to `float32`, whereas in the original kernel it was initialized to the autocasted / in-feat dtype.
- Added `torch.cuda.amp.{custom_fwd,custom_bwd}` to the `autograd.Function`.

All changes are enumerated in `unsloth/kernels/fused_cel.py`. Additionally, adapter layers and configs in `fused_cel.py` enable integration with `transformers` and `unsloth`. An illustrative stand-in for these changes is sketched below.
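As a rough, plain-PyTorch stand-in for those three changes (the actual kernel is Triton-based; this class and its chunk math are assumptions for illustration), the `autograd.Function` looks roughly like:

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class FusedCELSketch(torch.autograd.Function):
    """Logits are recomputed per chunk in backward, so the full
    (N, vocab) tensor never exists and nothing huge is saved."""

    @staticmethod
    @custom_fwd                 # change 3: interoperate with torch autocast
    def forward(ctx, x, W, labels, n_loop_iters):
        N = x.shape[0]
        chunk = (N + n_loop_iters - 1) // n_loop_iters
        loss = x.new_zeros((), dtype=torch.float32)   # change 2: fp32 init
        n_valid = (labels != -100).sum().clamp(min=1)
        for i in range(0, N, chunk):
            logits = (x[i:i + chunk] @ W.t()).float()
            loss += torch.nn.functional.cross_entropy(
                logits, labels[i:i + chunk],
                ignore_index=-100, reduction="sum")
        ctx.save_for_backward(x, W, labels, n_valid)
        ctx.n_loop_iters = n_loop_iters
        return loss / n_valid

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        x, W, labels, n_valid = ctx.saved_tensors
        N = x.shape[0]
        chunk = (N + ctx.n_loop_iters - 1) // ctx.n_loop_iters
        dX = torch.empty_like(x)
        dW = torch.zeros(W.shape, dtype=torch.float32, device=W.device)
        for i in range(0, N, chunk):
            xb, yb = x[i:i + chunk], labels[i:i + chunk]
            d = torch.softmax((xb @ W.t()).float(), dim=-1)
            valid = yb != -100
            d[valid, yb[valid]] -= 1.0    # dloss/dlogits = softmax - onehot
            d[~valid] = 0.0
            d *= grad_out.float() / n_valid
            dX[i:i + chunk] = (d @ W.float()).to(x.dtype)
            dW += d.t() @ xb.float()
        return dX, dW.to(W.dtype), None, None
```

Change 1, the `3-D` handling, then amounts to flattening `(batch, seq, hidden)` hidden states and `(batch, seq)` labels to these 2-D / 1-D shapes before calling `FusedCELSketch.apply`.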
.Tests
See `tests/test_CEL.py` for correctness checks. The comments in the tests describe numerical edge cases. A hypothetical check in the same spirit is sketched below.
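The actual tests live in the PR; as a hypothetical example of the pattern (reusing the `chunked_cel` sketch from the Overview, with explicit tolerances since the chunked reduction order differs from the eager reference):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, H, V = 64, 32, 128                      # tokens, hidden, vocab (toy sizes)
x = torch.randn(N, H, requires_grad=True)
W = torch.randn(V, H, requires_grad=True)
y = torch.randint(0, V, (N,))

# Loss should match the unchunked eager reference.
ref = F.cross_entropy(x @ W.t(), y)
out = chunked_cel(x, W, y, n_loop_iters=4)
torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-6)

# Gradients w.r.t. the head input and weight should match too.
gx_ref, gW_ref = torch.autograd.grad(ref, (x, W))
gx, gW = torch.autograd.grad(out, (x, W))
torch.testing.assert_close(gx, gx_ref, rtol=1e-4, atol=1e-6)
torch.testing.assert_close(gW, gW_ref, rtol=1e-4, atol=1e-6)
```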
Benchmarks
Following are results from preliminary benchmarking / testing on an NVIDIA `L4` GPU for a small `llama`-like model with and without the fused CEL layer. The takeaway is that the memory-efficiency claims of the original repo hold: overall memory usage is lower, decreasing linearly with the number of loop iterations. The results can be reproduced by passing the provided options to `benchmark_hf_test_cel.py` (run with `--help` to see all options). Below is the overall config, followed by training losses / grad norms and overall training metrics for `float32` and `bfloat16`.

Test config:

```
max_steps=50
model_id=hf-internal-testing/tiny-random-LlamaForCausalLM
batch_size=2
max_seq_len=256
packing=True
grad_accum_steps=1
load_in_4bit=False
use_lora=False
fused_cel_n_loop_iters=[1, 2, 4]
```
float32

Training metrics for `float32`:

bfloat16

Training metrics for `bfloat16`:

A rough way to observe the memory scaling in isolation is sketched below.
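This probe is an assumption, not part of `benchmark_hf_test_cel.py`; it reuses the `FusedCELSketch` class from the Changes section and needs a CUDA device:

```python
import torch

N, H, V = 2 * 256, 512, 32000   # tokens, hidden, vocab: illustrative sizes
for n_loop_iters in (1, 2, 4):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(N, H, device="cuda", requires_grad=True)
    W = torch.randn(V, H, device="cuda", requires_grad=True)
    y = torch.randint(0, V, (N,), device="cuda")
    FusedCELSketch.apply(x, W, y, n_loop_iters).backward()
    # Live logits shrink with n_loop_iters; the fp32 dW accumulator
    # adds a constant overhead, so scaling is not perfectly linear here.
    print(n_loop_iters, torch.cuda.max_memory_allocated() // 2**20, "MiB")
```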
Next Steps
- Integrate with unsloth's `FastLanguageModel`
- Test `LoRA` and `QLoRA` configs