From 05f0edba2f328e9161173e433eaebcafd7e455ab Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 30 Jan 2025 02:52:59 +0000 Subject: [PATCH 01/32] run make checkstyle Signed-off-by: Hongpeng Guo --- examples/lightning/training.py | 2 +- examples/medusa/callback.py | 6 +-- src/liger_kernel/ops/cross_entropy.py | 6 +-- .../ops/fused_linear_cross_entropy.py | 6 +-- src/liger_kernel/ops/fused_linear_jsd.py | 6 +-- src/liger_kernel/ops/jsd.py | 6 +-- src/liger_kernel/ops/layer_norm.py | 6 +-- src/liger_kernel/ops/utils.py | 3 +- .../transformers/cross_entropy.py | 6 +-- .../fused_linear_cross_entropy.py | 6 +-- src/liger_kernel/transformers/group_norm.py | 12 ++--- src/liger_kernel/transformers/monkey_patch.py | 54 +++++++++---------- test/transformers/test_monkey_patch.py | 6 +-- test/triton/test_triton_monkey_patch.py | 6 +-- 14 files changed, 65 insertions(+), 66 deletions(-) diff --git a/examples/lightning/training.py b/examples/lightning/training.py index 8e58d8b11..ab1164878 100644 --- a/examples/lightning/training.py +++ b/examples/lightning/training.py @@ -158,7 +158,7 @@ def formatting_func(self, example): for i in range(len(example["question"])): choices = "" for j in range(len(example["choices"][i])): - choices += f"{j+1}. {example['choices'][i][j]}; " + choices += f"{j + 1}. {example['choices'][i][j]}; " s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. " s += f"{QUESTION}{example['question'][i]} " s += f"{CHOICES}{choices} " diff --git a/examples/medusa/callback.py b/examples/medusa/callback.py index 33a9d1946..673243b77 100644 --- a/examples/medusa/callback.py +++ b/examples/medusa/callback.py @@ -352,9 +352,9 @@ def _get_effective_num_gpus(): else: return world_size - assert ( - world_size != 0 - ), "WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1." + assert world_size != 0, ( + "WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1." + ) # TODO: add deepspeed support return world_size diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 9e4ab69e8..62ee98434 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -289,9 +289,9 @@ def cross_entropy_forward( weight_sum = 0.0 if weight is not None: assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" - assert torch.is_floating_point( - weight - ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}" + assert torch.is_floating_point(weight), ( + f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}" + ) sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item() weight_sum = weight.sum().item() # ensure weight is contiguous diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 4e54db473..267763560 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -58,9 +58,9 @@ def fused_linear_cross_entropy_forward( ce_weight_sum = 0.0 if ce_weight is not None: assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" - assert torch.is_floating_point( - ce_weight - ), f"If given, weight has to be a Tensor of floating point dtype. 
Got: {ce_weight.dtype}" + assert torch.is_floating_point(ce_weight), ( + f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}" + ) total_sum_non_ignore_ce_weight = ( torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item() ) diff --git a/src/liger_kernel/ops/fused_linear_jsd.py b/src/liger_kernel/ops/fused_linear_jsd.py index f0c4d7bea..55c992f05 100644 --- a/src/liger_kernel/ops/fused_linear_jsd.py +++ b/src/liger_kernel/ops/fused_linear_jsd.py @@ -195,9 +195,9 @@ def forward( """ has_label = False if shift_labels is not None: - assert shift_labels.shape == ( - teacher_input.shape[0], - ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + assert shift_labels.shape == (teacher_input.shape[0],), ( + f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + ) shift_labels = shift_labels.contiguous() has_label = True diff --git a/src/liger_kernel/ops/jsd.py b/src/liger_kernel/ops/jsd.py index 5b6fc5219..882b4099b 100644 --- a/src/liger_kernel/ops/jsd.py +++ b/src/liger_kernel/ops/jsd.py @@ -157,9 +157,9 @@ def forward( """ has_label = False if shift_labels is not None: - assert shift_labels.shape == ( - _input.shape[0], - ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + assert shift_labels.shape == (_input.shape[0],), ( + f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + ) shift_labels = shift_labels.contiguous() has_label = True diff --git a/src/liger_kernel/ops/layer_norm.py b/src/liger_kernel/ops/layer_norm.py index 6d527c7ee..4dc97ea29 100644 --- a/src/liger_kernel/ops/layer_norm.py +++ b/src/liger_kernel/ops/layer_norm.py @@ -147,9 +147,9 @@ def layer_norm_forward(X, W, B, eps): Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device) RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device) - assert ( - X.shape[1] == W.shape[0] - ), f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}" + assert X.shape[1] == W.shape[0], ( + f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}" + ) _layer_norm_forward_kernel[(n_rows,)]( Y, diff --git a/src/liger_kernel/ops/utils.py b/src/liger_kernel/ops/utils.py index 8a15bf8d8..736990b89 100644 --- a/src/liger_kernel/ops/utils.py +++ b/src/liger_kernel/ops/utils.py @@ -49,8 +49,7 @@ def calculate_settings(n): BLOCK_SIZE = triton.next_power_of_2(n) if BLOCK_SIZE > MAX_FUSED_SIZE: raise RuntimeError( - f"Cannot launch Triton kernel since n = {n} exceeds " - f"the recommended Triton blocksize = {MAX_FUSED_SIZE}." + f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}." ) num_warps = 4 diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index ac522c2e0..f01b1f57e 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -17,9 +17,9 @@ def __init__( return_z_loss: bool = False, ): super().__init__() - assert (label_smoothing >= 0) and ( - label_smoothing <= 1 - ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" + assert (label_smoothing >= 0) and (label_smoothing <= 1), ( + f"label_smoothing must be between 0.0 and 1.0. 
Got: {label_smoothing}" + ) assert reduction in { "mean", "sum", diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index d959a0f0b..6a9f19d7f 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -17,9 +17,9 @@ def __init__( return_z_loss: bool = False, ): super().__init__() - assert (label_smoothing >= 0) and ( - label_smoothing <= 1 - ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" + assert (label_smoothing >= 0) and (label_smoothing <= 1), ( + f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" + ) assert reduction in { "mean", "sum", diff --git a/src/liger_kernel/transformers/group_norm.py b/src/liger_kernel/transformers/group_norm.py index ca3d314e2..22a48a843 100644 --- a/src/liger_kernel/transformers/group_norm.py +++ b/src/liger_kernel/transformers/group_norm.py @@ -21,9 +21,9 @@ def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones "zeros", ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}" - assert ( - num_channels % num_groups == 0 - ), f"Number of channels {num_channels} must be divisible by num_groups {num_groups}" + assert num_channels % num_groups == 0, ( + f"Number of channels {num_channels} must be divisible by num_groups {num_groups}" + ) self.num_channels = num_channels self.num_groups = num_groups self.eps = eps @@ -34,9 +34,9 @@ def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones def forward(self, hidden_states): # hidden_states: (batch_size, num_channels, *) assert hidden_states.dim() >= 3, f"Input must have atleast 3 dimensions, got {hidden_states.dim()}" - assert ( - hidden_states.size(1) == self.num_channels - ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}" + assert hidden_states.size(1) == self.num_channels, ( + f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}" + ) return LigerGroupNormFunction.apply( hidden_states, self.weight, diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index eafce145e..c23c9757e 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -85,9 +85,9 @@ def apply_liger_kernel_to_llama( loaded. Default is None. """ - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) from transformers.models.llama import modeling_llama from transformers.models.llama.modeling_llama import LlamaModel @@ -159,9 +159,9 @@ def apply_liger_kernel_to_mllama( loaded. Default is None. """ - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) from transformers.models.mllama import modeling_mllama from transformers.models.mllama.modeling_mllama import MllamaForCausalLM @@ -261,9 +261,9 @@ def apply_liger_kernel_to_mistral( model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. 
""" - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) from transformers.models.mistral import modeling_mistral from transformers.models.mistral.modeling_mistral import MistralModel @@ -321,9 +321,9 @@ def apply_liger_kernel_to_mixtral( loaded. Default is None. """ - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) from transformers.models.mixtral import modeling_mixtral from transformers.models.mixtral.modeling_mixtral import MixtralModel @@ -393,9 +393,9 @@ def apply_liger_kernel_to_gemma( model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) from transformers.models.gemma import modeling_gemma from transformers.models.gemma.modeling_gemma import GemmaModel @@ -467,9 +467,9 @@ def apply_liger_kernel_to_gemma2( model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) from transformers.models.gemma2 import modeling_gemma2 from transformers.models.gemma2.modeling_gemma2 import Gemma2Model @@ -544,9 +544,9 @@ def apply_liger_kernel_to_qwen2( model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) from transformers.models.qwen2 import modeling_qwen2 from transformers.models.qwen2.modeling_qwen2 import Qwen2Model @@ -619,9 +619,9 @@ def apply_liger_kernel_to_qwen2_vl( model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) from transformers.models.qwen2_vl import modeling_qwen2_vl from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel @@ -689,9 +689,9 @@ def apply_liger_kernel_to_phi3( model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." 
+ assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) from transformers.models.phi3 import modeling_phi3 from transformers.models.phi3.modeling_phi3 import Phi3Model diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 811cd74cc..76e937de1 100644 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -216,9 +216,9 @@ def test_patching_apis_support_patching_model_instance(): for func in patching_functions: sig = inspect.signature(func) # Ensure 'model' is in the parameters - assert ( - "model" in sig.parameters - ), f"{func.__name__} does not have 'model' as an argument. All patching methods must support patching an existing model instance." + assert "model" in sig.parameters, ( + f"{func.__name__} does not have 'model' as an argument. All patching methods must support patching an existing model instance." + ) def test_apply_liger_kernel_to_instance_for_llama(): diff --git a/test/triton/test_triton_monkey_patch.py b/test/triton/test_triton_monkey_patch.py index f031c2af0..eeb8c173d 100644 --- a/test/triton/test_triton_monkey_patch.py +++ b/test/triton/test_triton_monkey_patch.py @@ -24,6 +24,6 @@ def test_import_custom_cache_manager(): cache_manager = get_cache_manager(key=random_hex_key) from liger_kernel.triton.monkey_patch import LigerTritonFileCacheManager - assert isinstance( - cache_manager, LigerTritonFileCacheManager - ), "Cache manager should have been LigerTritonFileCacheManager" + assert isinstance(cache_manager, LigerTritonFileCacheManager), ( + "Cache manager should have been LigerTritonFileCacheManager" + ) From 6a26dbbd0fb46e9b4d6071d548276a7781913a5d Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 30 Jan 2025 01:28:44 +0000 Subject: [PATCH 02/32] wip initial try test existing unitest Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 95 ++++++++--- .../ops/fused_linear_cross_entropy.py | 152 ++++++++++++------ .../transformers/cross_entropy.py | 14 +- src/liger_kernel/transformers/functional.py | 24 ++- .../fused_linear_cross_entropy.py | 15 +- 5 files changed, 223 insertions(+), 77 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 62ee98434..e06cfc376 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -24,12 +24,14 @@ @triton.jit def liger_cross_entropy_kernel( X_ptr, + dX_entropy_ptr, X_stride, Y_ptr, Y_stride, weight_ptr, loss_ptr, z_loss_ptr, + entropy_loss_ptr, loss_stride, n_cols, n_non_ignore, @@ -41,6 +43,7 @@ def liger_cross_entropy_kernel( reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time softcap, RETURN_Z_LOSS: tl.constexpr, + RETURN_ENTROPY_LOSS: tl.constexpr, BLOCK_SIZE: tl.constexpr, HAS_WEIGHT: tl.constexpr, HAS_SOFTCAPPING: tl.constexpr, @@ -51,6 +54,7 @@ def liger_cross_entropy_kernel( Parameters: X_ptr: Pointer to input tensor. + dX_entropy_ptr: Pointer to tensor to store the gradient of the input w.r.s. to the entropy loss X_stride (int): The stride of the input tensor. Y_ptr: Pointer to target tensor. Y_stride (int): The stride of the target tensor. @@ -68,6 +72,7 @@ def liger_cross_entropy_kernel( reduction (str): The string for the reduction to apply softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). 
RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. + RETURN_ENTROPY_LOSS (int): The boolean value to decide whether storing entropy loss to entropy_loss_ptr or not. It must be 0 or 1. BLOCK_SIZE (int): The block size for Triton operations. HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. @@ -94,7 +99,10 @@ def liger_cross_entropy_kernel( loss_ptr += program_id * loss_stride if RETURN_Z_LOSS: z_loss_ptr += program_id * loss_stride - + if RETURN_ENTROPY_LOSS: + entropy_loss_ptr += program_id * loss_stride + dX_entropy_ptr += program_id * X_stride + if HAS_WEIGHT: weight_y = tl.load(weight_ptr + y).cast(tl.float32) @@ -104,6 +112,7 @@ def liger_cross_entropy_kernel( # 3. [Online softmax] first pass: find max + sum m = float("-inf") # m is the max value. use the notation from the paper d = 0.0 # d is the sum. use the notation from the paper + entropy_loss = 0.0 # entropy loss ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation if HAS_SOFTCAPPING: ori_X_y = softcap * tanh(ori_X_y / softcap) @@ -166,10 +175,23 @@ def liger_cross_entropy_kernel( if HAS_SOFTCAPPING: intermediate = tanh(X_block / softcap) X_block = softcap * intermediate + + # load the derivatives of the entropy loss + if RETURN_ENTROPY_LOSS: + dX_entropy_block = tl.load( + dX_entropy_ptr + X_offsets, + mask=X_offsets < n_cols, + other=0.0, + ) if not HAS_WEIGHT: # softmax(x_i) X_block = tl.exp(X_block - m) / d + if RETURN_ENTROPY_LOSS: + # entropy loss term + entropy_loss += tl.sum(-X_block * tl.log(X_block)) + # derivatives of the entropy loss term + dX_entropy_block += -(tl.log(X_block) + 1) # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) X_block += 2 * lse_square_scale * lse * X_block # smoothing term @@ -182,6 +204,11 @@ def liger_cross_entropy_kernel( else: weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) softmax_X = tl.exp(X_block - m) / d + if RETURN_ENTROPY_LOSS: + # entropy loss term + entropy_loss += tl.sum(-softmax_X * tl.log(softmax_X)) + # derititive of the entropy loss + dX_entropy_block = - (tl.log(softmax_X) + 1) * weight_block # derivative of original_loss dloss_ori = (1 - label_smoothing) * softmax_X # specially handle dx_y @@ -197,15 +224,18 @@ def liger_cross_entropy_kernel( dloss_smooth = dloss_smooth / sum_non_ignore_weight # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. dz_loss = dz_loss / n_non_ignore + dloss_entropy = dloss_entropy / n_non_ignore # derivative of total_loss X_block = dloss_ori + dloss_smooth + dz_loss # chain rule softcapping # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) if HAS_SOFTCAPPING: - X_block = X_block * (1 - intermediate * intermediate) + X_block = X_block * (1 - intermediate * intermediate) tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) + if RETURN_ENTROPY_LOSS: + tl.store(dX_entropy_ptr + X_offsets, dX_entropy_block, mask=X_offsets < n_cols) # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 @@ -248,12 +278,16 @@ def liger_cross_entropy_kernel( loss = loss / n_non_ignore # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. 
z_loss = z_loss / n_non_ignore + # TODO: Implement weighted entropy loss. Currently, entropy loss is not scaled by weight. + entropy_loss = entropy_loss / n_non_ignore + loss += z_loss tl.store(loss_ptr, loss) if RETURN_Z_LOSS: tl.store(z_loss_ptr, z_loss) - + if RETURN_ENTROPY_LOSS: + tl.store(entropy_loss_ptr, entropy_loss) # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling @@ -271,8 +305,10 @@ def cross_entropy_forward( reduction, softcap, return_z_loss, + return_entropy_loss, ): assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + assert isinstance(return_entropy_loss, bool), f"return_entropy_loss must be True or False. Got: {return_entropy_loss}" BT, V = _input.shape n_rows = BT @@ -280,9 +316,10 @@ def cross_entropy_forward( BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) # unreduced loss + dX_entropy_2d = torch.zeros(n_rows, V, dtype=_input.dtype, device=_input.device) if return_entropy_loss else None loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None - + entropy_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_entropy_loss else None target_mask = target != ignore_index n_non_ignore = target_mask.sum().item() sum_non_ignore_weight = n_non_ignore @@ -307,12 +344,14 @@ def cross_entropy_forward( # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory liger_cross_entropy_kernel[(n_rows,)]( X_ptr=_input, + dX_entropy_ptr=dX_entropy_2d, X_stride=_input.stride(-2), Y_ptr=target, Y_stride=target.stride(-1), # always 1 weight_ptr=weight, # dummy if None loss_ptr=loss_1d, z_loss_ptr=z_loss_1d, + entropy_loss_ptr=entropy_loss_1d, loss_stride=loss_1d.stride(-1), # always 1 n_cols=V, n_non_ignore=n_non_ignore, @@ -335,14 +374,29 @@ def cross_entropy_forward( if reduction == "none": loss = loss_1d z_loss = z_loss_1d if return_z_loss else None + entropy_loss = entropy_loss_1d if return_entropy_loss else None else: loss = torch.sum(loss_1d) z_loss = torch.sum(z_loss_1d) if return_z_loss else None + entropy_loss = torch.sum(entropy_loss_1d) if return_entropy_loss else None - return loss, z_loss, _input + return loss, z_loss, entropy_loss, _input, dX_entropy_2d -def cross_entropy_backward(_input, grad_output): +def cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_output_entropy): + # calculate the gradient of the input w.r.s. to the entropy loss + BT, V = _input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + element_mul_kernel[(n_rows,)]( + dX_entropy_2d, + dX_entropy_2d.stride(-2), + grad_output_entropy, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): pass @@ -350,10 +404,6 @@ def cross_entropy_backward(_input, grad_output): # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. 
else: - BT, V = _input.shape - n_rows = BT - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - element_mul_kernel[(n_rows,)]( _input, _input.stride(-2), @@ -363,7 +413,7 @@ def cross_entropy_backward(_input, grad_output): num_warps=32 if not is_hip() else 16, ) - return _input + return _input + dX_entropy_2d class LigerCrossEntropyFunction(torch.autograd.Function): @@ -384,6 +434,7 @@ def forward( reduction: str = "mean", softcap: Optional[float] = None, return_z_loss: bool = False, + return_entropy_loss: bool = False, ): """ The forward pass of the Liger Cross Entropy loss. @@ -403,7 +454,7 @@ def forward( Returns: tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None. """ - loss, z_loss, _input = cross_entropy_forward( + loss, z_loss, entropy_loss, _input, dX_entropy_2d = cross_entropy_forward( _input, target, weight, @@ -413,17 +464,19 @@ def forward( reduction, softcap, return_z_loss, + return_entropy_loss, ) # TODO: investigation # If we don't detach the _input tensor, the memory will double # Not sure why but seems that there will be a time both grad and value exist but in different location - ctx.save_for_backward(_input.detach()) + ctx.save_for_backward(_input.detach(), dX_entropy_2d.detach()) ctx.return_z_loss = return_z_loss - - return loss, z_loss + ctx.return_entropy_loss = return_entropy_loss + + return loss, z_loss, entropy_loss @staticmethod - def backward(ctx, grad_output, grad_ouput2): + def backward(ctx, grad_output, grad_ouput2, grad_ouput3): """ The backward pass of the Liger Cross Entropy loss. @@ -431,14 +484,20 @@ def backward(ctx, grad_output, grad_ouput2): ctx : The context object with saved tensors. grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. grad_output2 (tenosr): No use. + grad_output3 (tenosr): The tensor containing the gradient of the loss with respect to the entropy loss. Returns: tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. """ if ctx.return_z_loss: del grad_ouput2 # z_loss is only for logging - (_input,) = ctx.saved_tensors - _input = cross_entropy_backward(_input, grad_output) + (_input, dX_entropy_2d) = ctx.saved_tensors + _input = cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_ouput3) + + # delete the tensors that are not used in remaining steps + del grad_ouput3 + del dX_entropy_2d + return ( _input, None, diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 267763560..fb3748632 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -25,8 +25,10 @@ def fused_linear_cross_entropy_forward( reduction="mean", softcap=None, return_z_loss=False, + return_entropy_loss=False, ): assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + assert isinstance(return_entropy_loss, bool), f"return_entropy_loss must be True or False. Got: {return_entropy_loss}" device = _input.device # inputs have shape: BT x H @@ -44,12 +46,19 @@ def fused_linear_cross_entropy_forward( chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + # initialize the gradients w.r.s. 
to the cross entropy loss grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None grad_input = torch.zeros_like(_input, device=device) grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None + # initialize the gradients w.r.s. to the entropy loss + grad_entropy_weight = torch.zeros_like(weight, device=device) if return_entropy_loss and weight.requires_grad else None + grad_entropy_input = torch.zeros_like(_input, device=device) if return_entropy_loss else None + grad_entropy_bias = torch.zeros_like(bias, device=device) if return_entropy_loss and bias is not None else None + # we use fp32 for loss accumulator loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None + entropy_loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_entropy_loss else None # TODO: evaluate how CUDA synchronization caused by .item() affects the speed target_mask = target != ignore_index @@ -77,6 +86,9 @@ def fused_linear_cross_entropy_forward( logits_chunk = _input_chunk @ weight.t() # chunk_size x V if bias is not None: logits_chunk = logits_chunk + bias + + # create a tensor to store the gradient of the input w.r.s. to the entropy loss + grad_entropy_logits_chunk = torch.zeros_like(logits_chunk, device=device) if return_entropy_loss else None target_chunk = target[start_idx:end_idx] # chunk_size, @@ -85,20 +97,23 @@ def fused_linear_cross_entropy_forward( # unreduced loss loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None - + entropy_loss_1d_slice = entropy_loss_1d[start_idx:end_idx] if return_entropy_loss else None # ensure _input and target are contiguous logits_chunk = logits_chunk.contiguous() target_chunk = target_chunk.contiguous() + grad_entropy_logits_chunk = grad_entropy_logits_chunk.contiguous() # Here we calculate the gradient of logits_chunk in place so we can save memory. 
liger_cross_entropy_kernel[(n_rows,)]( X_ptr=logits_chunk, + dX_entropy_ptr=grad_entropy_logits_chunk, X_stride=logits_chunk.stride(-2), Y_ptr=target_chunk, Y_stride=target_chunk.stride(-1), # always 1 weight_ptr=ce_weight, loss_ptr=loss_1d_slice, z_loss_ptr=z_loss_1d_slice, + entropy_loss_ptr=entropy_loss_1d_slice, loss_stride=loss_1d_slice.stride(-1), # always 1 n_cols=V, n_non_ignore=total_n_non_ignore, @@ -110,6 +125,7 @@ def fused_linear_cross_entropy_forward( reduction=reduction, softcap=softcap, RETURN_Z_LOSS=return_z_loss, + RETURN_ENTROPY_LOSS=return_entropy_loss, HAS_WEIGHT=True if ce_weight is not None else False, HAS_SOFTCAPPING=True if softcap is not None else False, BLOCK_SIZE=BLOCK_SIZE, @@ -119,6 +135,9 @@ def fused_linear_cross_entropy_forward( loss_1d[start_idx:end_idx] = loss_1d_slice if return_z_loss: z_loss_1d[start_idx:end_idx] = z_loss_1d_slice + if return_entropy_loss: + entropy_loss_1d[start_idx:end_idx] = entropy_loss_1d_slice + grad_logits_chunk = logits_chunk # chunk_size x V grad_input[start_idx:end_idx] = grad_logits_chunk @ weight @@ -142,62 +161,94 @@ def fused_linear_cross_entropy_forward( out=grad_bias, alpha=1.0, ) + + if return_entropy_loss: + grad_entropy_input[start_idx:end_idx] = grad_entropy_logits_chunk @ weight + if grad_weight is not None: + torch.addmm( + input=grad_entropy_weight, + mat1=grad_entropy_logits_chunk.t().to( + _input_chunk.dtype + ), + mat2=_input_chunk, + out=grad_entropy_weight, + alpha=1.0, + beta=1.0, + ) + if bias is not None: + torch.add( + input=grad_entropy_bias, + other=grad_entropy_logits_chunk.sum(dim=0), + out=grad_entropy_bias, + alpha=1.0, + ) if reduction == "none": loss = loss_1d z_loss = z_loss_1d if return_z_loss else None - + entropy_loss = entropy_loss_1d if return_entropy_loss else None else: loss = torch.sum(loss_1d) z_loss = torch.sum(z_loss_1d) if return_z_loss else None - return loss, z_loss, grad_input, grad_weight, grad_bias - - -def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): - # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time - if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): - # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place - # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. - BT, H = grad_input.shape - n_rows = BT - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + entropy_loss = torch.sum(entropy_loss_1d) if return_entropy_loss else None + return loss, z_loss, entropy_loss, grad_input, grad_weight, grad_bias, grad_entropy_input, grad_entropy_weight, grad_entropy_bias + +def _fused_linear_backward_helper(grad_output, grad_input, grad_weight, grad_bias): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. 
+ BT, H = grad_input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + element_mul_kernel[(n_rows,)]( + grad_input, + grad_input.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + # handle grad_weight + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V element_mul_kernel[(n_rows,)]( - grad_input, - grad_input.stride(-2), + grad_weight, + grad_weight.stride(-2), grad_output, H, BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, ) - # handle grad_weight - if grad_weight is not None: - V, H = grad_weight.shape - n_rows = V - - element_mul_kernel[(n_rows,)]( - grad_weight, - grad_weight.stride(-2), - grad_output, - H, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32 if not is_hip() else 16, - ) + if grad_bias is not None: + V = grad_bias.shape[0] + n_rows = V - if grad_bias is not None: - V = grad_bias.shape[0] - n_rows = V - - element_mul_kernel[(n_rows,)]( - grad_bias, - grad_bias.stride(-1), - grad_output, - 1, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32 if not is_hip() else 16, - ) + element_mul_kernel[(n_rows,)]( + grad_bias, + grad_bias.stride(-1), + grad_output, + 1, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) return grad_input, grad_weight, grad_bias + + +def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias, grad_entropy_output, grad_entropy_input, grad_entropy_weight, grad_entropy_bias): + # Calculate the gradient with respect to the entropy losses + grad_entropy_input, grad_entropy_weight, grad_entropy_bias = _fused_linear_backward_helper( + grad_entropy_output, grad_entropy_input, grad_entropy_weight, grad_entropy_bias + ) + # If cross entropy is the last layer, grad_output is 1.0. 
Skip the mul to save time + if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + grad_input, grad_weight, grad_bias = _fused_linear_backward_helper( + grad_output, grad_input, grad_weight, grad_bias + ) + return grad_input + grad_entropy_input, grad_weight + grad_entropy_weight, grad_bias + grad_entropy_bias class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): @@ -216,6 +267,7 @@ def forward( reduction="mean", softcap=None, return_z_loss: bool = False, + return_entropy_loss: bool = False, ): """ Fusing the last linear layer with cross-entropy loss @@ -236,7 +288,7 @@ def forward( reduction: reduction to apply """ - loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( + loss, z_loss, entropy_loss, grad_input, grad_weight, grad_bias, grad_entropy_input, grad_entropy_weight, grad_entropy_bias = fused_linear_cross_entropy_forward( _input=_input, weight=weight, target=target, @@ -248,25 +300,35 @@ def forward( reduction=reduction, softcap=softcap, return_z_loss=return_z_loss, + return_entropy_loss=return_entropy_loss, ) # downcast to dtype and store for backward ctx.save_for_backward( grad_input.detach(), grad_weight.detach() if grad_weight is not None else None, grad_bias.detach() if bias is not None else None, + grad_entropy_input.detach() if grad_entropy_input is not None else None, + grad_entropy_weight.detach() if grad_entropy_weight is not None else None, + grad_entropy_bias.detach() if grad_entropy_bias is not None else None, ) ctx.return_z_loss = return_z_loss - return loss, z_loss + return loss, z_loss, entropy_loss @staticmethod @amp_custom_bwd - def backward(ctx, grad_output, grad_output2): + def backward(ctx, grad_output, grad_output2, grad_output3): if ctx.return_z_loss: del grad_output2 # z_loss is only for logging - (grad_input, grad_weight, grad_bias) = ctx.saved_tensors + (grad_input, grad_weight, grad_bias, grad_entropy_input, grad_entropy_weight, grad_entropy_bias) = ctx.saved_tensors grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( - grad_output, grad_input, grad_weight, grad_bias + grad_output, grad_input, grad_weight, grad_bias, grad_output3, grad_entropy_input, grad_entropy_weight, grad_entropy_bias ) + # delete the tensors that are not used in remaining steps + del grad_output3 + del grad_entropy_input + del grad_entropy_weight + del grad_entropy_bias + return ( grad_input, grad_weight, diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index f01b1f57e..a32d59f76 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -15,6 +15,7 @@ def __init__( reduction: str = "mean", softcap: Optional[float] = None, return_z_loss: bool = False, + return_entropy_loss: bool = False, ): super().__init__() assert (label_smoothing >= 0) and (label_smoothing <= 1), ( @@ -33,9 +34,10 @@ def __init__( self.reduction = reduction self.softcap = softcap self.return_z_loss = return_z_loss - + self.return_entropy_loss = return_entropy_loss + def forward(self, _input: torch.Tensor, target: torch.Tensor): - loss, z_loss = LigerCrossEntropyFunction.apply( + loss, z_loss, entropy_loss = LigerCrossEntropyFunction.apply( _input, target, self.weight, @@ -46,6 +48,10 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor): self.softcap, self.return_z_loss, ) - if not self.return_z_loss: + if not self.return_z_loss and not self.return_entropy_loss: return loss - return loss, 
z_loss + if self.return_z_loss and not self.return_entropy_loss: + return loss, z_loss + if not self.return_z_loss and self.return_entropy_loss: + return loss, entropy_loss + return loss, z_loss, entropy_loss diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index c2f51e952..1ac4de0c1 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -28,8 +28,9 @@ def liger_cross_entropy( lse_square_scale: float = 0.0, softcap: Optional[float] = None, return_z_loss: bool = False, + return_entropy_loss: bool = False, ): - loss, z_loss = LigerCrossEntropyFunction.apply( + loss, z_loss, entropy_loss = LigerCrossEntropyFunction.apply( input, target, weight, @@ -39,10 +40,15 @@ def liger_cross_entropy( reduction, softcap, return_z_loss, + return_entropy_loss, ) - if not return_z_loss: + if not return_z_loss and not return_entropy_loss: return loss - return loss, z_loss + if return_z_loss and not return_entropy_loss: + return loss, z_loss + if not return_z_loss and return_entropy_loss: + return loss, entropy_loss + return loss, z_loss, entropy_loss def liger_fused_linear_cross_entropy( @@ -57,8 +63,9 @@ def liger_fused_linear_cross_entropy( reduction: str = "mean", softcap: Optional[float] = None, return_z_loss: bool = False, + return_entropy_loss: bool = False, ): - loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply( + loss, z_loss, entropy_loss = LigerFusedLinearCrossEntropyFunction.apply( input, weight, target, @@ -70,10 +77,15 @@ def liger_fused_linear_cross_entropy( reduction, softcap, return_z_loss, + return_entropy_loss, ) - if not return_z_loss: + if not return_z_loss and not return_entropy_loss: return loss - return loss, z_loss + if return_z_loss and not return_entropy_loss: + return loss, z_loss + if not return_z_loss and return_entropy_loss: + return loss, entropy_loss + return loss, z_loss, entropy_loss def liger_fused_linear_jsd( diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index 6a9f19d7f..1a8ab23c6 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -15,6 +15,7 @@ def __init__( reduction: str = "mean", softcap: Optional[float] = None, return_z_loss: bool = False, + return_entropy_loss: bool = False, ): super().__init__() assert (label_smoothing >= 0) and (label_smoothing <= 1), ( @@ -33,9 +34,10 @@ def __init__( self.reduction = reduction self.softcap = softcap self.return_z_loss = return_z_loss - + self.return_entropy_loss = return_entropy_loss + def forward(self, lin_weight, _input, target, bias=None): - loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply( + loss, z_loss, entropy_loss = LigerFusedLinearCrossEntropyFunction.apply( _input, lin_weight, target, @@ -47,7 +49,12 @@ def forward(self, lin_weight, _input, target, bias=None): self.reduction, self.softcap, self.return_z_loss, + self.return_entropy_loss, ) - if not self.return_z_loss: + if not self.return_z_loss and not self.return_entropy_loss: return loss - return loss, z_loss + if self.return_z_loss and not self.return_entropy_loss: + return loss, z_loss + if not self.return_z_loss and self.return_entropy_loss: + return loss, entropy_loss + return loss, z_loss, entropy_loss From 7dad5609b4e0f9bd2d5d4deba85f3a51d20bc343 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 30 Jan 2025 03:04:50 +0000 Subject: [PATCH 03/32] ruff style 
check Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 30 ++++---- .../ops/fused_linear_cross_entropy.py | 71 +++++++++++++++---- .../transformers/cross_entropy.py | 2 +- .../fused_linear_cross_entropy.py | 2 +- 4 files changed, 74 insertions(+), 31 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index e06cfc376..c08bdfcc4 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -102,7 +102,7 @@ def liger_cross_entropy_kernel( if RETURN_ENTROPY_LOSS: entropy_loss_ptr += program_id * loss_stride dX_entropy_ptr += program_id * X_stride - + if HAS_WEIGHT: weight_y = tl.load(weight_ptr + y).cast(tl.float32) @@ -112,7 +112,7 @@ def liger_cross_entropy_kernel( # 3. [Online softmax] first pass: find max + sum m = float("-inf") # m is the max value. use the notation from the paper d = 0.0 # d is the sum. use the notation from the paper - entropy_loss = 0.0 # entropy loss + entropy_loss = 0.0 # entropy loss ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation if HAS_SOFTCAPPING: ori_X_y = softcap * tanh(ori_X_y / softcap) @@ -175,7 +175,7 @@ def liger_cross_entropy_kernel( if HAS_SOFTCAPPING: intermediate = tanh(X_block / softcap) X_block = softcap * intermediate - + # load the derivatives of the entropy loss if RETURN_ENTROPY_LOSS: dX_entropy_block = tl.load( @@ -208,7 +208,7 @@ def liger_cross_entropy_kernel( # entropy loss term entropy_loss += tl.sum(-softmax_X * tl.log(softmax_X)) # derititive of the entropy loss - dX_entropy_block = - (tl.log(softmax_X) + 1) * weight_block + dX_entropy_block = -(tl.log(softmax_X) + 1) * weight_block # derivative of original_loss dloss_ori = (1 - label_smoothing) * softmax_X # specially handle dx_y @@ -224,14 +224,13 @@ def liger_cross_entropy_kernel( dloss_smooth = dloss_smooth / sum_non_ignore_weight # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. dz_loss = dz_loss / n_non_ignore - dloss_entropy = dloss_entropy / n_non_ignore # derivative of total_loss X_block = dloss_ori + dloss_smooth + dz_loss # chain rule softcapping # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) if HAS_SOFTCAPPING: - X_block = X_block * (1 - intermediate * intermediate) + X_block = X_block * (1 - intermediate * intermediate) tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) if RETURN_ENTROPY_LOSS: @@ -280,7 +279,7 @@ def liger_cross_entropy_kernel( z_loss = z_loss / n_non_ignore # TODO: Implement weighted entropy loss. Currently, entropy loss is not scaled by weight. entropy_loss = entropy_loss / n_non_ignore - + loss += z_loss tl.store(loss_ptr, loss) @@ -289,6 +288,7 @@ def liger_cross_entropy_kernel( if RETURN_ENTROPY_LOSS: tl.store(entropy_loss_ptr, entropy_loss) + # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling # The optimal maximum block size depends on your hardware, your kernel, and your dtype @@ -308,7 +308,9 @@ def cross_entropy_forward( return_entropy_loss, ): assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" - assert isinstance(return_entropy_loss, bool), f"return_entropy_loss must be True or False. 
Got: {return_entropy_loss}" + assert isinstance(return_entropy_loss, bool), ( + f"return_entropy_loss must be True or False. Got: {return_entropy_loss}" + ) BT, V = _input.shape n_rows = BT @@ -388,7 +390,7 @@ def cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_output_entro BT, V = _input.shape n_rows = BT BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - + element_mul_kernel[(n_rows,)]( dX_entropy_2d, dX_entropy_2d.stride(-2), @@ -472,7 +474,7 @@ def forward( ctx.save_for_backward(_input.detach(), dX_entropy_2d.detach()) ctx.return_z_loss = return_z_loss ctx.return_entropy_loss = return_entropy_loss - + return loss, z_loss, entropy_loss @staticmethod @@ -493,11 +495,11 @@ def backward(ctx, grad_output, grad_ouput2, grad_ouput3): (_input, dX_entropy_2d) = ctx.saved_tensors _input = cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_ouput3) - + # delete the tensors that are not used in remaining steps - del grad_ouput3 - del dX_entropy_2d - + del grad_ouput3 + del dX_entropy_2d + return ( _input, None, diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index fb3748632..fb08f2abe 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -28,7 +28,9 @@ def fused_linear_cross_entropy_forward( return_entropy_loss=False, ): assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" - assert isinstance(return_entropy_loss, bool), f"return_entropy_loss must be True or False. Got: {return_entropy_loss}" + assert isinstance(return_entropy_loss, bool), ( + f"return_entropy_loss must be True or False. Got: {return_entropy_loss}" + ) device = _input.device # inputs have shape: BT x H @@ -51,10 +53,12 @@ def fused_linear_cross_entropy_forward( grad_input = torch.zeros_like(_input, device=device) grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None # initialize the gradients w.r.s. to the entropy loss - grad_entropy_weight = torch.zeros_like(weight, device=device) if return_entropy_loss and weight.requires_grad else None + grad_entropy_weight = ( + torch.zeros_like(weight, device=device) if return_entropy_loss and weight.requires_grad else None + ) grad_entropy_input = torch.zeros_like(_input, device=device) if return_entropy_loss else None grad_entropy_bias = torch.zeros_like(bias, device=device) if return_entropy_loss and bias is not None else None - + # we use fp32 for loss accumulator loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None @@ -86,7 +90,7 @@ def fused_linear_cross_entropy_forward( logits_chunk = _input_chunk @ weight.t() # chunk_size x V if bias is not None: logits_chunk = logits_chunk + bias - + # create a tensor to store the gradient of the input w.r.s. 
to the entropy loss grad_entropy_logits_chunk = torch.zeros_like(logits_chunk, device=device) if return_entropy_loss else None @@ -161,15 +165,13 @@ def fused_linear_cross_entropy_forward( out=grad_bias, alpha=1.0, ) - + if return_entropy_loss: grad_entropy_input[start_idx:end_idx] = grad_entropy_logits_chunk @ weight if grad_weight is not None: torch.addmm( input=grad_entropy_weight, - mat1=grad_entropy_logits_chunk.t().to( - _input_chunk.dtype - ), + mat1=grad_entropy_logits_chunk.t().to(_input_chunk.dtype), mat2=_input_chunk, out=grad_entropy_weight, alpha=1.0, @@ -191,7 +193,18 @@ def fused_linear_cross_entropy_forward( loss = torch.sum(loss_1d) z_loss = torch.sum(z_loss_1d) if return_z_loss else None entropy_loss = torch.sum(entropy_loss_1d) if return_entropy_loss else None - return loss, z_loss, entropy_loss, grad_input, grad_weight, grad_bias, grad_entropy_input, grad_entropy_weight, grad_entropy_bias + return ( + loss, + z_loss, + entropy_loss, + grad_input, + grad_weight, + grad_bias, + grad_entropy_input, + grad_entropy_weight, + grad_entropy_bias, + ) + def _fused_linear_backward_helper(grad_output, grad_input, grad_weight, grad_bias): # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place @@ -236,9 +249,18 @@ def _fused_linear_backward_helper(grad_output, grad_input, grad_weight, grad_bia num_warps=32 if not is_hip() else 16, ) return grad_input, grad_weight, grad_bias - -def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias, grad_entropy_output, grad_entropy_input, grad_entropy_weight, grad_entropy_bias): + +def fused_linear_cross_entropy_backward( + grad_output, + grad_input, + grad_weight, + grad_bias, + grad_entropy_output, + grad_entropy_input, + grad_entropy_weight, + grad_entropy_bias, +): # Calculate the gradient with respect to the entropy losses grad_entropy_input, grad_entropy_weight, grad_entropy_bias = _fused_linear_backward_helper( grad_entropy_output, grad_entropy_input, grad_entropy_weight, grad_entropy_bias @@ -288,7 +310,17 @@ def forward( reduction: reduction to apply """ - loss, z_loss, entropy_loss, grad_input, grad_weight, grad_bias, grad_entropy_input, grad_entropy_weight, grad_entropy_bias = fused_linear_cross_entropy_forward( + ( + loss, + z_loss, + entropy_loss, + grad_input, + grad_weight, + grad_bias, + grad_entropy_input, + grad_entropy_weight, + grad_entropy_bias, + ) = fused_linear_cross_entropy_forward( _input=_input, weight=weight, target=target, @@ -319,16 +351,25 @@ def forward( def backward(ctx, grad_output, grad_output2, grad_output3): if ctx.return_z_loss: del grad_output2 # z_loss is only for logging - (grad_input, grad_weight, grad_bias, grad_entropy_input, grad_entropy_weight, grad_entropy_bias) = ctx.saved_tensors + (grad_input, grad_weight, grad_bias, grad_entropy_input, grad_entropy_weight, grad_entropy_bias) = ( + ctx.saved_tensors + ) grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( - grad_output, grad_input, grad_weight, grad_bias, grad_output3, grad_entropy_input, grad_entropy_weight, grad_entropy_bias + grad_output, + grad_input, + grad_weight, + grad_bias, + grad_output3, + grad_entropy_input, + grad_entropy_weight, + grad_entropy_bias, ) # delete the tensors that are not used in remaining steps del grad_output3 del grad_entropy_input del grad_entropy_weight del grad_entropy_bias - + return ( grad_input, grad_weight, diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index 
a32d59f76..5c1d7b920 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -35,7 +35,7 @@ def __init__( self.softcap = softcap self.return_z_loss = return_z_loss self.return_entropy_loss = return_entropy_loss - + def forward(self, _input: torch.Tensor, target: torch.Tensor): loss, z_loss, entropy_loss = LigerCrossEntropyFunction.apply( _input, diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index 1a8ab23c6..0acf2be33 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -35,7 +35,7 @@ def __init__( self.softcap = softcap self.return_z_loss = return_z_loss self.return_entropy_loss = return_entropy_loss - + def forward(self, lin_weight, _input, target, bias=None): loss, z_loss, entropy_loss = LigerFusedLinearCrossEntropyFunction.apply( _input, From 1b13b2f08282ae1bac2053079a2352e5c33d6d26 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 30 Jan 2025 04:13:26 +0000 Subject: [PATCH 04/32] fix for cross_entropy Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 36 ++++++++++++------- .../transformers/cross_entropy.py | 1 + test/transformers/test_cross_entropy.py | 8 ++++- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index c08bdfcc4..7154a71be 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -365,6 +365,7 @@ def cross_entropy_forward( reduction=reduction, softcap=softcap, RETURN_Z_LOSS=return_z_loss, + RETURN_ENTROPY_LOSS=return_entropy_loss, BLOCK_SIZE=BLOCK_SIZE, HAS_WEIGHT=True if weight is not None else False, HAS_SOFTCAPPING=True if softcap is not None else False, @@ -386,19 +387,10 @@ def cross_entropy_forward( def cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_output_entropy): - # calculate the gradient of the input w.r.s. to the entropy loss BT, V = _input.shape n_rows = BT BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - element_mul_kernel[(n_rows,)]( - dX_entropy_2d, - dX_entropy_2d.stride(-2), - grad_output_entropy, - V, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32 if not is_hip() else 16, - ) # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): pass @@ -414,8 +406,20 @@ def cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_output_entro BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, ) + + # calculate the gradient of the input w.r.s. 
to the entropy loss + if dX_entropy_2d is not None: + element_mul_kernel[(n_rows,)]( + dX_entropy_2d, + dX_entropy_2d.stride(-2), + grad_output_entropy, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + _input += dX_entropy_2d - return _input + dX_entropy_2d + return _input class LigerCrossEntropyFunction(torch.autograd.Function): @@ -471,7 +475,10 @@ def forward( # TODO: investigation # If we don't detach the _input tensor, the memory will double # Not sure why but seems that there will be a time both grad and value exist but in different location - ctx.save_for_backward(_input.detach(), dX_entropy_2d.detach()) + if return_entropy_loss: + ctx.save_for_backward(_input.detach(), dX_entropy_2d.detach()) + else: + ctx.save_for_backward(_input.detach()) ctx.return_z_loss = return_z_loss ctx.return_entropy_loss = return_entropy_loss @@ -492,8 +499,12 @@ def backward(ctx, grad_output, grad_ouput2, grad_ouput3): """ if ctx.return_z_loss: del grad_ouput2 # z_loss is only for logging + + if ctx.return_entropy_loss: + (_input, dX_entropy_2d) = ctx.saved_tensors + else: + (_input,), dX_entropy_2d = ctx.saved_tensors, None - (_input, dX_entropy_2d) = ctx.saved_tensors _input = cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_ouput3) # delete the tensors that are not used in remaining steps @@ -510,4 +521,5 @@ def backward(ctx, grad_output, grad_ouput2, grad_ouput3): None, None, None, + None, ) diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index 5c1d7b920..abcc19930 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -47,6 +47,7 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor): self.reduction, self.softcap, self.return_z_loss, + self.return_entropy_loss, ) if not self.return_z_loss and not self.return_entropy_loss: return loss diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index b88033f2a..44fbb53fa 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -860,14 +860,17 @@ def test_float32_internal(): # Run kernel for bfloat16 X_bf16 = X_init.clone() + dX_entropy_bf16 = torch.zeros_like(X_bf16) loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device=device) liger_cross_entropy_kernel[(batch_size,)]( X_ptr=X_bf16, + dX_entropy_ptr=dX_entropy_bf16, X_stride=X_bf16.stride(-2), Y_ptr=Y, Y_stride=Y.stride(-1), weight_ptr=X_bf16, # dummy ptr, not used z_loss_ptr=loss_bf16, # dummy ptr, not used + entropy_loss_ptr=loss_bf16, # dummy ptr, not used loss_ptr=loss_bf16, loss_stride=loss_bf16.stride(-1), n_cols=n_cols, @@ -880,6 +883,7 @@ def test_float32_internal(): reduction=reduction, softcap=softcap, RETURN_Z_LOSS=0, # False + RETURN_ENTROPY_LOSS=0, # False HAS_WEIGHT=False, HAS_SOFTCAPPING=False, BLOCK_SIZE=BLOCK_SIZE, @@ -888,15 +892,17 @@ def test_float32_internal(): # Run kernel for float32 X_fp32 = X_init.float() + dX_entropy_fp32 = torch.zeros_like(X_fp32) loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device=device) liger_cross_entropy_kernel[(batch_size,)]( X_ptr=X_fp32, + dX_entropy_ptr=dX_entropy_fp32, X_stride=X_fp32.stride(-2), Y_ptr=Y, Y_stride=Y.stride(-1), weight_ptr=X_fp32, # dummy ptr, not used - loss_ptr=loss_fp32, z_loss_ptr=loss_fp32, # dummy ptr, not used + entropy_loss_ptr=loss_fp32, # dummy ptr, not used loss_stride=loss_fp32.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, From 
8a43d1e2b1c16c6ce01699e65a2ea924baef3143 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 30 Jan 2025 04:13:54 +0000 Subject: [PATCH 05/32] fix checkstyle Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 7154a71be..60ec02c1e 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -406,7 +406,7 @@ def cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_output_entro BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, ) - + # calculate the gradient of the input w.r.s. to the entropy loss if dX_entropy_2d is not None: element_mul_kernel[(n_rows,)]( @@ -499,7 +499,7 @@ def backward(ctx, grad_output, grad_ouput2, grad_ouput3): """ if ctx.return_z_loss: del grad_ouput2 # z_loss is only for logging - + if ctx.return_entropy_loss: (_input, dX_entropy_2d) = ctx.saved_tensors else: From 82d9b55e3ef064c2e90aef2c9079162dac52aaf2 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 30 Jan 2025 04:24:17 +0000 Subject: [PATCH 06/32] wip fix flce Signed-off-by: Hongpeng Guo --- .../ops/fused_linear_cross_entropy.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index fb08f2abe..0d0edecd0 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -105,7 +105,8 @@ def fused_linear_cross_entropy_forward( # ensure _input and target are contiguous logits_chunk = logits_chunk.contiguous() target_chunk = target_chunk.contiguous() - grad_entropy_logits_chunk = grad_entropy_logits_chunk.contiguous() + if return_entropy_loss: + grad_entropy_logits_chunk = grad_entropy_logits_chunk.contiguous() # Here we calculate the gradient of logits_chunk in place so we can save memory. liger_cross_entropy_kernel[(n_rows,)]( @@ -261,16 +262,21 @@ def fused_linear_cross_entropy_backward( grad_entropy_weight, grad_entropy_bias, ): - # Calculate the gradient with respect to the entropy losses - grad_entropy_input, grad_entropy_weight, grad_entropy_bias = _fused_linear_backward_helper( - grad_entropy_output, grad_entropy_input, grad_entropy_weight, grad_entropy_bias - ) # If cross entropy is the last layer, grad_output is 1.0. 
Skip the mul to save time if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): grad_input, grad_weight, grad_bias = _fused_linear_backward_helper( grad_output, grad_input, grad_weight, grad_bias ) - return grad_input + grad_entropy_input, grad_weight + grad_entropy_weight, grad_bias + grad_entropy_bias + # Calculate the gradient with respect to the entropy losses + if grad_entropy_output is not None: + grad_entropy_input, grad_entropy_weight, grad_entropy_bias = _fused_linear_backward_helper( + grad_entropy_output, grad_entropy_input, grad_entropy_weight, grad_entropy_bias + ) + grad_input += grad_entropy_input + grad_weight += grad_entropy_weight + grad_bias += grad_entropy_bias + + return grad_input, grad_weight, grad_bias class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): @@ -382,4 +388,5 @@ def backward(ctx, grad_output, grad_output2, grad_output3): None, None, None, + None, ) From 984e85fd82ec30e1ee3c34569d240db9bd4e09e9 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 30 Jan 2025 08:04:26 +0000 Subject: [PATCH 07/32] fix bugs Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/fused_linear_cross_entropy.py | 6 ++++-- test/transformers/test_fused_linear_cross_entropy.py | 8 +++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 0d0edecd0..903e11fc2 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -273,8 +273,10 @@ def fused_linear_cross_entropy_backward( grad_entropy_output, grad_entropy_input, grad_entropy_weight, grad_entropy_bias ) grad_input += grad_entropy_input - grad_weight += grad_entropy_weight - grad_bias += grad_entropy_bias + if grad_weight is not None: + grad_weight += grad_entropy_weight + if grad_bias is not None: + grad_bias += grad_entropy_bias return grad_input, grad_weight, grad_bias diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index ffbe52275..d6a4343e1 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -256,7 +256,7 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, ce_weight, atol bias = torch.randn(V, device=device, dtype=dtype) if bias else None ce_weight = torch.randn(V, device=device) if ce_weight else None - y1, z1 = liger_fused_linear_cross_entropy( + y1, z1, e1 = liger_fused_linear_cross_entropy( input=x1, weight=weight, target=target, @@ -268,13 +268,15 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, ce_weight, atol reduction="mean", softcap=30.0, return_z_loss=True, + return_entropy_loss=True, ) - y2, z2 = LigerFusedLinearCrossEntropyFunction.apply( - x2, weight, target, bias, ce_weight, -100, 1e-4, 0.1, "mean", 30.0, True + y2, z2, e2 = LigerFusedLinearCrossEntropyFunction.apply( + x2, weight, target, bias, ce_weight, -100, 1e-4, 0.1, "mean", 30.0, True, True ) assert torch.allclose(y1, y2, atol=atol, rtol=rtol) assert torch.allclose(z1, z2, atol=atol, rtol=rtol) + assert torch.allclose(e1, e2, atol=atol, rtol=rtol) grad_output = torch.randn_like(y1) From eb90401b8ab6b02e173bee8ff9b95849e5e25e41 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 30 Jan 2025 08:04:54 +0000 Subject: [PATCH 08/32] fix bugs Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/fused_linear_cross_entropy.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 903e11fc2..ca14a359a 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -268,7 +268,7 @@ def fused_linear_cross_entropy_backward( grad_output, grad_input, grad_weight, grad_bias ) # Calculate the gradient with respect to the entropy losses - if grad_entropy_output is not None: + if grad_entropy_output is not None: grad_entropy_input, grad_entropy_weight, grad_entropy_bias = _fused_linear_backward_helper( grad_entropy_output, grad_entropy_input, grad_entropy_weight, grad_entropy_bias ) @@ -277,7 +277,7 @@ def fused_linear_cross_entropy_backward( grad_weight += grad_entropy_weight if grad_bias is not None: grad_bias += grad_entropy_bias - + return grad_input, grad_weight, grad_bias From 7684eed89de3b49d4576602e7d4d31f46a3d87f7 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 30 Jan 2025 08:33:53 +0000 Subject: [PATCH 09/32] fix Signed-off-by: Hongpeng Guo --- test/transformers/test_cross_entropy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 44fbb53fa..a66acd687 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -903,6 +903,7 @@ def test_float32_internal(): weight_ptr=X_fp32, # dummy ptr, not used z_loss_ptr=loss_fp32, # dummy ptr, not used entropy_loss_ptr=loss_fp32, # dummy ptr, not used + loss_ptr=loss_fp32, loss_stride=loss_fp32.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, @@ -914,6 +915,7 @@ def test_float32_internal(): reduction=reduction, softcap=softcap, RETURN_Z_LOSS=0, # False + RETURN_ENTROPY_LOSS=0, # False HAS_WEIGHT=False, HAS_SOFTCAPPING=False, BLOCK_SIZE=BLOCK_SIZE, From bed2d45539e85f759996b2627015007ef044fac0 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 30 Jan 2025 08:43:10 +0000 Subject: [PATCH 10/32] fix a unit test Signed-off-by: Hongpeng Guo --- test/transformers/test_cross_entropy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index a66acd687..7f0a50067 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -395,7 +395,7 @@ def _test_correctness_functional( target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) - y1, y1_z = liger_cross_entropy( + y1, y1_z, y1_e = liger_cross_entropy( x1, target, None, @@ -405,8 +405,9 @@ def _test_correctness_functional( reduction="mean", softcap=30.0, return_z_loss=True, + return_entropy_loss=True, ) - y2, y2_z = LigerCrossEntropyFunction.apply(x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True) + y2, y2_z, y2_e = LigerCrossEntropyFunction.apply(x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True, True) assert torch.allclose(y1, y2, atol=atol, rtol=rtol) assert torch.allclose(y1_z, y2_z, atol=atol, rtol=rtol) From a967e65cc6d8fa4b8ef84f39e492f4faf9a4cfaf Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 3 Feb 2025 03:29:54 +0000 Subject: [PATCH 11/32] fix ce kernel, add unit test make it work Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 38 ++++- test/transformers/test_cross_entropy.py | 176 +++++++++++++++++++++++- 2 files changed, 200 insertions(+), 14 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py 
b/src/liger_kernel/ops/cross_entropy.py index 60ec02c1e..735d6a4fa 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -148,6 +148,27 @@ def liger_cross_entropy_kernel( # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d lse = m + tl.log(d) + + # 3.5 Calculate the entropy loss + if RETURN_ENTROPY_LOSS: + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + + softmax_X = tl.exp(X_block - m) / d + # Mask for valid columns and non-zero softmax + valid_mask = (X_offsets < n_cols) & (softmax_X > 0.0) + entropy_term = tl.where(valid_mask, -softmax_X * tl.log(softmax_X), 0.0) + entropy_loss += tl.sum(entropy_term) + # 4. [Online Softmax] Second pass: compute gradients # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) @@ -188,10 +209,10 @@ def liger_cross_entropy_kernel( # softmax(x_i) X_block = tl.exp(X_block - m) / d if RETURN_ENTROPY_LOSS: - # entropy loss term - entropy_loss += tl.sum(-X_block * tl.log(X_block)) + # Mask for valid columns and non-zero softmax + valid_mask = (X_offsets < n_cols) & (X_block > 0.0) # derivatives of the entropy loss term - dX_entropy_block += -(tl.log(X_block) + 1) + dX_entropy_block = X_block * (-tl.log(X_block) - entropy_loss) # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) X_block += 2 * lse_square_scale * lse * X_block # smoothing term @@ -201,14 +222,17 @@ def liger_cross_entropy_kernel( # reduction scale if reduction == "mean": X_block = X_block / n_non_ignore + dX_entropy_block = dX_entropy_block / n_non_ignore else: weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) softmax_X = tl.exp(X_block - m) / d if RETURN_ENTROPY_LOSS: - # entropy loss term - entropy_loss += tl.sum(-softmax_X * tl.log(softmax_X)) + # Mask for valid columns and non-zero softmax + valid_mask = (X_offsets < n_cols) & (softmax_X > 0.0) # derititive of the entropy loss - dX_entropy_block = -(tl.log(softmax_X) + 1) * weight_block + # d_entropy_term = softmax_X * (-tl.log(softmax_X) - entropy_loss) + # dX_entropy_block = tl.where(valid_mask, d_entropy_term, 0.0) + dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss) # derivative of original_loss dloss_ori = (1 - label_smoothing) * softmax_X # specially handle dx_y @@ -222,6 +246,7 @@ def liger_cross_entropy_kernel( if reduction == "mean": dloss_ori = dloss_ori / sum_non_ignore_weight dloss_smooth = dloss_smooth / sum_non_ignore_weight + dX_entropy_block = dX_entropy_block / sum_non_ignore_weight # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. dz_loss = dz_loss / n_non_ignore # derivative of total_loss @@ -406,7 +431,6 @@ def cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_output_entro BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, ) - # calculate the gradient of the input w.r.s. 
to the entropy loss if dX_entropy_2d is not None: element_mul_kernel[(n_rows,)]( diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 7f0a50067..d171e3fa1 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -18,7 +18,7 @@ set_seed(42) -class CrossEntropyWithZLoss(torch.nn.Module): +class CrossEntropyWithZLossWithEntropyLoss(torch.nn.Module): def __init__( self, weight=None, @@ -27,6 +27,7 @@ def __init__( ignore_index=-100, label_smoothing=0.0, return_z_loss=False, + return_entropy_loss=False, dtype=torch.float32, ): super().__init__() @@ -35,6 +36,7 @@ def __init__( self.reduction = reduction self.ignore_index = ignore_index self.return_z_loss = return_z_loss + self.return_entropy_loss = return_entropy_loss self.label_smoothing = label_smoothing self.dtype = dtype @@ -53,26 +55,41 @@ def forward(self, logits, targets): label_smoothing=self.label_smoothing, ignore_index=self.ignore_index, ) - # Compute log-sum-exp term lse = torch.logsumexp(logits, dim=-1) # Z-loss term z_loss = torch.where(targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0) + # Entropy loss term + entropy_loss = torch.sum(-F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1), dim=-1) + entropy_loss = torch.where(target_mask, entropy_loss, 0.0) + if self.reduction == "mean": z_loss = z_loss.sum() / target_mask.sum() + entropy_loss = entropy_loss.sum() / target_mask.sum() elif self.reduction == "sum": z_loss = z_loss.sum() + entropy_loss = entropy_loss.sum() else: z_loss = z_loss + entropy_loss = entropy_loss + ce_loss = ce_loss.to(self.dtype) z_loss = z_loss.to(self.dtype) + entropy_loss = entropy_loss.to(self.dtype) # Final loss: cross-entropy loss + Z-loss + # Note that the entropy loss is not directly added to the total loss total_loss = ce_loss + z_loss - if self.return_z_loss: + + # Return the total loss and optionally the z_loss and entropy_loss + if self.return_z_loss and self.return_entropy_loss: + return total_loss, z_loss, entropy_loss + elif self.return_z_loss: return total_loss, z_loss + elif self.return_entropy_loss: + return total_loss, entropy_loss else: return total_loss @@ -204,7 +221,7 @@ def _test_correctness_with_z_loss_once( return_z_loss, ): torch.manual_seed(0) - torch_ce = CrossEntropyWithZLoss( + torch_ce = CrossEntropyWithZLossWithEntropyLoss( lse_square_scale=lse_square_scale, return_z_loss=return_z_loss, dtype=dtype, @@ -250,7 +267,7 @@ def _test_correctness_with_z_loss_with_other_params_once( reduction, ): torch.manual_seed(0) - torch_ce = CrossEntropyWithZLoss( + torch_ce = CrossEntropyWithZLossWithEntropyLoss( lse_square_scale=lse_square_scale, return_z_loss=return_z_loss, label_smoothing=label_smoothing, @@ -326,7 +343,7 @@ def _test_correctness_with_weight_with_other_params_once( rtol, ): torch.manual_seed(0) - torch_ce = CrossEntropyWithZLoss( + torch_ce = CrossEntropyWithZLossWithEntropyLoss( weight=weight, lse_square_scale=lse_square_scale, ignore_index=ignore_index, @@ -379,6 +396,74 @@ def _test_correctness_not_last_layer_once(target_ce, B, T, V, reduction, scalar, assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_entropy_loss_with_other_params_once( + target_ce, + B, + T, + V, + scalar, + dtype, + atol, + rtol, + lse_square_scale, + return_z_loss, + return_entropy_loss, + label_smoothing, + ignore_index, + reduction, +): + torch.manual_seed(0) + torch_ce = CrossEntropyWithZLossWithEntropyLoss( + 
lse_square_scale=lse_square_scale, + return_z_loss=return_z_loss, + return_entropy_loss=return_entropy_loss, + label_smoothing=label_smoothing, + ignore_index=ignore_index, + reduction=reduction, + dtype=dtype, + ) + + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar + _input = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices + target[indices_to_assign] = ignore_index + + if return_z_loss and return_entropy_loss: + output, z_output, entropy_output = torch_ce(_input, target) + output2, z_output2, entropy_output2 = target_ce(_input2, target) + elif return_z_loss: + output, z_output = torch_ce(_input, target) + output2, z_output2 = target_ce(_input2, target) + elif return_entropy_loss: + output, entropy_output = torch_ce(_input, target) + output2, entropy_output2 = target_ce(_input2, target) + else: + output = torch_ce(_input, target) + output2 = target_ce(_input2, target) + + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + if return_z_loss: + assert torch.allclose(z_output, z_output2, atol=atol, rtol=rtol) + if return_entropy_loss: + assert torch.allclose(entropy_output, entropy_output2, atol=atol, rtol=rtol) + + loss1 = output + entropy_output if return_entropy_loss else output + loss2 = output2 + entropy_output2 if return_entropy_loss else output2 + + loss1.backward() + loss2.backward() + assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + + def _test_correctness_functional( B, T, @@ -815,7 +900,7 @@ def test_correctness_with_weight_with_other_params_once( @pytest.mark.parametrize( "B, T, V", [ - (2, 4096, 32000), # llama2, mistral + (2, 4096, 32000), # llama2, mistral, reduce T to 1024 to avoid OOM on torch implementation # # weird shapes (3, 423, 32000), ], @@ -839,6 +924,83 @@ def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rto _test_correctness_not_last_layer_once(liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol) +@pytest.mark.parametrize( + "B, T, V", + [ + (2, 1024, 32000), # llama2, mistral + # weird shapes + (3, 423, 32000), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + (1.0, torch.float32, 1e-8, 1e-6), + ], +) +@pytest.mark.parametrize( + "return_z_loss, lse_square_scale", + [ + (True, 1e-4), + (False, 1e-5), + ], +) +@pytest.mark.parametrize( + "label_smoothing, ignore_index, reduction", + [ + (0.1, 42, "mean"), + (0.2, -42, "sum"), + ], +) +@pytest.mark.parametrize("return_entropy_loss", [True]) +def test_correctness_with_entropy_loss_with_other_params_once( + B, + T, + V, + scalar, + dtype, + atol, + rtol, + lse_square_scale, + return_z_loss, + return_entropy_loss, + label_smoothing, + ignore_index, + reduction, +): + test_ce = LigerCrossEntropyLoss( + lse_square_scale=lse_square_scale, + return_z_loss=return_z_loss, + return_entropy_loss=return_entropy_loss, + label_smoothing=label_smoothing, + ignore_index=ignore_index, + reduction=reduction, + ) + 
_test_correctness_with_entropy_loss_with_other_params_once( + test_ce, + B, + T, + V, + scalar, + dtype, + atol, + rtol, + lse_square_scale, + return_z_loss, + return_entropy_loss, + label_smoothing, + ignore_index, + reduction, + ) + + def test_float32_internal(): """ This test validates that the internal softmax calculations occur in float32, From 068b9be8c77a3bdf58b221aa291af2b5477c140d Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 3 Feb 2025 03:30:29 +0000 Subject: [PATCH 12/32] fix style Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 7 +++---- test/transformers/test_cross_entropy.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 735d6a4fa..0bd375838 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -148,7 +148,7 @@ def liger_cross_entropy_kernel( # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d lse = m + tl.log(d) - + # 3.5 Calculate the entropy loss if RETURN_ENTROPY_LOSS: for i in range(0, n_cols, BLOCK_SIZE): @@ -162,14 +162,13 @@ def liger_cross_entropy_kernel( if HAS_SOFTCAPPING: intermediate = tanh(X_block / softcap) X_block = softcap * intermediate - + softmax_X = tl.exp(X_block - m) / d # Mask for valid columns and non-zero softmax valid_mask = (X_offsets < n_cols) & (softmax_X > 0.0) entropy_term = tl.where(valid_mask, -softmax_X * tl.log(softmax_X), 0.0) entropy_loss += tl.sum(entropy_term) - # 4. [Online Softmax] Second pass: compute gradients # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) # dx_y = (softmax(x_y) - 1) / N @@ -222,7 +221,7 @@ def liger_cross_entropy_kernel( # reduction scale if reduction == "mean": X_block = X_block / n_non_ignore - dX_entropy_block = dX_entropy_block / n_non_ignore + dX_entropy_block = dX_entropy_block / n_non_ignore else: weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) softmax_X = tl.exp(X_block - m) / d diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index d171e3fa1..697d6cd40 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -64,7 +64,7 @@ def forward(self, logits, targets): # Entropy loss term entropy_loss = torch.sum(-F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1), dim=-1) entropy_loss = torch.where(target_mask, entropy_loss, 0.0) - + if self.reduction == "mean": z_loss = z_loss.sum() / target_mask.sum() entropy_loss = entropy_loss.sum() / target_mask.sum() @@ -455,7 +455,7 @@ def _test_correctness_with_entropy_loss_with_other_params_once( assert torch.allclose(z_output, z_output2, atol=atol, rtol=rtol) if return_entropy_loss: assert torch.allclose(entropy_output, entropy_output2, atol=atol, rtol=rtol) - + loss1 = output + entropy_output if return_entropy_loss else output loss2 = output2 + entropy_output2 if return_entropy_loss else output2 From 32ac203bfb20b9e8ed7bd1fff615fd0757321a1a Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 3 Feb 2025 04:26:06 +0000 Subject: [PATCH 13/32] add unit test to flce Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 8 ++--- .../test_fused_linear_cross_entropy.py | 35 ++++++++++++++----- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 0bd375838..bf44e4fee 100644 --- 
a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -221,7 +221,8 @@ def liger_cross_entropy_kernel( # reduction scale if reduction == "mean": X_block = X_block / n_non_ignore - dX_entropy_block = dX_entropy_block / n_non_ignore + if RETURN_ENTROPY_LOSS: + dX_entropy_block = dX_entropy_block / n_non_ignore else: weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) softmax_X = tl.exp(X_block - m) / d @@ -229,8 +230,6 @@ def liger_cross_entropy_kernel( # Mask for valid columns and non-zero softmax valid_mask = (X_offsets < n_cols) & (softmax_X > 0.0) # derititive of the entropy loss - # d_entropy_term = softmax_X * (-tl.log(softmax_X) - entropy_loss) - # dX_entropy_block = tl.where(valid_mask, d_entropy_term, 0.0) dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss) # derivative of original_loss dloss_ori = (1 - label_smoothing) * softmax_X @@ -245,9 +244,10 @@ def liger_cross_entropy_kernel( if reduction == "mean": dloss_ori = dloss_ori / sum_non_ignore_weight dloss_smooth = dloss_smooth / sum_non_ignore_weight - dX_entropy_block = dX_entropy_block / sum_non_ignore_weight # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. dz_loss = dz_loss / n_non_ignore + if RETURN_ENTROPY_LOSS: + dX_entropy_block = dX_entropy_block / sum_non_ignore_weight # derivative of total_loss X_block = dloss_ori + dloss_smooth + dz_loss diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index d6a4343e1..99db98f47 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -3,7 +3,7 @@ import pytest import torch -from test.transformers.test_cross_entropy import CrossEntropyWithZLoss +from test.transformers.test_cross_entropy import CrossEntropyWithZLossWithEntropyLoss from test.utils import assert_verbose_allclose from test.utils import set_seed @@ -46,16 +46,18 @@ def __init__( reduction: str = "mean", softcap: Optional[float] = None, return_z_loss: bool = False, + return_entropy_loss: bool = False, ): super().__init__() self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) - self.ce_loss = CrossEntropyWithZLoss( + self.ce_loss = CrossEntropyWithZLossWithEntropyLoss( weight=ce_weight, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, return_z_loss=return_z_loss, + return_entropy_loss=return_entropy_loss, ) self.softcap = softcap @@ -80,6 +82,7 @@ def __init__( reduction: str = "mean", softcap: Optional[float] = None, return_z_loss: bool = False, + return_entropy_loss: bool = False, ): super().__init__() self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) @@ -91,6 +94,7 @@ def __init__( reduction=reduction, softcap=softcap, return_z_loss=return_z_loss, + return_entropy_loss=return_entropy_loss, ) def forward(self, x, y): @@ -105,7 +109,7 @@ def forward(self, x, y): @pytest.mark.parametrize( "B, T, H, V", [ - pytest.param(8, 128, 1024, 4096, marks=pytest.mark.skipif(device="xpu", reason="skip for XPU")), + # pytest.param(8, 128, 1024, 4096, marks=pytest.mark.skipif(device="xpu", reason="skip for XPU")), (4, 47, 31, 123), # random shape ], ) @@ -120,15 +124,16 @@ def forward(self, x, y): ("none", 1.0, torch.float32, 1e-3, 5e-2), ], ) -@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("bias", [False]) @pytest.mark.parametrize( "has_ce_weight, 
label_smoothing, ignore_index, lse_square_scale, softcap, return_z_loss", [ (False, 0, -100, 0, None, False), # Pass non-default values once to ensure all params work along - (True, 0.1, 42, 1e-4, 30.0, True), + (False, 0.1, 42, 1e-4, 30.0, True), ], ) +@pytest.mark.parametrize("return_entropy_loss", [True]) def test_correctness( B, T, @@ -144,6 +149,7 @@ def test_correctness( reduction, softcap, return_z_loss, + return_entropy_loss, atol, rtol, ): @@ -162,6 +168,7 @@ def test_correctness( reduction=reduction, softcap=softcap, return_z_loss=return_z_loss, + return_entropy_loss=return_entropy_loss, dtype=dtype, ).to(device) liger_lm_head_ce = LigerLMHeadCE( @@ -175,6 +182,7 @@ def test_correctness( reduction=reduction, softcap=softcap, return_z_loss=return_z_loss, + return_entropy_loss=return_entropy_loss, dtype=dtype, ).to(device) @@ -196,9 +204,15 @@ def test_correctness( indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices target[indices_to_assign] = ignore_index - if return_z_loss: + if return_z_loss and return_entropy_loss: + output1, z_output1, e_output1 = torch_lm_head_ce(_input1, target) + output2, z_output2, e_output2 = liger_lm_head_ce(_input2, target) + elif return_z_loss: output1, z_output1 = torch_lm_head_ce(_input1, target) output2, z_output2 = liger_lm_head_ce(_input2, target) + elif return_entropy_loss: + output1, e_output1 = torch_lm_head_ce(_input1, target) + output2, e_output2 = liger_lm_head_ce(_input2, target) else: output1 = torch_lm_head_ce(_input1, target) output2 = liger_lm_head_ce(_input2, target) @@ -206,9 +220,14 @@ def test_correctness( assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) if return_z_loss: assert_verbose_allclose(z_output1, z_output2, atol=atol, rtol=rtol) + if return_entropy_loss: + assert_verbose_allclose(e_output1, e_output2, atol=atol, rtol=rtol) + + loss1 = output1 if not return_entropy_loss else output1 + e_output1 + loss2 = output2 if not return_entropy_loss else output2 + e_output2 - output1.backward(gradient=torch.ones_like(output1)) - output2.backward(gradient=torch.ones_like(output2)) + loss1.backward(gradient=torch.ones_like(output1)) + loss2.backward(gradient=torch.ones_like(output2)) assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) From 201f47e9ab64ac3b382c124d3a67c42de73a6e60 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 3 Feb 2025 04:40:20 +0000 Subject: [PATCH 14/32] revert the chanegs on unit tests Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 6 +----- test/transformers/test_fused_linear_cross_entropy.py | 8 ++++---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index bf44e4fee..a4a3d7142 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -165,7 +165,7 @@ def liger_cross_entropy_kernel( softmax_X = tl.exp(X_block - m) / d # Mask for valid columns and non-zero softmax - valid_mask = (X_offsets < n_cols) & (softmax_X > 0.0) + valid_mask = X_offsets < n_cols entropy_term = tl.where(valid_mask, -softmax_X * tl.log(softmax_X), 0.0) entropy_loss += tl.sum(entropy_term) @@ -208,8 +208,6 @@ def liger_cross_entropy_kernel( # softmax(x_i) X_block = tl.exp(X_block - m) / d if RETURN_ENTROPY_LOSS: - # Mask for valid columns and non-zero softmax - valid_mask = (X_offsets < n_cols) & (X_block > 0.0) # derivatives of the entropy loss term dX_entropy_block = X_block * (-tl.log(X_block) - 
entropy_loss) # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) @@ -227,8 +225,6 @@ def liger_cross_entropy_kernel( weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) softmax_X = tl.exp(X_block - m) / d if RETURN_ENTROPY_LOSS: - # Mask for valid columns and non-zero softmax - valid_mask = (X_offsets < n_cols) & (softmax_X > 0.0) # derititive of the entropy loss dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss) # derivative of original_loss diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index 99db98f47..d8963e08e 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -109,7 +109,7 @@ def forward(self, x, y): @pytest.mark.parametrize( "B, T, H, V", [ - # pytest.param(8, 128, 1024, 4096, marks=pytest.mark.skipif(device="xpu", reason="skip for XPU")), + pytest.param(8, 128, 1024, 4096, marks=pytest.mark.skipif(device="xpu", reason="skip for XPU")), (4, 47, 31, 123), # random shape ], ) @@ -124,16 +124,16 @@ def forward(self, x, y): ("none", 1.0, torch.float32, 1e-3, 5e-2), ], ) -@pytest.mark.parametrize("bias", [False]) +@pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize( "has_ce_weight, label_smoothing, ignore_index, lse_square_scale, softcap, return_z_loss", [ - (False, 0, -100, 0, None, False), + (True, 0, -100, 0, None, False), # Pass non-default values once to ensure all params work along (False, 0.1, 42, 1e-4, 30.0, True), ], ) -@pytest.mark.parametrize("return_entropy_loss", [True]) +@pytest.mark.parametrize("return_entropy_loss", [True, False]) def test_correctness( B, T, From 38c5d4438d8eb390df229709e3ce891eb264929d Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 3 Feb 2025 05:26:06 +0000 Subject: [PATCH 15/32] improve ce unit test Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 8 +++++--- test/transformers/test_cross_entropy.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index a4a3d7142..9b97b2f57 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -203,13 +203,15 @@ def liger_cross_entropy_kernel( mask=X_offsets < n_cols, other=0.0, ) - + # valid mask for the entropy loss + valid_mask = X_offsets < n_cols + if not HAS_WEIGHT: # softmax(x_i) X_block = tl.exp(X_block - m) / d if RETURN_ENTROPY_LOSS: # derivatives of the entropy loss term - dX_entropy_block = X_block * (-tl.log(X_block) - entropy_loss) + dX_entropy_block = tl.where(valid_mask, X_block * (-tl.log(X_block) - entropy_loss), 0.0) # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) X_block += 2 * lse_square_scale * lse * X_block # smoothing term @@ -220,7 +222,7 @@ def liger_cross_entropy_kernel( if reduction == "mean": X_block = X_block / n_non_ignore if RETURN_ENTROPY_LOSS: - dX_entropy_block = dX_entropy_block / n_non_ignore + dX_entropy_block = tl.where(valid_mask, dX_entropy_block / n_non_ignore, 0.0) else: weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) softmax_X = tl.exp(X_block - m) / d diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 697d6cd40..213bacfe2 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -456,8 +456,8 @@ def 
_test_correctness_with_entropy_loss_with_other_params_once( if return_entropy_loss: assert torch.allclose(entropy_output, entropy_output2, atol=atol, rtol=rtol) - loss1 = output + entropy_output if return_entropy_loss else output - loss2 = output2 + entropy_output2 if return_entropy_loss else output2 + loss1 = 2 * output + 3 * entropy_output if return_entropy_loss else 2 * output + loss2 = 2 * output2 + 3 * entropy_output2 if return_entropy_loss else 2 * output2 loss1.backward() loss2.backward() From 96c31920dbab13c94f2fd8453b53fa35a69b8759 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 3 Feb 2025 05:26:21 +0000 Subject: [PATCH 16/32] improve ce unit test Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 9b97b2f57..94d13bf06 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -205,7 +205,7 @@ def liger_cross_entropy_kernel( ) # valid mask for the entropy loss valid_mask = X_offsets < n_cols - + if not HAS_WEIGHT: # softmax(x_i) X_block = tl.exp(X_block - m) / d From af8488005d9c2e782268041457bd30c91f9fc5a5 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 3 Feb 2025 09:48:17 +0000 Subject: [PATCH 17/32] handle comments partial Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 94d13bf06..c04e247c5 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -206,12 +206,13 @@ def liger_cross_entropy_kernel( # valid mask for the entropy loss valid_mask = X_offsets < n_cols + softmax_X = tl.exp(X_block - m) / d if not HAS_WEIGHT: # softmax(x_i) - X_block = tl.exp(X_block - m) / d + X_block = softmax_X if RETURN_ENTROPY_LOSS: # derivatives of the entropy loss term - dX_entropy_block = tl.where(valid_mask, X_block * (-tl.log(X_block) - entropy_loss), 0.0) + dX_entropy_block = X_block * (-tl.log(X_block) - entropy_loss) # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) X_block += 2 * lse_square_scale * lse * X_block # smoothing term @@ -222,10 +223,9 @@ def liger_cross_entropy_kernel( if reduction == "mean": X_block = X_block / n_non_ignore if RETURN_ENTROPY_LOSS: - dX_entropy_block = tl.where(valid_mask, dX_entropy_block / n_non_ignore, 0.0) + dX_entropy_block = dX_entropy_block / n_non_ignore else: weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) - softmax_X = tl.exp(X_block - m) / d if RETURN_ENTROPY_LOSS: # derititive of the entropy loss dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss) @@ -299,8 +299,10 @@ def liger_cross_entropy_kernel( loss = loss / n_non_ignore # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. z_loss = z_loss / n_non_ignore - # TODO: Implement weighted entropy loss. Currently, entropy loss is not scaled by weight. 
- entropy_loss = entropy_loss / n_non_ignore + if HAS_WEIGHT: + entropy_loss = entropy_loss / sum_non_ignore_weight + else: + entropy_loss = entropy_loss / n_non_ignore loss += z_loss From 4307e37462139370ea9bdd68cec115ac569611b0 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 3 Feb 2025 01:49:40 -0800 Subject: [PATCH 18/32] Update src/liger_kernel/ops/cross_entropy.py Co-authored-by: Qingquan Song --- src/liger_kernel/ops/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index c04e247c5..3137da006 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -54,7 +54,7 @@ def liger_cross_entropy_kernel( Parameters: X_ptr: Pointer to input tensor. - dX_entropy_ptr: Pointer to tensor to store the gradient of the input w.r.s. to the entropy loss + dX_entropy_ptr: Pointer to tensor to store the gradient of the input w.r.t the entropy loss X_stride (int): The stride of the input tensor. Y_ptr: Pointer to target tensor. Y_stride (int): The stride of the target tensor. From 8d65866536f3b646f464fbc2d01621669738927f Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 3 Feb 2025 09:51:03 +0000 Subject: [PATCH 19/32] fix typo Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 3137da006..8cc1876c7 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -430,7 +430,7 @@ def cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_output_entro BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, ) - # calculate the gradient of the input w.r.s. 
to the entropy loss + # calculate the gradient of the input w.r.t the entropy loss if dX_entropy_2d is not None: element_mul_kernel[(n_rows,)]( dX_entropy_2d, From 4c970428e01a27844bfd27e04795601da722e8b5 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 02:21:05 +0000 Subject: [PATCH 20/32] fix bug in softcap and ce weight confusion Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 8cc1876c7..60f4dee5b 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -206,13 +206,13 @@ def liger_cross_entropy_kernel( # valid mask for the entropy loss valid_mask = X_offsets < n_cols - softmax_X = tl.exp(X_block - m) / d + softmax_X = tl.exp(X_block - m) / d + if RETURN_ENTROPY_LOSS: + # derivatives of the entropy loss term + dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss) if not HAS_WEIGHT: # softmax(x_i) X_block = softmax_X - if RETURN_ENTROPY_LOSS: - # derivatives of the entropy loss term - dX_entropy_block = X_block * (-tl.log(X_block) - entropy_loss) # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) X_block += 2 * lse_square_scale * lse * X_block # smoothing term @@ -226,9 +226,6 @@ def liger_cross_entropy_kernel( dX_entropy_block = dX_entropy_block / n_non_ignore else: weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) - if RETURN_ENTROPY_LOSS: - # derititive of the entropy loss - dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss) # derivative of original_loss dloss_ori = (1 - label_smoothing) * softmax_X # specially handle dx_y @@ -245,7 +242,8 @@ def liger_cross_entropy_kernel( # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. dz_loss = dz_loss / n_non_ignore if RETURN_ENTROPY_LOSS: - dX_entropy_block = dX_entropy_block / sum_non_ignore_weight + # Note that the weight is only applied to ce loss, not for entropy loss. + dX_entropy_block = dX_entropy_block / n_non_ignore # derivative of total_loss X_block = dloss_ori + dloss_smooth + dz_loss @@ -253,6 +251,8 @@ def liger_cross_entropy_kernel( # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) if HAS_SOFTCAPPING: X_block = X_block * (1 - intermediate * intermediate) + if RETURN_ENTROPY_LOSS: + dX_entropy_block = dX_entropy_block * (1 - intermediate * intermediate) tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) if RETURN_ENTROPY_LOSS: @@ -299,10 +299,8 @@ def liger_cross_entropy_kernel( loss = loss / n_non_ignore # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. z_loss = z_loss / n_non_ignore - if HAS_WEIGHT: - entropy_loss = entropy_loss / sum_non_ignore_weight - else: - entropy_loss = entropy_loss / n_non_ignore + # Note that the weight is only applied to ce loss, not for entropy loss. 
+ entropy_loss = entropy_loss / n_non_ignore loss += z_loss From 74d0f0ea7d57ec6910c923fdce5f95ff15202615 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 02:21:16 +0000 Subject: [PATCH 21/32] fix bug in softcap and ce weight confusion Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 2 +- test/chunked_loss/test_dpo_loss.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 60f4dee5b..6aa0da6a5 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -206,7 +206,7 @@ def liger_cross_entropy_kernel( # valid mask for the entropy loss valid_mask = X_offsets < n_cols - softmax_X = tl.exp(X_block - m) / d + softmax_X = tl.exp(X_block - m) / d if RETURN_ENTROPY_LOSS: # derivatives of the entropy loss term dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss) diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index be2ccc36c..25bb58710 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -56,7 +56,7 @@ def alignment_loss( chosen_logratios = policy_chosen_logps - ref_chosen_logps rejected_logratios = policy_rejected_logps - ref_rejected_logps - chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps) + chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps) rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps) logits_diff = self.beta * (chosen_logratios - rejected_logratios) From 800599923fd39100ce966b87991342ee4b0d45bb Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 02:34:14 +0000 Subject: [PATCH 22/32] bisec unittes to test on ci Signed-off-by: Hongpeng Guo --- test/transformers/test_cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 213bacfe2..1e3510cf0 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -927,7 +927,7 @@ def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rto @pytest.mark.parametrize( "B, T, V", [ - (2, 1024, 32000), # llama2, mistral + (2, 4096, 32000), # llama2, mistral # weird shapes (3, 423, 32000), ], @@ -956,7 +956,7 @@ def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rto "label_smoothing, ignore_index, reduction", [ (0.1, 42, "mean"), - (0.2, -42, "sum"), + (0.2, -42, "mean"), ], ) @pytest.mark.parametrize("return_entropy_loss", [True]) From e341aeabf6e0dfb4f948db316e275d0044723760 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 02:43:03 +0000 Subject: [PATCH 23/32] refactor code Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 6aa0da6a5..bdbc09836 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -207,9 +207,14 @@ def liger_cross_entropy_kernel( valid_mask = X_offsets < n_cols softmax_X = tl.exp(X_block - m) / d + if RETURN_ENTROPY_LOSS: # derivatives of the entropy loss term dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss) + # Note that the weight is only applied to ce loss, not for entropy loss. 
+ if reduction == "mean": + dX_entropy_block = dX_entropy_block / n_non_ignore + if not HAS_WEIGHT: # softmax(x_i) X_block = softmax_X @@ -222,8 +227,6 @@ def liger_cross_entropy_kernel( # reduction scale if reduction == "mean": X_block = X_block / n_non_ignore - if RETURN_ENTROPY_LOSS: - dX_entropy_block = dX_entropy_block / n_non_ignore else: weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) # derivative of original_loss @@ -241,9 +244,6 @@ def liger_cross_entropy_kernel( dloss_smooth = dloss_smooth / sum_non_ignore_weight # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. dz_loss = dz_loss / n_non_ignore - if RETURN_ENTROPY_LOSS: - # Note that the weight is only applied to ce loss, not for entropy loss. - dX_entropy_block = dX_entropy_block / n_non_ignore # derivative of total_loss X_block = dloss_ori + dloss_smooth + dz_loss From ced5709b8aa46c7ed0c9bc93912f20d4376bb8d6 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 02:56:27 +0000 Subject: [PATCH 24/32] revert changes to unit tests Signed-off-by: Hongpeng Guo --- test/transformers/test_cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 1e3510cf0..cf7b9e772 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -956,7 +956,7 @@ def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rto "label_smoothing, ignore_index, reduction", [ (0.1, 42, "mean"), - (0.2, -42, "mean"), + (0.2, -42, "sum"), ], ) @pytest.mark.parametrize("return_entropy_loss", [True]) From c1d36e643d1ba52bee4b470f9c13af2052da12c0 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 03:18:41 +0000 Subject: [PATCH 25/32] change a new way calculating entropy Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index bdbc09836..cdaaa916b 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -151,6 +151,7 @@ def liger_cross_entropy_kernel( # 3.5 Calculate the entropy loss if RETURN_ENTROPY_LOSS: + sum_p_x = 0.0 # sum of softmax(x_i) * x_i for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load( @@ -164,10 +165,10 @@ def liger_cross_entropy_kernel( X_block = softcap * intermediate softmax_X = tl.exp(X_block - m) / d - # Mask for valid columns and non-zero softmax - valid_mask = X_offsets < n_cols - entropy_term = tl.where(valid_mask, -softmax_X * tl.log(softmax_X), 0.0) - entropy_loss += tl.sum(entropy_term) + # Cumulate the sum of softmax(x_i) * x_i + sum_p_x += tl.sum(tl.where(X_offsets < n_cols, softmax_X * X_block, 0.0)) + + entropy_loss = lse - sum_p_x # 4. 
[Online Softmax] Second pass: compute gradients # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) @@ -203,11 +204,9 @@ def liger_cross_entropy_kernel( mask=X_offsets < n_cols, other=0.0, ) - # valid mask for the entropy loss - valid_mask = X_offsets < n_cols + # Calculate the softmax of the input softmax_X = tl.exp(X_block - m) / d - if RETURN_ENTROPY_LOSS: # derivatives of the entropy loss term dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss) From b1053a381f6d1261ef2eb2df3630ecb68eaf01cf Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 03:35:35 +0000 Subject: [PATCH 26/32] make deriv stable Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index cdaaa916b..e58e91d8e 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -209,7 +209,7 @@ def liger_cross_entropy_kernel( softmax_X = tl.exp(X_block - m) / d if RETURN_ENTROPY_LOSS: # derivatives of the entropy loss term - dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss) + dX_entropy_block = softmax_X * (m - X_block + tl.log(d) - entropy_loss) # Note that the weight is only applied to ce loss, not for entropy loss. if reduction == "mean": dX_entropy_block = dX_entropy_block / n_non_ignore From 7af2fe360117374bd4726542ca2a9258d32d2039 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 03:45:45 +0000 Subject: [PATCH 27/32] bisect unitets Signed-off-by: Hongpeng Guo --- test/transformers/test_cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index cf7b9e772..d360425f6 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -938,7 +938,7 @@ def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rto pytest.param( 1.0, torch.bfloat16, - 1e-8, + 1e-7, 5e-2, marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), From 6162e88d0484a151ba0af22da6834e9055a88ea0 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 04:09:00 +0000 Subject: [PATCH 28/32] fix wip Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index e58e91d8e..50c0491a6 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -209,7 +209,7 @@ def liger_cross_entropy_kernel( softmax_X = tl.exp(X_block - m) / d if RETURN_ENTROPY_LOSS: # derivatives of the entropy loss term - dX_entropy_block = softmax_X * (m - X_block + tl.log(d) - entropy_loss) + dX_entropy_block = tl.where(X_offsets < n_cols, softmax_X * (m - X_block + tl.log(d) - entropy_loss), 0.0) # Note that the weight is only applied to ce loss, not for entropy loss. 
if reduction == "mean": dX_entropy_block = dX_entropy_block / n_non_ignore From 02fd7785bae77f6b46c150cd9fff4c49025284c3 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 04:47:22 +0000 Subject: [PATCH 29/32] try to make it numerical stable Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 50c0491a6..2d5cfd350 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -166,7 +166,7 @@ def liger_cross_entropy_kernel( softmax_X = tl.exp(X_block - m) / d # Cumulate the sum of softmax(x_i) * x_i - sum_p_x += tl.sum(tl.where(X_offsets < n_cols, softmax_X * X_block, 0.0)) + sum_p_x += tl.sum(tl.where(X_offsets < n_cols, tl.math.fma(softmax_X, X_block, 0.0), 0.0)) entropy_loss = lse - sum_p_x @@ -209,7 +209,8 @@ def liger_cross_entropy_kernel( softmax_X = tl.exp(X_block - m) / d if RETURN_ENTROPY_LOSS: # derivatives of the entropy loss term - dX_entropy_block = tl.where(X_offsets < n_cols, softmax_X * (m - X_block + tl.log(d) - entropy_loss), 0.0) + log_softmax_X_plus_entropy = X_block - m - tl.log(d) + entropy_loss + dX_entropy_block = tl.math.fma(softmax_X, -log_softmax_X_plus_entropy, 0.0) # Note that the weight is only applied to ce loss, not for entropy loss. if reduction == "mean": dX_entropy_block = dX_entropy_block / n_non_ignore From 7f53b5962a3953db18c20bf16c5b29f97d497dd1 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 07:09:34 +0000 Subject: [PATCH 30/32] wip another Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 2d5cfd350..651e1c846 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -209,8 +209,7 @@ def liger_cross_entropy_kernel( softmax_X = tl.exp(X_block - m) / d if RETURN_ENTROPY_LOSS: # derivatives of the entropy loss term - log_softmax_X_plus_entropy = X_block - m - tl.log(d) + entropy_loss - dX_entropy_block = tl.math.fma(softmax_X, -log_softmax_X_plus_entropy, 0.0) + dX_entropy_block = softmax_X * sum_p_x - softmax_X * X_block # Note that the weight is only applied to ce loss, not for entropy loss. 
if reduction == "mean": dX_entropy_block = dX_entropy_block / n_non_ignore From 62d2ca3449a8826298d775bb9b277f4b2a4d8366 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 07:12:12 +0000 Subject: [PATCH 31/32] revert a unittest Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 2 +- test/transformers/test_cross_entropy.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 651e1c846..2efc8e088 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -203,7 +203,7 @@ def liger_cross_entropy_kernel( dX_entropy_ptr + X_offsets, mask=X_offsets < n_cols, other=0.0, - ) + ).cast(tl.float32) # Calculate the softmax of the input softmax_X = tl.exp(X_block - m) / d diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index d360425f6..cf7b9e772 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -938,7 +938,7 @@ def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rto pytest.param( 1.0, torch.bfloat16, - 1e-7, + 1e-8, 5e-2, marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), From 0d6487c0fadf02b757261d0c9b94c0a6dbfc7fa0 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 6 Feb 2025 17:39:13 +0000 Subject: [PATCH 32/32] update unittest Signed-off-by: Hongpeng Guo --- src/liger_kernel/ops/cross_entropy.py | 31 ++++++++++++++----------- test/transformers/test_cross_entropy.py | 4 ++-- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 2efc8e088..3a863994d 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -150,8 +150,8 @@ def liger_cross_entropy_kernel( lse = m + tl.log(d) # 3.5 Calculate the entropy loss + sum_p_x = 0.0 # sum of softmax(x_i) * x_i if RETURN_ENTROPY_LOSS: - sum_p_x = 0.0 # sum of softmax(x_i) * x_i for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load( @@ -166,7 +166,7 @@ def liger_cross_entropy_kernel( softmax_X = tl.exp(X_block - m) / d # Cumulate the sum of softmax(x_i) * x_i - sum_p_x += tl.sum(tl.where(X_offsets < n_cols, tl.math.fma(softmax_X, X_block, 0.0), 0.0)) + sum_p_x += tl.sum(tl.where(X_offsets < n_cols, softmax_X * X_block, 0.0)) entropy_loss = lse - sum_p_x @@ -197,6 +197,9 @@ def liger_cross_entropy_kernel( intermediate = tanh(X_block / softcap) X_block = softcap * intermediate + # Calculate the softmax of the input + softmax_X = tl.exp(X_block - m) / d + # load the derivatives of the entropy loss if RETURN_ENTROPY_LOSS: dX_entropy_block = tl.load( @@ -205,9 +208,6 @@ def liger_cross_entropy_kernel( other=0.0, ).cast(tl.float32) - # Calculate the softmax of the input - softmax_X = tl.exp(X_block - m) / d - if RETURN_ENTROPY_LOSS: # derivatives of the entropy loss term dX_entropy_block = softmax_X * sum_p_x - softmax_X * X_block # Note that the weight is only applied to ce loss, not for entropy loss. 
@@ -429,15 +429,16 @@ def cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_output_entro ) # calculate the gradient of the input w.r.t the entropy loss if dX_entropy_2d is not None: - element_mul_kernel[(n_rows,)]( - dX_entropy_2d, - dX_entropy_2d.stride(-2), - grad_output_entropy, - V, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32 if not is_hip() else 16, - ) - _input += dX_entropy_2d + if not torch.equal(grad_output_entropy, torch.tensor(1.0, device=grad_output_entropy.device)): + element_mul_kernel[(n_rows,)]( + dX_entropy_2d, + dX_entropy_2d.stride(-2), + grad_output_entropy, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + _input = dX_entropy_2d + _input return _input @@ -525,6 +526,8 @@ def backward(ctx, grad_output, grad_ouput2, grad_ouput3): else: (_input,), dX_entropy_2d = ctx.saved_tensors, None + print(grad_output, grad_ouput3) + _input = cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_ouput3) # delete the tensors that are not used in remaining steps diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index cf7b9e772..1a79f4ad5 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -456,8 +456,8 @@ def _test_correctness_with_entropy_loss_with_other_params_once( if return_entropy_loss: assert torch.allclose(entropy_output, entropy_output2, atol=atol, rtol=rtol) - loss1 = 2 * output + 3 * entropy_output if return_entropy_loss else 2 * output - loss2 = 2 * output2 + 3 * entropy_output2 if return_entropy_loss else 2 * output2 + loss1 = 3.0989439 * entropy_output if return_entropy_loss else 3.0 * output + loss2 = 3.0989439 * entropy_output2 if return_entropy_loss else 3.0 * output2 loss1.backward() loss2.backward()
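Note on the entropy math used in the kernel changes above: the later commits switch to computing the per-row entropy as lse - sum_i softmax(x)_i * x_i and its input gradient as softmax(x)_i * (sum_j softmax(x)_j * x_j - x_i), which is what `entropy_loss = lse - sum_p_x` and `dX_entropy_block = softmax_X * sum_p_x - softmax_X * X_block` implement. The following standalone PyTorch sketch (not part of the patch series; the shapes, seed, and tolerance are arbitrary illustrative choices) numerically checks both identities against torch.autograd before any reduction or weighting is applied.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 128, dtype=torch.float64, requires_grad=True)

# Reference entropy per row: H = -sum_i p_i * log p_i
p = F.softmax(x, dim=-1)
h_ref = -(p * F.log_softmax(x, dim=-1)).sum(dim=-1)

# Closed form used by the kernel: H = lse - sum_i p_i * x_i
lse = torch.logsumexp(x, dim=-1)
sum_p_x = (p * x).sum(dim=-1)
h_closed = lse - sum_p_x
assert torch.allclose(h_ref, h_closed)

# Analytic per-row gradient used by the kernel (before reduction scaling):
# dH/dx_i = p_i * (sum_p_x - x_i)
grad_closed = p * (sum_p_x.unsqueeze(-1) - x)

# Compare against autograd's gradient of sum(H) w.r.t. x
(grad_auto,) = torch.autograd.grad(h_ref.sum(), x)
assert torch.allclose(grad_auto, grad_closed, atol=1e-10)
print("entropy value and gradient identities check out")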