Cat+BN+Relu inference is very slow on Max 1550 #3576

Open
jianyizh opened this issue Feb 28, 2025 · 2 comments

Comments

jianyizh (Contributor) commented Feb 28, 2025

The following Python code is extracted from dpn107 in the PyTorch dynamo benchmark suite. It takes 6.32 ms (142.08 GB/s) on a Max 1550. Since this is essentially an elementwise kernel, it should run close to peak memory bandwidth.
I'm using the main branch at 5d23561 with PyTorch main branch 91e7c7945cef141c4796e216e535ba8ac0da2609.

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

from torch._dynamo.testing import rand_strided
from torch._C import _xpu_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints={'y': 32768, 'x': 4096}, tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*fp32', 'in_ptr7': '*fp32', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*fp32', 'in_ptr11': '*fp32', 'in_ptr12': '*bf16', 'in_ptr13': '*bf16', 'out_ptr1': '*bf16', 'out_ptr2': '*bf16', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=64, cc={'architecture': 13136561920, 'driver_version': '1.6.32224+14', 'gpu_eu_count': 512, 'gpu_subslice_count': 64, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 512, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1550', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 68719476736, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__native_batch_norm_legit_no_training_cat_relu_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 15, 'num_reduction': 0, 'backend_hash': 'B6A4CE6841FCE5A1FE328285E92D2FF3F693DDFFF2EBC3E85AE470CCA1CFEE31', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.90397984},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__native_batch_norm_legit_no_training_cat_relu_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, out_ptr1, out_ptr2, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 24064
    xnumel = 3136
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    y0 = (yindex % 376)
    x2 = xindex
    y1 = yindex // 376
    y3 = yindex
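    # y0: output channel (0..375), y1: batch index, x2: flattened spatial index;
    # in_ptr6..in_ptr13 hold the per-channel statistics and affine parameters
    # of the two batch-norm branches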
    tmp36 = tl.load(in_ptr6 + (y0), ymask, eviction_policy='evict_last')
    tmp38 = tl.load(in_ptr7 + (y0), ymask, eviction_policy='evict_last')
    tmp40 = tl.load(in_ptr8 + (y0), ymask, eviction_policy='evict_last').to(tl.float32)
    tmp43 = tl.load(in_ptr9 + (y0), ymask, eviction_policy='evict_last').to(tl.float32)
    tmp49 = tl.load(in_ptr10 + (y0), ymask, eviction_policy='evict_last')
    tmp51 = tl.load(in_ptr11 + (y0), ymask, eviction_policy='evict_last')
    tmp53 = tl.load(in_ptr12 + (y0), ymask, eviction_policy='evict_last').to(tl.float32)
    tmp56 = tl.load(in_ptr13 + (y0), ymask, eviction_policy='evict_last').to(tl.float32)
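    # Reconstruct torch.cat along channels: channels [0, 256) are the sum of
    # in_ptr0..in_ptr4, channels [256, 356) come from in_ptr5, and channels
    # [356, 376) reuse channels [256, 276) of in_ptr4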
    tmp0 = y0
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 256, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (296*x2 + 928256*y1 + (y0)), ymask & xmask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (276*x2 + 865536*y1 + (y0)), ymask & xmask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp7 = tmp5 + tmp6
    tmp8 = tl.load(in_ptr2 + (276*x2 + 865536*y1 + (y0)), ymask & xmask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp9 = tmp7 + tmp8
    tmp10 = tl.load(in_ptr3 + (276*x2 + 865536*y1 + (y0)), ymask & xmask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp9 + tmp10
    tmp12 = tl.load(in_ptr4 + (276*x2 + 865536*y1 + (y0)), ymask & xmask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp13 = tmp11 + tmp12
    tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
    tmp15 = tl.where(tmp4, tmp13, tmp14)
    tmp16 = tmp0 >= tmp3
    tmp17 = tl.full([1, 1], 376, tl.int64)
    tmp18 = tmp0 < tmp17
    tmp19 = tl.broadcast_to((-256) + y0, [XBLOCK, YBLOCK])
    tmp20 = tl.full([1, 1], 0, tl.int64)
    tmp21 = tmp19 >= tmp20
    tmp22 = tl.full([1, 1], 100, tl.int64)
    tmp23 = tmp19 < tmp22
    tmp24 = tmp23 & tmp16
    tmp25 = tl.load(in_ptr5 + (x2 + 3136*((-256) + y0) + 313600*y1), ymask & xmask & tmp24, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp26 = tmp19 >= tmp22
    tmp27 = tl.full([1, 1], 120, tl.int64)
    tmp28 = tmp19 < tmp27
    tmp29 = tmp26 & tmp16
    tmp30 = tl.load(in_ptr4 + (256 + 276*x2 + 865536*y1 + ((-100) + ((-256) + y0))), ymask & xmask & tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp31 = tl.where(tmp23, tmp25, tmp30)
    tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
    tmp33 = tl.where(tmp16, tmp31, tmp32)
    tmp34 = tl.where(tmp4, tmp15, tmp33)
    tmp35 = tmp34.to(tl.float32)
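    # BN branch 1: (x - mean) * invstd * weight + bias, then ReLU -> out_ptr1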
    tmp37 = tmp35 - tmp36
    tmp39 = tmp37 * tmp38
    tmp41 = tmp40.to(tl.float32)
    tmp42 = tmp39 * tmp41
    tmp44 = tmp43.to(tl.float32)
    tmp45 = tmp42 + tmp44
    tmp46 = tmp45.to(tl.float32)
    tmp47 = tl.full([1, 1], 0, tl.int32)
    tmp48 = triton_helpers.maximum(tmp47, tmp46)
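    # BN branch 2: same computation with the second parameter set -> out_ptr2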
    tmp50 = tmp35 - tmp49
    tmp52 = tmp50 * tmp51
    tmp54 = tmp53.to(tl.float32)
    tmp55 = tmp52 * tmp54
    tmp57 = tmp56.to(tl.float32)
    tmp58 = tmp55 + tmp57
    tmp59 = tmp58.to(tl.float32)
    tmp60 = triton_helpers.maximum(tmp47, tmp59)
    tl.store(out_ptr1 + (y0 + 376*x2 + 1179136*y1), tmp48, ymask & xmask)
    tl.store(out_ptr2 + (y0 + 376*x2 + 1179136*y1), tmp60, ymask & xmask)


def get_args():
    arg_0 = rand_strided((64, 296, 56, 56), (928256, 1, 16576, 296), device='xpu:0', dtype=torch.bfloat16)
    arg_1 = rand_strided((64, 276, 56, 56), (865536, 1, 15456, 276), device='xpu:0', dtype=torch.bfloat16)
    arg_2 = rand_strided((64, 276, 56, 56), (865536, 1, 15456, 276), device='xpu:0', dtype=torch.bfloat16)
    arg_3 = rand_strided((64, 276, 56, 56), (865536, 1, 15456, 276), device='xpu:0', dtype=torch.bfloat16)
    arg_4 = rand_strided((64, 276, 56, 56), (865536, 1, 15456, 276), device='xpu:0', dtype=torch.bfloat16)
    arg_5 = rand_strided((64, 100, 56, 56), (313600, 3136, 56, 1), device='xpu:0', dtype=torch.bfloat16)
    arg_6 = rand_strided((376, 1, 1), (1, 1, 1), device='xpu:0', dtype=torch.float32)
    arg_7 = rand_strided((376, 1, 1), (1, 1, 1), device='xpu:0', dtype=torch.float32)
    arg_8 = rand_strided((376, 1, 1), (1, 1, 1), device='xpu:0', dtype=torch.bfloat16)
    arg_9 = rand_strided((376, 1, 1), (1, 1, 1), device='xpu:0', dtype=torch.bfloat16)
    arg_10 = rand_strided((376, 1, 1), (1, 1, 1), device='xpu:0', dtype=torch.float32)
    arg_11 = rand_strided((376, 1, 1), (1, 1, 1), device='xpu:0', dtype=torch.float32)
    arg_12 = rand_strided((376, 1, 1), (1, 1, 1), device='xpu:0', dtype=torch.bfloat16)
    arg_13 = rand_strided((376, 1, 1), (1, 1, 1), device='xpu:0', dtype=torch.bfloat16)
    arg_14 = rand_strided((64, 376, 56, 56), (1179136, 1, 21056, 376), device='xpu:0', dtype=torch.bfloat16)
    arg_15 = rand_strided((64, 376, 56, 56), (1179136, 1, 21056, 376), device='xpu:0', dtype=torch.bfloat16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9, arg_10, arg_11, arg_12, arg_13, arg_14, arg_15,


def call(args):
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__native_batch_norm_legit_no_training_cat_relu_8.run(*args, 24064, 3136, grid=grid(24064, 3136), stream=stream0)


def benchmark_all_configs(args):
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        return triton_poi_fused__native_batch_norm_legit_no_training_cat_relu_8.benchmark_all_configs(*args, 24064, 3136, grid=grid(24064, 3136))


if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker

    args = get_args()
    ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)
    num_gb = 0.90397984
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms    {num_gb:.3f}GB    {gb_per_s:.2f}GB/s")
jianyizh (Contributor, Author) commented:

cc @riverliuintel

riverliuintel commented:

@vlad-penkin This is not a must-have target for PT 2.7, and it is not a performance regression.
