
Softmax gives inf values when performed on 1D vector #228

Open
umalesTT opened this issue Feb 4, 2025 · 2 comments
Labels: bug (Something isn't working), models blocker (Issues blocking models bringup)

Comments


umalesTT commented Feb 4, 2025

Related to #167:

The MLIR workaround tenstorrent/tt-mlir@3a0fd0e enabled the current softmax tests to pass (see #215). However, during MNIST training with batch_size = 1, the softmax outputs turned out to be inf.

I confirmed this by adding a 1D vector case to test_softmax, where it fails as expected (while the other four non-1D tests pass); a sketch of such a case is below.
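A minimal sketch of what that added case might look like (the parametrization, tolerances, and comparison against a host reference are illustrative stand-ins, not the actual test_softmax harness):

import jax
import jax.numpy as jnp
import pytest

@pytest.mark.parametrize(
    "x_shape, axis",
    [
        ((32, 32), 1),  # non-1D cases like this pass
        ((1, 32), 1),   # effectively 1D reduce: output comes back as inf on device
        ((32,), 0),     # 1D vector: output comes back as inf on device
    ],
)
def test_softmax(x_shape, axis):
    x = jax.random.normal(jax.random.PRNGKey(0), x_shape, dtype=jnp.float32)

    out = jax.jit(lambda t: jax.nn.softmax(t, axis=axis))(x)  # device under test
    golden = jax.nn.softmax(x, axis=axis)                     # host/CPU reference

    assert jnp.all(jnp.isfinite(out)), "softmax produced inf"
    assert jnp.allclose(out, golden, atol=1e-2)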

umalesTT added the bug and models blocker labels on Feb 4, 2025

umalesTT commented Feb 4, 2025

For more info, below are the StableHLO graphs for the two cases that fail.

x_shape = (1,32) and axis = 1

module @jit_apply_softmax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x32xf32>) -> (tensor<1x32xf32> {jax.result_info = ""}) {
    %cst = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %0 = stablehlo.reduce(%arg0 init: %cst) applies stablehlo.maximum across dimensions = [1] : (tensor<1x32xf32>, tensor<f32>) -> tensor<1xf32>
    %cst_0 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %1 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %2 = stablehlo.maximum %1, %0 : tensor<1xf32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<1xf32>) -> tensor<1x1xf32>
    %4 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xf32>) -> tensor<1x32xf32>
    %5 = stablehlo.subtract %arg0, %4 : tensor<1x32xf32>
    %6 = stablehlo.exponential %5 : tensor<1x32xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %7 = stablehlo.reduce(%6 init: %cst_1) applies stablehlo.add across dimensions = [1] : (tensor<1x32xf32>, tensor<f32>) -> tensor<1xf32>
    %8 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<1xf32>) -> tensor<1x1xf32>
    %9 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xf32>) -> tensor<1x32xf32>
    %10 = stablehlo.divide %6, %9 : tensor<1x32xf32>
    return %10 : tensor<1x32xf32>
  }
}

x_shape = (32, ) and axis = 0

module @jit_apply_softmax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32xf32>) -> (tensor<32xf32> {jax.result_info = ""}) {
    %cst = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %0 = stablehlo.reduce(%arg0 init: %cst) applies stablehlo.maximum across dimensions = [0] : (tensor<32xf32>, tensor<f32>) -> tensor<f32>
    %cst_0 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %1 = stablehlo.maximum %cst_0, %0 : tensor<f32>
    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<1xf32>) -> tensor<32xf32>
    %4 = stablehlo.subtract %arg0, %3 : tensor<32xf32>
    %5 = stablehlo.exponential %4 : tensor<32xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %6 = stablehlo.reduce(%5 init: %cst_1) applies stablehlo.add across dimensions = [0] : (tensor<32xf32>, tensor<f32>) -> tensor<f32>
    %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %8 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<1xf32>) -> tensor<32xf32>
    %9 = stablehlo.divide %5, %8 : tensor<32xf32>
    return %9 : tensor<32xf32>
  }
}
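For reference, StableHLO like the above can be regenerated by lowering the softmax call with jax.jit(...).lower(...).as_text(); a small sketch, assuming the op under test is jax.nn.softmax (the function name apply_softmax mirrors the jit_apply_softmax module name above):

import jax
import jax.numpy as jnp

def apply_softmax(x):
    return jax.nn.softmax(x, axis=0)

x = jnp.zeros((32,), dtype=jnp.float32)
# Lowered.as_text() prints the StableHLO module (module @jit_apply_softmax ...)
print(jax.jit(apply_softmax).lower(x).as_text())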

Hopefully, once tenstorrent/tt-metal#17270 is merged, there will be no need for this workaround, but we should keep this flaw in mind until then.

sgligorijevicTT (Contributor) commented:

tenstorrent/tt-metal#17270 was merged; this should be fixed after uplifting.
