diff --git a/src/flag_gems/ops/dropout.py b/src/flag_gems/ops/dropout.py
index 7301679d..dff36e96 100644
--- a/src/flag_gems/ops/dropout.py
+++ b/src/flag_gems/ops/dropout.py
@@ -6,6 +6,30 @@
 from ..utils import libentry
 
 
+# Probe, at import time, whether tl.rand can take an int64 offset: compile
+# and launch a one-program kernel that does so, and fall back to int32
+# offsets if any step of the probe raises.
+try:
+    tl_rand_dtype = tl.int64
+
+    @triton.jit
+    def _rand(seed, offset):
+        offset = offset.to(tl_rand_dtype)
+        z = tl.rand(seed, offset, n_rounds=6)
+
+    _grid = (1,)
+    _seed, _offset = philox_cuda_seed_offset(0)
+    _rand[_grid](_seed, _offset)
+except Exception:
+    tl_rand_dtype = tl.int32
+finally:
+    # The probe can fail before every temporary is bound, so remove them
+    # without assuming they exist; a bare `del` would raise NameError here.
+    for _name in ("_grid", "_seed", "_offset"):
+        globals().pop(_name, None)
+    del _name
+
+
 @libentry()
 @triton.autotune(
     configs=[
@@ -32,13 +56,15 @@ def dropout_forward_kernel(
     philox_offset,
     N_BLOCK_SIZE: tl.constexpr,
 ):
+    philox_seed = philox_seed.to(tl.int64)
+    philox_offset = philox_offset.to(tl_rand_dtype)
     pid = tl.program_id(0) * N_BLOCK_SIZE
     offset = pid + tl.arange(0, N_BLOCK_SIZE)
     mask = offset < N
     X_ptr = X + offset
     Y_ptr = Y + offset
     inp = tl.load(X_ptr, mask=mask, other=0.0)
-    philox_offset = philox_offset + offset.to(tl.uint64)
+    philox_offset = philox_offset + offset
     pmask = tl.rand(philox_seed, philox_offset, n_rounds=6) > p
     p = 1.0 / (1.0 - p)
     out = tl.where(pmask, inp * p, 0.0)
@@ -71,13 +97,14 @@ def dropout_backward_kernel(
     philox_offset,
     N_BLOCK_SIZE: tl.constexpr,
 ):
+    philox_seed = philox_seed.to(tl.int64)
+    philox_offset = philox_offset.to(tl_rand_dtype)
     pid = tl.program_id(0) * N_BLOCK_SIZE
     offset = pid + tl.arange(0, N_BLOCK_SIZE)
     mask = offset < N
     DY_ptr = DY + offset
     DX_ptr = DX + offset
-
-    philox_offset = philox_offset + offset.to(tl.uint64)
+    philox_offset = philox_offset + offset
     pmask = tl.rand(philox_seed, philox_offset, n_rounds=6) > p
 
     dy = tl.load(DY_ptr, mask=mask, other=0.0)