[Q][GPU][BF16] torch.mul is lowered to HLO as an f32 multiply #8545

Open
apivovarov opened this issue Jan 8, 2025 · 1 comment

apivovarov commented Jan 8, 2025

❓ Questions and Help

torch 2.5.1
torch_xla 2.5.1
cuda 12.4
GPU NVIDIA L4

The following example uses torch.mul where both operands are bf16, but in the HLO graph, I see an f32 multiply operation.

# Shell: dump the HLO as text, including the state around every pass
export XLA_FLAGS="--xla_dump_to=/tmp/dump --xla_dump_hlo_as_text --xla_dump_hlo_pass_re=.*"

# Python: reproduction script
import torch
import torch_xla as xla

device = xla.device(0)

def foo(a, b):
  y = torch.mul(a, b)
  return y

a = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=device)
b = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=device)

y = foo(a, b)
print(y)

HLO dump (module_0000.SyncTensorsGraph.16.before_optimizations.txt):

HloModule SyncTensorsGraph.16, entry_computation_layout={()->(bf16[5,9216,64]{2,1,0})}

ENTRY SyncTensorsGraph.16 {
  constant.7 = bf16[] constant(1)
  reshape.8 = bf16[1,1,1]{2,1,0} reshape(constant.7)
  broadcast.9 = bf16[1,1,1]{2,1,0} broadcast(reshape.8), dimensions={0,1,2}
  reshape.10 = bf16[] reshape(broadcast.9)
  broadcast.11 = bf16[5,9216,64]{2,1,0} broadcast(reshape.10), dimensions={}
  convert.12 = f32[5,9216,64]{2,1,0} convert(broadcast.11)
  constant.1 = bf16[] constant(1)
  reshape.2 = bf16[1,1,1]{2,1,0} reshape(constant.1)
  broadcast.3 = bf16[1,1,1]{2,1,0} broadcast(reshape.2), dimensions={0,1,2}
  reshape.4 = bf16[] reshape(broadcast.3)
  broadcast.5 = bf16[5,9216,64]{2,1,0} broadcast(reshape.4), dimensions={}
  convert.6 = f32[5,9216,64]{2,1,0} convert(broadcast.5)
  multiply.13 = f32[5,9216,64]{2,1,0} multiply(convert.12, convert.6)
  convert.14 = bf16[5,9216,64]{2,1,0} convert(multiply.13)
  ROOT tuple.15 = (bf16[5,9216,64]{2,1,0}) tuple(convert.14)
} // SyncTensorsGraph.16
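
For larger graphs it can be tedious to check by eye which element type each multiply uses. The following is a minimal sketch, assuming the dump follows the same one-instruction-per-line text layout as above. `op_dtypes` is a hypothetical helper written for this thread, not part of torch_xla or XLA, and the regular expression is a simplification that skips tuple-typed instructions.

```python
import re
from collections import Counter

# Matches array-typed HLO instruction lines such as:
#   multiply.13 = f32[5,9216,64]{2,1,0} multiply(convert.12, convert.6)
#   %multiply.3 = bf16[5]{0} multiply(bf16[5]{0} %a, bf16[5]{0} %b)
# Tuple-typed results (e.g. the ROOT tuple) do not match and are skipped.
_INSTR = re.compile(
    r"^\s*(?:ROOT\s+)?%?(?P<name>[\w.\-]+)\s*=\s*"
    r"(?P<dtype>[a-z][a-z0-9]*)\[[^\]]*\]\S*\s+"
    r"(?P<op>[a-z\-]+)\("
)


def op_dtypes(hlo_text, op="multiply"):
    """Count the result element types of every `op` instruction in HLO text."""
    counts = Counter()
    for line in hlo_text.splitlines():
        m = _INSTR.match(line)
        if m and m.group("op") == op:
            counts[m.group("dtype")] += 1
    return dict(counts)


# A few lines copied from the dump above.
sample = """\
  convert.12 = f32[5,9216,64]{2,1,0} convert(broadcast.11)
  convert.6 = f32[5,9216,64]{2,1,0} convert(broadcast.5)
  multiply.13 = f32[5,9216,64]{2,1,0} multiply(convert.12, convert.6)
  convert.14 = bf16[5,9216,64]{2,1,0} convert(multiply.13)
"""

print(op_dtypes(sample))             # {'f32': 1}
print(op_dtypes(sample, "convert"))  # {'f32': 2, 'bf16': 1}
```

To run it on a real dump, read one of the text files under /tmp/dump and pass its contents to `op_dtypes`.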

I was able to get a bf16 multiply by setting export XLA_USE_BF16=1, but I received the following warning:

XLA_USE_BF16 will be deprecated after the 2.5 release, please convert your model to bf16 directly

What is the correct way to get a bf16 multiply in the HLO (High-Level Optimizer) graph without using the deprecated flag?

avizon-aws commented Jan 13, 2025

I tried the same thing using autocast, and it seems to be working as you expect. Below is the code to replicate.

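import torch
import torch_xla.core.xla_model as xm

a = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=xm.xla_device())
b = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=xm.xla_device())

# Run the multiply under XLA autocast with bf16 as the target dtype.
with torch.autocast(device_type='xla', dtype=torch.bfloat16):
    y = torch.mul(a, b)

# Cut the lazy graph here so it is compiled and executed.
xm.mark_step()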

Below is the HLO:

HloModule SyncTensorsGraph.16, entry_computation_layout={()->(bf16[5,9216,64]{2,1,0}, bf16[5,9216,64]{2,1,0}, bf16[5,9216,64]{2,1,0})}

ENTRY %SyncTensorsGraph.16 () -> (bf16[5,9216,64], bf16[5,9216,64], bf16[5,9216,64]) {
  %constant.1 = bf16[] constant(1)
  %broadcast.5 = bf16[5,9216,64]{2,1,0} broadcast(bf16[] %constant.1), dimensions={}
  %constant.6 = bf16[] constant(1)
  %broadcast.10 = bf16[5,9216,64]{2,1,0} broadcast(bf16[] %constant.6), dimensions={}
  %constant.7 = bf16[] constant(1)
  %broadcast.12 = bf16[5,9216,64]{2,1,0} broadcast(bf16[] %constant.7), dimensions={}
  ROOT %tuple.15 = (bf16[5,9216,64]{2,1,0}, bf16[5,9216,64]{2,1,0}, bf16[5,9216,64]{2,1,0}) tuple(bf16[5,9216,64]{2,1,0} %broadcast.5, bf16[5,9216,64]{2,1,0} %broadcast.10, bf16[5,9216,64]{2,1,0} %broadcast.12), frontend_attributes={neff_output_names="output0,output1,output2"}
}

Flags:
export XLA_DOWNCAST_BF16=0
export XLA_USE_BF16=0

Can you try replicating this and confirm?
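
When replicating, it may help to confirm that neither of the two flags listed above is still set in the shell from an earlier experiment. The following is a minimal sketch using only the standard library. `active_bf16_flags` is a hypothetical helper, and treating unset, empty, or "0" as disabled is an assumption about how the flags are interpreted, so anything else is reported for manual inspection.

```python
import os

# The two environment flags named in this thread.
BF16_FLAGS = ("XLA_USE_BF16", "XLA_DOWNCAST_BF16")


def active_bf16_flags(env=None):
    """Return the flags from BF16_FLAGS that look enabled in `env`."""
    env = os.environ if env is None else env
    # Assumption: unset, empty, or "0" means disabled; any other value is
    # returned so it can be checked by hand.
    return {k: env[k] for k in BF16_FLAGS if env.get(k, "0") not in ("", "0")}


# Example with an explicit mapping instead of the real environment.
print(active_bf16_flags({"XLA_USE_BF16": "1", "XLA_DOWNCAST_BF16": "0"}))
# {'XLA_USE_BF16': '1'}
```

Calling `active_bf16_flags()` with no argument inspects the current process environment; an empty result means neither flag appears to be enabled.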
