[WIP][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss #551

Draft: wants to merge 10 commits into base branch main
2 changes: 1 addition & 1 deletion examples/lightning/training.py
@@ -158,7 +158,7 @@ def formatting_func(self, example):
for i in range(len(example["question"])):
choices = ""
for j in range(len(example["choices"][i])):
choices += f"{j+1}. {example['choices'][i][j]}; "
choices += f"{j + 1}. {example['choices'][i][j]}; "
s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. "
s += f"{QUESTION}{example['question'][i]} "
s += f"{CHOICES}{choices} "
6 changes: 3 additions & 3 deletions examples/medusa/callback.py
@@ -352,9 +352,9 @@ def _get_effective_num_gpus():
else:
return world_size

assert (
world_size != 0
), "WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1."
assert world_size != 0, (
"WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1."
)

# TODO: add deepspeed support
return world_size
105 changes: 89 additions & 16 deletions src/liger_kernel/ops/cross_entropy.py
@@ -24,12 +24,14 @@
@triton.jit
def liger_cross_entropy_kernel(
X_ptr,
dX_entropy_ptr,
X_stride,
Y_ptr,
Y_stride,
weight_ptr,
loss_ptr,
z_loss_ptr,
entropy_loss_ptr,
loss_stride,
n_cols,
n_non_ignore,
@@ -41,6 +43,7 @@ def liger_cross_entropy_kernel(
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
RETURN_ENTROPY_LOSS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
@@ -51,6 +54,7 @@

Parameters:
X_ptr: Pointer to input tensor.
dX_entropy_ptr: Pointer to the tensor storing the gradient of the input w.r.t. the entropy loss.
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
@@ -68,6 +72,7 @@
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
RETURN_ENTROPY_LOSS (int): The boolean value to decide whether storing entropy loss to entropy_loss_ptr or not. It must be 0 or 1.
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
@@ -94,6 +99,9 @@
loss_ptr += program_id * loss_stride
if RETURN_Z_LOSS:
z_loss_ptr += program_id * loss_stride
if RETURN_ENTROPY_LOSS:
entropy_loss_ptr += program_id * loss_stride
dX_entropy_ptr += program_id * X_stride

if HAS_WEIGHT:
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -104,6 +112,7 @@
# 3. [Online softmax] first pass: find max + sum
m = float("-inf") # m is the max value. use the notation from the paper
d = 0.0 # d is the sum. use the notation from the paper
entropy_loss = 0.0 # entropy loss
ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
if HAS_SOFTCAPPING:
ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -167,9 +176,22 @@
intermediate = tanh(X_block / softcap)
X_block = softcap * intermediate

# load the (zero-initialized) gradient buffer for the entropy loss
if RETURN_ENTROPY_LOSS:
dX_entropy_block = tl.load(
dX_entropy_ptr + X_offsets,
mask=X_offsets < n_cols,
other=0.0,
)

if not HAS_WEIGHT:
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
if RETURN_ENTROPY_LOSS:
# entropy loss term
entropy_loss += tl.sum(-X_block * tl.log(X_block))
# derivatives of the entropy loss term
dX_entropy_block += -(tl.log(X_block) + 1)
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
@@ -182,6 +204,11 @@
else:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
softmax_X = tl.exp(X_block - m) / d
if RETURN_ENTROPY_LOSS:
# entropy loss term
entropy_loss += tl.sum(-softmax_X * tl.log(softmax_X))
# derivative of the entropy loss
dX_entropy_block = -(tl.log(softmax_X) + 1) * weight_block
# derivative of original_loss
dloss_ori = (1 - label_smoothing) * softmax_X
# specially handle dx_y
@@ -206,6 +233,8 @@
X_block = X_block * (1 - intermediate * intermediate)

tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
if RETURN_ENTROPY_LOSS:
tl.store(dX_entropy_ptr + X_offsets, dX_entropy_block, mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
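For reference, the per-row entropy term accumulated in the loop above is H = -sum_i p_i * log(p_i) over the softmax probabilities p, and the per-element derivative stored in dX_entropy_block is -(log(p_i) + 1) (additionally multiplied by the class weight in the HAS_WEIGHT branch). Below is a minimal NumPy sketch of the unweighted quantities; it is purely illustrative and not part of this diff:

import numpy as np

def reference_entropy(logits_row):
    # softmax probabilities for a single row of logits
    shifted = logits_row - logits_row.max()
    p = np.exp(shifted) / np.exp(shifted).sum()
    # entropy term the kernel accumulates: H = -sum_i p_i * log(p_i)
    entropy = -(p * np.log(p)).sum()
    # per-element derivative of H with respect to p_i, as stored in dX_entropy_block
    d_entropy_dp = -(np.log(p) + 1.0)
    return entropy, d_entropy_dp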
@@ -248,11 +277,16 @@ def liger_cross_entropy_kernel(
loss = loss / n_non_ignore
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
z_loss = z_loss / n_non_ignore
# TODO: Implement weighted entropy loss. Currently, entropy loss is not scaled by weight.
entropy_loss = entropy_loss / n_non_ignore

loss += z_loss

tl.store(loss_ptr, loss)
if RETURN_Z_LOSS:
tl.store(z_loss_ptr, z_loss)
if RETURN_ENTROPY_LOSS:
tl.store(entropy_loss_ptr, entropy_loss)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -271,27 +305,32 @@ def cross_entropy_forward(
reduction,
softcap,
return_z_loss,
return_entropy_loss,
):
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
assert isinstance(return_entropy_loss, bool), (
f"return_entropy_loss must be True or False. Got: {return_entropy_loss}"
)

BT, V = _input.shape
n_rows = BT

BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

# unreduced loss
dX_entropy_2d = torch.zeros(n_rows, V, dtype=_input.dtype, device=_input.device) if return_entropy_loss else None
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None

entropy_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_entropy_loss else None
target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()
sum_non_ignore_weight = n_non_ignore
weight_sum = 0.0
if weight is not None:
assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
assert torch.is_floating_point(
weight
), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
assert torch.is_floating_point(weight), (
f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
)
sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
weight_sum = weight.sum().item()
# ensure weight is contiguous
@@ -307,12 +346,14 @@ def cross_entropy_forward(
# Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
liger_cross_entropy_kernel[(n_rows,)](
X_ptr=_input,
dX_entropy_ptr=dX_entropy_2d,
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
weight_ptr=weight, # dummy if None
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
entropy_loss_ptr=entropy_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
@@ -324,6 +365,7 @@
reduction=reduction,
softcap=softcap,
RETURN_Z_LOSS=return_z_loss,
RETURN_ENTROPY_LOSS=return_entropy_loss,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
@@ -335,25 +377,27 @@
if reduction == "none":
loss = loss_1d
z_loss = z_loss_1d if return_z_loss else None
entropy_loss = entropy_loss_1d if return_entropy_loss else None
else:
loss = torch.sum(loss_1d)
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
entropy_loss = torch.sum(entropy_loss_1d) if return_entropy_loss else None

return loss, z_loss, _input
return loss, z_loss, entropy_loss, _input, dX_entropy_2d


def cross_entropy_backward(_input, grad_output):
def cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_output_entropy):
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
pass

# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
else:
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

element_mul_kernel[(n_rows,)](
_input,
_input.stride(-2),
Expand All @@ -363,6 +407,18 @@ def cross_entropy_backward(_input, grad_output):
num_warps=32 if not is_hip() else 16,
)

# calculate the gradient of the input w.r.t. the entropy loss
if dX_entropy_2d is not None:
element_mul_kernel[(n_rows,)](
dX_entropy_2d,
dX_entropy_2d.stride(-2),
grad_output_entropy,
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
_input += dX_entropy_2d

return _input


@@ -384,6 +440,7 @@ def forward(
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
return_entropy_loss: bool = False,
):
"""
The forward pass of the Liger Cross Entropy loss.
@@ -403,7 +460,7 @@
Returns:
tuple: A tuple with the computed losses with respect to loss and z loss. The elements are tensors or None.
"""
loss, z_loss, _input = cross_entropy_forward(
loss, z_loss, entropy_loss, _input, dX_entropy_2d = cross_entropy_forward(
_input,
target,
weight,
@@ -413,32 +470,47 @@
reduction,
softcap,
return_z_loss,
return_entropy_loss,
)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
ctx.save_for_backward(_input.detach())
if return_entropy_loss:
ctx.save_for_backward(_input.detach(), dX_entropy_2d.detach())
else:
ctx.save_for_backward(_input.detach())
ctx.return_z_loss = return_z_loss
ctx.return_entropy_loss = return_entropy_loss

return loss, z_loss
return loss, z_loss, entropy_loss

@staticmethod
def backward(ctx, grad_output, grad_ouput2):
def backward(ctx, grad_output, grad_ouput2, grad_ouput3):
"""
The backward pass of the Liger Cross Entropy loss.

Parameters:
ctx : The context object with saved tensors.
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
grad_output2 (tensor): No use.
grad_output3 (tensor): The tensor containing the gradient of the loss with respect to the entropy loss.
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
if ctx.return_z_loss:
del grad_ouput2 # z_loss is only for logging

(_input,) = ctx.saved_tensors
_input = cross_entropy_backward(_input, grad_output)
if ctx.return_entropy_loss:
(_input, dX_entropy_2d) = ctx.saved_tensors
else:
(_input,), dX_entropy_2d = ctx.saved_tensors, None

_input = cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_ouput3)

# delete the tensors that are not used in remaining steps
del grad_ouput3
del dX_entropy_2d

return (
_input,
None,
@@ -449,4 +521,5 @@ def backward(ctx, grad_output, grad_ouput2):
None,
None,
None,
None,
)
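Taken together, the new return_entropy_loss path can be exercised roughly as follows. This is a hypothetical usage sketch, not code from this PR: the leading positional arguments are assumed to follow the existing LigerCrossEntropyFunction.forward signature (_input, target, weight, ignore_index, lse_square_scale, label_smoothing, reduction, softcap, return_z_loss, return_entropy_loss), and the entropy_coeff weighting is an illustrative choice.

import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")

# forward returns (loss, z_loss, entropy_loss); z_loss is None when return_z_loss=False
loss, z_loss, entropy_loss = LigerCrossEntropyFunction.apply(
    logits, target, None, -100, 0.0, 0.0, "mean", None, False, True
)

# hypothetical objective: subtract an entropy bonus, so backward() receives a
# non-trivial gradient for the entropy_loss output and routes it through dX_entropy_2d
entropy_coeff = 0.01
(loss - entropy_coeff * entropy_loss).backward()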