From a33dfbc883a80f3875ec534a6bb5e7d3b159543e Mon Sep 17 00:00:00 2001 From: neuron-code-sharing-robot <163452788+neuron-code-sharing-robot@users.noreply.github.com> Date: Fri, 13 Dec 2024 13:29:52 -0800 Subject: [PATCH 1/3] NeuronSDK 2.21: NKI-Samples Update (#6) Update nki-samples for NeuronSDK 2.21 Beta Release. 1. Use new nki.jit decorator for kernels. 2. Added samples for the new direct allocation feature. 3. Misc tests and documentation improvements. Co-authored-by: aws-qieqingy <122939906+aws-qieqingy@users.noreply.github.com> --- CONTRIBUTING.md | 69 +---- LICENSE | 16 - LICENSE.txt | 1 + README.md | 93 ++++-- src/reference/__init__.py | 22 +- src/reference/allocated_attention.py | 283 ++++++++++++++++++ src/reference/allocated_fused_linear.py | 114 +++++++ src/reference/attention.py | 213 ++++++++----- src/reference/tutorial.py | 22 +- src/reference/vision.py | 21 +- .../average_pool2d/average_pool2d_jax.py | 24 +- .../average_pool2d_nki_kernels.py | 49 ++- .../average_pool2d/average_pool2d_torch.py | 11 +- .../fused_mamba/mamba_nki_kernels.py | 38 ++- src/tutorials/fused_mamba/mamba_torch.py | 7 +- .../layernorm/layernorm_nki_kernel.py | 143 +++++---- src/tutorials/layernorm/layernorm_torch.py | 18 +- .../matrix_multiplication_nki_kernels.py | 52 +++- .../matrix_multiplication_torch.py | 13 +- src/tutorials/rmsnorm/rmsnorm_jax.py | 9 +- src/tutorials/rmsnorm/rmsnorm_nki_kernels.py | 14 +- src/tutorials/rmsnorm/rmsnorm_torch.py | 8 +- .../sd_attention/sd_attention_nki_kernels.py | 101 ++++--- .../sd_attention/sd_attention_torch.py | 9 +- .../tensor_addition/tensor_addition_jax.py | 36 +-- .../tensor_addition_nki_kernels.py | 28 +- .../tensor_addition/tensor_addition_torch.py | 32 +- src/tutorials/transpose2d/transpose2d_jax.py | 16 +- .../transpose2d/transpose2d_nki_kernels.py | 13 +- .../transpose2d/transpose2d_torch.py | 8 +- .../flash_attention_benchmark.py | 2 + .../sd2_512_benchmark.py | 2 + .../sd2_inpainting_936_624_benchmark.py | 2 + test/unit/test_SD_attention_small_head.py | 19 +- .../test_allocated_SD_attention_small_head.py | 67 +++++ test/unit/test_flash_attn_bwd.py | 31 +- test/unit/test_flash_attn_fwd.py | 69 +++-- test/unit/test_resize_nearest.py | 16 +- test/unit/test_rmsnorm_qkv.py | 65 ++++ test/unit/test_select_and_scatter.py | 32 +- 40 files changed, 1202 insertions(+), 586 deletions(-) delete mode 100644 LICENSE create mode 100644 LICENSE.txt create mode 100644 src/reference/allocated_attention.py create mode 100644 src/reference/allocated_fused_linear.py create mode 100644 test/unit/test_allocated_SD_attention_small_head.py create mode 100644 test/unit/test_rmsnorm_qkv.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4c16260..32ce44e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ # Contributing Guidelines -Thank you for your interest in contributing to our project. Whether it's a new NKI kernel, improving existing kernel code, bug fix, new feature, correction, or additional +Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional documentation, we greatly value feedback and contributions from our community. Please read through this document before submitting any issues or pull requests to ensure we have all the necessary @@ -24,13 +24,14 @@ reported the issue. Please try to include as much information as you can. Detail Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 1. You are working against the latest source on the *main* branch. -2. You check existing open, and recently merged pull requests to make sure someone else hasn't addressed the problem already. +2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. +3. You open an issue to discuss any significant work - we would hate for your time to be wasted. To send us a pull request, please: 1. Fork the repository. -2. Modify the source; please focus on the specific changes you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. -3. Please ensure your change satisfies the requirements listed in [Testing Requirements](#testing-requirements) and [Coding Guidelines](#coding-guidelines) +2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. +3. Ensure local tests pass. 4. Commit to your fork using clear commit messages. 5. Send us a pull request, answering any default questions in the pull request interface. 6. Wait for a repository collaborator to look at your pull request, run the automated tests, and review. If additional changes or discussion is needed, a collaborator will get back to you, so please stay involved in the conversation. @@ -39,64 +40,8 @@ To send us a pull request, please: GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). -### Testing Requirements -Running the binaries for a NKI kernel require Neuron devices on an AWS EC2 instance from trn1, trn1n, or inf2 instance families. -Details on setting up an instance can be found in [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html). -If you would like to test your kernel without requiring a Neuron device, you can use `nki.simulate()` to run your kernel using `NumPy` input/output tensors and types. -An example can be found in the [layernorm tutorial test](test/unit/test_tutorials_layernorm.py). However, kernels with _only_ simulation tests will not be accepted. - -#### Requirements for Kernels Targeting `src/reference/` - -All kernels located in this folder need to have the following tests. - -1. Numeric accuracy tests with `nki.baremetal`. The output from the kernel -must be validated against a CPU reference implementation. See `test_flash_attn_fwd_numerical` in [test_flash_attn_fwd.py](test/unit/test_flash_attn_fwd.py) as an example. Documentation for `nki.baremetal` is available at [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.baremetal.html). - -2. Performance benchmark tests with `nki.benchmark`. The unit test must have performance checks. At a minimum, put an assertion to verify p99 latency meets a certain threshold. See `test_flash_attn_fwd_perf` in [test_flash_attn_fwd.py](test/unit/test_flash_attn_fwd.py) as an example. Documentation for `nki.benchmark` is available at [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.benchmark.html) - -3. End-to-End integration tests that use your kernel in a model. - - a. Each test should be in its own separate folder. - - b. Each Test must have a `run.sh` script, that accepts an argument \. See [run.sh of FlashAttention](test/integration/flash_attention/run.sh) as an example. - - c. The test scripts must produce benchmark results with the `benchmark` function, located in [LatencyCollector.py](test/integration/perf_utils/LatencyCollector.py). The `benchmark` function will write the latency of your E2E model to the `test_result.json`. - - d. Register your test target in [run_integration.sh](test/integration/run_integration.sh). - - -### Coding Guidelines -Most guidelines are covered by a **PEP-8** check on all newly submitted code, which covers aspects such as code layout and basic Python naming conventions. -In addition to PEP-8, we use the following NKI specific style guidelines: - -1. **Abbreviations** - * Importing NKI modules should use consistent names. For example, - ``` - import neuronxcc.nki as nki - import neuronxcc.nki.isa as nisa - import neuronxcc.nki.language as nl - import neuronxcc.nki.typing as nt - import numpy as np - ``` -2. Variable Names - * Indexing should specify partition and free dimensions along with the variable they are used for. For example: - The index for the partition dimension for tile `a` would be - ``` - i_p_a = nl.arange(128)[:, None] - ``` - while the index for the free dimension for tile `b` would be - ``` - i_f_b = nl.arange(512)[None, :] - ``` - * Name loop variables, indices, and buffers consistently, and specify their intended use in the name. - -3. Documentation - * New kernels should containing inline docstrings that describe the semantics of the kernel, and provide information on the IO layout. - Upon release, we generate the documentation for our kernels and merge them into the NKI API documentation which will appear in the official AWS NKI documentation. - - -## Finding Contributions to Work on +## Finding contributions to work on Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/aws-neuron/nki-samples/labels/help%20wanted) issues is a great place to start. @@ -106,7 +51,7 @@ For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of opensource-codeofconduct@amazon.com with any additional questions or comments. -## Security Issue Notifications +## Security issue notifications If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 3b1fad4..0000000 --- a/LICENSE +++ /dev/null @@ -1,16 +0,0 @@ -MIT No Attribution - -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software is furnished to do so. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..e7f39e2 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1 @@ +TODO: Fill LICENSE after it is finalized \ No newline at end of file diff --git a/README.md b/README.md index 60602a9..2d97f6b 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ At the core of the Neuron SDK is the Neuron Compiler, which takes computation gr them into highly optimized machine code. [NKI](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki) is a Python-based programming environment designed for the compiler which -adopts commonly used NumPy and Triton-like syntax along with tile-level semantics. +adopts commonly used NumPy andTriton-like syntax along with tile-level semantics. NKI also interoperates with the Neuron Profiler, providing insights into performance bottlenecks and instruction latencies. It offers tensor printing support, standard error messaging, and built-in kernel simulation capabilities for efficient debugging purposes. NKI offers two types of programming interfaces: @@ -16,25 +16,31 @@ enabling bare-metal access to the chip for full control. ![alt "High-level flow of NKI in the Neuron Compiler. NKI emits IR immediately before the backend-IR compilation stage"](doc_assets/high-level-nki-flow.png#center "High-Level NKI Flow") -## Documentation -The latest NKI documentation can be found on the AWS Documentation site, [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/). -Documentation for NKI kernels are both inline (docstring) and available on the documentation site's -[kernel API reference page](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.kernels.html). +### nki.language +**nki.language** enables precise control over computation and data movement on NeuronCores-- the processing units within AWS Inferentia and Trainium chips. +Developers can control data movement between device memory and on-chip memory explicitly using `nl.load()` and `nl.store()` operations. +Developers can then perform the desired computations on loaded tensors, such as element-wise operations or tensor contractions, +providing crucial performance improvements. Additionally, developers can control how computation is performed on different compute engines inside NeuronCores. +nki.language APIs are considered high-level APIs and are designed for "ease of use" for ML practitioners. +To achieve the best performance, developers can enlist the nki.isa APIs. + +![alt "Diagram of the NeuronCore Architecture. It shows 4 engines: tensor, vector, scalar, and GPSIMD, connected to SBUF memory. The tensor, vector, and scalar engines are also connected to a high-speed PSUM memory bank that supports accumulate on write. Lastly the HBM (DRAM) is connected to both SBUF and PSUM memory banks."](doc_assets/pm-nc.png#scale_50#center "NeuronCore Architecture") + +### nki.isa + +**nki.isa** provides direct access to chip instructions to offer flexibility and fine-grained control over instruction usage and performance optimizations. +Developers can utilize various `nki.isa` instructions using the Tensor, Vector, Scalar, GP-SIMD, and DMA engines. +For example, developers can use `nki.isa.nc_matmul()` to compute a matrix multiplication using Tensor Engine. +Alternatively, developers can use `nki.isa.activation()` to apply an activation function on every element of the input tile using Scalar Engine. ## Repository Structure ### src #### reference -This folder contains the source code of the `neuronxcc.nki.kernels`, and they are optimized kernels from the Neuron Team serving as samples. - -All kernels located in this folder have numeric accuracy tests +The [reference kernels](src/reference/) are optimized reference kernels. All kernels located in this folder must have all of numeric accuracy tests and performance benchmarks defined in the [test](test/) directory. We also demonstrate using these kernels end-to-end in our [integration tests](test/integration/). -Note that these kernels are already being deployed as part of the Neuron stack. With flash attention as an example, -[compiling Llama models with transformers-neuronx](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/transformers-neuronx-developer-guide.html) -will automatically invoke the `flash_fwd` kernel in [attention.py](src/reference/attention.py). Therefore, replacing the framework operators with these NKI kernels likely won't result in extra performance benefit. - #### tutorials The [tutorial kernels](src/tutorials/) are for educational purpose and include the kernels that are used in NKI guides. @@ -52,16 +58,65 @@ verify the numeric accuracy of the operation, and publish performance results to The [integration tests](tests/integration) folder contains integration tests of (selected) kernels. They verify the numeric accuracy of the model’s output, and publish end-to-end performance results into the [integration benchmarks](docs/benchmarks/integration) folder. -## Maintenance Policy -NKI is currently released as **beta** while we gather feedback from our users and integrate it into the API. NKI API follow the [Neuron SDK Maintenance Policy](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/sdk-policy.html). +## Documentation +The latest NKI documentation can be found on the AWS Documentation site, [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/). +Documentation for NKI kernels are both inline (docstring) and available on the documentation site's +[kernel API reference page](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.kernels.html). -## Getting Help -Have a look at the GitHub issues for this repository where you will find past issues customers have encountered with workarounds and clarifications. -If you cannot find a suitable issue for your use-case feel free to [file an issue](https://github.com/aws-neuron/nki-samples/issues/new) to ask for assistance or to suggest improvements. Please read [CONTRIBUTING.md](CONTRIBUTING.md) for detailed information on submitting issues. +## Versioning +NKI is currently released as **beta** while we gather feedback from our users and integrate it into the API. We will also be updating the NKI API as needed +to support new Neuron and Neuron Compiler features. While NKI is in beta we may need to make backwards-incompatible changes to incorporate feedback from +our users or to support new use-cases of NKI on Neuron devices. Upon releasing NKI as generally available (GA), we will commit to not making backwards +incompatible changes to the NKI API for any supported version of the Neuron compiler. ## Contributing -We invite you to join the NKI community! If you'd like to share kernels you create with the community, we welcome your contributions to this repository via -GitHub pull-requests as well as through filed issues discussing features, bug fixes, new use-cases, and API improvements. Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more information +We invite you to join the NKI community! If you'd like to share kernels you create with the community, we welcome your contributions to this repository via. +GitHub pull-requests as well as through filed issues discussing features, bug fixes, new use-cases, and API improvements. + +### Getting Help +Have a look at the GitHub issues for this repository where you will find past issues customers have encountered with workarounds and clarifications. +If you cannot find a suitable issue for your use-case feel free to file an issue asking for assistance or to suggest improvements. + +In addition, extensive NKI documentation can be found [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki). + +### Testing and Merging +Running the binaries for a NKI kernel require Neuron devices on an AWS EC2 instance from trn1, trn1n, or inf2 instance families. +Details on setting up an instance can be found in [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html). + +Before merging, the Neuron team will need to internally test and verify kernels work as expected. If the change is accepted, +we will manually merge your changes, and it will be merged here upon the next release. + +If you would like to test your kernel without a requiring a Neuron device, you can use `nki.simulate()` to run your kernel using `NumPy` tensors and types. +An example can be found in the [layernorm tutorial test](test/unit/test_tutorials_layernorm.py). + +### Coding Guidelines +Most guidelines are covered by a **PEP-8** check on all newly submitted code, which covers aspects such as code layout and basic Python naming conventions. +In addition to PEP-8, we use the following NKI specific style guidelines: + +1. **Abbreviations** + * Importing NKI modules should use consistent names. For example, + ``` + import neuronxcc.nki as nki + import neuronxcc.nki.isa as nisa + import neuronxcc.nki.language as nl + import neuronxcc.nki.typing as nt + import numpy as np + ``` +2. Variable Names + * Indexing should specify partition and free dimensions along with the variable they are used for. For example: + The index for the partition dimension for tile `a` would be + ``` + i_p_a = nl.arange(128)[:, None] + ``` + while the index for the free dimension for tile `b` would be + ``` + i_f_b = nl.arange(512)[None, :] + ``` + * Name loop variables, indices, and buffers consistently, and specify their intended use in the name. + +3. Documentation + * New kernels should containing inline docstrings that describe the semantics of the kernel, and provide information on the IO layout. + Upon release, we generate the documentation for our kernels and merge them into the NKI API documentation which will appear in the official AWS NKI documentation. ## Licensing This repository is licensed under the terms of the [MIT-0 License](LICENSE.txt) \ No newline at end of file diff --git a/src/reference/__init__.py b/src/reference/__init__.py index ad4a18a..922dd83 100644 --- a/src/reference/__init__.py +++ b/src/reference/__init__.py @@ -6,7 +6,27 @@ Kernels here are the same to the ones available in the NKI Github Sample Repo. -TODO: Insert link to Github Repo when available +https://github.com/aws-neuron/nki-samples """ from neuronxcc.nki.kernels.attention import fused_self_attn_for_SD_small_head_size, flash_attn_bwd, flash_fwd from neuronxcc.nki.kernels.vision import resize_nearest_fixed_dma_kernel, select_and_scatter_kernel +from neuronxcc.nki.kernels.tutorial import add_kernel_nx8x128x512 +from neuronxcc.nki.kernels.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size +from neuronxcc.nki.kernels.allocated_fused_linear import allocated_fused_rms_norm_qkv + +from neuronxcc.nki._private_kernels.legacy.attention import \ + (fused_self_attn_for_SD_small_head_size as _fused_self_attn_for_SD_small_head_size, + flash_attn_bwd as _flash_attn_bwd, flash_fwd as _flash_fwd) +from neuronxcc.nki._private_kernels.legacy.vision import ( + resize_nearest_fixed_dma_kernel as _resize_nearest_fixed_dma_kernel, + select_and_scatter_kernel as _select_and_scatter_kernel) +from neuronxcc.nki._private_kernels.legacy.tutorial import add_kernel_nx8x128x512 as _add_kernel_nx8x128x512 +from neuronxcc.nki._private_kernels.legacy.allocated_fused_linear import _allocated_fused_rms_norm_qkv + +fused_self_attn_for_SD_small_head_size._legacy_func = _fused_self_attn_for_SD_small_head_size +flash_attn_bwd._legacy_func = _flash_attn_bwd +flash_fwd._legacy_func = _flash_fwd +resize_nearest_fixed_dma_kernel._legacy_func = _resize_nearest_fixed_dma_kernel +select_and_scatter_kernel._legacy_func = _select_and_scatter_kernel +add_kernel_nx8x128x512._legacy_func = _add_kernel_nx8x128x512 +allocated_fused_rms_norm_qkv._legacy_func = _allocated_fused_rms_norm_qkv diff --git a/src/reference/allocated_attention.py b/src/reference/allocated_attention.py new file mode 100644 index 0000000..564412c --- /dev/null +++ b/src/reference/allocated_attention.py @@ -0,0 +1,283 @@ +import functools +import neuronxcc.nki as nki +import neuronxcc.nki.language as nl +import neuronxcc.nki.isa as nisa +import neuronxcc.nki.compiler as ncc +from neuronxcc.nki.language import par_dim +import numpy as np + +@nki.jit +def allocated_fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, + use_causal_mask=False, + mixed_percision=True): + """ + Allocated fused self attention kernel for small head size Stable Diffusion workload. + + Computes (softmax(Q.T@K)V).T. The wired layout is choosen to avoid transpose as + much as possible to simplify the debug. The kernel uses the direct allocation API, + and implements double buffering to achive better performance than automatic allocation. + As of NeuronSDK 2.21, it achieves 18% better performance than auto allocated equivalent. + To see the performance gap, you can use ``force_auto_alloc`` decorator to override + manual allocation and benchmark the performance difference. + + This kernel is designed to be used for Stable Diffusion models where the + n_heads is equal to 128. Seqlen must be divisible by 1024, and smaller than 5120. + Assertion is thrown if ``n_heads`` or sequence length does not satisfy the requirement. + These restrictions are to simplify the address calculation in allocations. + + IO tensor layouts: + - q_ptr: shape (bs, d_heads, seq_q) + - k_ptr: shape (bs, d_heads, seq_k) + - v_ptr: shape (bs, seq_v, n_heads) + - out_ptr: shape (bs, d_heads, seq_q) + - We use seq_q and seq_k just for clarity, this kernel requires seq_q == seq_k + + IO tensor dtypes: + - This kernel assumes all IO tensors have the same dtype + - If mixed_percision is True, then all Tensor Engine operation will be performed in + bfloat16 and accumulation will be performed in float32. Otherwise the intermediates + will be in the same type as the inputs. + """ + # Use q_ref dtype as the intermediate tensor dtype + # Assume all IO tensors have the same dtype + kernel_dtype = np.float32 + pe_in_dt = nl.bfloat16 if mixed_percision else kernel_dtype + + kernel_dtype_itemsize = np.dtype(kernel_dtype).itemsize + pe_in_dt_itemsize = np.dtype(pe_in_dt).itemsize + assert q_ref.dtype == k_ref.dtype == v_ref.dtype + + # Shape checking + bs, d_head, seqlen = q_ref.shape + assert d_head <= 128, "Cannot use this kernel for d_head > 128" + assert tuple(q_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!' + assert tuple(k_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!' + assert tuple(v_ref.shape) == (bs, seqlen, + d_head), f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}' + out_ref = nl.ndarray((bs, d_head, seqlen), dtype=q_ref.dtype, buffer=nl.shared_hbm) + + assert d_head == 128 + + cur_addr = 0 + + id0 = nl.arange(0, 128)[:, None] + id1 = nl.arange(0, 128)[None, :] + identity = nl.shared_constant(np.identity(128, dtype=np.int8), dtype=nl.bfloat16) + identity_load = nl.ndarray((par_dim(128), 128), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr)) + cur_addr += 128 * pe_in_dt_itemsize + identity_load[id0, id1] = nl.load(identity) + + identity_load_fp32 = nl.ndarray((par_dim(128), 128), dtype=np.float32, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr)) + cur_addr += 128 * np.dtype(np.float32).itemsize + identity_load_fp32[id0, id1] = nl.load(identity) + + # Softmax scaling factor, multiplied onto Q + softmax_scale = 0.125 + + # Different batch samples/attention heads have independent attention + batch_id = nl.program_id(axis=0) + + q_seq_n_tiles, q_seq_tile_size = seqlen // 128, 128 + k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512 + # No tiling on d_head dimension since the number of d_head fits in SB + d_head_tile_size = d_head + v_seq_n_tiles, v_seq_tile_size = seqlen // 128, 128 + + ################################### + # Step 1. preload tensors + ################################### + v_local = nl.ndarray((v_seq_n_tiles, par_dim(v_seq_tile_size), d_head), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(v_seq_n_tiles, ))) # 8kb + cur_addr += v_seq_n_tiles * d_head * pe_in_dt_itemsize + + for i_v_seq_tile in nl.affine_range(v_seq_n_tiles): + ip_v = nl.arange(v_seq_tile_size)[:, None] + if_v = nl.arange(d_head_tile_size)[None, :] + v_local[i_v_seq_tile, ip_v, if_v] = nl.load( + v_ref[batch_id, i_v_seq_tile * v_seq_tile_size + ip_v, if_v], + dtype=pe_in_dt) + + q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(q_seq_n_tiles, ))) # 8kb + cur_addr += q_seq_n_tiles * q_seq_tile_size * pe_in_dt_itemsize + ip_q = nl.arange(d_head_tile_size)[:, None] + if_q = nl.arange(q_seq_tile_size)[None, :] + for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): + q_local[i_q_seq_tile, ip_q, if_q] = nl.load( + q_ref[batch_id, ip_q, i_q_seq_tile * q_seq_tile_size + if_q], + dtype=pe_in_dt) + q_local[i_q_seq_tile, ip_q, if_q] = nl.multiply(q_local[i_q_seq_tile, ip_q, if_q], softmax_scale) + + k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(k_seq_n_tiles, ))) # 8kb + cur_addr += k_seq_n_tiles * k_seq_tile_size * pe_in_dt_itemsize + ip_k = nl.arange(d_head_tile_size)[:, None] + if_k = nl.arange(k_seq_tile_size)[None, :] + for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): + k_local[i_k_seq_tile, ip_k, if_k] = nl.load( + k_ref[batch_id, + ip_k, + i_k_seq_tile * k_seq_tile_size + if_k + ], + dtype=pe_in_dt) + + for i_q_seq_tile in nl.affine_range(q_seq_n_tiles//2): # indent = 2 + # perform activation and reduction in softmax in larger tile to amortize instruction overhead + reduction_size = 1024 + reduction_tiles = seqlen // reduction_size + + # =================================== SBUF Allocation Starts =================================== + + # The num_free_tiles is intentionally set to (1, ) to disable double buffering on the first matmul. + # From the profile, when the first matmul is double buffered, the tensor_scalar_reduce instruction that writes to this buffer + # spends long time waiting for the matmul it depends on to be executed. The instruction scheduler made a bad decision and + # clogged the pipeline when double buffering is on. This is a workaround to hint the scheduler. + qk_res_buf = nl.ndarray((2, par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(1, ))) # 32 k + cur_addr += seqlen * kernel_dtype_itemsize + exp_res = nl.ndarray((2, par_dim(q_seq_tile_size), seqlen),dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 16 kb + cur_addr += seqlen * 2 * pe_in_dt_itemsize + trans_softmax_res = nl.ndarray( + (2, par_dim(v_seq_tile_size), seqlen), name='trans_softmax_res', + dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 16kb + cur_addr += seqlen * 2 * pe_in_dt_itemsize + + sum_divisor = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 1kb + cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize + sum_reciprocal_broadcast = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 1kb + cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize + + attn_res_sbuf = nl.ndarray((2, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype, + buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, )), name="attn_res_sbuf") # 1kb + cur_addr += 2 * q_seq_tile_size * kernel_dtype_itemsize + attn_res_div = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype, + buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2,))) # 1kb + cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize + + neg_max_res = nl.ndarray((2, par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 64b + cur_addr += 2 * k_seq_n_tiles * kernel_dtype_itemsize + partial_sum_res = nl.ndarray((2, par_dim(q_seq_tile_size), reduction_tiles), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 32b + cur_addr += 2 * reduction_tiles * kernel_dtype_itemsize + neg_max_res_final = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b + cur_addr += 2 * 1 * kernel_dtype_itemsize + sum_res = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b + cur_addr += 2 * 1 * kernel_dtype_itemsize + sum_reciprocal = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b + cur_addr += 2 * 1 * kernel_dtype_itemsize + + # =================================== SBUF Allocation End =================================== + + qk_psum = nl.ndarray((2, k_seq_n_tiles, par_dim(q_seq_tile_size), k_seq_tile_size), + dtype=np.float32, buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(2, 4))) + + assert k_seq_tile_size == 4 * v_seq_tile_size + local_tp_buf = nl.ndarray((2, k_seq_n_tiles, par_dim(q_seq_tile_size), k_seq_tile_size), dtype=np.float32, + buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(2, 4))) + + def psum_addr(bank_map, idx, pdim_size, fdim_size): + return (bank_map[idx], 0, 0) + + # Result psum buffer has the hidden dim as P + # qk_psum is using 0, 1, 2, 3 for fisrt interleave group, and 4, 5, 6, 7 for the second. + # assign 1 and 5 avoid bank collision between groups + attn_res_psum = nl.ndarray((2, par_dim(d_head_tile_size), q_seq_tile_size), + dtype=np.float32, buffer=ncc.psum.alloc(functools.partial(psum_addr, bank_map={(0, ): 1, (1, ): 5}))) + + sum_local_tp_buf = nl.ndarray((2, par_dim(q_seq_tile_size), k_seq_tile_size), dtype=np.float32, + buffer=ncc.psum.alloc(functools.partial(psum_addr, bank_map={(0, ): 2, (1, ): 7}))) + + for i_interleave_grp in nl.affine_range(2): + # A SBUF buffer tile for an independent softmax tile + ip_max = nl.arange(q_seq_tile_size)[:, None] + if_max = nl.arange(k_seq_n_tiles)[None, :] + + # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) + for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): # indent = 4 + + # Tensor indices for accessing qk result in k_seq_tile_size + ip_qk = nl.arange(q_seq_tile_size)[:, None] + if_qk = nl.arange(k_seq_tile_size)[None, :] + + ############################################################## + # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) + ############################################################## + qk_psum[i_interleave_grp, i_k_seq_tile, ip_qk, if_qk] = nisa.nc_matmul(moving=k_local[i_k_seq_tile, ip_k, if_k], + stationary=q_local[i_q_seq_tile*2+i_interleave_grp, ip_q, if_q]) + + ################################### + # Step 3. Apply optional causal mask + ################################### + if use_causal_mask: + assert not use_causal_mask, "Causal mask not supported yet!" + # Magic number -9984.0 to replace -inf similar to what Tensorizer uses + qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.affine_select( + pred=(i_q_seq_tile * q_seq_tile_size + ip_qk >= i_k_seq_tile * k_seq_tile_size + if_qk), + on_true_tile=qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype) + else: + # Copy result to SBUF and find partial maximum for softmax + qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.tensor_scalar_reduce(qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], np.add, 1.0, + reduce_op=np.max, reduce_res=neg_max_res[i_interleave_grp, ip_max, i_k_seq_tile], dtype=kernel_dtype) + + # Find global max from tiles + neg_max_res_final[i_interleave_grp, ip_max, 0] = nisa.tensor_reduce( + np.max, data=neg_max_res[i_interleave_grp, ip_max, if_max], + axis=(1,), dtype=kernel_dtype, negate=True) + + ip_softmax = nl.arange(q_seq_tile_size)[:, None] + if_softmax = nl.arange(seqlen)[None, :] + ip_sum_res = nl.arange(q_seq_tile_size)[:, None] + if_sum_res = nl.arange(d_head_tile_size)[None, :] + + if_reduction = nl.arange(reduction_size)[None, :] + for i_exp in nl.affine_range(reduction_tiles): + exp_res[i_interleave_grp, ip_softmax, i_exp*reduction_size + if_reduction] = nisa.activation_reduce(np.exp, + data=qk_res_buf[i_interleave_grp, ip_softmax, i_exp * reduction_size + if_reduction], + reduce_op=np.sum, reduce_res=partial_sum_res[i_interleave_grp, ip_softmax, i_exp], + bias=neg_max_res_final[i_interleave_grp, ip_max, 0], scale=1.0, + ) + + sum_res[i_interleave_grp, ip_softmax, 0] = nisa.tensor_reduce(np.add, data=partial_sum_res[i_interleave_grp, :, :], axis=(1,), + dtype=kernel_dtype) + + sum_reciprocal[i_interleave_grp, ip_softmax, 0] = nl.divide(1.0, sum_res[i_interleave_grp, ip_softmax, 0]) + sum_reciprocal_broadcast[i_interleave_grp, ip_softmax, if_sum_res] = sum_reciprocal[i_interleave_grp, ip_softmax, 0].broadcast_to((q_seq_tile_size, d_head_tile_size)) + sum_divisor[i_interleave_grp, ip_sum_res, if_sum_res] = nl.copy(sum_reciprocal_broadcast[i_interleave_grp, ip_softmax, if_sum_res], dtype=kernel_dtype) + + ################################### + # Step 5. transpose(softmax_res) + ################################### + ip_scores_t = nl.arange(v_seq_tile_size)[:, None] + if_scores_t = nl.arange(v_seq_tile_size)[None, :] + # Loop over matmul_1 contraction + for i_v_seq_tile in nl.affine_range(v_seq_n_tiles // 4): + for i_offset in nl.affine_range(4): + ip_scores = nl.arange(v_seq_tile_size)[:, None] + if_scores = nl.arange(v_seq_tile_size)[None, :] + + local_tp_buf[i_interleave_grp, i_v_seq_tile, ip_scores, i_offset*v_seq_tile_size + if_scores] = nisa.nc_matmul( + exp_res[i_interleave_grp, ip_scores, (i_v_seq_tile*4+i_offset) * v_seq_tile_size + if_scores], + identity_load) + + if_batch = nl.arange(k_seq_tile_size)[None, :] + trans_softmax_res[i_interleave_grp, ip_scores_t, i_v_seq_tile*k_seq_tile_size + if_batch] = nl.copy(local_tp_buf[i_interleave_grp, i_v_seq_tile, ip_scores, if_batch]) + + ip_out = nl.arange(d_head_tile_size)[:, None] + if_out = nl.arange(q_seq_tile_size)[None, :] + + for i_v_seq_tile in nl.affine_range(v_seq_n_tiles): + ###################################################################### + # Step 6. matmul_1(stationary=v_local, moving=trans_softmax_res, contract=seqlen_v=seqlen_k) + ###################################################################### + ip_v_t = nl.arange(v_seq_tile_size)[:, None] + if_v_t = nl.arange(d_head_tile_size)[None, :] + attn_res_psum[i_interleave_grp, ip_out, if_out] += \ + nisa.nc_matmul(moving=trans_softmax_res[i_interleave_grp, ip_scores_t, i_v_seq_tile*v_seq_tile_size+if_scores_t], + stationary=v_local[i_v_seq_tile, ip_v_t, if_v_t]) + + attn_res_sbuf[i_interleave_grp, ip_out, if_out] = nisa.tensor_copy(attn_res_psum[i_interleave_grp, ip_out, if_out], + dtype=kernel_dtype, engine=nisa.vector_engine) + + sum_local_tp_buf[i_interleave_grp, ip_sum_res, if_sum_res] = nisa.nc_matmul(sum_divisor[i_interleave_grp, ip_sum_res, if_sum_res], identity_load_fp32) + attn_res_div[i_interleave_grp, ip_sum_res, if_sum_res] = attn_res_sbuf[i_interleave_grp, :, :] * sum_local_tp_buf[i_interleave_grp, ip_sum_res, if_sum_res] + + nl.store( + out_ref[batch_id, ip_out, (i_q_seq_tile*2+i_interleave_grp) * q_seq_tile_size + if_out], + value=attn_res_div[i_interleave_grp, :, :]) + + return out_ref \ No newline at end of file diff --git a/src/reference/allocated_fused_linear.py b/src/reference/allocated_fused_linear.py new file mode 100644 index 0000000..21e32af --- /dev/null +++ b/src/reference/allocated_fused_linear.py @@ -0,0 +1,114 @@ +""" +Copyright (c) 2024, Amazon.com. All Rights Reserved + +kernels - Fused normalization with linear layers + +""" + +import neuronxcc.nki.language as nl +import neuronxcc.nki.isa as nisa +import neuronxcc.nki.compiler as ncc +import math +import numpy as np +from neuronxcc import nki +from neuronxcc.nki.language import par_dim + +@nki.jit +def allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=nl.float32, eps=1e-6): + """ + Allocated kernel that computes RMSNorm(hidden) @ wQKV. This kernel is designed to only handle fp16/bf16 tensor types. + Internally, normalizations are cast to fp32 to avoid NaN errors. + + Args: + hidden (_type_): Input tensor of the attention block in BSH layout + weights (_type_): Fused QKV linear weights, assumed to be eltwise-multiplied with RMS norm weight vector (gamma) + out_tensor (_type_): Output tensor + norm_dtype (_type_, optional): Data type for RMS norm, should be f32 to avoid NaN. Defaults to nl.float32. + eps (_type_, optional): RMS norm epsilon term. Defaults to 1e-6. + """ + # Hidden should be in BSH layout. + batch, batchless_shape = hidden.shape[0], hidden.shape[1:] + seqlen, dim = batchless_shape + _dim, head_dim = weights.shape + + assert dim <= 8192 and dim & 128 == 0, "Unsupported hidden dimension" + assert _dim == dim, "Reduction dimension must match" + assert head_dim <= 512, "Head dimension must be 512 or less" + + out_tensor = nl.ndarray((batch, seqlen, head_dim), dtype=hidden.dtype, buffer=nl.shared_hbm) + + pmax, fmax = nl.tile_size.pmax, nl.tile_size.psum_fmax # 128, 512 + ix, iy = nl.mgrid[0:pmax, 0:dim] + i_lhs = nl.mgrid[0:pmax, 0:pmax] + i_rhs = nl.mgrid[0:pmax, 0:fmax] + i_res = nl.mgrid[0:pmax, 0:fmax] + M = math.ceil(dim / pmax) + NUM_TRANSP_TILES = math.ceil(dim / fmax) + NUM_TILES = math.ceil(seqlen / pmax) + TILES_INT = math.ceil(NUM_TILES / 2) + scale = 1 / dim + + iden_x, iden_y = nl.mgrid[0:pmax, 0:128] + + identity_a = nl.shared_constant(np.identity(n=128, dtype=np.int8), dtype=hidden.dtype) + identity_tensor = nl.ndarray((par_dim(pmax), 128), dtype=weights.dtype, buffer=ncc.sbuf.mod_alloc(base_addr=0)) + identity_tensor[iden_x, iden_y] = nl.load(identity_a, dtype=weights.dtype) + bias_placeholder = nl.ndarray((par_dim(pmax), 1), dtype=np.float32, buffer=ncc.sbuf.mod_alloc(base_addr=128*2)) + bias_placeholder[...] = 0 + + for b in nl.affine_range(batch): + weights_buffer = nl.ndarray((M, par_dim(pmax), fmax), dtype=weights.dtype, + buffer=ncc.sbuf.mod_alloc(base_addr=260+(3*dim+fmax)*2+(dim+1)*4, num_free_tiles=(M,))) + # Preload the entire weights tensor. everything fits in SBUF for LLaMA 3.1 70B + for m in nl.affine_range(M): + weights_buffer[m, i_rhs.p, i_rhs.x] = nl.load(weights[m*pmax+i_rhs.p, i_rhs.x], + mask=(m*pmax+i_rhs.p= k_pos - # For tiles on and on the right of the diagonal, need to do affine_select. + # For tiles on and to the right of the diagonal, need to do affine_select. # Magic number -9984.0 to replace -inf similar to what Tensorizer uses qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = nisa.affine_select( pred=pred, @@ -154,9 +165,9 @@ def _flash_attention_core(q_local_tile, k, v, for k_d_i in nl.sequential_range(REDUCTION_TILE // B_F_SIZE): offset = k_d_i + k_r_i * (REDUCTION_TILE // B_F_SIZE) \ + global_k_large_tile_idx * (LARGE_TILE_SZ // B_F_SIZE) \ - + q_tile_idx * (seq_len // B_F_SIZE) \ - + (head_id * q_h_per_k_h + gqa_head_idx) * (seq_len // B_F_SIZE) * seq_q_num_tiles \ - + batch_id * nl.num_programs(1) * (seq_len // B_F_SIZE) * seq_q_num_tiles + + q_tile_idx * seq_k_num_tiles \ + + (head_id * q_h_per_k_h + gqa_head_idx) * seq_k_num_tiles * seq_q_num_tiles \ + + batch_id * nheads * seq_k_num_tiles * seq_q_num_tiles offset_seed = nl.add(seed_tensor[0, 0], offset, mask=forward_mask) nl.random_seed(seed=offset_seed, mask=forward_mask) softmax_dropout = nl.dropout(p_local[i_q_p, k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE + i_q_f], @@ -172,9 +183,9 @@ def _flash_attention_core(q_local_tile, k, v, for i_p_t in nl.affine_range(LARGE_TILE_SZ // 512): p_local_t_tmp = nl.ndarray((par_dim(B_P_SIZE), 512), buffer=nl.psum, dtype=np.float32) for i_p_t_local in nl.affine_range(512//128): - p_local_t_tmp[i_q_p, i_p_t_local*128 + i_f_128] = nisa.nc_transpose(p_local[i_q_p, i_p_t*512+i_p_t_local * B_P_SIZE + i_f_128]) + p_local_t_tmp[i_q_p, i_p_t_local*128 + i_f_128] = nisa.nc_transpose(p_local[i_q_p, i_p_t*512+i_p_t_local * B_P_SIZE + i_f_128], mask=forward_mask) i_f_512 = nl.arange(512)[None, :] - p_local_transposed[i_q_p, i_p_t * 512 + i_f_512 ] = nl.copy(p_local_t_tmp[i_q_p, i_f_512], dtype=kernel_dtype) + p_local_transposed[i_q_p, i_p_t * 512 + i_f_512 ] = nl.copy(p_local_t_tmp[i_q_p, i_f_512], dtype=kernel_dtype, mask=forward_mask) ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask) pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), dtype=np.float32, buffer=nl.psum) @@ -199,7 +210,8 @@ def _flash_attention_core(q_local_tile, k, v, l_buffer[olm_buffer_idx, i_q_p, 0] = nl.copy(l_buffer[olm_buffer_idx-1, i_q_p, 0], mask=negation_mask) -def flash_fwd(q, k, v, seed, o, lse=None, +@nki.jit +def flash_fwd(q, k, v, seed, softmax_scale=None, use_causal_mask=True, mixed_precision=True, @@ -213,8 +225,8 @@ def flash_fwd(q, k, v, seed, o, lse=None, - v: shape (bs, nv_heads, d, seq_v) if config.should_transpose_v else (bs, nv_heads, seq_v, d) - seed: shape (1,) - o: shape (bs, n_heads, seq_q, d) - - lse: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None - - We use seq_q and seq_k just for clarity, this kernel requires seq_q == seq_k + - lse: shape (bs, n_heads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None + - This kernel requires seq_k == seq_v IO tensor dtypes: - This kernel assumes all IO tensors have the same dtype @@ -246,36 +258,48 @@ def flash_fwd(q, k, v, seed, o, lse=None, config = config or FlashConfig() B_F_SIZE=512 B_P_SIZE=128 - b , h, d, n = q.shape + b, h, d, seqlen_q = q.shape B_D_SIZE = d - k_h = k.shape[1] - v_shape = v.shape + _, k_h, _, seqlen_k = k.shape if config.should_transpose_v: - assert tuple(v_shape) == (b, k_h, d, n), f"V shape does not match layout requirements, expect: {(b, k_h, d, n)} but got {v_shape}" - assert tuple(k.shape) == (b, k_h, d, n), f" k and v shape does not match the layout defined in the function, but got {k.shape}" + assert tuple(v.shape) == (b, k_h, d, seqlen_k), f"Expect shape of V to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {v.shape}" + assert tuple(k.shape) == (b, k_h, d, seqlen_k), f"Expect shape of K to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {k.shape}" else: - assert tuple(v_shape) == (b, k_h, n, d), f"V shape does not match layout requirements, expect: {(b, k_h, n, d)} but got {v_shape}" - assert tuple(k.shape) == (b,k_h, d, n), f" k and v shape does not match the layout defined in the function, but got {k.shape}" + assert tuple(v.shape) == (b, k_h, seqlen_k, d), f"Expect shape of V to be {(b, k_h, seqlen_k, d)} (batch, heads, seqlen_k, d_head) but got {v.shape}" + assert tuple(k.shape) == (b, k_h, d, seqlen_k), f"Expect shape of K to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {k.shape}" assert d <= 128, f" we do not support head_dim > 128, got head dim {d}" kernel_dtype = nl.bfloat16 if mixed_precision else q.dtype - acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype + acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype + + o = nl.ndarray((b, h, seqlen_q, d), dtype=q.dtype, buffer=nl.shared_hbm) + if config.training: + lse = nl.ndarray((b, h, nl.tile_size.pmax, seqlen_q // nl.tile_size.pmax), + dtype=acc_type, buffer=nl.shared_hbm) + else: + lse = None i_q_p = nl.arange(B_P_SIZE)[:,None] i_0_f = nl.arange(1)[None, :] - n_tile_q = n//B_P_SIZE # since q will be loaded on PE batch_id = nl.program_id(axis=0) - head_id = nl.program_id(axis=1) + + head_dims = list(range(1, nl.program_ndim())) + head_dims_shape = list(nl.num_programs(i) for i in head_dims) + head_dims_idx = list(nl.program_id(i) for i in head_dims) + head_id = linearize(head_dims_shape, head_dims_idx) + softmax_scale = softmax_scale or (1.0 / (d ** 0.5)) + n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine + LARGE_TILE_SZ = config.seq_tile_size # FIXME: Add masking for different seqlen values. assert config.seq_tile_size >= 512, f" seq tile_size {config.seq_tile_size} cannot be less than 512" - assert n % LARGE_TILE_SZ == 0, f"seqlen is not divisible by {LARGE_TILE_SZ}" - num_large_k_tile = n // LARGE_TILE_SZ + assert seqlen_k % LARGE_TILE_SZ == 0, f"Need seqlen_k to be divisible by {LARGE_TILE_SZ} but got {seqlen_k}" + num_large_k_tile = seqlen_k // LARGE_TILE_SZ # inference flag, check if lse is none - inference = not(config.training) + inference = not config.training if inference: assert lse is None, "lse should be none for inference" assert seed is None, f"seed should be None for inference, but got {seed}" @@ -331,11 +355,13 @@ def flash_fwd(q, k, v, seed, o, lse=None, i_f_d = nl.arange(B_D_SIZE)[None, :] i_p_d = nl.arange(B_D_SIZE)[:,None] q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype) - q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, head_id * q_h_per_k_h + i_q_h, i_p_d, i*B_P_SIZE+i_f_128], dtype=kernel_dtype) \ - * softmax_scale # load (d, 128) tile in SBUF + q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, + head_id * q_h_per_k_h + i_q_h, i_p_d, + i * B_P_SIZE + i_f_128], + dtype=kernel_dtype) * softmax_scale # load (d, 128) tile in SBUF # handle first tile and compute max and lse explicitly by passing initialize=True _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, - q_h_per_k_h=q_h_per_k_h, + q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h, o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i], batch_id=batch_id, head_id=head_id, gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=0, @@ -376,10 +402,12 @@ def flash_fwd(q, k, v, seed, o, lse=None, i_f_d = nl.arange(B_D_SIZE)[None, :] i_p_d = nl.arange(B_D_SIZE)[:,None] q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype) - q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, head_id * q_h_per_k_h + i_q_h, i_p_d, i*B_P_SIZE+i_f_128], dtype=kernel_dtype) \ - * softmax_scale # load (d, 128) tile in SBUF + q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, + head_id * q_h_per_k_h + i_q_h, i_p_d, + i * B_P_SIZE + i_f_128], + dtype=kernel_dtype) * softmax_scale # load (d, 128) tile in SBUF _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, - q_h_per_k_h=q_h_per_k_h, + q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h, o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i], batch_id=batch_id, head_id=head_id, gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=j, @@ -396,19 +424,23 @@ def flash_fwd(q, k, v, seed, o, lse=None, nl.exp(m_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f] - l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f]), dtype=kernel_dtype) - nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h, i * B_P_SIZE + i_q_p, i_f_d], out[i_q_p, i_f_d]) + nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h, i*B_P_SIZE + i_q_p, i_f_d], out[i_q_p, i_f_d]) if not inference: lse_local = nl.zeros((par_dim(B_P_SIZE), 1), dtype=acc_type) lse_local[i_q_p, i_0_f] = nl.copy(l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f], dtype=acc_type) nl.store(lse[batch_id, head_id * q_h_per_k_h + i_q_h, i_q_p, i + i_0_f], lse_local[i_q_p, i_0_f]) + if config.training: + return o, lse + + return o +@nki.jit def flash_attn_bwd( q_ref, k_ref, v_ref, o_ref, dy_ref, lse_ref, seed_ref, - out_dq_ref, out_dk_ref, out_dv_ref, use_causal_mask=False, mixed_precision=False, dropout_p=0.0, @@ -454,56 +486,60 @@ def flash_attn_bwd( kernel_dtype = q_ref.dtype mixed_dtype = np.dtype(np.float32) if mixed_precision else kernel_dtype - assert q_ref.dtype == k_ref.dtype == v_ref.dtype == o_ref.dtype == dy_ref.dtype \ - == out_dq_ref.dtype == out_dk_ref.dtype == out_dv_ref.dtype + assert q_ref.dtype == k_ref.dtype == v_ref.dtype == o_ref.dtype == dy_ref.dtype assert lse_ref.dtype == mixed_dtype # Shape checking - bs, nheads, d_head, seqlen = q_ref.shape - assert tuple(k_ref.shape) == (bs, nheads, d_head, seqlen), \ + bs, nheads, d_head, seqlen_q = q_ref.shape + _, _, _, seqlen_k = k_ref.shape + assert tuple(k_ref.shape) == (bs, nheads, d_head, seqlen_k), \ f"Input K shape mismatch, got {k_ref.shape}" - assert tuple(v_ref.shape) == (bs, nheads, d_head, seqlen), \ + assert tuple(v_ref.shape) == (bs, nheads, d_head, seqlen_k), \ f"Input V shape mismatch, got {v_ref.shape}" - assert tuple(o_ref.shape) == (bs, nheads, d_head, seqlen), \ + assert tuple(o_ref.shape) == (bs, nheads, d_head, seqlen_q), \ f"Input o shape mismatch, got {o_ref.shape}" - assert tuple(dy_ref.shape) == (bs, nheads, d_head, seqlen), \ + assert tuple(dy_ref.shape) == (bs, nheads, d_head, seqlen_q), \ f"Input dy shape mismatch, got {dy_ref.shape}" - assert tuple(lse_ref.shape) == (bs, nheads, nl.tile_size.pmax, seqlen // nl.tile_size.pmax), \ + assert tuple(lse_ref.shape) == (bs, nheads, nl.tile_size.pmax, seqlen_q // nl.tile_size.pmax), \ f"Input lse shape mismatch, got {lse_ref.shape}" if seed_ref is not None: assert tuple(seed_ref.shape) == (1,), \ f"Input seed shape mismatch, got {seed_ref.shape}" - assert tuple(out_dq_ref.shape) == (bs, nheads, d_head, seqlen), \ - f"Output dQ shape mismatch, got {out_dq_ref.shape}" - assert tuple(out_dk_ref.shape) == (bs, nheads, d_head, seqlen), \ - f"Output dK shape mismatch, got {out_dk_ref.shape}" - assert tuple(out_dv_ref.shape) == (bs, nheads, d_head, seqlen), \ - f"Output dV shape mismatch, got {out_dv_ref.shape}" + out_dq_ref = nl.ndarray((bs, nheads, d_head, seqlen_q), dtype=q_ref.dtype, + buffer=nl.shared_hbm) + out_dk_ref = nl.ndarray((bs, nheads, d_head, seqlen_k), dtype=q_ref.dtype, + buffer=nl.shared_hbm) + out_dv_ref = nl.ndarray((bs, nheads, d_head, seqlen_k), dtype=q_ref.dtype, + buffer=nl.shared_hbm) # FIXME: Add masking for different seqlen values. - assert seqlen % 128 == 0, \ - f"Input sequence length must be divisible by 128, got {seqlen}" + assert seqlen_q % 128 == 0 and seqlen_k % 128 == 0, \ + f"Input sequence lengths must be divisible by 128, got seqlen_q == {seqlen_q} and seqlen_k == {seqlen_k}" # Softmax scaling factor, multiplied onto Q softmax_scale = softmax_scale or 1.0 / float(d_head ** 0.5) # Different batch samples/attention heads have independent attention batch_id = nl.program_id(axis=0) - head_id = nl.program_id(axis=1) - assert nl.num_programs(1) == nheads, \ - f"The grid shape mismatch, got {nl.num_programs(1)} but should be {nheads}" + head_dims = list(range(1, nl.program_ndim())) + head_dims_shape = list(nl.num_programs(i) for i in head_dims) + head_dims_idx = list(nl.program_id(i) for i in head_dims) + head_id = linearize(head_dims_shape, head_dims_idx) - q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen, 128), 128 + assert n_elts(head_dims_shape) == nheads, \ + f"The grid shape mismatch, got {n_elts(head_dims_shape)} but should be {nheads}" + + q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen_q, 128), 128 d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128) - if seqlen >= 512: - k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512 + if seqlen_k >= 512: + k_seq_n_tiles, k_seq_tile_size = seqlen_k // 512, 512 else: - k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128 + k_seq_n_tiles, k_seq_tile_size = seqlen_k // 128, 128 - k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen // 128, 128 + k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen_k // 128, 128 k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward ############################################################## @@ -615,7 +651,7 @@ def flash_attn_bwd( dk_psum=dk_psum, dv_psum=dv_psum, dq_local_reduced=dq_local_reduced, softmax_exp_bias=softmax_exp_bias, dy_o_sum=dy_o_sum, local_i_q_seq_tile=i_q_seq_tile, local_i_k_seq_tile=i_k_seq_tile, - seqlen=seqlen, d_head=d_head, + seqlen_q=seqlen_q, seqlen_k=seqlen_k, d_head=d_head, nheads=nheads, use_causal_mask=use_causal_mask, kernel_dtype=kernel_dtype, mixed_dtype=mixed_dtype, softmax_scale=softmax_scale, @@ -654,36 +690,43 @@ def flash_attn_bwd( value=dq_local_reduced[i_q_seq_tile, i_d_head_tile, ip_dq, if_dq], ) -@trace + return out_dq_ref, out_dk_ref, out_dv_ref + + def _flash_attn_bwd_core( q_local, k_local, transposed_k_local, v_local, dy_local, dk_psum, dv_psum, dq_local_reduced, softmax_exp_bias, dy_o_sum, local_i_q_seq_tile, local_i_k_seq_tile, - seqlen, d_head, + seqlen_q, seqlen_k, d_head, nheads, use_causal_mask, kernel_dtype, mixed_dtype, softmax_scale, seed_local, dropout_p, dropout_p_local, global_i_q_seq_tile = None, global_i_k_seq_tile = None, + # Used for nl.loop_reduce on dQ if local_i_k_seq_tile is not an index e.g. if it has an offset + local_i_k_seq_tile_for_dq_reduce = None, ): """ - The flash backward core funciton to calculate the gradients of Q, K and V + The flash backward core function to calculate the gradients of Q, K and V of the given tiles. The result will be accumulated into the dk, dv, dq psum """ - q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen, 128), 128 + q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen_q, 128), 128 d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128) - if seqlen >= 512: - k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512 + if seqlen_k >= 512: + k_seq_n_tiles, k_seq_tile_size = seqlen_k // 512, 512 else: - k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128 - k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen // 128, 128 + k_seq_n_tiles, k_seq_tile_size = seqlen_k // 128, 128 + k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen_k // 128, 128 k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward if global_i_q_seq_tile is None: global_i_q_seq_tile = local_i_q_seq_tile global_i_k_seq_tile = local_i_k_seq_tile + + if local_i_k_seq_tile_for_dq_reduce is None: + local_i_k_seq_tile_for_dq_reduce = local_i_k_seq_tile mask = global_i_q_seq_tile * q_seq_tile_size >= global_i_k_seq_tile * k_seq_tile_size if use_causal_mask else None # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F] @@ -735,7 +778,7 @@ def _flash_attn_bwd_core( if dropout_p > 0.0: offset = global_i_k_seq_tile + global_i_q_seq_tile * k_seq_n_tiles \ + head_id * k_seq_n_tiles * q_seq_n_tiles \ - + batch_id * nl.num_programs(1) * k_seq_n_tiles * q_seq_n_tiles + + batch_id * nheads * k_seq_n_tiles * q_seq_n_tiles offset_seed = nl.add(seed_local[0, 0], offset, mask=mask) nl.random_seed(seed=offset_seed, mask=mask) softmax_y[ip_q, if_k] = nl.dropout(softmax_y[ip_q, if_k], rate=dropout_p_local[ip_q, 0], mask=mask) @@ -778,12 +821,12 @@ def _flash_attn_bwd_core( ##################################################################### softmax_dx_local = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf) softmax_dx_local[ip_q, if_k] = \ - nisa.tensor_scalar(data=softmax_dy[ip_q, if_k], - op0=np.subtract, - operand0=dy_o_sum[local_i_q_seq_tile, ip_q, 0], - op1=np.multiply, - operand1=softmax_y[ip_q, if_k], - mask=mask) + nisa.scalar_tensor_tensor(data=softmax_dy[ip_q, if_k], + op0=np.subtract, + operand0=dy_o_sum[local_i_q_seq_tile, ip_q, 0], + op1=np.multiply, + operand1=softmax_y[ip_q, if_k], + mask=mask) ##################################################################### # Step 5.1 Calculate dK, with matmul(stationary=Q, moving=softmax_dx) @@ -820,10 +863,12 @@ def _flash_attn_bwd_core( mask=mask) dq_local = nl.multiply(dq_psum[ip_dq, if_dq], softmax_scale, dtype=kernel_dtype, mask=mask) dq_local_reduced[local_i_q_seq_tile, i_d_head_tile, ip_dq, if_dq] = nl.loop_reduce( - dq_local, op=np.add, loop_indices=(local_i_k_seq_tile,), + dq_local, op=np.add, loop_indices=(local_i_k_seq_tile_for_dq_reduce,), dtype=mixed_dtype, mask=mask) -def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_causal_mask=False, + +@nki.jit +def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False, mixed_percision=True): """ Fused self attention kernel for small head size Stable Diffusion workload. @@ -853,16 +898,17 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau # Assume all IO tensors have the same dtype kernel_dtype = q_ref.dtype pe_in_dt = nl.bfloat16 if mixed_percision else np.float32 - assert q_ref.dtype == k_ref.dtype == v_ref.dtype == out_ref.dtype + assert q_ref.dtype == k_ref.dtype == v_ref.dtype # Shape checking bs, d_head, seqlen = q_ref.shape assert d_head <= 128, "Cannot use this kernel for d_head > 128" assert tuple(q_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!' assert tuple(k_ref.shape) == (bs, seqlen, d_head), 'Input shape mismatch!' - assert tuple(v_ref.shape) == (bs, seqlen, - d_head), f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}' - assert tuple(out_ref.shape) == (bs, seqlen, d_head), 'Output shape mismatch!' + assert tuple(v_ref.shape) == (bs, seqlen, d_head), \ + f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}' + + out_ref = nl.ndarray((bs, seqlen, d_head), dtype=q_ref.dtype, buffer=nl.shared_hbm) # Softmax scaling factor, multiplied onto Q softmax_scale = 0.125 @@ -1028,4 +1074,5 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau nl.store( out_ref[batch_id, i_q_seq_tile * q_seq_tile_size + if_out, ip_out], value=attn_res_div) - \ No newline at end of file + + return out_ref diff --git a/src/reference/tutorial.py b/src/reference/tutorial.py index 4f3ebef..b32492b 100644 --- a/src/reference/tutorial.py +++ b/src/reference/tutorial.py @@ -5,9 +5,14 @@ """ +from neuronxcc import nki import neuronxcc.nki.language as nl -def add_kernel_nx8x128x512(a_ptr, b_ptr, c_ptr, n_elements): + +@nki.jit +def add_kernel_nx8x128x512(a_ptr, b_ptr, n_elements): + c_ptr = nl.ndarray(a_ptr.shape, dtype=a_ptr.dtype, buffer=nl.shared_hbm) + ix = nl.arange(128)[:, None] iy = nl.arange(512)[None, :] @@ -18,12 +23,9 @@ def add_kernel_nx8x128x512(a_ptr, b_ptr, c_ptr, n_elements): for i in nl.affine_range(8): offset = j * block_size + i * tile_size + 512 * ix + iy - mask = offset < n_elements - a_ptr = a_ptr.ptr + offset - b_ptr = b_ptr.ptr + offset - c_ptr = c_ptr.ptr + offset - - a = nl.load(a_ptr, mask=mask) - b = nl.load(b_ptr, mask=mask) - c = a + b - nl.store(c_ptr, value=c, mask=mask) \ No newline at end of file + a = nl.load(a_ptr[j, i, ix, iy], mask=offset < n_elements) + b = nl.load(b_ptr[j, i, ix, iy], mask=offset < n_elements) + c = nl.add(a, b, mask=offset < n_elements) + nl.store(c_ptr[j, i, ix, iy], value=c, mask=offset < n_elements) + + return c_ptr diff --git a/src/reference/vision.py b/src/reference/vision.py index bc54941..4899d27 100644 --- a/src/reference/vision.py +++ b/src/reference/vision.py @@ -8,10 +8,13 @@ import neuronxcc.nki.language as nl import neuronxcc.nki.isa as nisa +from neuronxcc import nki from neuronxcc.nki.language import par_dim import neuronxcc.nki.typing as nt -def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor): + +@nki.jit +def select_and_scatter_kernel(operand_tensor, source_tensor): """ Implementation of a select-and-scatter kernel. @@ -51,7 +54,10 @@ def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor): assert C == 64 and N % 2 == 0 kernel_dtype = operand_tensor.dtype - assert operand_tensor.dtype == source_tensor.dtype == out_tensor.dtype + assert operand_tensor.dtype == source_tensor.dtype + + out_tensor = nl.ndarray((N, C, H, W), dtype=operand_tensor.dtype, + buffer=nl.shared_hbm) p = 128 # num of partitions to use for ib in nl.affine_range(N // 2): @@ -156,8 +162,11 @@ def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor): nl.store(out_tensor[2 * ib + ib_1, 0:64, 0:H, 0:W], value=out_local[(ib_1 * 64):((ib_1 + 1) * 64), 0:H, 0:W]) + return out_tensor -def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor): + +@nki.jit +def resize_nearest_fixed_dma_kernel(data_tensor, out_shape): """ Resize the input image to the given size using the nearest interpolation mode. This kernel is designed to be used when the scaling factor is not an integer. @@ -174,7 +183,9 @@ def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor): """ in_b, in_h, in_w, in_c = data_tensor.shape - out_b, out_h, out_w, out_c = out_tensor.shape + out_b, out_h, out_w, out_c = out_shape + out_tensor = nl.ndarray(out_shape, dtype=data_tensor.dtype, + buffer=nl.shared_hbm) assert in_b == out_b, "Input batch and output batch must be identical" assert in_c == out_c, "Input channel and output channel must be identical" @@ -198,3 +209,5 @@ def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor): local_data = nl.load(target_addr) dst_addr_0 = out_tile[b_map, i, c_map] nl.store(dst_addr_0, value=local_data) + + return out_tensor diff --git a/src/tutorials/average_pool2d/average_pool2d_jax.py b/src/tutorials/average_pool2d/average_pool2d_jax.py index e3b428d..139c42d 100644 --- a/src/tutorials/average_pool2d/average_pool2d_jax.py +++ b/src/tutorials/average_pool2d/average_pool2d_jax.py @@ -4,29 +4,22 @@ JAX implementation for average pool 2D NKI tutorial. """ -from functools import partial -from jax_neuronx import nki_call -import jax +# NKI_EXAMPLE_40_BEGIN import jax.numpy as jnp - -from average_pool2d_nki_kernels import tensor_avgpool_kernel_ - - -def tensor_avgpool_kernel(in_array, pool_size): - return nki_call( - partial(tensor_avgpool_kernel_, pool_size=pool_size), - in_array, - out_shape=jax.ShapeDtypeStruct((C, HOUT, WOUT), dtype=in_array.dtype), - ) +# NKI_EXAMPLE_40_END +from average_pool2d_nki_kernels import tensor_avgpool_kernel +# NKI_EXAMPLE_40_BEGIN # Reference JAX implementation def jax_average_pool_2D(in_tensor, pool_size): c, h_in, w_in = in_tensor.shape reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size) return jnp.nanmean(reshaped, axis=(2, 4)) + # NKI_EXAMPLE_40_END +# NKI_EXAMPLE_41_BEGIN if __name__ == "__main__": POOL_SIZE = 2 C, HIN, WIN = 2, 6, 6 @@ -34,7 +27,9 @@ def jax_average_pool_2D(in_tensor, pool_size): in_array = jnp.arange(C * HIN * WIN, dtype=jnp.float32).reshape(C, HIN, WIN) + # NKI_EXAMPLE_39_BEGIN out_nki = tensor_avgpool_kernel(in_array, pool_size=POOL_SIZE) + # NKI_EXAMPLE_39_END out_jax = jax_average_pool_2D(in_array, pool_size=POOL_SIZE) print(in_array, out_nki, out_jax) @@ -42,4 +37,5 @@ def jax_average_pool_2D(in_tensor, pool_size): if jnp.allclose(out_nki, out_jax): print("NKI and JAX match") else: - print("NKI and JAX differ") \ No newline at end of file + print("NKI and JAX differ") + # NKI_EXAMPLE_41_END diff --git a/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py b/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py index c81a4a5..68d3a31 100644 --- a/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py +++ b/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py @@ -5,48 +5,40 @@ """ import numpy as np +# NKI_EXAMPLE_37_BEGIN import neuronxcc.nki as nki import neuronxcc.nki.language as nl +from neuronxcc.nki.typing import tensor - -def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size): +@nki.jit +def tensor_avgpool_kernel(in_tensor, pool_size): """NKI kernel to compute a 2D avg-pool operation Args: in_tensor: an input tensor, of shape C x H x W pool_size: an integer representing a (square) pool-window size + + Return: out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size) """ # Get input/output dimensions sz_cin, sz_hin, sz_win = in_tensor.shape - sz_cout, sz_hout, sz_wout = out_tensor.shape - assert sz_cin == sz_cout + sz_hout = sz_hin // pool_size + sz_wout = sz_win // pool_size + # Create output tensor shared between all SPMD instances as result tensor + out_tensor = nl.ndarray((sz_cin, sz_hout, sz_wout), dtype=in_tensor.dtype, + buffer=nl.shared_hbm) # Set relevant sizes sz_p = sz_cin sz_pool = pool_size - # Generate tensor h/w index patterns - # 3D indexing according to [C, H, W] - i_p = nl.arange(sz_p)[:, None, None] # 3D for - i_win = nl.arange(sz_win)[None, None, :] - i_hin = nl.arange(sz_hin)[None, :, None] - - i_wout = nl.arange(sz_wout)[None, None, :] - i_hout = nl.arange(sz_hout)[None, :, None] - # Generate pool index patterns (requires two extra dimensions, for the pool window) - i_0 = nl.arange(sz_p)[:, None, None, None, None] # - i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None] # y_outer - i_2 = nl.arange(sz_pool)[None, None, :, None, None] # y_inner - i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None] # x_outer - i_4 = nl.arange(sz_pool)[None, None, None, None, :] # x_inner + i0, i1, i2, i3, i4 = nl.mgrid[0:sz_p, 0:sz_hin//sz_pool, 0:sz_pool, 0:sz_win//sz_pool, 0:sz_pool] # Load input data from external memory to on-chip memory - # Declare ndarray to force a 3D tensor (temporary requirement) - in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype) - in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win]) + in_tile: tensor[sz_p, sz_hin, sz_win] = nl.load(in_tensor) # Perform the pooling operation: # We use numpy's advanced indexing, in order to extend in_tile to 5D, and then reduce-average two dimension. @@ -54,10 +46,15 @@ def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size): # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2]. # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4]. - out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size) + out_tile : tensor[sz_p, sz_hout, sz_wout] = nl.sum(in_tile[i0, sz_pool*i1+i2, sz_pool*i3+i4], + axis=[2,4]) / (pool_size*pool_size) + + # Store the results back to hbm + nl.store(out_tensor, value=out_tile) - # Store the results back to external memory - nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile) + # Transfer the ownership of `out_tensor` to the caller + return out_tensor + # NKI_EXAMPLE_37_END # Reference NumPy implementation @@ -74,10 +71,8 @@ def np_average_pool_2D(in_tensor, pool_size): HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE in_tensor = np.arange(C * HIN * WIN, dtype=np.float16).reshape(C, HIN, WIN) - out_nki = np.zeros((C, HOUT, WOUT), dtype=np.float16) - tensor_avgpool_kernel_baremetal = nki.baremetal(tensor_avgpool_kernel_) - tensor_avgpool_kernel_baremetal(in_tensor, out_nki, POOL_SIZE) + out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE) out_np = np_average_pool_2D(in_tensor, POOL_SIZE) diff --git a/src/tutorials/average_pool2d/average_pool2d_torch.py b/src/tutorials/average_pool2d/average_pool2d_torch.py index 3409a31..c5fb4ea 100644 --- a/src/tutorials/average_pool2d/average_pool2d_torch.py +++ b/src/tutorials/average_pool2d/average_pool2d_torch.py @@ -4,13 +4,14 @@ PyTorch implementation for average pool 2D NKI tutorial. """ +# NKI_EXAMPLE_38_BEGIN import torch -from torch_neuronx import nki_jit from torch_xla.core import xla_model as xm - -from average_pool2d_nki_kernels import tensor_avgpool_kernel_ +# NKI_EXAMPLE_38_END +from average_pool2d_nki_kernels import tensor_avgpool_kernel +# NKI_EXAMPLE_38_BEGIN if __name__ == "__main__": device = xm.xla_device() @@ -22,8 +23,7 @@ in_tensor = torch.arange(C * HIN * WIN, dtype=torch.bfloat16).reshape(C, HIN, WIN).to(device=device) out_nki = torch.zeros((C, HOUT, WOUT), dtype=torch.bfloat16).to(device=device) - tensor_avgpool_kernel_torch = nki_jit(tensor_avgpool_kernel_) - tensor_avgpool_kernel_torch(in_tensor, out_nki, POOL_SIZE) + out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE) out_torch = torch.nn.functional.avg_pool2d(in_tensor, POOL_SIZE, POOL_SIZE) @@ -33,3 +33,4 @@ print("NKI and Torch match") else: print("NKI and Torch differ") + # NKI_EXAMPLE_38_END diff --git a/src/tutorials/fused_mamba/mamba_nki_kernels.py b/src/tutorials/fused_mamba/mamba_nki_kernels.py index 9f8af60..4ff6642 100644 --- a/src/tutorials/fused_mamba/mamba_nki_kernels.py +++ b/src/tutorials/fused_mamba/mamba_nki_kernels.py @@ -4,16 +4,19 @@ Mamba-v1 NKI kernel implementation. """ +# NKI_EXAMPLE_25_BEGIN import neuronxcc.nki as nki import neuronxcc.nki.language as nl import neuronxcc.nki.isa as nisa import numpy as np +# NKI_EXAMPLE_25_END import os import argparse import itertools - -def mamba_v1(delta, u, A, B, C, output): +# NKI_EXAMPLE_25_BEGIN +@nki.jit +def mamba_v1(delta, u, A, B, C): """Computes the SSM operation in the Mamba model. :param delta: (batch_size, channels, seq_len) @@ -24,6 +27,9 @@ def mamba_v1(delta, u, A, B, C, output): :return: (batch_size, channels, seq_len) """ batch_size, channels, seq_len = delta.shape + output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype, + buffer=nl.shared_hbm) + _, state_size = A.shape # We can relax this using mask paramters in all the NKI API calls @@ -84,8 +90,12 @@ def mamba_v1(delta, u, A, B, C, output): nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len], scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len]) + return output +# NKI_EXAMPLE_25_END -def mamba_v2(delta, u, A, B, C, output): +# NKI_EXAMPLE_26_BEGIN +@nki.jit +def mamba_v2(delta, u, A, B, C): """Computes the SSM operation in the Mamba model. :param delta: (batch_size, channels, seq_len) @@ -96,6 +106,8 @@ def mamba_v2(delta, u, A, B, C, output): :return: (batch_size, channels, seq_len) """ batch_size, channels, seq_len = delta.shape + output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype, + buffer=nl.shared_hbm) _, state_size = A.shape assert channels % 128 == 0 @@ -153,8 +165,12 @@ def mamba_v2(delta, u, A, B, C, output): nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len], scanC_accum[0:channel_psize, 0:seq_len]) + return output +# NKI_EXAMPLE_26_END + -def mamba_v3(delta, u, A, B, C, output): +@nki.jit +def mamba_v3(delta, u, A, B, C): """Computes the SSM operation in the Mamba model. :param delta: (batch_size, channels, seq_len) @@ -165,6 +181,8 @@ def mamba_v3(delta, u, A, B, C, output): :return: (batch_size, channels, seq_len) """ batch_size, channels, seq_len = delta.shape + output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype, + buffer=nl.shared_hbm) _, state_size = A.shape # Map channels to the partition dimension @@ -239,6 +257,7 @@ def mamba_v3(delta, u, A, B, C, output): # Store scanC_accum for a single batch to output nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len], scanC_accum[0:channel_psize, 0:seq_len]) + return output def parse_args(): @@ -310,9 +329,7 @@ def parse_args(): if args.mode == "accuracy": # v1: reference kernel print(f">>>> Running v1 (reference).") - nki_out_v1 = np.empty((batch, channels, seq_len), dtype=dtype) - nki.baremetal(mamba_v1)\ - (delta, u, A, B, C, nki_out_v1) + nki_out_v1 = mamba_v1(delta, u, A, B, C) for version in args.version: if version == "v1": @@ -321,9 +338,7 @@ def parse_args(): print(f">>>> Running version {version}.") func = func_dict[version] - nki_out_test = np.empty((batch, channels, seq_len), dtype=dtype) - nki.baremetal(func)\ - (delta, u, A, B, C, nki_out_test) + nki_out_test = func(delta, u, A, B, C) print(f">>>> mamba {version} matches?", np.all(nki_out_test == nki_out_v1)) assert np.all(nki_out_test == nki_out_v1) @@ -333,11 +348,10 @@ def parse_args(): for version in args.version: print(f">>>> Running version {version}.") func = func_dict[version] - nki_out_test = np.empty((batch, channels, seq_len), dtype=dtype) nki.benchmark(func, save_neff_name='file.neff', save_trace_name='profile.ntff')\ - (delta, u, A, B, C, nki_out_test) + (delta, u, A, B, C) # TODO: rename neff/ntff (bug in nki.benchmark with neff name) os.rename("file.neff", f"{version}_b{batch}_sl{seq_len}_c{channels}_ss{state_size}.neff") os.rename("profile.ntff", f"{version}_b{batch}_sl{seq_len}_c{channels}_ss{state_size}.ntff") diff --git a/src/tutorials/fused_mamba/mamba_torch.py b/src/tutorials/fused_mamba/mamba_torch.py index a2e593f..cd94a0b 100644 --- a/src/tutorials/fused_mamba/mamba_torch.py +++ b/src/tutorials/fused_mamba/mamba_torch.py @@ -5,6 +5,7 @@ """ +# NKI_EXAMPLE_24_BEGIN import torch import torch_neuronx import torch_xla.core.xla_model as xm @@ -99,16 +100,14 @@ def parse_args(): torch_out = mamba_layer(delta, A, B, u, C) xm.mark_step() print(torch_out) + # NKI_EXAMPLE_24_END if args.mode == "accuracy": # Call NKI mamba_v1 kernel to check accuracy from mamba_nki_kernels import mamba_v1 - from torch_neuronx import nki_jit - - nki_out = torch.empty((batch, channels, seq_len), dtype=dtype, device=device) xm.mark_step() - nki_jit(mamba_v1)(delta, u, A, B, C, nki_out) + nki_out = mamba_v1(delta, u, A, B, C) xm.mark_step() allclose = torch.allclose(torch_out, nki_out, atol=1e-2, rtol=1e-2) diff --git a/src/tutorials/layernorm/layernorm_nki_kernel.py b/src/tutorials/layernorm/layernorm_nki_kernel.py index 503ce7d..c0c235c 100644 --- a/src/tutorials/layernorm/layernorm_nki_kernel.py +++ b/src/tutorials/layernorm/layernorm_nki_kernel.py @@ -4,21 +4,27 @@ LayerNorm NKI kernel implementation. """ +# NKI_EXAMPLE_45_BEGIN import neuronxcc.nki as nki import neuronxcc.nki.language as nl import neuronxcc.nki.isa as nisa import numpy as np import math +# NKI_EXAMPLE_45_END import os import argparse -def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector, output_tensor): +# NKI_EXAMPLE_45_BEGIN +@nki.jit +def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector): """Computes LayerNorm. Used nki.language APIs only. """ + output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype, + buffer=nl.shared_hbm) + # Ensure that the shapes of tensors match - assert input_tensor.shape == output_tensor.shape assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0] # Generate tile indices for loading/storing data @@ -58,12 +64,20 @@ def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector, ou nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb, mask=(i * nl.tile_size.pmax + i_p_io < num_rows)) -def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector, output_tensor): + return output_tensor + # NKI_EXAMPLE_45_END + + +# NKI_EXAMPLE_46_BEGIN +@nki.jit +def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector): """Computes LayerNorm. Used nki.isa APIs to calculate mean/variance and perform shift/scale. """ + output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype, + buffer=nl.shared_hbm) + # Ensure that the shapes of tensors match - assert input_tensor.shape == output_tensor.shape assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0] # Generate tile indices for loading/storing data @@ -122,69 +136,66 @@ def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector, ou nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb, mask=(i * nl.tile_size.pmax + i_p_io < num_rows)) + return output_tensor + # NKI_EXAMPLE_46_END + def parse_args(): - parser = argparse.ArgumentParser( - """Run LayerNorm pytorch implementation. - """) - parser.add_argument("--nrows", - default=4*1024, - type=int, - help="""The number of input rows""") - parser.add_argument("--ncols", - default=8*1024, - type=int, - help="""The number of input columns""") - parser.add_argument("--mode", - choices=["accuracy", "perf"], - default="accuracy", - help="""Do accuracy test or perf test. - Accuracy test compares LayerNorm kernel against PyTorch implementation. - Perf test will generate a NEFF for the PyTorch implementation in local directory - for a manual run of neuron-profile. - """) - args = parser.parse_args() - return args + parser = argparse.ArgumentParser( + """Run LayerNorm pytorch implementation. + """) + parser.add_argument("--nrows", + default=4*1024, + type=int, + help="""The number of input rows""") + parser.add_argument("--ncols", + default=8*1024, + type=int, + help="""The number of input columns""") + parser.add_argument("--mode", + choices=["accuracy", "perf"], + default="accuracy", + help="""Do accuracy test or perf test. + Accuracy test compares LayerNorm kernel against PyTorch implementation. + Perf test will generate a NEFF for the PyTorch implementation in local directory + for a manual run of neuron-profile. + """) + args = parser.parse_args() + return args if __name__ == "__main__": - args = parse_args() - func_dict = {"v1": nki_layernorm_kernel_v1, - "v2": nki_layernorm_kernel_v2, - } - - # Generate toy example - num_rows = args.nrows - num_cols = args.ncols - input_tensor = np.random.rand(num_rows, num_cols).astype(np.float32) - gamma_vector = np.random.rand(num_cols).astype(np.float32) - beta_vector = np.random.rand(num_cols).astype(np.float32) - epsilon = 1e-5 - - if args.mode == "accuracy": - # version 1 - print(f">>>> Running version 1") - nki_out_v1 = np.empty((num_rows, num_cols), dtype=np.float32) - nki.baremetal(nki_layernorm_kernel_v1)\ - (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_v1) - # version 2 - print(f">>>> Running version 2") - nki_out_v2 = np.empty((num_rows, num_cols), dtype=np.float32) - nki.baremetal(nki_layernorm_kernel_v2)\ - (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_v2) - # compare - np_all = np.all(nki_out_v1 == nki_out_v1) - print(f">>>> LayerNorm V1 and V2 matches?", np_all) - assert np_all - - else: - # perf mode - for version in ["v1", "v2"]: - print(f">>>> Running version {version}.") - func = func_dict[version] - nki_out_test = np.empty((num_rows, num_cols), dtype=np.float32) - nki.benchmark(func, - save_neff_name='file.neff', - save_trace_name='profile.ntff')\ - (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_test) - os.rename("file.neff", f"{version}_{num_rows}_{num_cols}.neff") - os.rename("profile.ntff", f"{version}_{num_rows}_{num_cols}.ntff") + args = parse_args() + func_dict = {"v1": nki_layernorm_kernel_v1, + "v2": nki_layernorm_kernel_v2, + } + + # Generate toy example + num_rows = args.nrows + num_cols = args.ncols + input_tensor = np.random.rand(num_rows, num_cols).astype(np.float32) + gamma_vector = np.random.rand(num_cols).astype(np.float32) + beta_vector = np.random.rand(num_cols).astype(np.float32) + epsilon = 1e-5 + + if args.mode == "accuracy": + # version 1 + print(f">>>> Running version 1") + nki_out_v1 = nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector) + # version 2 + print(f">>>> Running version 2") + nki_out_v2 = nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector) + # compare + np_all = np.all(nki_out_v1 == nki_out_v1) + print(f">>>> LayerNorm V1 and V2 matches?", np_all) + assert np_all + + else: + # perf mode + for version in ["v1", "v2"]: + print(f">>>> Running version {version}.") + func = func_dict[version] + benchmark_kernel = nki.benchmark(func, save_neff_name='file.neff', + save_trace_name='profile.ntff') + nki_out_test = benchmark_kernel(input_tensor, epsilon, gamma_vector, beta_vector) + os.rename("file.neff", f"{version}_{num_rows}_{num_cols}.neff") + os.rename("profile.ntff", f"{version}_{num_rows}_{num_cols}.ntff") diff --git a/src/tutorials/layernorm/layernorm_torch.py b/src/tutorials/layernorm/layernorm_torch.py index 59853fd..c2be186 100644 --- a/src/tutorials/layernorm/layernorm_torch.py +++ b/src/tutorials/layernorm/layernorm_torch.py @@ -4,9 +4,9 @@ LayerNorm NKI kernel implementation. """ +# NKI_EXAMPLE_47_BEGIN import torch from torch_xla.core import xla_model as xm -from torch_neuronx import nki_jit import argparse import os @@ -42,13 +42,16 @@ def parse_args(): args = parser.parse_args() return args + +from neuronxcc.nki.docs.examples.layernorm.layernorm_nki_kernel import nki_layernorm_kernel_v1, \ + nki_layernorm_kernel_v2 + if __name__ == "__main__": args = parse_args() - from neuronxcc.nki.docs.examples.layernorm.layernorm_nki_kernel import nki_layernorm_kernel_v1, nki_layernorm_kernel_v2 func_dict = {"v1": nki_layernorm_kernel_v1, "v2": nki_layernorm_kernel_v2, } - + device = xm.xla_device() num_rows = args.nrows num_cols = args.ncols @@ -58,7 +61,7 @@ def parse_args(): gamma_vector = torch.rand((num_cols), dtype=torch.float32) beta_vector = torch.rand((num_cols), dtype=torch.float32) epsilon = 1e-5 - + # Compute torch layernorm layer in cpu output_torch = layernorm_layer(input_tensor, epsilon, gamma_vector, beta_vector) @@ -66,17 +69,15 @@ def parse_args(): input_tensor = input_tensor.to(device=device) gamma_vector = gamma_vector.to(device=device) beta_vector = beta_vector.to(device=device) - output_nki = torch.zeros((num_rows, num_cols), dtype=torch.float32).to(device=device) print(f">>>> Running version {args.version}.") func = func_dict[args.version] # add nki_jit decorator - nki_layernorm_kernel = nki_jit(func) # Compute NKI layernorm kernel in NeuronDevice xm.mark_step() - nki_layernorm_kernel(input_tensor, epsilon, gamma_vector, beta_vector, output_nki) + output_nki = func(input_tensor, epsilon, gamma_vector, beta_vector) xm.mark_step() output_nki = output_nki.to(device='cpu') @@ -86,5 +87,6 @@ def parse_args(): print("NKI and Torch match") else: print("NKI and Torch differ") - + # NKI_EXAMPLE_47_END + assert allclose \ No newline at end of file diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py b/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py index 7aeb5d6..8f913f2 100644 --- a/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py +++ b/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py @@ -12,7 +12,9 @@ import numpy as np -def nki_matmul_basic_(lhsT, rhs, result): +# NKI_EXAMPLE_16_BEGIN +@nki.jit +def nki_matmul_basic_(lhsT, rhs): """NKI kernel to compute a 64x128x512 matrix multiplication operation Args: @@ -20,8 +22,11 @@ def nki_matmul_basic_(lhsT, rhs, result): matrix multiplication, delivered transposed for optimal performance rhs: an input tensor of shape [128,512], a right hand side argument of the matrix multiplication + Returns: result: the resulting output tensor of shape [64,512] """ + result = nl.ndarray((64, 512), dtype=lhsT.dtype, buffer=nl.shared_hbm) + # Defining indexes for input LHS.T # - Note: here we take LayoutConstraint #1 into account: # "For MatMult, contraction axis must be mapped to P-dim" @@ -53,8 +58,13 @@ def nki_matmul_basic_(lhsT, rhs, result): # This dictates which indices to use to address the result tile. nl.store(result[i_out_p, i_out_f], value=result_sbuf) + return result + # NKI_EXAMPLE_16_END + -def nki_matmul_tiled_(lhsT, rhs, result): +# NKI_EXAMPLE_18_BEGIN +@nki.jit +def nki_matmul_tiled_(lhsT, rhs): """NKI kernel to compute a matrix multiplication operation in a tiled manner Args: @@ -64,12 +74,14 @@ def nki_matmul_tiled_(lhsT, rhs, result): rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N is a multiple of 512. It is the right-hand-side argument of the matrix multiplication. + Returns: result: the resulting output tensor of shape [M,N] """ K, M = lhsT.shape K_, N = rhs.shape assert K == K_, "lhsT and rhs must have the same contraction dimension" + result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm) TILE_M = nl.tile_size.gemm_stationary_fmax # 128 TILE_K = nl.tile_size.pmax # 128 @@ -100,8 +112,13 @@ def nki_matmul_tiled_(lhsT, rhs, result): nl.store(result[m * TILE_M:(m + 1) * TILE_M, n * TILE_N:(n + 1) * TILE_N], value=res_sb) + return result + # NKI_EXAMPLE_18_END -def nki_matmul_hoist_load_(lhsT, rhs, result): + +# NKI_EXAMPLE_19_BEGIN +@nki.jit +def nki_matmul_hoist_load_(lhsT, rhs): """NKI kernel to compute a matrix multiplication operation in a tiled manner while hoisting the load of the lhsT and rhs to outer loops. @@ -112,12 +129,14 @@ def nki_matmul_hoist_load_(lhsT, rhs, result): rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N is a multiple of 512. It is the right-hand-side argument of the matrix multiplication. + Returns: result: the resulting output tensor of shape [M,N] """ K, M = lhsT.shape K_, N = rhs.shape assert K == K_, "lhsT and rhs must have the same contraction dimension" + result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm) TILE_M = nl.tile_size.gemm_stationary_fmax # 128 TILE_K = nl.tile_size.pmax # 128 @@ -163,8 +182,13 @@ def nki_matmul_hoist_load_(lhsT, rhs, result): res_sb = nl.copy(res_psum, dtype=result.dtype) nl.store(result[m * TILE_M + i_res.p, n * TILE_N + i_res.x], value=res_sb) + return result + # NKI_EXAMPLE_19_END + -def nki_matmul_block_free_dimension_(lhsT, rhs, result): +# NKI_EXAMPLE_20_BEGIN +@nki.jit +def nki_matmul_block_free_dimension_(lhsT, rhs): """NKI kernel to compute a matrix multiplication operation while blocking the free dimensions of the LHS and RHS to improve memory access pattern. @@ -175,12 +199,14 @@ def nki_matmul_block_free_dimension_(lhsT, rhs, result): rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N is a multiple of 512. It is the right-hand-side argument of the matrix multiplication. + Returns: result: the resulting output tensor of shape [M,N] """ K, M = lhsT.shape K_, N = rhs.shape assert K == K_, "lhsT and rhs must have the same contraction dimension" + result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm) TILE_M = nl.tile_size.gemm_stationary_fmax # 128 TILE_K = nl.tile_size.pmax # 128 @@ -243,11 +269,15 @@ def nki_matmul_block_free_dimension_(lhsT, rhs, result): (n * TILES_IN_BLOCK_N + bn) * TILE_N + i_res.x], value=res_sb) + return result + # NKI_EXAMPLE_20_END + +# NKI_EXAMPLE_21_BEGIN +@nki.jit def nki_matmul_fully_optimized_( lhsT, rhs, - result, # Meta-parameters TILES_IN_BLOCK_M=16, TILES_IN_BLOCK_N=2, @@ -264,13 +294,15 @@ def nki_matmul_fully_optimized_( rhs: an input tensor of shape [K,N], where K is a multiple of 128 * TILES_IN_BLOCK_K and N is a multiple of 512 * TILES_IN_BLOCK_N. It is the right-hand-side argument of the matrix multiplication. - result: the resulting output tensor of shape [M,N] TILES_IN_BLOCK_*: meta parameters to control blocking dimensions + Returns: + result: the resulting output tensor of shape [M,N] """ K, M = lhsT.shape K_, N = rhs.shape assert K == K_, "lhsT and rhs must have the same contraction dimension" + result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm) TILE_M = nl.tile_size.gemm_stationary_fmax # 128 TILE_K = nl.tile_size.pmax # 128 @@ -360,16 +392,19 @@ def nki_matmul_fully_optimized_( BLOCK_N * n + i_res_packed.x], value=result_packed[i_res_packed.p, i_res_packed.x]) + return result +# NKI_EXAMPLE_21_END + +# NKI_EXAMPLE_23_BEGIN if __name__ == "__main__": # Benchmarking with large matrices to show the differences more clearly lhsT = nt.tensor[[8192, 4096], nl.bfloat16] rhs = nt.tensor[[8192, 8192], nl.bfloat16] - output = nt.tensor[[4096, 8192], nl.bfloat16] def benchmark_nki(nki_func): bench_func = nki.benchmark(warmup=5, iters=10)(nki_func) - bench_func(lhsT, rhs, output) + bench_func(lhsT, rhs) latency_res = bench_func.benchmark_result.nc_latency p99 = latency_res.get_latency_percentile(99) print("Latency: {:.2f} ms (P99)".format(p99 / 1000.0)) @@ -385,3 +420,4 @@ def benchmark_nki(nki_func): print("Benchmarking nki_matmul_fully_optimized") benchmark_nki(nki_matmul_fully_optimized_) + # NKI_EXAMPLE_23_END diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py b/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py index ec0084c..de39ce8 100644 --- a/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py +++ b/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py @@ -7,23 +7,21 @@ import torch from torch_xla.core import xla_model as xm -from torch_neuronx import nki_jit from matrix_multiplication_nki_kernels import nki_matmul_basic_, nki_matmul_tiled_, nki_matmul_hoist_load_, nki_matmul_block_free_dimension_, nki_matmul_fully_optimized_ if __name__ == "__main__": + # NKI_EXAMPLE_17_BEGIN device = xm.xla_device() cpu = torch.device('cpu') # Test the small workload with basic kernel lhs_small = torch.rand((64, 128), dtype=torch.bfloat16, device=device) rhs_small = torch.rand((128, 512), dtype=torch.bfloat16, device=device) - output_small = torch.zeros((64, 512), dtype=torch.bfloat16, device=device) # Run NKI kernel - nki_matmul_basic_jit = nki_jit(nki_matmul_basic_) - nki_matmul_basic_jit(lhs_small.T, rhs_small, output_small) + output_small = nki_matmul_basic_(lhs_small.T, rhs_small) # Run torch reference output_small_torch = torch.matmul(lhs_small, rhs_small) @@ -34,18 +32,18 @@ print("NKI and Torch match") else: print("NKI and Torch differ") + # NKI_EXAMPLE_17_END + # NKI_EXAMPLE_22_BEGIN # Test the large workload with tiled kernels lhs = torch.rand((4096, 1024), dtype=torch.bfloat16, device=device) rhs = torch.rand((1024, 2048), dtype=torch.bfloat16, device=device) - output = torch.zeros((4096, 2048), dtype=torch.bfloat16, device=device) # Run torch reference output_torch = torch.matmul(lhs, rhs).to(device=cpu) def check_match(nki_func): - jit_func = nki_jit(nki_func) - jit_func(lhs.T, rhs, output) + output = nki_func(lhs.T, rhs) output_nki = output.to(device=cpu) if torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2): print("NKI and Torch match") @@ -63,3 +61,4 @@ def check_match(nki_func): print("Checking correctness of nki_matmul_fully_optimized") check_match(nki_matmul_fully_optimized_) + # NKI_EXAMPLE_22_END diff --git a/src/tutorials/rmsnorm/rmsnorm_jax.py b/src/tutorials/rmsnorm/rmsnorm_jax.py index 5b412d8..f0efc20 100644 --- a/src/tutorials/rmsnorm/rmsnorm_jax.py +++ b/src/tutorials/rmsnorm/rmsnorm_jax.py @@ -7,9 +7,9 @@ import jax import jax.numpy as jnp -from jax_neuronx import nki_call from rmsnorm_nki_kernels import nki_rmsnorm_kernel +# NKI_EXAMPLE_44_BEGIN # Reference JAX implementation def jax_rms_norm(a_tensor, g_tensor): # Square the tensor (element-wise) @@ -26,11 +26,7 @@ def jax_rms_norm(a_tensor, g_tensor): a_tensor = jax.random.uniform(a_key, (250, 512)) g_tensor = jax.random.uniform(g_key, (512,)) -output_nki = nki_call( - nki_rmsnorm_kernel, - a_tensor, g_tensor, - out_shape=jax.ShapeDtypeStruct(a_tensor.shape, dtype=a_tensor.dtype), -) +output_nki = nki_rmsnorm_kernel(a_tensor, g_tensor) print(a_tensor) @@ -43,3 +39,4 @@ def jax_rms_norm(a_tensor, g_tensor): print("NKI and JAX match") else: print("NKI and JAX differ") + # NKI_EXAMPLE_44_END diff --git a/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py b/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py index 140b682..402eecd 100644 --- a/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py +++ b/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py @@ -6,20 +6,23 @@ """ import numpy as np +# NKI_EXAMPLE_42_BEGIN import math import neuronxcc.nki as nki import neuronxcc.nki.language as nl -def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor): +@nki.jit +def nki_rmsnorm_kernel(a_tensor, g_tensor): # Calculate out_tensor = a_tensor/RMS(a_tensor) * g_tensor # Where RMS(a_tensor) = sqrt((1/N) * sum(a_tensor * a_tensor)) # and N = a_tensor.shape[1] # Reduction (mean) is performed in the free (2nd) dimension + out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, + buffer=nl.shared_hbm) # Make sure shapes match assert a_tensor.shape[1] == g_tensor.shape[0] - assert a_tensor.shape == out_tensor.shape # Generate tensor indices to index input tensor ix = nl.arange(128)[:, None] @@ -68,14 +71,15 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor): nl.store(out_tensor[i * 128 + ix, iy], value=out_tile, mask=(i * 128 + ix < num_rows)) + return out_tensor + # NKI_EXAMPLE_42_END + if __name__ == "__main__": a = np.random.rand(128, 512).astype(np.float32) g = np.random.rand(512).astype(np.float32) - output_nki = np.zeros(a.shape, dtype=a.dtype) - nki_rmsnorm_kernel_baremetal = nki.baremetal(nki_rmsnorm_kernel) - nki_rmsnorm_kernel_baremetal(a, g, output_nki) + output_nki = nki_rmsnorm_kernel(a, g) print(f"output_nki={output_nki}") # One-line numpy RMSNorm diff --git a/src/tutorials/rmsnorm/rmsnorm_torch.py b/src/tutorials/rmsnorm/rmsnorm_torch.py index 71ced3e..c9bfc69 100644 --- a/src/tutorials/rmsnorm/rmsnorm_torch.py +++ b/src/tutorials/rmsnorm/rmsnorm_torch.py @@ -5,11 +5,11 @@ """ -from torch_neuronx.xla_impl.ops import nki_jit import torch import os from rmsnorm_nki_kernels import nki_rmsnorm_kernel +# NKI_EXAMPLE_43_BEGIN # Reference torch implementation def torch_rmsnorm_kernel(a_tensor, g_tensor): # Square the tensor (element-wise) @@ -25,13 +25,10 @@ def torch_rmsnorm_kernel(a_tensor, g_tensor): from torch_xla.core import xla_model as xm device = xm.xla_device() -nki_rmsnorm_kernel = nki_jit(nki_rmsnorm_kernel) - a_tensor = torch.rand((250, 512), dtype=torch.float32).to(device=device) g_tensor = torch.rand((512), dtype=torch.float32).to(device=device) -output_nki = torch.zeros((250, 512), dtype=torch.float32).to(device=device) -nki_rmsnorm_kernel(a_tensor, g_tensor, output_nki) +output_nki = nki_rmsnorm_kernel(a_tensor, g_tensor) print(f"output_nki={output_nki}") output_torch = torch_rmsnorm_kernel(a_tensor, g_tensor) @@ -41,3 +38,4 @@ def torch_rmsnorm_kernel(a_tensor, g_tensor): print("NKI and Torch match") else: print("NKI and Torch differ") +# NKI_EXAMPLE_43_END diff --git a/src/tutorials/sd_attention/sd_attention_nki_kernels.py b/src/tutorials/sd_attention/sd_attention_nki_kernels.py index e5eec25..6d1f781 100644 --- a/src/tutorials/sd_attention/sd_attention_nki_kernels.py +++ b/src/tutorials/sd_attention/sd_attention_nki_kernels.py @@ -12,7 +12,9 @@ import argparse from scipy.special import softmax -def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_causal_mask=False, +# NKI_EXAMPLE_31_BEGIN +@nki.jit +def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False, mixed_percision=True): """ Fused self attention kernel for small head dimension Stable Diffusion workload, @@ -44,7 +46,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau # Assume all IO tensors have the same dtype kernel_dtype = q_ref.dtype pe_in_dt = nl.bfloat16 if mixed_percision else np.float32 - assert q_ref.dtype == k_ref.dtype == v_ref.dtype == out_ref.dtype + assert q_ref.dtype == k_ref.dtype == v_ref.dtype # Shape checking seqlen, d_head = q_ref.shape @@ -53,7 +55,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau assert tuple(k_ref.shape) == (seqlen, d_head), 'Input shape mismatch!' assert tuple(v_ref.shape) == (seqlen,d_head), \ f'Input shape mismatch! Expected: {(seqlen, d_head)} Actual: {tuple(v_ref.shape)}' - assert tuple(out_ref.shape) == (seqlen, d_head), 'Output shape mismatch!' + out_ref = nl.ndarray((seqlen, d_head), dtype=q_ref.dtype, buffer=nl.shared_hbm) # Softmax scaling factor, multiplied onto Q softmax_scale = 0.125 @@ -210,58 +212,61 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau out_ref[i_q_seq_tile * q_seq_tile_size + if_out, ip_out], value=attn_res_div) + return out_ref +# NKI_EXAMPLE_31_END + def parse_args(): - parser = argparse.ArgumentParser("Run Stable Diffusion Attention NKI kernel.") - parser.add_argument("--mode", - choices=["accuracy", "perf"], - default="accuracy", - help="""Do accuracy test or perf test. - Accuracy test uses cpu golden output as golden reference. - """) + parser = argparse.ArgumentParser("Run Stable Diffusion Attention NKI kernel.") + parser.add_argument("--mode", + choices=["accuracy", "perf"], + default="accuracy", + help="""Do accuracy test or perf test. + Accuracy test uses cpu golden output as golden reference. + """) + + args = parser.parse_args() + return args - args = parser.parse_args() - return args def cpu_golden_attn(q, k, v): - softmax_scale = 0.125 + softmax_scale = 0.125 - q_scaled = q * softmax_scale - raw_score = np.matmul(q_scaled, k.transpose()) - norm_score = softmax(raw_score, axis=-1) + q_scaled = q * softmax_scale + raw_score = np.matmul(q_scaled, k.transpose()) + norm_score = softmax(raw_score, axis=-1) - return np.matmul(norm_score, v) + return np.matmul(norm_score, v) if __name__ == "__main__": - args = parse_args() - - print(f"Running {args.mode} mode.") - - seqlen, d_head = 4096, 64 - - # Set up input tensors - dtype = np.float32 - q_tensor = np.random.rand(seqlen, d_head).astype(dtype) - k_tensor = np.random.rand(seqlen, d_head).astype(dtype) - v_tensor = np.random.rand(seqlen, d_head).astype(dtype) - output_nki = np.empty((seqlen, d_head), dtype=dtype) - output_golden = cpu_golden_attn(q_tensor, k_tensor, v_tensor) - - if args.mode == "accuracy": - nki.baremetal(fused_self_attn_for_SD_small_head_size)\ - (q_tensor, k_tensor, v_tensor, output_nki) - allclose = np.allclose(output_nki, output_golden, atol=1e-5, rtol=1e-3) - print(f">>>> SD attention matches CPU reference? {allclose}") - assert allclose, "Accuracy check fails!" - - else: - benchmark_func = nki.benchmark(fused_self_attn_for_SD_small_head_size, - save_neff_name='file.neff', - save_trace_name='profile.ntff') - benchmark_func(q_tensor, k_tensor, v_tensor, output_nki) - - metrics = benchmark_func.benchmark_result.nc_latency - print(">>>> SD attention benchmark results") - print("latency.p50 = " + str(metrics.get_latency_percentile(50))) - print("latency.p99 = " + str(metrics.get_latency_percentile(99))) \ No newline at end of file + args = parse_args() + + print(f"Running {args.mode} mode.") + + seqlen, d_head = 4096, 64 + + # Set up input tensors + dtype = np.float32 + q_tensor = np.random.rand(seqlen, d_head).astype(dtype) + k_tensor = np.random.rand(seqlen, d_head).astype(dtype) + v_tensor = np.random.rand(seqlen, d_head).astype(dtype) + output_nki = np.empty((seqlen, d_head), dtype=dtype) + output_golden = cpu_golden_attn(q_tensor, k_tensor, v_tensor) + + if args.mode == "accuracy": + output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor) + allclose = np.allclose(output_nki, output_golden, atol=1e-5, rtol=1e-3) + print(f">>>> SD attention matches CPU reference? {allclose}") + assert allclose, "Accuracy check fails!" + + else: + benchmark_func = nki.benchmark(fused_self_attn_for_SD_small_head_size, + save_neff_name='file.neff', + save_trace_name='profile.ntff') + benchmark_func(q_tensor, k_tensor, v_tensor) + + metrics = benchmark_func.benchmark_result.nc_latency + print(">>>> SD attention benchmark results") + print("latency.p50 = " + str(metrics.get_latency_percentile(50))) + print("latency.p99 = " + str(metrics.get_latency_percentile(99))) \ No newline at end of file diff --git a/src/tutorials/sd_attention/sd_attention_torch.py b/src/tutorials/sd_attention/sd_attention_torch.py index f124607..639e5cf 100644 --- a/src/tutorials/sd_attention/sd_attention_torch.py +++ b/src/tutorials/sd_attention/sd_attention_torch.py @@ -5,8 +5,8 @@ """ +# NKI_EXAMPLE_32_BEGIN import torch -from torch_neuronx.xla_impl.ops import nki_jit from torch_xla.core import xla_model as xm from sd_attention_nki_kernels import fused_self_attn_for_SD_small_head_size @@ -28,10 +28,8 @@ def cpu_golden_attn(q, k, v): q_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device) k_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device) v_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device) - output_nki = torch.zeros((4096, 64), dtype=torch.float32).to(device=device) - nki_func = nki_jit(func=fused_self_attn_for_SD_small_head_size) - nki_func(q_tensor, k_tensor, v_tensor, output_nki) + output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor) output_torch = cpu_golden_attn(q_tensor, k_tensor, v_tensor) @@ -42,4 +40,5 @@ def cpu_golden_attn(q, k, v): else: print("NKI and Torch differ") - assert allclose \ No newline at end of file + assert allclose + # NKI_EXAMPLE_32_END diff --git a/src/tutorials/tensor_addition/tensor_addition_jax.py b/src/tutorials/tensor_addition/tensor_addition_jax.py index 9655b84..e40f962 100644 --- a/src/tutorials/tensor_addition/tensor_addition_jax.py +++ b/src/tutorials/tensor_addition/tensor_addition_jax.py @@ -4,42 +4,15 @@ JAX implementation for tensor addition NKI tutorial. """ +# NKI_EXAMPLE_30_BEGIN import jax import jax.numpy as jnp -from jax_neuronx import nki_call +# NKI_EXAMPLE_30_END -from tensor_addition_nki_kernels import nki_tensor_add_kernel_ - - -def nki_tensor_add(a_input, b_input): - """NKI kernel caller to compute element-wise addition of two input tensors - - This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs - - Args: - a_input: a first input tensor, of shape [N*128, M*512] - b_input: a second input tensor, of shape [N*128, M*512] - - Returns: - a tensor of shape [N*128, M*512], the result of a_input + b_input - """ - - # The SPMD launch grid denotes the number of kernel instances. - # In this case, we use a 2D grid where the size of each invocation is 128x512 - grid_x = a_input.shape[0] // 128 - grid_y = a_input.shape[1] // 512 - - out_shape = jax.ShapeDtypeStruct((a_input.shape[0], a_input.shape[1]), dtype=a_input.dtype) - - return nki_call( - nki_tensor_add_kernel_, - a_input, - b_input, - grid=(grid_x, grid_y), - out_shape=out_shape, - ) +from tensor_addition_nki_kernels import nki_tensor_add +# NKI_EXAMPLE_30_BEGIN if __name__ == "__main__": seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42)) @@ -59,3 +32,4 @@ def nki_tensor_add(a_input, b_input): print("NKI and JAX differ") assert allclose + # NKI_EXAMPLE_30_END diff --git a/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py b/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py index 2b49237..ea72488 100644 --- a/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py +++ b/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py @@ -5,20 +5,26 @@ """ import numpy as np +# NKI_EXAMPLE_27_BEGIN import neuronxcc.nki as nki import neuronxcc.nki.language as nl -def nki_tensor_add_kernel_(a_input, b_input, c_output): +@nki.jit +def nki_tensor_add_kernel_(a_input, b_input): """NKI kernel to compute element-wise addition of two input tensors - This kernel assumes strict input/output tile-sizes, of up-to [128,512] + This kernel assumes strict input/output sizes can be uniformly tiled to [128,512] Args: - a_input: a first input tensor, of shape [128,512] - b_input: a second input tensor, of shape [128,512] - c_output: an output tensor, of shape [128,512] + a_input: a first input tensor + b_input: a second input tensor + + Returns: + c_output: an output tensor """ + # Create output tensor shared between all SPMD instances as result tensor + c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm) # Calculate tile offsets based on current 'program' offset_i_x = nl.program_id(0) * 128 @@ -39,7 +45,12 @@ def nki_tensor_add_kernel_(a_input, b_input, c_output): # store the addition results back to device memory (c_output) nl.store(c_output[ix, iy], value=c_tile) + # Transfer the ownership of `c_output` to the caller + return c_output + # NKI_EXAMPLE_27_END + +# NKI_EXAMPLE_28_BEGIN def nki_tensor_add(a_input, b_input): """NKI kernel caller to compute element-wise addition of two input tensors @@ -57,12 +68,9 @@ def nki_tensor_add(a_input, b_input): # In this case, we use a 2D grid where the size of each invocation is 128x512 grid_x = a_input.shape[0] // 128 grid_y = a_input.shape[1] // 512 - c_output = np.zeros(a_input.shape, dtype=a_input.dtype) - - nki_tensor_add_kernel_baremetal = nki.baremetal(nki_tensor_add_kernel_) - nki_tensor_add_kernel_baremetal[grid_x, grid_y](a_input, b_input, c_output) - return c_output + return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input) + # NKI_EXAMPLE_28_END if __name__ == "__main__": diff --git a/src/tutorials/tensor_addition/tensor_addition_torch.py b/src/tutorials/tensor_addition/tensor_addition_torch.py index 942e728..83673e5 100644 --- a/src/tutorials/tensor_addition/tensor_addition_torch.py +++ b/src/tutorials/tensor_addition/tensor_addition_torch.py @@ -4,38 +4,15 @@ PyTorch implementation for tensor addition NKI tutorial. """ +# NKI_EXAMPLE_29_BEGIN import torch from torch_xla.core import xla_model as xm -from torch_neuronx import nki_jit +# NKI_EXAMPLE_29_END -from tensor_addition_nki_kernels import nki_tensor_add_kernel_ +from tensor_addition_nki_kernels import nki_tensor_add -def nki_tensor_add(a_input, b_input): - """NKI kernel caller to compute element-wise addition of two input tensors - - This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs - - Args: - a_input: a first input tensor, of shape [N*128, M*512] - b_input: a second input tensor, of shape [N*128, M*512] - - Returns: - a tensor of shape [N*128, M*512], the result of a_input + b_input - """ - - # The SPMD launch grid denotes the number of kernel instances. - # In this case, we use a 2D grid where the size of each invocation is 128x512 - grid_x = a_input.shape[0] // 128 - grid_y = a_input.shape[1] // 512 - c_output = torch.zeros(a_input.shape, dtype=a_input.dtype).to(device=device) - - # Decorate the NKI kernel for PyTorch tracing - nki_tensor_add_kernel_torch = nki_jit(nki_tensor_add_kernel_) - nki_tensor_add_kernel_torch[grid_x, grid_y](a_input, b_input, c_output) - - return c_output - +# NKI_EXAMPLE_29_BEGIN if __name__ == "__main__": device = xm.xla_device() @@ -55,3 +32,4 @@ def nki_tensor_add(a_input, b_input): print("NKI and Torch differ") assert allclose + # NKI_EXAMPLE_29_END diff --git a/src/tutorials/transpose2d/transpose2d_jax.py b/src/tutorials/transpose2d/transpose2d_jax.py index 024782c..f23ceef 100644 --- a/src/tutorials/transpose2d/transpose2d_jax.py +++ b/src/tutorials/transpose2d/transpose2d_jax.py @@ -5,25 +5,18 @@ """ +# NKI_EXAMPLE_36_BEGIN import jax import jax.numpy as jnp -from functools import partial -from jax_neuronx import nki_call +# NKI_EXAMPLE_36_END from transpose2d_nki_kernels import tensor_transpose2D_kernel_ - -def transpose2D(in_tensor, shape2D): - return nki_call( - partial(tensor_transpose2D_kernel_, shape2D=shape2D), - in_tensor, - out_shape=jax.ShapeDtypeStruct(in_tensor.shape, dtype=in_tensor.dtype) - ) - +# NKI_EXAMPLE_36_BEGIN if __name__ == "__main__": P, X, Y = 5, 37, 44 a = jax.random.uniform(jax.random.PRNGKey(42), (P, X * Y)) - a_t_nki = transpose2D(a, (X, Y)) + a_t_nki = tensor_transpose2D_kernel_(a, shape2D=(X, Y)) a_t_jax = jnp.transpose(a.reshape(P, X, Y), axes=(0, 2, 1)).reshape(P, X * Y) print(a, a_t_nki, a_t_jax) @@ -35,3 +28,4 @@ def transpose2D(in_tensor, shape2D): print("NKI and JAX differ") assert allclose +# NKI_EXAMPLE_36_END diff --git a/src/tutorials/transpose2d/transpose2d_nki_kernels.py b/src/tutorials/transpose2d/transpose2d_nki_kernels.py index d993c7e..171e6ed 100644 --- a/src/tutorials/transpose2d/transpose2d_nki_kernels.py +++ b/src/tutorials/transpose2d/transpose2d_nki_kernels.py @@ -5,11 +5,13 @@ """ import numpy as np +# NKI_EXAMPLE_33_BEGIN import neuronxcc.nki as nki import neuronxcc.nki.language as nl -def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D): +@nki.jit +def tensor_transpose2D_kernel_(in_tensor, shape2D): """ NKI kernel to reorder the elements on axis[1] of the input tensor. @@ -36,6 +38,8 @@ def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D): shape2D: tuple representing the dimensions to be transposed: (#rows, #cols) out_tensor: an output (transposed) tensor """ + out_tensor = nl.ndarray(in_tensor.shape, dtype=in_tensor.dtype, + buffer=nl.shared_hbm) # Gather input shapes sz_p, _ = in_tensor.shape @@ -64,14 +68,15 @@ def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D): # Finally, we store out_tile to external memory nl.store(out_tensor, value=out_tile) + return out_tensor + # NKI_EXAMPLE_33_END + if __name__ == "__main__": P, X, Y = 5, 3, 4 a = np.arange(P*X*Y, dtype=np.int8).reshape((P, X*Y)) - a_t_nki = np.zeros((P, Y*X), dtype=np.int8) - tensor_transpose2D_kernel_torch = nki.baremetal(tensor_transpose2D_kernel_) - tensor_transpose2D_kernel_torch(a, a_t_nki, (X, Y)) + a_t_nki = tensor_transpose2D_kernel_(a, (X, Y)) a_t_np = np.transpose(a.reshape(P, X, Y), (0, 2, 1)).reshape(P, X * Y) diff --git a/src/tutorials/transpose2d/transpose2d_torch.py b/src/tutorials/transpose2d/transpose2d_torch.py index 71083d7..61fe367 100644 --- a/src/tutorials/transpose2d/transpose2d_torch.py +++ b/src/tutorials/transpose2d/transpose2d_torch.py @@ -4,13 +4,15 @@ PyTorch implementation for transpose2d NKI tutorial. """ +# NKI_EXAMPLE_34_BEGIN import torch from torch_xla.core import xla_model as xm -from torch_neuronx import nki_jit +# NKI_EXAMPLE_34_END from transpose2d_nki_kernels import tensor_transpose2D_kernel_ +# NKI_EXAMPLE_34_BEGIN if __name__ == "__main__": device = xm.xla_device() @@ -18,8 +20,7 @@ a = torch.arange(P*X*Y, dtype=torch.int8).reshape((P, X*Y)).to(device=device) a_t_nki = torch.zeros((P, Y*X), dtype=torch.int8).to(device=device) - tensor_transpose2D_kernel_torch = nki_jit(tensor_transpose2D_kernel_) - tensor_transpose2D_kernel_torch(a, a_t_nki, (X, Y)) + a_t_nki = tensor_transpose2D_kernel_(a, (X, Y)) a_t_torch = torch.transpose(a.reshape(P, X, Y), 1, 2).reshape(P, X * Y) @@ -32,3 +33,4 @@ print("NKI and PyTorch differ") assert allclose + # NKI_EXAMPLE_34_END diff --git a/test/integration/flash_attention/flash_attention_benchmark.py b/test/integration/flash_attention/flash_attention_benchmark.py index 5aa2e40..918a14f 100644 --- a/test/integration/flash_attention/flash_attention_benchmark.py +++ b/test/integration/flash_attention/flash_attention_benchmark.py @@ -14,6 +14,8 @@ from flash_attention import nki_flash_attn_func +parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(parent_dir) from perf_utils.LatencyCollector import benchmark if len(sys.argv) != 2: diff --git a/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py b/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py index e7fd205..5d63424 100644 --- a/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py +++ b/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py @@ -8,6 +8,8 @@ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler from diffusers.models.unet_2d_condition import UNet2DConditionOutput +parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(parent_dir) from perf_utils.LatencyCollector import benchmark import sys diff --git a/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py b/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py index 3ba0eab..4970f72 100644 --- a/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py +++ b/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py @@ -23,6 +23,8 @@ else: from diffusers.models.cross_attention import CrossAttention +parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(parent_dir) from perf_utils.LatencyCollector import benchmark import sys diff --git a/test/unit/test_SD_attention_small_head.py b/test/unit/test_SD_attention_small_head.py index 5480fa4..32e6945 100644 --- a/test/unit/test_SD_attention_small_head.py +++ b/test/unit/test_SD_attention_small_head.py @@ -11,7 +11,7 @@ test_trace_file_path='local_trace.ntff' numeric_func = baremetal(fused_self_attn_for_SD_small_head_size) -bench_func = benchmark(warmup=5, iters=10, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size) +bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size) def cpu_golden_attn(q, k, v): softmax_scale = 0.125 @@ -34,16 +34,16 @@ def test_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency): q = np.random.random_sample((bs, d, seqlen)).astype(np.float32) k = np.random.random_sample((bs, seqlen, d)).astype(np.float32) v = np.random.random_sample((bs, seqlen, d)).astype(np.float32) - out = nl.static_cast(np.ndarray(shape=(bs, seqlen, d)), dtype) - + q_dev = nl.static_cast(q, dtype) k_dev = nl.static_cast(k, dtype) v_dev = nl.static_cast(v, dtype) - bench_func[bs](q_dev, k_dev, v_dev, out) - latency_res = bench_func.benchmark_result.nc_latency - p99 = latency_res.get_latency_percentile(99) - assert p99 <= latency + bench_func_ = bench_func[bs] + bench_func_(q_dev, k_dev, v_dev) + latency_res = bench_func_.benchmark_result.nc_latency + p50 = latency_res.get_latency_percentile(50) + assert p50 <= latency*1.05 # short running kernels are subjected to hardware fluctuation assert os.path.getsize(test_trace_file_path) > 0 @pytest.mark.parametrize("bs,seqlen,d,dtype", [ @@ -54,13 +54,12 @@ def test_attention_for_SD_numberic(self, bs, seqlen, d, dtype): q = np.random.random_sample((bs, d, seqlen)).astype(np.float32) k = np.random.random_sample((bs, seqlen, d)).astype(np.float32) v = np.random.random_sample((bs, seqlen, d)).astype(np.float32) - out = nl.static_cast(np.ndarray(shape=(bs, seqlen, d)), dtype) - + q_dev = nl.static_cast(q, dtype) k_dev = nl.static_cast(k, dtype) v_dev = nl.static_cast(v, dtype) - numeric_func[bs](q_dev, k_dev, v_dev, out) + out = numeric_func[bs](q_dev, k_dev, v_dev) out = nl.static_cast(out, np.float32) golden_result = cpu_golden_attn(q, k, v) assert np.allclose(out, golden_result, atol=1e-2) diff --git a/test/unit/test_allocated_SD_attention_small_head.py b/test/unit/test_allocated_SD_attention_small_head.py new file mode 100644 index 0000000..ee0de86 --- /dev/null +++ b/test/unit/test_allocated_SD_attention_small_head.py @@ -0,0 +1,67 @@ +""" +Copyright (c) 2023, Amazon.com. All Rights Reserved +""" +import os +import pytest +from neuronxcc.nki.kernels.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size +from neuronxcc.nki import benchmark, baremetal +import neuronxcc.nki as nki +import neuronxcc.nki.language as nl +import numpy as np +from scipy.special import softmax + +test_trace_file_path='local_trace.ntff' +numeric_func = baremetal(allocated_fused_self_attn_for_SD_small_head_size) +bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(allocated_fused_self_attn_for_SD_small_head_size) + +def cpu_golden_attn(q, k, v): + softmax_scale = 0.125 + q_scaled = q * softmax_scale + raw_score = np.matmul(q_scaled.transpose(0, 2, 1), k) + + norm_score = softmax(raw_score, axis=-1) + + # Transpose the result so it has the same layout as ours + return np.matmul(norm_score, v).transpose(0, 2, 1) + +class TestAttention: + + @pytest.mark.parametrize("bs,seqlen,d,dtype,latency", [ + [1, 4096, 128, np.float32, 410], + [1, 4096, 128, nl.bfloat16, 350], + [1, 5120, 128, nl.bfloat16, 586] + ]) + def test_allocated_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency): + q = np.random.random_sample((bs, d, seqlen)).astype(np.float32) + k = np.random.random_sample((bs, d, seqlen)).astype(np.float32) + v = np.random.random_sample((bs, seqlen, d)).astype(np.float32) + + q_dev = nl.static_cast(q, dtype) + k_dev = nl.static_cast(k, dtype) + v_dev = nl.static_cast(v, dtype) + + bench_func_ = bench_func[bs] + bench_func_(q_dev, k_dev, v_dev) + latency_res = bench_func_.benchmark_result.nc_latency + p50 = latency_res.get_latency_percentile(50) + assert p50 <= latency * 1.05 # short running kernels are subjected to hardware fluctuation + assert os.path.getsize(test_trace_file_path) > 0 + + @pytest.mark.parametrize("bs,seqlen,d,dtype", [ + [1, 4096, 128, np.float32], + [1, 4096, 128, nl.bfloat16], + [1, 5120, 128, nl.bfloat16] + ]) + def test_allocated_attention_for_SD_numberic(self, bs, seqlen, d, dtype): + q = np.random.random_sample((bs, d, seqlen)).astype(np.float32) + k = np.random.random_sample((bs, d, seqlen)).astype(np.float32) + v = np.random.random_sample((bs, seqlen, d)).astype(np.float32) + + q_dev = nl.static_cast(q, dtype) + k_dev = nl.static_cast(k, dtype) + v_dev = nl.static_cast(v, dtype) + + out = numeric_func[bs](q_dev, k_dev, v_dev) + out = nl.static_cast(out, np.float32) + golden_result = cpu_golden_attn(q, k, v) + assert np.allclose(out, golden_result, atol=1e-2) diff --git a/test/unit/test_flash_attn_bwd.py b/test/unit/test_flash_attn_bwd.py index a55abbe..3aedab0 100644 --- a/test/unit/test_flash_attn_bwd.py +++ b/test/unit/test_flash_attn_bwd.py @@ -7,6 +7,8 @@ import neuronxcc.nki.language as nl import numpy as np +from TestDecorators import xfail + numeric_func = baremetal(flash_attn_bwd) bench_func = benchmark(warmup=5, iters=10)(flash_attn_bwd) @@ -85,6 +87,7 @@ def mixed_precision_matmul(a, b): class TestAttention: + @xfail # P167481231 @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, latency", [ [1, 4, 32*1024, 128, nl.bfloat16, 117000], ]) @@ -97,23 +100,16 @@ def test_flash_attn_bwd_perf(self, bs, nheads, seqlen, d, dtype, latency): lse = np.random.random_sample([bs, nheads, nl.tile_size.pmax, seqlen // nl.tile_size.pmax]).astype(np.float32) seed = None - out_dq = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype) - out_dk = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype) - out_dv = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype) - q = nl.static_cast(q, dtype) k = nl.static_cast(k, dtype) v = nl.static_cast(v, dtype) o_proj = nl.static_cast(o_proj, dtype) dy = nl.static_cast(dy, dtype) - out_dq = nl.static_cast(out_dq, dtype) - out_dk = nl.static_cast(out_dk, dtype) - out_dv = nl.static_cast(out_dv, dtype) - - bench_func[bs, nheads](q, k, v, o_proj, dy, lse, seed, - out_dq, out_dk, out_dv, - use_causal_mask=True, mixed_precision=True) - latency_res = bench_func.benchmark_result.nc_latency + + bench_func_ = bench_func[bs, nheads] + bench_func_(q, k, v, o_proj, dy, lse, seed, + use_causal_mask=True, mixed_precision=True) + latency_res = bench_func_.benchmark_result.nc_latency p99 = latency_res.get_latency_percentile(99) assert p99 <= latency @@ -130,10 +126,7 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype): v = nl.static_cast(v, dtype) dy = nl.static_cast(dy, dtype) seed = None - out_dq = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype) - out_dk = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype) - out_dv = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype) - + dq_golden, dk_golden, dv_golden, cached_negative_max, cached_sum_reciprocal, o_proj = \ cpu_attention_backward(q, k, v, dy, use_causal_mask=True) cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen // nl.tile_size.pmax, @@ -142,9 +135,9 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype): nl.tile_size.pmax).transpose(0, 1, 3, 2) lse = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal)) - numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed, - out_dq, out_dk, out_dv, - use_causal_mask=True, mixed_precision=True) + out_dq, out_dk, out_dv = numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed, + use_causal_mask=True, + mixed_precision=True) assert np.allclose(out_dq, dq_golden, atol=1e-2) assert np.allclose(out_dk, dk_golden, atol=1e-2) diff --git a/test/unit/test_flash_attn_fwd.py b/test/unit/test_flash_attn_fwd.py index 4d91164..fff4ac2 100644 --- a/test/unit/test_flash_attn_fwd.py +++ b/test/unit/test_flash_attn_fwd.py @@ -63,75 +63,84 @@ def mixed_precision_matmul(a, b): class TestAttention: - @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\ + @pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\ mixed_precision, training, tile_size, kv_heads, should_transpose_v, latency", [ - [1, 6, 32*1024, 96, nl.bfloat16, True, True, True, 2048, 3, False, 87000000000], - [1, 1, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 15100000000], + [1, 6, 32*1024, 32*1024, 96, nl.bfloat16, True, True, True, 2048, 3, False, 87000000000], + [1, 1, 32*1024, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 15100000000], + # Non-square + [1, 3, 32*1024, 16*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 7550000000], + [1, 3, 16*1024, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 7550000000], ]) - def test_flash_attn_fwd_perf(self, bs, nheads, seqlen, d, dtype, use_causal_mask, + def test_flash_attn_fwd_perf(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, mixed_precision, training, tile_size, kv_heads, should_transpose_v,latency): - q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 - k = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 + q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2 + k = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2 if should_transpose_v: - v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 + v = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2 else: - v = (np.random.random_sample([bs, nheads, seqlen, d]) - 0.5) * 2 - o_proj = np.zeros(shape=[bs, nheads, seqlen, d], dtype=dtype) - out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax], + v = (np.random.random_sample([bs, nheads, seqlen_k, d]) - 0.5) * 2 + o_proj = np.zeros(shape=[bs, nheads, seqlen_q, d], dtype=dtype) + out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen_q // nl.tile_size.pmax], dtype=nl.float32 if mixed_precision else dtype) if training else None seed = None q = nl.static_cast(q, dtype) k = nl.static_cast(k, dtype) v = nl.static_cast(v, dtype) - o_proj = nl.static_cast(o_proj, dtype) config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v}) heads = nheads if kv_heads is None else kv_heads - bench_func[bs, heads](q, k, v, seed, o_proj, out_lse, - use_causal_mask=use_causal_mask, mixed_precision=mixed_precision, config=config) - latency_res = bench_func.benchmark_result.nc_latency + bench_func_ = bench_func[bs, heads] + bench_func_(q, k, v, seed, use_causal_mask=use_causal_mask, + mixed_precision=mixed_precision, config=config) + latency_res = bench_func_.benchmark_result.nc_latency p99 = latency_res.get_latency_percentile(99) assert p99 <= latency - @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\ + @pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\ training, tile_size, kv_heads, should_transpose_v", [ - [1, 6, 4096, 128, np.float32, True, True, 2048, 3, False], - [1, 1, 4096, 128, np.float32, True, False, 2048, None, False], + [1, 6, 4096, 4096, 128, np.float32, True, True, 2048, 3, False], + [1, 1, 4096, 4096, 128, np.float32, True, False, 2048, None, False], + [1, 1, 8192, 4096, 128, np.float32, True, False, 2048, None, False], + [1, 1, 4096, 8192, 128, np.float32, True, False, 2048, None, False], ]) - def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen, d, dtype, use_causal_mask, + def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, training, tile_size, kv_heads, should_transpose_v): - q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 - k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen]) - 0.5) * 2 + q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2 + k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen_k]) - 0.5) * 2 if should_transpose_v: - v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 + v = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2 cpu_permute = (0, 1, 2, 3) else: - v = (np.random.random_sample([bs, kv_heads or nheads, seqlen, d]) - 0.5) * 2 + v = (np.random.random_sample([bs, kv_heads or nheads, seqlen_k, d]) - 0.5) * 2 cpu_permute = (0, 1, 3, 2) - o_proj = np.zeros(shape=[bs, nheads, seqlen, d], dtype=dtype) + q = nl.static_cast(q, dtype) k = nl.static_cast(k, dtype) v = nl.static_cast(v, dtype) seed = None - out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax], - dtype=np.float32) if training else None o_proj_golden, cached_negative_max, cached_sum_reciprocal = \ cpu_attention_forward(q, k, v.transpose(cpu_permute), use_causal_mask=use_causal_mask,mixed_precision=True) o_proj_golden = o_proj_golden.transpose(0,1,3,2) # (b,h, d, seq) - cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen // nl.tile_size.pmax, + cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen_q // nl.tile_size.pmax, nl.tile_size.pmax).transpose(0, 1, 3, 2) - cached_sum_reciprocal = cached_sum_reciprocal.reshape(bs, nheads, seqlen // nl.tile_size.pmax, + cached_sum_reciprocal = cached_sum_reciprocal.reshape(bs, nheads, seqlen_q // nl.tile_size.pmax, nl.tile_size.pmax).transpose(0, 1, 3, 2) lse_golden = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal)) if training else None config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v}) heads = nheads if kv_heads is None else kv_heads - numeric_func[bs, heads](q, k, v, seed, o_proj, out_lse, seed, - use_causal_mask=use_causal_mask, mixed_precision=True, config=config) + results = numeric_func[bs, heads](q, k, v, seed, + use_causal_mask=use_causal_mask, + mixed_precision=True, + config=config) - assert np.allclose(o_proj, o_proj_golden, atol=1e-2) if training: + o_proj, out_lse = results + assert np.allclose(o_proj, o_proj_golden, atol=1e-2) assert np.allclose(out_lse, lse_golden, atol=1e-2) + else: + o_proj = results + assert np.allclose(o_proj, o_proj_golden, atol=1e-2) diff --git a/test/unit/test_resize_nearest.py b/test/unit/test_resize_nearest.py index a77968b..2bbc601 100644 --- a/test/unit/test_resize_nearest.py +++ b/test/unit/test_resize_nearest.py @@ -11,6 +11,7 @@ numeric_func = baremetal(resize_nearest_fixed_dma_kernel) bench_func = benchmark(warmup=5, iters=10)(resize_nearest_fixed_dma_kernel) + def cpu_golden_result(data_tensor, output_shape): in_b, in_h, in_w, in_c = data_tensor.shape out_b, out_h, out_w, out_c = output_shape @@ -36,18 +37,18 @@ def cpu_golden_result(data_tensor, output_shape): class TestResizeNearest: @pytest.mark.parametrize("in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype, latency", [ - [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32, 1722], + [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32, 1740], [1, 30, 20, 1280, 1, 59, 38, 1280, nl.float16, 659], [1, 30, 20, 1280, 1, 59, 38, 1280, nl.bfloat16, 659], ]) def test_resize_nearest_for_perf(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype, latency): input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32) - output_tensor = nl.static_cast(np.ndarray(shape=(out_b, out_h, out_w, out_c)), dtype) - + input_dev = nl.static_cast(input_tensor, dtype) - bench_func[in_b](input_dev, output_tensor) - latency_res = bench_func.benchmark_result.nc_latency + bench_func_ = bench_func[in_b] + bench_func_(input_dev, (out_b, out_h, out_w, out_c)) + latency_res = bench_func_.benchmark_result.nc_latency p99 = latency_res.get_latency_percentile(99) assert p99 <= latency @@ -58,11 +59,10 @@ def test_resize_nearest_for_perf(self, in_b, in_h, in_w, in_c, out_b, out_h, out ]) def test_resize_nearest_for_numberic(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype): input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32) - output_tensor = nl.static_cast(np.ndarray(shape=(out_b, out_h, out_w, out_c)), dtype) - + input_dev = nl.static_cast(input_tensor, dtype) - numeric_func[in_b](input_dev, output_tensor) + output_tensor = numeric_func[in_b](input_dev, (out_b, out_h, out_w, out_c)) output_tensor = nl.static_cast(output_tensor, np.float32) golden_result = cpu_golden_result(input_tensor, output_tensor.shape) assert np.allclose(output_tensor, golden_result, atol=1e-2) diff --git a/test/unit/test_rmsnorm_qkv.py b/test/unit/test_rmsnorm_qkv.py new file mode 100644 index 0000000..24ad31c --- /dev/null +++ b/test/unit/test_rmsnorm_qkv.py @@ -0,0 +1,65 @@ +""" +Copyright (c) 2024, Amazon.com. All Rights Reserved +""" +import pytest +from neuronxcc.nki.kernels.allocated_fused_linear import allocated_fused_rms_norm_qkv +from neuronxcc.nki import benchmark, baremetal +import neuronxcc.nki.language as nl +import numpy as np + +numeric_func = baremetal(allocated_fused_rms_norm_qkv) +bench_func = benchmark(warmup=5, iters=10)(allocated_fused_rms_norm_qkv) + +np.random.seed(0) + + +def rms_norm(hidden, gamma, eps=1e-6): + rms = np.sqrt(np.mean(np.square(hidden), axis=-1, keepdims=True)) + norm = hidden * np.reciprocal(rms + eps) + if gamma is not None: + norm *= gamma + return norm + +def cpu_golden_result(hidden, gamma, qkv_weights, dtype, do_norm=True): + if do_norm: + hidden = rms_norm(hidden, gamma) + qkv_out = (hidden @ qkv_weights).astype(dtype) + return qkv_out + +class TestRMSNormQKV: + @pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype, latency", [ + [1, 128, 512, 512, np.float16, 25], + [1, 512, 1024, 384, nl.bfloat16, 40], + [1, 128, 1024, 512, nl.bfloat16, 28], + [1, 1024, 8192, 512, nl.bfloat16, 301 * 1.02] + ]) + def test_allocated_rmsnorm_qkv_perf(self, batch, seqlen, dim, d_head, dtype, latency): + hidden = np.random.random_sample((batch, seqlen, dim)).astype(np.float32) + weights = np.random.random_sample((dim, d_head)).astype(np.float32) + + hidden = nl.static_cast(hidden, dtype) + weights = nl.static_cast(weights, dtype) + + bench_func(hidden, weights) + latency_res = bench_func.benchmark_result.nc_latency + p99 = latency_res.get_latency_percentile(99) + assert p99 <= latency + + @pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype", [ + [1, 128, 512, 512, np.float16], + [1, 512, 1024, 384, nl.bfloat16], + [1, 128, 1024, 512, nl.bfloat16], + [1, 1024, 8192, 512, nl.bfloat16] + ]) + def test_allocated_rmsnorm_qkv_numeric(self, batch, seqlen, dim, d_head, dtype): + hidden = np.random.random_sample((batch, seqlen, dim)) + weights = np.random.random_sample((dim, d_head)) + + hidden_dev = nl.static_cast(hidden, dtype) + weights_dev = nl.static_cast(weights, dtype) + + out = numeric_func(hidden_dev, weights_dev) + out = nl.static_cast(out, np.float32) + golden_res = nl.static_cast(cpu_golden_result(hidden, None, weights, dtype, do_norm=True), np.float32) + assert np.allclose(out, golden_res, atol=1e-2, rtol=1e-2) + diff --git a/test/unit/test_select_and_scatter.py b/test/unit/test_select_and_scatter.py index 70f7a7c..fc99b37 100644 --- a/test/unit/test_select_and_scatter.py +++ b/test/unit/test_select_and_scatter.py @@ -39,7 +39,6 @@ def cpu_golden_result(operand_tensor, source_tensor, window_dimensions=(3, 3), w out_h = h * stride_h + local_h - padding[0] out_w = w * stride_w + local_w - padding[1] output_tensor[n, c, out_h, out_w] += source_tensor[n, c, h, w] - return output_tensor class TestSelectAndScatter: @@ -47,31 +46,28 @@ class TestSelectAndScatter: [8, 64, 112, 112, 56, 56, np.float32, 4500], ]) def test_select_and_scatter_for_perf(self, n, c, operand_h, operand_w, source_h, source_w, dtype, latency): - operand_tensor = np.random.random_sample((n, c, operand_h, operand_w)).astype(np.float32) - source_tensor = np.random.random_sample((n, c, source_h, source_w)).astype(np.float32) - output_tensor = nl.static_cast(np.ndarray(shape=(n, c, operand_h, operand_w)), dtype) - - operand_dev = nl.static_cast(operand_tensor, dtype) - source_dev = nl.static_cast(source_tensor, dtype) + operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype) + source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype) - bench_func(operand_dev, source_dev, output_tensor) + bench_func(operand_dev, source_dev) latency_res = bench_func.benchmark_result.nc_latency p99 = latency_res.get_latency_percentile(99) assert p99 <= latency @pytest.mark.parametrize("n, c, operand_h, operand_w, source_h, source_w, dtype", [ [8, 64, 112, 112, 56, 56, np.float32], - pytest.param(8, 64, 112, 112, 56, 56, nl.bfloat16, marks=pytest.mark.xfail), + [8, 64, 112, 112, 56, 56, nl.bfloat16], ]) def test_select_and_scatter_for_numeric(self, n, c, operand_h, operand_w, source_h, source_w, dtype): - operand_tensor = np.random.random_sample((n, c, operand_h, operand_w)).astype(np.float32) - source_tensor = np.random.random_sample((n, c, source_h, source_w)).astype(np.float32) - output_tensor = nl.static_cast(np.ndarray(shape=(n, c, operand_h, operand_w)), dtype) - - operand_dev = nl.static_cast(operand_tensor, dtype) - source_dev = nl.static_cast(source_tensor, dtype) + operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype) + source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype) + + sw = nl.static_cast(np.ndarray(shape=(n, c, source_h, source_w, 3, 3)), dtype) + operand_tensor = nl.static_cast(operand_dev, np.float32) + source_tensor = nl.static_cast(source_dev, np.float32) - numeric_func(operand_dev, source_dev, output_tensor) + output_dev = numeric_func(operand_dev, source_dev) golden_result = cpu_golden_result(operand_tensor, source_tensor) - output_tensor = nl.static_cast(output_tensor, np.float32) - assert np.allclose(output_tensor, golden_result) \ No newline at end of file + nki_result = nl.static_cast(output_dev, np.float32) + + assert np.allclose(nki_result, golden_result, rtol=1e-2, atol=1e-2) From bdeca88513a228c182a2b46a9b124c3fbdf0eab0 Mon Sep 17 00:00:00 2001 From: aws-qieqingy <122939906+aws-qieqingy@users.noreply.github.com> Date: Fri, 13 Dec 2024 20:52:29 -0500 Subject: [PATCH 2/3] Recover documentation change overwritten by code push (#7) --- CONTRIBUTING.md | 69 ++++++++++++++++++++++++++++++++---- LICENSE | 16 +++++++++ LICENSE.txt | 1 - README.md | 93 ++++++++++--------------------------------------- 4 files changed, 97 insertions(+), 82 deletions(-) create mode 100644 LICENSE delete mode 100644 LICENSE.txt diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 32ce44e..4c16260 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ # Contributing Guidelines -Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional +Thank you for your interest in contributing to our project. Whether it's a new NKI kernel, improving existing kernel code, bug fix, new feature, correction, or additional documentation, we greatly value feedback and contributions from our community. Please read through this document before submitting any issues or pull requests to ensure we have all the necessary @@ -24,14 +24,13 @@ reported the issue. Please try to include as much information as you can. Detail Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 1. You are working against the latest source on the *main* branch. -2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. -3. You open an issue to discuss any significant work - we would hate for your time to be wasted. +2. You check existing open, and recently merged pull requests to make sure someone else hasn't addressed the problem already. To send us a pull request, please: 1. Fork the repository. -2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. -3. Ensure local tests pass. +2. Modify the source; please focus on the specific changes you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. +3. Please ensure your change satisfies the requirements listed in [Testing Requirements](#testing-requirements) and [Coding Guidelines](#coding-guidelines) 4. Commit to your fork using clear commit messages. 5. Send us a pull request, answering any default questions in the pull request interface. 6. Wait for a repository collaborator to look at your pull request, run the automated tests, and review. If additional changes or discussion is needed, a collaborator will get back to you, so please stay involved in the conversation. @@ -40,8 +39,64 @@ To send us a pull request, please: GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). +### Testing Requirements +Running the binaries for a NKI kernel require Neuron devices on an AWS EC2 instance from trn1, trn1n, or inf2 instance families. +Details on setting up an instance can be found in [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html). -## Finding contributions to work on +If you would like to test your kernel without requiring a Neuron device, you can use `nki.simulate()` to run your kernel using `NumPy` input/output tensors and types. +An example can be found in the [layernorm tutorial test](test/unit/test_tutorials_layernorm.py). However, kernels with _only_ simulation tests will not be accepted. + +#### Requirements for Kernels Targeting `src/reference/` + +All kernels located in this folder need to have the following tests. + +1. Numeric accuracy tests with `nki.baremetal`. The output from the kernel +must be validated against a CPU reference implementation. See `test_flash_attn_fwd_numerical` in [test_flash_attn_fwd.py](test/unit/test_flash_attn_fwd.py) as an example. Documentation for `nki.baremetal` is available at [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.baremetal.html). + +2. Performance benchmark tests with `nki.benchmark`. The unit test must have performance checks. At a minimum, put an assertion to verify p99 latency meets a certain threshold. See `test_flash_attn_fwd_perf` in [test_flash_attn_fwd.py](test/unit/test_flash_attn_fwd.py) as an example. Documentation for `nki.benchmark` is available at [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.benchmark.html) + +3. End-to-End integration tests that use your kernel in a model. + + a. Each test should be in its own separate folder. + + b. Each Test must have a `run.sh` script, that accepts an argument \. See [run.sh of FlashAttention](test/integration/flash_attention/run.sh) as an example. + + c. The test scripts must produce benchmark results with the `benchmark` function, located in [LatencyCollector.py](test/integration/perf_utils/LatencyCollector.py). The `benchmark` function will write the latency of your E2E model to the `test_result.json`. + + d. Register your test target in [run_integration.sh](test/integration/run_integration.sh). + + +### Coding Guidelines +Most guidelines are covered by a **PEP-8** check on all newly submitted code, which covers aspects such as code layout and basic Python naming conventions. +In addition to PEP-8, we use the following NKI specific style guidelines: + +1. **Abbreviations** + * Importing NKI modules should use consistent names. For example, + ``` + import neuronxcc.nki as nki + import neuronxcc.nki.isa as nisa + import neuronxcc.nki.language as nl + import neuronxcc.nki.typing as nt + import numpy as np + ``` +2. Variable Names + * Indexing should specify partition and free dimensions along with the variable they are used for. For example: + The index for the partition dimension for tile `a` would be + ``` + i_p_a = nl.arange(128)[:, None] + ``` + while the index for the free dimension for tile `b` would be + ``` + i_f_b = nl.arange(512)[None, :] + ``` + * Name loop variables, indices, and buffers consistently, and specify their intended use in the name. + +3. Documentation + * New kernels should containing inline docstrings that describe the semantics of the kernel, and provide information on the IO layout. + Upon release, we generate the documentation for our kernels and merge them into the NKI API documentation which will appear in the official AWS NKI documentation. + + +## Finding Contributions to Work on Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/aws-neuron/nki-samples/labels/help%20wanted) issues is a great place to start. @@ -51,7 +106,7 @@ For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of opensource-codeofconduct@amazon.com with any additional questions or comments. -## Security issue notifications +## Security Issue Notifications If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..3b1fad4 --- /dev/null +++ b/LICENSE @@ -0,0 +1,16 @@ +MIT No Attribution + +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/LICENSE.txt b/LICENSE.txt deleted file mode 100644 index e7f39e2..0000000 --- a/LICENSE.txt +++ /dev/null @@ -1 +0,0 @@ -TODO: Fill LICENSE after it is finalized \ No newline at end of file diff --git a/README.md b/README.md index 2d97f6b..60602a9 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ At the core of the Neuron SDK is the Neuron Compiler, which takes computation gr them into highly optimized machine code. [NKI](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki) is a Python-based programming environment designed for the compiler which -adopts commonly used NumPy andTriton-like syntax along with tile-level semantics. +adopts commonly used NumPy and Triton-like syntax along with tile-level semantics. NKI also interoperates with the Neuron Profiler, providing insights into performance bottlenecks and instruction latencies. It offers tensor printing support, standard error messaging, and built-in kernel simulation capabilities for efficient debugging purposes. NKI offers two types of programming interfaces: @@ -16,31 +16,25 @@ enabling bare-metal access to the chip for full control. ![alt "High-level flow of NKI in the Neuron Compiler. NKI emits IR immediately before the backend-IR compilation stage"](doc_assets/high-level-nki-flow.png#center "High-Level NKI Flow") -### nki.language -**nki.language** enables precise control over computation and data movement on NeuronCores-- the processing units within AWS Inferentia and Trainium chips. -Developers can control data movement between device memory and on-chip memory explicitly using `nl.load()` and `nl.store()` operations. -Developers can then perform the desired computations on loaded tensors, such as element-wise operations or tensor contractions, -providing crucial performance improvements. Additionally, developers can control how computation is performed on different compute engines inside NeuronCores. -nki.language APIs are considered high-level APIs and are designed for "ease of use" for ML practitioners. -To achieve the best performance, developers can enlist the nki.isa APIs. - -![alt "Diagram of the NeuronCore Architecture. It shows 4 engines: tensor, vector, scalar, and GPSIMD, connected to SBUF memory. The tensor, vector, and scalar engines are also connected to a high-speed PSUM memory bank that supports accumulate on write. Lastly the HBM (DRAM) is connected to both SBUF and PSUM memory banks."](doc_assets/pm-nc.png#scale_50#center "NeuronCore Architecture") - -### nki.isa - -**nki.isa** provides direct access to chip instructions to offer flexibility and fine-grained control over instruction usage and performance optimizations. -Developers can utilize various `nki.isa` instructions using the Tensor, Vector, Scalar, GP-SIMD, and DMA engines. -For example, developers can use `nki.isa.nc_matmul()` to compute a matrix multiplication using Tensor Engine. -Alternatively, developers can use `nki.isa.activation()` to apply an activation function on every element of the input tile using Scalar Engine. +## Documentation +The latest NKI documentation can be found on the AWS Documentation site, [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/). +Documentation for NKI kernels are both inline (docstring) and available on the documentation site's +[kernel API reference page](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.kernels.html). ## Repository Structure ### src #### reference -The [reference kernels](src/reference/) are optimized reference kernels. All kernels located in this folder must have all of numeric accuracy tests +This folder contains the source code of the `neuronxcc.nki.kernels`, and they are optimized kernels from the Neuron Team serving as samples. + +All kernels located in this folder have numeric accuracy tests and performance benchmarks defined in the [test](test/) directory. We also demonstrate using these kernels end-to-end in our [integration tests](test/integration/). +Note that these kernels are already being deployed as part of the Neuron stack. With flash attention as an example, +[compiling Llama models with transformers-neuronx](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/transformers-neuronx-developer-guide.html) +will automatically invoke the `flash_fwd` kernel in [attention.py](src/reference/attention.py). Therefore, replacing the framework operators with these NKI kernels likely won't result in extra performance benefit. + #### tutorials The [tutorial kernels](src/tutorials/) are for educational purpose and include the kernels that are used in NKI guides. @@ -58,65 +52,16 @@ verify the numeric accuracy of the operation, and publish performance results to The [integration tests](tests/integration) folder contains integration tests of (selected) kernels. They verify the numeric accuracy of the model’s output, and publish end-to-end performance results into the [integration benchmarks](docs/benchmarks/integration) folder. -## Documentation -The latest NKI documentation can be found on the AWS Documentation site, [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/). -Documentation for NKI kernels are both inline (docstring) and available on the documentation site's -[kernel API reference page](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.kernels.html). +## Maintenance Policy +NKI is currently released as **beta** while we gather feedback from our users and integrate it into the API. NKI API follow the [Neuron SDK Maintenance Policy](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/sdk-policy.html). -## Versioning -NKI is currently released as **beta** while we gather feedback from our users and integrate it into the API. We will also be updating the NKI API as needed -to support new Neuron and Neuron Compiler features. While NKI is in beta we may need to make backwards-incompatible changes to incorporate feedback from -our users or to support new use-cases of NKI on Neuron devices. Upon releasing NKI as generally available (GA), we will commit to not making backwards -incompatible changes to the NKI API for any supported version of the Neuron compiler. +## Getting Help +Have a look at the GitHub issues for this repository where you will find past issues customers have encountered with workarounds and clarifications. +If you cannot find a suitable issue for your use-case feel free to [file an issue](https://github.com/aws-neuron/nki-samples/issues/new) to ask for assistance or to suggest improvements. Please read [CONTRIBUTING.md](CONTRIBUTING.md) for detailed information on submitting issues. ## Contributing -We invite you to join the NKI community! If you'd like to share kernels you create with the community, we welcome your contributions to this repository via. -GitHub pull-requests as well as through filed issues discussing features, bug fixes, new use-cases, and API improvements. - -### Getting Help -Have a look at the GitHub issues for this repository where you will find past issues customers have encountered with workarounds and clarifications. -If you cannot find a suitable issue for your use-case feel free to file an issue asking for assistance or to suggest improvements. - -In addition, extensive NKI documentation can be found [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki). - -### Testing and Merging -Running the binaries for a NKI kernel require Neuron devices on an AWS EC2 instance from trn1, trn1n, or inf2 instance families. -Details on setting up an instance can be found in [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html). - -Before merging, the Neuron team will need to internally test and verify kernels work as expected. If the change is accepted, -we will manually merge your changes, and it will be merged here upon the next release. - -If you would like to test your kernel without a requiring a Neuron device, you can use `nki.simulate()` to run your kernel using `NumPy` tensors and types. -An example can be found in the [layernorm tutorial test](test/unit/test_tutorials_layernorm.py). - -### Coding Guidelines -Most guidelines are covered by a **PEP-8** check on all newly submitted code, which covers aspects such as code layout and basic Python naming conventions. -In addition to PEP-8, we use the following NKI specific style guidelines: - -1. **Abbreviations** - * Importing NKI modules should use consistent names. For example, - ``` - import neuronxcc.nki as nki - import neuronxcc.nki.isa as nisa - import neuronxcc.nki.language as nl - import neuronxcc.nki.typing as nt - import numpy as np - ``` -2. Variable Names - * Indexing should specify partition and free dimensions along with the variable they are used for. For example: - The index for the partition dimension for tile `a` would be - ``` - i_p_a = nl.arange(128)[:, None] - ``` - while the index for the free dimension for tile `b` would be - ``` - i_f_b = nl.arange(512)[None, :] - ``` - * Name loop variables, indices, and buffers consistently, and specify their intended use in the name. - -3. Documentation - * New kernels should containing inline docstrings that describe the semantics of the kernel, and provide information on the IO layout. - Upon release, we generate the documentation for our kernels and merge them into the NKI API documentation which will appear in the official AWS NKI documentation. +We invite you to join the NKI community! If you'd like to share kernels you create with the community, we welcome your contributions to this repository via +GitHub pull-requests as well as through filed issues discussing features, bug fixes, new use-cases, and API improvements. Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more information ## Licensing This repository is licensed under the terms of the [MIT-0 License](LICENSE.txt) \ No newline at end of file From 23835a53cfcb83fcf6bb0eeb0cb98f39d15a8019 Mon Sep 17 00:00:00 2001 From: aws-qieqingy <122939906+aws-qieqingy@users.noreply.github.com> Date: Sat, 21 Dec 2024 16:27:43 -0500 Subject: [PATCH 3/3] Sync latest kernels and unit tests (#8) * Add latest kernels and unit tests for NeuronSDK release 2.21 * Move code into src/nki_samples to make unit test executable * Add Github workflow to run simulation tests on main branch and incoming PRs --- .github/workflows/run_simulation_tests.yml | 32 + src/nki_samples/reference/__init__.py | 10 + .../reference/allocated_attention.py | 12 +- .../reference/allocated_fused_linear.py | 0 src/{ => nki_samples}/reference/attention.py | 926 ++++++++++-------- src/{ => nki_samples}/reference/tutorial.py | 0 src/{ => nki_samples}/reference/vision.py | 0 .../average_pool2d/average_pool2d_jax.py | 0 .../average_pool2d_nki_kernels.py | 0 .../average_pool2d/average_pool2d_torch.py | 0 .../fused_mamba/mamba_nki_kernels.py | 0 .../tutorials/fused_mamba/mamba_torch.py | 0 .../layernorm/layernorm_nki_kernel.py | 0 .../tutorials/layernorm/layernorm_torch.py | 0 .../matrix_multiplication_nki_kernels.py | 0 .../matrix_multiplication_torch.py | 0 .../tutorials/rmsnorm/rmsnorm_jax.py | 0 .../tutorials/rmsnorm/rmsnorm_nki_kernels.py | 0 .../tutorials/rmsnorm/rmsnorm_torch.py | 0 .../sd_attention/sd_attention_nki_kernels.py | 0 .../sd_attention/sd_attention_torch.py | 0 .../tensor_addition/tensor_addition_jax.py | 0 .../tensor_addition_nki_kernels.py | 0 .../tensor_addition/tensor_addition_torch.py | 0 .../tutorials/transpose2d/transpose2d_jax.py | 0 .../transpose2d/transpose2d_nki_kernels.py | 0 .../transpose2d/transpose2d_torch.py | 0 src/reference/__init__.py | 32 - test/unit/README.md | 8 +- test/unit/__main__.py | 14 - test/unit/conftest.py | 28 + test/unit/test_SD_attention_small_head.py | 14 +- .../test_allocated_SD_attention_small_head.py | 15 +- test/unit/test_flash_attn_bwd.py | 21 +- test/unit/test_flash_attn_fwd.py | 24 +- test/unit/test_neuron_profile.py | 86 ++ test/unit/test_resize_nearest.py | 16 +- test/unit/test_rmsnorm_qkv.py | 18 +- test/unit/test_select_and_scatter.py | 16 +- 39 files changed, 758 insertions(+), 514 deletions(-) create mode 100644 .github/workflows/run_simulation_tests.yml create mode 100644 src/nki_samples/reference/__init__.py rename src/{ => nki_samples}/reference/allocated_attention.py (97%) rename src/{ => nki_samples}/reference/allocated_fused_linear.py (100%) rename src/{ => nki_samples}/reference/attention.py (54%) rename src/{ => nki_samples}/reference/tutorial.py (100%) rename src/{ => nki_samples}/reference/vision.py (100%) rename src/{ => nki_samples}/tutorials/average_pool2d/average_pool2d_jax.py (100%) rename src/{ => nki_samples}/tutorials/average_pool2d/average_pool2d_nki_kernels.py (100%) rename src/{ => nki_samples}/tutorials/average_pool2d/average_pool2d_torch.py (100%) rename src/{ => nki_samples}/tutorials/fused_mamba/mamba_nki_kernels.py (100%) rename src/{ => nki_samples}/tutorials/fused_mamba/mamba_torch.py (100%) rename src/{ => nki_samples}/tutorials/layernorm/layernorm_nki_kernel.py (100%) rename src/{ => nki_samples}/tutorials/layernorm/layernorm_torch.py (100%) rename src/{ => nki_samples}/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py (100%) rename src/{ => nki_samples}/tutorials/matrix_multiplication/matrix_multiplication_torch.py (100%) rename src/{ => nki_samples}/tutorials/rmsnorm/rmsnorm_jax.py (100%) rename src/{ => nki_samples}/tutorials/rmsnorm/rmsnorm_nki_kernels.py (100%) rename src/{ => nki_samples}/tutorials/rmsnorm/rmsnorm_torch.py (100%) rename src/{ => nki_samples}/tutorials/sd_attention/sd_attention_nki_kernels.py (100%) rename src/{ => nki_samples}/tutorials/sd_attention/sd_attention_torch.py (100%) rename src/{ => nki_samples}/tutorials/tensor_addition/tensor_addition_jax.py (100%) rename src/{ => nki_samples}/tutorials/tensor_addition/tensor_addition_nki_kernels.py (100%) rename src/{ => nki_samples}/tutorials/tensor_addition/tensor_addition_torch.py (100%) rename src/{ => nki_samples}/tutorials/transpose2d/transpose2d_jax.py (100%) rename src/{ => nki_samples}/tutorials/transpose2d/transpose2d_nki_kernels.py (100%) rename src/{ => nki_samples}/tutorials/transpose2d/transpose2d_torch.py (100%) delete mode 100644 src/reference/__init__.py delete mode 100644 test/unit/__main__.py create mode 100644 test/unit/conftest.py create mode 100644 test/unit/test_neuron_profile.py diff --git a/.github/workflows/run_simulation_tests.yml b/.github/workflows/run_simulation_tests.yml new file mode 100644 index 0000000..7ea89ef --- /dev/null +++ b/.github/workflows/run_simulation_tests.yml @@ -0,0 +1,32 @@ +name: Run Python Simulation Tests + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com + python -m pip install wget awscli + python -m pip install pytest + python -m pip install neuronx-cc==2.* + - name: Test with pytest + run: | + PYTHONPATH=$PYTHONPATH:src/ pytest test/unit/ --simulation-only \ No newline at end of file diff --git a/src/nki_samples/reference/__init__.py b/src/nki_samples/reference/__init__.py new file mode 100644 index 0000000..c9e5d37 --- /dev/null +++ b/src/nki_samples/reference/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2023, Amazon.com. All Rights Reserved + +""" +Package containing public kernels for Neuron Kernel Interface (NKI). + +Kernels here are also available in the `neuronxcc.nki.kernels` namespace, and they +are synced with this repository on every Neuron SDK release. + +https://github.com/aws-neuron/nki-samples +""" diff --git a/src/reference/allocated_attention.py b/src/nki_samples/reference/allocated_attention.py similarity index 97% rename from src/reference/allocated_attention.py rename to src/nki_samples/reference/allocated_attention.py index 564412c..94b513f 100644 --- a/src/reference/allocated_attention.py +++ b/src/nki_samples/reference/allocated_attention.py @@ -9,13 +9,13 @@ @nki.jit def allocated_fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False, - mixed_percision=True): + mixed_precision=True): """ Allocated fused self attention kernel for small head size Stable Diffusion workload. - Computes (softmax(Q.T@K)V).T. The wired layout is choosen to avoid transpose as + Computes (softmax(Q.T@K)V).T. The wired layout is chosen to avoid transpose as much as possible to simplify the debug. The kernel uses the direct allocation API, - and implements double buffering to achive better performance than automatic allocation. + and implements double buffering to achieve better performance than automatic allocation. As of NeuronSDK 2.21, it achieves 18% better performance than auto allocated equivalent. To see the performance gap, you can use ``force_auto_alloc`` decorator to override manual allocation and benchmark the performance difference. @@ -34,14 +34,14 @@ def allocated_fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, IO tensor dtypes: - This kernel assumes all IO tensors have the same dtype - - If mixed_percision is True, then all Tensor Engine operation will be performed in + - If mixed_precision is True, then all Tensor Engine operation will be performed in bfloat16 and accumulation will be performed in float32. Otherwise the intermediates will be in the same type as the inputs. """ # Use q_ref dtype as the intermediate tensor dtype # Assume all IO tensors have the same dtype kernel_dtype = np.float32 - pe_in_dt = nl.bfloat16 if mixed_percision else kernel_dtype + pe_in_dt = nl.bfloat16 if mixed_precision else kernel_dtype kernel_dtype_itemsize = np.dtype(kernel_dtype).itemsize pe_in_dt_itemsize = np.dtype(pe_in_dt).itemsize @@ -211,7 +211,7 @@ def psum_addr(bank_map, idx, pdim_size, fdim_size): on_true_tile=qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype) else: # Copy result to SBUF and find partial maximum for softmax - qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.tensor_scalar_reduce(qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], np.add, 1.0, + qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.tensor_scalar_reduce(data=qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], op0=np.add, operand0=1.0, reduce_op=np.max, reduce_res=neg_max_res[i_interleave_grp, ip_max, i_k_seq_tile], dtype=kernel_dtype) # Find global max from tiles diff --git a/src/reference/allocated_fused_linear.py b/src/nki_samples/reference/allocated_fused_linear.py similarity index 100% rename from src/reference/allocated_fused_linear.py rename to src/nki_samples/reference/allocated_fused_linear.py diff --git a/src/reference/attention.py b/src/nki_samples/reference/attention.py similarity index 54% rename from src/reference/attention.py rename to src/nki_samples/reference/attention.py index 9bf0444..3c456a6 100644 --- a/src/reference/attention.py +++ b/src/nki_samples/reference/attention.py @@ -15,6 +15,7 @@ from functools import reduce as functools_reduce from operator import mul as operator_mul + def n_elts(shape): return functools_reduce(operator_mul, shape, 1) @@ -27,21 +28,59 @@ def linearize(shape, indices): def div_ceil(n, d): return (n + d - 1) // d + @dataclass(frozen=True) class FlashConfig: """ Config class for flash attention with default values """ seq_tile_size:int = 2048 + attn_core_tile_size:int = 256 training:bool = True should_transpose_v:bool = False + lse_dtype: str = "" + + +@nki.jit(mode='trace') +def transpose_p_local(p_local_transposed, p_local, LARGE_TILE_SZ): + for i in nl.affine_range(LARGE_TILE_SZ // 512): + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.sbuf, dtype=p_local.dtype) + else: + p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.psum, dtype=np.float32) + + for j in nl.affine_range(512 // 128): + j_128_slice = nl.ds(j * 128, 128) + i_j_128_slice = nl.ds(i * 512 + j * 128, 128) + + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose( + p_local[:, i_j_128_slice]) + else: + p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose( + p_local[:, i_j_128_slice]) + + p_local_transposed[:, nl.ds(i * 512, 512)] = nl.copy( + p_local_t_tmp, dtype=p_local_transposed.dtype) - __annotations__ = { - 'seq_tile_size': int, - 'training': bool, - 'should_transpose_v': bool - } +@nki.jit(mode='trace') +def dropout_p_local(p_local, dropout_p, dropout_p_tensor, seed_tensor, + seed_offset_base, k_r_i, REDUCTION_TILE): + B_F_SIZE = 512 + for k_d_i in nl.sequential_range(REDUCTION_TILE // B_F_SIZE): + p_local_f_slice = nl.ds(k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE, B_F_SIZE) + + offset = k_d_i + seed_offset_base + offset_seed = nl.add(seed_tensor, offset, dtype=nl.int32) + nl.random_seed(seed=offset_seed) + softmax_dropout = nl.dropout(p_local[:, p_local_f_slice], + rate=dropout_p_tensor[:, 0]) + p_local[:, p_local_f_slice] = nl.multiply( + softmax_dropout, 1 / (1 - dropout_p)) + + +@nki.jit(mode='trace') def _flash_attention_core(q_local_tile, k, v, q_h_per_k_h, seqlen_q, nheads, o_buffer, l_buffer, m_buffer, @@ -49,169 +88,212 @@ def _flash_attention_core(q_local_tile, k, v, local_k_large_tile_idx, kernel_dtype, acc_type, flash_config: FlashConfig, - olm_buffer_idx=None, - global_k_large_tile_idx=None, - use_causal_mask=False, initialize=False, + use_causal_mask, initialize, B_P_SIZE=128, B_F_SIZE=512, B_D_SIZE=128, - dropout_p=0.0, dropout_p_tensor=None, seed_tensor=None - ): + dropout_p=0.0, dropout_p_tensor=None, seed_tensor=None, + logit_bias_tile=None): """ The flash attention core function to calcualte self attention between a tile of q and a block of K and V. The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF already. The block size of K and V is defined in the seq_tile_size of the flash_config. The results are stored in the following three buffers - o_buffer: (num_large_k_tile, B_P_SIZE, d) - l_buffer: (num_large_k_tile, B_P_SIZE, 1) - m_buffer: (num_large_k_tile, B_P_SIZE, 1) + o_buffer: (B_P_SIZE, d) + l_buffer: (B_P_SIZE, 1) + m_buffer: (B_P_SIZE, 1) """ LARGE_TILE_SZ = flash_config.seq_tile_size - REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE seqlen_k = k.shape[-1] seq_q_num_tiles = seqlen_q // B_P_SIZE seq_k_num_tiles = seqlen_k // B_F_SIZE - # Indices used by the distributed attention - if global_k_large_tile_idx is None: - global_k_large_tile_idx = local_k_large_tile_idx - if olm_buffer_idx is None: - olm_buffer_idx = local_k_large_tile_idx - - i_q_p = nl.arange(B_P_SIZE)[:, None] - i_q_f = nl.arange(B_F_SIZE)[None, :] - i_d_p = nl.arange(B_D_SIZE)[:, None] - i_d_f = nl.arange(B_D_SIZE)[None, :] - i_f_128 = nl.arange(B_P_SIZE)[None, :] - i_f_k_tiles = nl.arange(num_k_tile_per_large_tile)[None, :] - - # mask are used to only apply computation to the lower half of the matrix, - # which reduce the arthimetic intensity by half - forward_mask = q_tile_idx * B_P_SIZE >= global_k_large_tile_idx * LARGE_TILE_SZ if use_causal_mask else None - # Negation mask is the negation of `forward_mask`, which is used for the - # instructions executed on the blocks in the upper triangular section - # of the matrix. - # These instructions should not be executed when causual mask is disabled. - # - # For example, the o_buffer still needs to be propagated from o[j-1] to o[j] in - # the upper triangular of the matrix. - negation_mask = q_tile_idx * B_P_SIZE < global_k_large_tile_idx * LARGE_TILE_SZ if use_causal_mask else None - qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), buffer=nl.sbuf, dtype=acc_type) max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), dtype=acc_type) + for k_i in nl.affine_range(num_k_tile_per_large_tile): - qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE), + k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) + + qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), dtype=np.float32, buffer=nl.psum) # (128, 512) - multiplication_required_selection = global_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE <= q_tile_idx * B_P_SIZE if use_causal_mask else None - qk_psum[i_q_p, i_q_f] += nl.matmul(q_local_tile, k[i_d_p, k_i * B_F_SIZE + i_q_f], transpose_x=True, - mask=multiplication_required_selection) # (p(128), 512) + if use_causal_mask: + multiplication_required_selection = q_tile_idx * B_P_SIZE >= local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE + else: + multiplication_required_selection = True + + if multiplication_required_selection: + qk_psum[:, :] = nl.matmul(q_local_tile, k[:, k_i_b_f_slice], transpose_x=True) # (p(128), 512) + else: + qk_psum[:, :] = 0 if use_causal_mask: - left_diagonal_selection = q_tile_idx * B_P_SIZE >= global_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE - diagonal_and_right_selection = (q_tile_idx * B_P_SIZE < global_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE) & forward_mask + left_diagonal_selection = q_tile_idx * B_P_SIZE >= local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE + diagonal_and_right_selection = (q_tile_idx * B_P_SIZE < local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE) + right_diagonal_selection = ((q_tile_idx + 1) * B_P_SIZE <= local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE) + diagonal = ((q_tile_idx * B_P_SIZE < local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE) & + ((q_tile_idx + 1) * B_P_SIZE > local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE)) + i_q_p, i_q_f = nl.mgrid[0:B_P_SIZE, 0:B_F_SIZE] q_pos = q_tile_idx * B_P_SIZE + i_q_p - k_pos = global_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE + i_q_f + k_pos = local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE + i_q_f pred = q_pos >= k_pos - # For tiles on and to the right of the diagonal, need to do affine_select. - # Magic number -9984.0 to replace -inf similar to what Tensorizer uses - qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = nisa.affine_select( - pred=pred, - on_true_tile=qk_psum[i_q_p, i_q_f], on_false_value=-9984.0, dtype=kernel_dtype, - mask=diagonal_and_right_selection) - - # For tiles on the left of the diagonal, direct copy, no select required. - qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = \ - nl.copy(qk_psum[i_q_p, i_q_f], dtype=kernel_dtype, mask=left_diagonal_selection) + + qk_select_tmp = nl.ndarray(qk_psum.shape, dtype=qk_psum.dtype, buffer=nl.sbuf) + + if logit_bias_tile is not None: + if right_diagonal_selection: + qk_select_tmp[...] = qk_psum + + # For tiles to the right of the diagonal, do affine_select. + # Magic number -9984.0 to replace -inf similar to what Tensorizer uses + qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select( + pred=pred, + on_true_tile=qk_select_tmp, on_false_value=-9984.0, dtype=acc_type) + + # For tiles on the diagonal, add logit bias and need to do affine_select. + intermediate = \ + nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice], + dtype=acc_type, mask=diagonal) + qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select( + pred=pred, + on_true_tile=intermediate, on_false_value=-9984.0, dtype=acc_type, + mask=diagonal) + + # For tiles on the left of the diagonal, just add logit bias, no select required. + qk_res_buf[:, k_i_b_f_slice] = \ + nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice], + dtype=acc_type, mask=left_diagonal_selection) + else: + # For tiles on and to the right of the diagonal, need to do affine_select. + # Magic number -9984.0 to replace -inf similar to what Tensorizer uses + if diagonal_and_right_selection: + qk_select_tmp[...] = qk_psum + + qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select( + pred=pred, + on_true_tile=qk_select_tmp, on_false_value=-9984.0, dtype=acc_type) + + # For tiles on the left of the diagonal, direct copy, no select required. + qk_res_buf[:, k_i_b_f_slice] = \ + nl.copy(qk_psum, dtype=acc_type, mask=left_diagonal_selection) else: - # Simply send psum result back to sbuf - qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = \ - nl.copy(qk_psum[i_q_p, i_q_f], dtype=kernel_dtype) + if logit_bias_tile is not None: + # Simply add logit bias which copies back to sbuf at the same time + qk_res_buf[:, k_i_b_f_slice] = \ + nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice], dtype=acc_type) + else: + # Simply send psum result back to sbuf + qk_res_buf[:, k_i_b_f_slice] = nl.copy(qk_psum, dtype=acc_type) # Calculate max of the current tile - max_local[i_q_p, k_i] = nisa.tensor_reduce(np.max, qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f], axis=(1,), - dtype=acc_type, negate=False, mask=forward_mask) - - max_ = nisa.tensor_reduce(np.max, max_local[i_q_p, i_f_k_tiles], axis=(1, ), - dtype=acc_type, negate=False, mask=forward_mask) - if not initialize: - m_previous = nl.copy(m_buffer[olm_buffer_idx - 1, i_q_p, 0]) - m_buffer[olm_buffer_idx, i_q_p, 0] = nl.maximum(m_previous, max_, mask=forward_mask) # (128,1) - if use_causal_mask: - m_buffer[olm_buffer_idx, i_q_p, 0] = nl.copy(m_previous, mask=negation_mask) + max_local[:, k_i] = nisa.tensor_reduce( + np.max, qk_res_buf[:, k_i_b_f_slice], axis=(1,), dtype=acc_type, + negate=False) - m_current = m_buffer[olm_buffer_idx, i_q_p, 0] - # Compute scaling factor - alpha = nisa.activation(np.exp, m_previous, bias=-1*m_current, scale=1.0, mask=forward_mask) - o_previous = nl.copy(o_buffer[olm_buffer_idx-1, i_q_p, i_d_f], mask=forward_mask) - o_previous_scaled = nl.multiply(o_previous, alpha, mask=forward_mask) - else: - m_buffer[0, i_q_p, 0] = nl.copy(max_) + max_ = nisa.tensor_reduce(np.max, max_local[:, :], axis=(1, ), + dtype=acc_type, negate=False) + + o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), dtype=o_buffer.dtype) + + if initialize: + m_buffer[:, 0] = nl.copy(max_) m_current = max_ + else: + m_previous = nl.copy(m_buffer[:, 0]) + m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1) + + m_current = m_buffer[:, 0] + # Compute scaling factor + alpha = nisa.activation(np.exp, m_current, bias=m_previous, scale=-1.0) + o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha) p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) - i_r_f = nl.arange(REDUCTION_TILE)[None,: ] + REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) + p_partial_sum = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type) + for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): - # compute exp(qk-max) - p_local[i_q_p, k_r_i * REDUCTION_TILE + i_r_f] = \ - nisa.activation(np.exp, - qk_res_buf[i_q_p, k_r_i * REDUCTION_TILE + i_r_f], - bias=-1 * m_current, - scale=1.0, - dtype=kernel_dtype, - mask=forward_mask) + k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) # dropout if dropout_p > 0.0: - for k_d_i in nl.sequential_range(REDUCTION_TILE // B_F_SIZE): - offset = k_d_i + k_r_i * (REDUCTION_TILE // B_F_SIZE) \ - + global_k_large_tile_idx * (LARGE_TILE_SZ // B_F_SIZE) \ - + q_tile_idx * seq_k_num_tiles \ - + (head_id * q_h_per_k_h + gqa_head_idx) * seq_k_num_tiles * seq_q_num_tiles \ - + batch_id * nheads * seq_k_num_tiles * seq_q_num_tiles - offset_seed = nl.add(seed_tensor[0, 0], offset, mask=forward_mask) - nl.random_seed(seed=offset_seed, mask=forward_mask) - softmax_dropout = nl.dropout(p_local[i_q_p, k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE + i_q_f], - rate=dropout_p_tensor[i_q_p, 0], - mask=forward_mask) - p_local[i_q_p, k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE + i_q_f] = \ - nl.multiply(softmax_dropout, 1 / (1 - dropout_p), mask=forward_mask) - - # Compute partial row-tile sum of exp(qk-max)) - p_partial_sum[i_q_p, k_r_i] = nl.sum(p_local[i_q_p, k_r_i * REDUCTION_TILE + i_r_f], axis=1, dtype=acc_type, mask=forward_mask) + # compute exp(qk-max) + p_local[:, k_r_i_reduce_slice] = \ + nisa.activation(np.exp, qk_res_buf[:, k_r_i_reduce_slice], + bias=-1 * m_current, scale=1.0, + dtype=kernel_dtype) + + seed_offset_base = k_r_i * (REDUCTION_TILE // B_F_SIZE) \ + + local_k_large_tile_idx * (LARGE_TILE_SZ // B_F_SIZE) \ + + q_tile_idx * seq_k_num_tiles \ + + (head_id * q_h_per_k_h + gqa_head_idx) * seq_k_num_tiles * seq_q_num_tiles \ + + batch_id * nheads * seq_k_num_tiles * seq_q_num_tiles + + dropout_p_local(p_local=p_local, dropout_p=dropout_p, + dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_tensor, + seed_offset_base=seed_offset_base, k_r_i=k_r_i, + REDUCTION_TILE=REDUCTION_TILE) + + # Compute partial row-tile sum of exp(qk-max)) + # FIXME: Use activation accumulate and accumulate over k_r_i loop? + p_partial_sum[:, k_r_i] = nl.sum(p_local[:, k_r_i_reduce_slice], + axis=1, dtype=acc_type) + else: + # compute exp(qk-max) + # Compute partial row-tile sum of exp(qk-max)) + # FIXME: Use activation accumulate to accumulate over k_r_i loop? + p_local[:, k_r_i_reduce_slice] = \ + nisa.activation_reduce(np.exp, qk_res_buf[:, k_r_i_reduce_slice], + bias=-1 * m_current, scale=1.0, + reduce_op=nl.add, reduce_res=p_partial_sum[:, k_r_i], + dtype=kernel_dtype) + + ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type) p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) - for i_p_t in nl.affine_range(LARGE_TILE_SZ // 512): - p_local_t_tmp = nl.ndarray((par_dim(B_P_SIZE), 512), buffer=nl.psum, dtype=np.float32) - for i_p_t_local in nl.affine_range(512//128): - p_local_t_tmp[i_q_p, i_p_t_local*128 + i_f_128] = nisa.nc_transpose(p_local[i_q_p, i_p_t*512+i_p_t_local * B_P_SIZE + i_f_128], mask=forward_mask) - i_f_512 = nl.arange(512)[None, :] - p_local_transposed[i_q_p, i_p_t * 512 + i_f_512 ] = nl.copy(p_local_t_tmp[i_q_p, i_f_512], dtype=kernel_dtype, mask=forward_mask) - - ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask) - pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), dtype=np.float32, buffer=nl.psum) + transpose_p_local(p_local_transposed=p_local_transposed, p_local=p_local, + LARGE_TILE_SZ=LARGE_TILE_SZ) + + pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), dtype=np.float32, + buffer=nl.psum, lazy_initialization=True) for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): - pv_psum[i_q_p, i_d_f] += nl.matmul(p_local_transposed[i_q_p, k_i * B_P_SIZE + i_f_128], - v[k_i, i_q_p, i_d_f], - transpose_x=True, - mask=forward_mask) # (128, 128) (p(Br), d) + pv_psum[:, :] += nl.matmul(p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], + v[k_i, :, :], transpose_x=True) # (128, 128) (p(Br), d) if initialize: - o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.copy(pv_psum[i_q_p, i_d_f]) - l_buffer[olm_buffer_idx, i_q_p, 0] = nl.add(nl.log(ps), max_) + o_buffer[:, :] = nl.copy(pv_psum[:, :]) + l_buffer[:, 0] = nl.add(nl.log(ps), max_) else: - if use_causal_mask: - o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.copy(o_buffer[olm_buffer_idx-1, i_q_p, i_d_f], mask=negation_mask) - o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask) + o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum) + + exp = nisa.activation(nl.exp, m_current, bias=l_buffer[:, 0], scale=-1.0) + l_buffer[:, 0] = nl.add(m_current, nisa.activation(nl.log, exp, bias=ps)) + + +@nki.jit(mode='trace') +def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config): + LARGE_TILE_SZ = config.seq_tile_size + B_P_SIZE = 128 + + if not config.should_transpose_v: + cur_v_tile[v_i, :, :] = nl.load( + v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :], + dtype=cur_v_tile.dtype) + return + + if nisa.get_nc_version() == nisa.nc_version.gen3: + cur_v_tile_transposed = nisa.dma_transpose( + v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)]) + cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed, + dtype=cur_v_tile.dtype) + return + + cur_v_tile[v_i, :, :] = nl.load_transpose2d( + v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)], + dtype=cur_v_tile.dtype) - l_prev = l_buffer[olm_buffer_idx-1, i_q_p, 0] - l_exp = nl.add(nl.exp(nl.subtract(l_prev, m_current, mask=forward_mask), mask=forward_mask), ps, mask=forward_mask) - l_buffer[olm_buffer_idx, i_q_p, 0] = nl.add(m_current, nl.log(l_exp, mask=forward_mask), mask=forward_mask) - if use_causal_mask: - l_buffer[olm_buffer_idx, i_q_p, 0] = nl.copy(l_buffer[olm_buffer_idx-1, i_q_p, 0], mask=negation_mask) @nki.jit -def flash_fwd(q, k, v, seed, +def flash_fwd(q, k, v, seed, logit_bias=None, softmax_scale=None, use_causal_mask=True, mixed_precision=True, @@ -224,27 +306,35 @@ def flash_fwd(q, k, v, seed, - k: shape (bs, nk_heads, d, seq_k) - v: shape (bs, nv_heads, d, seq_v) if config.should_transpose_v else (bs, nv_heads, seq_v, d) - seed: shape (1,) + - logit_bias: shape (bs, n_heads, seq_q, seq_k) - o: shape (bs, n_heads, seq_q, d) - lse: shape (bs, n_heads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None - This kernel requires seq_k == seq_v IO tensor dtypes: - This kernel assumes all IO tensors have the same dtype - - If mixed_percision is True, then all Tensor Engine operation will be performed in + - If mixed_precision is True, then all Tensor Engine operation will be performed in bfloat16 and accumulation will be performed in float32. Otherwise the intermediates will be in the same type as the inputs. Compile-time Constants: - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` - - mixed_precision: flag to set non-matmul ops in fp32 precision, defualt is set to `true`, if false, we use same precision as input types + - mixed_precision: flag to set non-matmul ops in fp32 precision, default is set to `true`, if false, we use same precision as input types - causal_mask: flag to set causal masking - - config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig` with Performance config parameters for flash attention with default values + - config: Instance of :class:`nki.kernels.attention.FlashConfig` with Performance config parameters for flash attention with default values seq_tile_size: `default=2048`, size of the kv tile size for attention computation reduction training: bool to indicate training vs inference `default=True` Performance Notes: - For better performance, the kernel is tiled to be of size `LARGE_TILE_SZ`, and Flash attention math techniques are applied in unit - of `LARGE_TILE_SZ`. Seqlen that is not divisible by `LARGE_TILE_SZ` is not supported at the moment. + For better performance, the kernel is tiled to be of size `config.seq_tile_size`, and Flash attention math techniques are applied in unit + of `config.seq_tile_size`. Seqlen that is not divisible by `config.seq_tile_size` is not supported at the moment. + + For large seqlen, `o_buffer` will overflow the statebuf. the kernel is tile `o_buffer` based on the value of `config.attn_core_tile_size`. + This is a tradeoff between memory usage and performance. The default value of `config.attn_core_tile_size` is 256, which means the `o_buffer` + will roughly take half of the statebuf. The computes are also tiled accordingly. DMA will be rematerialized + `seqlen_q // B_P_SIZE // attn_core_tile_size times`. + + GQA support Notes: the spmd kernel for launching kernel should be on kv_heads instead of nheads @@ -273,26 +363,27 @@ def flash_fwd(q, k, v, seed, o = nl.ndarray((b, h, seqlen_q, d), dtype=q.dtype, buffer=nl.shared_hbm) if config.training: + if config.lse_dtype: + lse_dtype = getattr(nl, config.lse_dtype) + else: + lse_dtype = acc_type lse = nl.ndarray((b, h, nl.tile_size.pmax, seqlen_q // nl.tile_size.pmax), - dtype=acc_type, buffer=nl.shared_hbm) + dtype=lse_dtype, buffer=nl.shared_hbm) else: lse = None - i_q_p = nl.arange(B_P_SIZE)[:,None] - i_0_f = nl.arange(1)[None, :] - + assert nl.program_ndim() == 2,\ + f'Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!' batch_id = nl.program_id(axis=0) - - head_dims = list(range(1, nl.program_ndim())) - head_dims_shape = list(nl.num_programs(i) for i in head_dims) - head_dims_idx = list(nl.program_id(i) for i in head_dims) - head_id = linearize(head_dims_shape, head_dims_idx) + head_id = nl.program_id(axis=1) softmax_scale = softmax_scale or (1.0 / (d ** 0.5)) n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine LARGE_TILE_SZ = config.seq_tile_size + attn_core_tile_size = config.attn_core_tile_size + # FIXME: Add masking for different seqlen values. assert config.seq_tile_size >= 512, f" seq tile_size {config.seq_tile_size} cannot be less than 512" assert seqlen_k % LARGE_TILE_SZ == 0, f"Need seqlen_k to be divisible by {LARGE_TILE_SZ} but got {seqlen_k}" @@ -316,131 +407,107 @@ def flash_fwd(q, k, v, seed, dropout_p_tensor = None seed_local = None - for i_q_h in nl.affine_range(q_h_per_k_h): + if logit_bias is not None: + b_logit_bias, h_logit_bias, _, _ = logit_bias.shape + assert b_logit_bias == 1 and h_logit_bias == 1, "only support broadcasting logit_bias with batch 1, n_heads 1" + n_remat = div_ceil(n_tile_q, attn_core_tile_size) + attn_core_tile_size = min(n_tile_q, attn_core_tile_size) + + for i_q_h in nl.affine_range(q_h_per_k_h): # =============== Global Flash Attention accumulators ====================== # - o_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), d), 0.0, dtype=acc_type, buffer=nl.sbuf) - l_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), 1), 0.0, dtype=acc_type, buffer=nl.sbuf) - m_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), 1), 0.0, dtype=acc_type) + l_buffer = nl.zeros((par_dim(B_P_SIZE), n_tile_q), dtype=acc_type, + buffer=nl.sbuf, lazy_initialization=True) # =============== Global Flash Attention accumulators END ================== # - j = 0 - cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) - cur_v_tile = nl.ndarray((LARGE_TILE_SZ//B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype) - load_tile_size = B_P_SIZE - for k_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_D_SIZE)[:, None] - load_f = nl.arange(load_tile_size)[None, :] - cur_k_tile[load_p, load_tile_size*k_i+load_f] = nl.load( - k[batch_id, head_id, load_p, load_tile_size*k_i+load_f] - ) - if config.should_transpose_v: - for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_D_SIZE)[:, None] - load_f = nl.arange(B_P_SIZE)[None, :] - - loaded = nl.load(v[batch_id, head_id, load_p, B_P_SIZE*v_i+load_f], dtype=kernel_dtype) - store_p = nl.arange(B_P_SIZE)[:, None] - store_f = nl.arange(B_D_SIZE)[None, :] - cur_v_tile[v_i, store_p, store_f] = nisa.nc_transpose(loaded) - else: - for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_P_SIZE)[:, None] - load_f = nl.arange(B_D_SIZE)[None, :] - - cur_v_tile[v_i, load_p, load_f] = nl.load(v[batch_id, head_id, B_P_SIZE*v_i+load_p, load_f], dtype=kernel_dtype) - - for i in nl.affine_range(n_tile_q): - i_f_128 = nl.arange(B_P_SIZE)[None, :] - i_f_d = nl.arange(B_D_SIZE)[None, :] - i_p_d = nl.arange(B_D_SIZE)[:,None] - q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype) - q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, - head_id * q_h_per_k_h + i_q_h, i_p_d, - i * B_P_SIZE + i_f_128], - dtype=kernel_dtype) * softmax_scale # load (d, 128) tile in SBUF - # handle first tile and compute max and lse explicitly by passing initialize=True - _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, - q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h, - o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i], - batch_id=batch_id, head_id=head_id, - gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=0, - kernel_dtype=kernel_dtype, acc_type=acc_type, - flash_config=config, use_causal_mask=use_causal_mask, - initialize=True, - B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE, - dropout_p=dropout_p, dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_local) - - for j in nl.sequential_range(1, num_large_k_tile): - cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) - cur_v_tile = nl.ndarray((LARGE_TILE_SZ//B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype) - load_tile_size = B_P_SIZE - for k_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_D_SIZE)[:, None] - load_f = nl.arange(load_tile_size)[None, :] - cur_k_tile[load_p, load_tile_size*k_i+load_f] = nl.load( - k[batch_id, head_id, load_p, j*LARGE_TILE_SZ+load_tile_size*k_i+load_f] - ) - if config.should_transpose_v: - for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_D_SIZE)[:, None] - load_f = nl.arange(B_P_SIZE)[None, :] + for i0 in nl.sequential_range(n_remat): + # =============== Global Flash Attention accumulators ====================== # + o_buffer = nl.zeros((attn_core_tile_size, par_dim(B_P_SIZE), d), dtype=acc_type, + buffer=nl.sbuf, lazy_initialization=True) + m_buffer = nl.zeros((attn_core_tile_size, par_dim(B_P_SIZE), 1), dtype=acc_type, + buffer=nl.sbuf, lazy_initialization=True) + # =============== Global Flash Attention accumulators END ================== # - loaded = nl.load(v[batch_id, head_id, load_p, j*LARGE_TILE_SZ+B_P_SIZE*v_i+load_f], dtype=kernel_dtype) - store_p = nl.arange(B_P_SIZE)[:, None] - store_f = nl.arange(B_D_SIZE)[None, :] - cur_v_tile[v_i, store_p, store_f] = nisa.nc_transpose(loaded) - else: + for j in nl.sequential_range(0, num_large_k_tile): + cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) + cur_v_tile = nl.ndarray((LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype) + + cur_k_tile[:, :] = nl.load(k[batch_id, head_id, :, nl.ds(j*LARGE_TILE_SZ, LARGE_TILE_SZ)]) + + load_tile_size = B_P_SIZE + + v_hbm_tile = v[batch_id, head_id] for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_P_SIZE)[:, None] - load_f = nl.arange(B_D_SIZE)[None, :] - - cur_v_tile[v_i, load_p, load_f] = nl.load(v[batch_id, head_id, j*LARGE_TILE_SZ+B_P_SIZE*v_i+load_p, load_f], dtype=kernel_dtype) - - for i in nl.affine_range(n_tile_q): - i_f_128 = nl.arange(B_P_SIZE)[None, :] - i_f_d = nl.arange(B_D_SIZE)[None, :] - i_p_d = nl.arange(B_D_SIZE)[:,None] - q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype) - q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, - head_id * q_h_per_k_h + i_q_h, i_p_d, - i * B_P_SIZE + i_f_128], - dtype=kernel_dtype) * softmax_scale # load (d, 128) tile in SBUF - _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, - q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h, - o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i], - batch_id=batch_id, head_id=head_id, - gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=j, - kernel_dtype=kernel_dtype, acc_type=acc_type, - flash_config=config, use_causal_mask=use_causal_mask, - initialize=False, - B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE, - dropout_p=dropout_p, dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_local) - - # -------- write output to buffer on HBM ------------ # - for i in nl.affine_range(n_tile_q): - out = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype) - out[i_q_p, i_f_d] = nl.multiply(o_buffer[i, num_large_k_tile - 1, i_q_p, i_f_d], - nl.exp(m_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f] - l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f]), - dtype=kernel_dtype) - - nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h, i*B_P_SIZE + i_q_p, i_f_d], out[i_q_p, i_f_d]) - if not inference: - lse_local = nl.zeros((par_dim(B_P_SIZE), 1), dtype=acc_type) - lse_local[i_q_p, i_0_f] = nl.copy(l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f], dtype=acc_type) - nl.store(lse[batch_id, head_id * q_h_per_k_h + i_q_h, i_q_p, i + i_0_f], lse_local[i_q_p, i_0_f]) + load_v_tile(v_hbm_tile=v_hbm_tile, cur_v_tile=cur_v_tile, j=j, v_i=v_i, + config=config) + + for i1 in nl.affine_range(attn_core_tile_size): + i = i0 * attn_core_tile_size + i1 + # mask are used to only apply computation to the lower half of the matrix, + # which reduce the arthimetic intensity by half. + # forward_mask imply initialize, i.e. if forward_mask is false, initialize will + # be false as well + if use_causal_mask: + forward_mask = i * B_P_SIZE >= j * LARGE_TILE_SZ + else: + forward_mask = True + + if (i < n_tile_q) & forward_mask: + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype) + q_hbm_tile = q[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load(q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)], + dtype=kernel_dtype) # load (d, 128) tile in SBUF + q_tile[:, :] = q_sbuf_tile * softmax_scale + + logit_bias_tile = None + if logit_bias is not None: + logit_bias_tile = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) + logit_bias_tile[:, :] = nl.load( + logit_bias[0, 0, nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(j * LARGE_TILE_SZ, LARGE_TILE_SZ)]) + + _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, + q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h, + o_buffer=o_buffer[i1], l_buffer=l_buffer[:, i], m_buffer=m_buffer[i1], + batch_id=batch_id, head_id=head_id, + gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=j, + kernel_dtype=kernel_dtype, acc_type=acc_type, + flash_config=config, use_causal_mask=use_causal_mask, + initialize=j == 0, + B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE, + dropout_p=dropout_p, dropout_p_tensor=dropout_p_tensor, + seed_tensor=seed_local, logit_bias_tile=logit_bias_tile) + + # -------- write output to buffer on HBM ------------ # + for i1 in nl.affine_range(attn_core_tile_size): + i = i0 * attn_core_tile_size + i1 + + if i < n_tile_q: + exp = nisa.activation(np.exp, l_buffer[:, i], bias=m_buffer[i1, :, :], + scale=-1.0) + out = nl.multiply(o_buffer[i1, :, :], exp, + dtype=kernel_dtype) + + nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h, + nl.ds(i*B_P_SIZE, B_P_SIZE), :], out) + + if not inference: + nl.store(lse[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], l_buffer[:, :]) if config.training: return o, lse return o + + @nki.jit def flash_attn_bwd( q_ref, k_ref, v_ref, o_ref, dy_ref, lse_ref, seed_ref, + logit_bias_ref=None, use_causal_mask=False, mixed_precision=False, dropout_p=0.0, @@ -457,6 +524,7 @@ def flash_attn_bwd( - dy_ref: shape (bs, nheads, head_size, seq) - lse_ref: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax) - seed_ref: shape (1,) + - logit_bias_ref: shape (bs, n_heads, seq_q, seq_k) - out_dq_ref: shape (bs, nheads, head_size, seq) - out_dk_ref: shape (bs, nheads, head_size, seq) - out_dv_ref: shape (bs, nheads, head_size, seq) @@ -464,11 +532,11 @@ def flash_attn_bwd( Detailed steps: 1. D = rowsum(dO ◦ O) (pointwise multiply) - 2. Recompute (softmax(Q^T@K)) + 2. Recompute (softmax(Q^T@K + logic_bias)) 2.1 Q^T@K 2.2 Scale the QK score - 2.3 Apply causal mask + 2.3 Apply causal mask and add logit_bias 2.4 softmax 3. Compute the gradients of y = score @ V with respect to the loss @@ -487,7 +555,6 @@ def flash_attn_bwd( mixed_dtype = np.dtype(np.float32) if mixed_precision else kernel_dtype assert q_ref.dtype == k_ref.dtype == v_ref.dtype == o_ref.dtype == dy_ref.dtype - assert lse_ref.dtype == mixed_dtype # Shape checking bs, nheads, d_head, seqlen_q = q_ref.shape @@ -520,16 +587,18 @@ def flash_attn_bwd( # Softmax scaling factor, multiplied onto Q softmax_scale = softmax_scale or 1.0 / float(d_head ** 0.5) + assert nl.program_ndim() == 2,\ + f'Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!' # Different batch samples/attention heads have independent attention batch_id = nl.program_id(axis=0) + head_id = nl.program_id(axis=1) - head_dims = list(range(1, nl.program_ndim())) - head_dims_shape = list(nl.num_programs(i) for i in head_dims) - head_dims_idx = list(nl.program_id(i) for i in head_dims) - head_id = linearize(head_dims_shape, head_dims_idx) + assert nl.num_programs(1) == nheads, \ + f"The grid shape mismatch, got {nl.num_programs(1)} but should be {nheads}" - assert n_elts(head_dims_shape) == nheads, \ - f"The grid shape mismatch, got {n_elts(head_dims_shape)} but should be {nheads}" + if logit_bias_ref is not None: + b_logit_bias, h_logit_bias, _, _ = logit_bias_ref.shape + assert b_logit_bias == 1 and h_logit_bias == 1, "Only support broadcasting logit_bias with batch 1, n_heads 1" q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen_q, 128), 128 d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128) @@ -545,45 +614,19 @@ def flash_attn_bwd( ############################################################## # Step 2.4 Prefetch exp bias for softmax ############################################################## - softmax_exp_bias = nl.zeros((q_seq_n_tiles, par_dim(q_seq_tile_size), 1), dtype=mixed_dtype) - for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): - ip_qk = nl.arange(q_seq_tile_size)[:, None] - lse_local = nl.load( - lse_ref[batch_id, head_id, ip_qk, i_q_seq_tile], - dtype=mixed_dtype) - softmax_exp_bias[i_q_seq_tile, ip_qk, 0] = lse_local * -1.0 + softmax_exp_bias = nl.zeros((par_dim(q_seq_tile_size), q_seq_n_tiles), dtype=mixed_dtype) + lse_local = nl.load(lse_ref[batch_id, head_id, :, :], dtype=mixed_dtype) + softmax_exp_bias[:, :] = lse_local * -1.0 ############################################################## # Step 1 Compute rowsum(dO ◦ O) ############################################################## dy_o_sum = nl.ndarray((q_seq_n_tiles, par_dim(q_seq_tile_size), 1), dtype=mixed_dtype) - for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): - ip_reduce = nl.arange(q_seq_tile_size)[:, None] - dy_o_partial = nl.zeros((par_dim(q_seq_tile_size), d_head_n_tiles), dtype=mixed_dtype) - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_load = nl.arange(d_head_tile_size)[:, None] - if_q = nl.arange(q_seq_tile_size)[None, :] - dy_local = nl.load_transpose2d( - dy_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_load, i_q_seq_tile * q_seq_tile_size + if_q], - dtype=mixed_dtype) - o_local = nl.load_transpose2d( - o_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_load, i_q_seq_tile * q_seq_tile_size + if_q], - dtype=mixed_dtype - ) - - dy_o_partial[ip_reduce, i_d_head_tile] = nisa.tensor_reduce( - np.add, data=dy_local*o_local, axis=(1,), dtype=mixed_dtype - ) - - dy_o_sum[i_q_seq_tile, ip_reduce, 0] = nisa.tensor_reduce( - np.add, data=dy_o_partial[ip_reduce, nl.arange(d_head_n_tiles)[None, :]], - axis=(1,), dtype=mixed_dtype - ) - - # Indices for prefetch - ip_qk = nl.arange(d_head_tile_size)[:, None] - if_q = nl.arange(q_seq_tile_size)[None, :] - if_k = nl.arange(k_seq_tile_size)[None, :] + compute_rowsum(dy_o_sum=dy_o_sum, + dy_ref_hbm_tile=dy_ref[batch_id, head_id], + o_ref_hbm_tile=o_ref[batch_id, head_id], + d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size, + q_seq_n_tiles=q_seq_n_tiles, q_seq_tile_size=q_seq_tile_size) if dropout_p > 0.0: seed_local = nl.load(seed_ref[0]) @@ -603,28 +646,25 @@ def flash_attn_bwd( _range = nl.sequential_range if dropout_p > 0.0 else nl.affine_range for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): - # Prefetch V, K - v_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=kernel_dtype) - k_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=kernel_dtype) - transposed_k_local = nl.zeros((k_seq_fwd_bwd_tile_multipler, d_head_n_tiles, par_dim(k_seq_tile_size_backward), d_head_tile_size), dtype=kernel_dtype) - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - k_local[i_d_head_tile, ip_qk, if_k] = nl.load( - k_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_k_seq_tile * k_seq_tile_size + if_k], - dtype=kernel_dtype) - v_local[i_d_head_tile, ip_qk, if_k] = nl.load( - v_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_k_seq_tile * k_seq_tile_size + if_k], - dtype=kernel_dtype) - ############################################################## - # Prefetch k transpose for the backward too - ############################################################## - if_k_backward = nl.arange(k_seq_tile_size_backward)[None, :] - ip_k_backward = nl.arange(k_seq_tile_size_backward)[:, None] - if_d_head = nl.arange(d_head_tile_size)[None, :] - for i_k_seq_tile_backward in nl.affine_range(k_seq_fwd_bwd_tile_multipler): - transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, ip_k_backward, if_d_head] = \ - nisa.nc_transpose(k_local[i_d_head_tile, ip_qk, - i_k_seq_tile_backward * k_seq_tile_size_backward + if_k_backward]) + i_k_seq_dslice = nl.ds(i_k_seq_tile * k_seq_tile_size, k_seq_tile_size) + # Prefetch V, K + v_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), + dtype=kernel_dtype) + k_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), + dtype=kernel_dtype) + transposed_k_local = nl.zeros((k_seq_fwd_bwd_tile_multipler, d_head_n_tiles, + par_dim(k_seq_tile_size_backward), d_head_tile_size), + dtype=kernel_dtype) + + load_kv(k_ref_hbm_tile=k_ref[batch_id, head_id], + v_ref_hbm_tile=v_ref[batch_id, head_id], + k_local=k_local, transposed_k_local=transposed_k_local, v_local=v_local, + d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size, + i_k_seq_tile=i_k_seq_tile, k_seq_tile_size=k_seq_tile_size, + k_seq_tile_size_backward=k_seq_tile_size_backward) + + # FIXME: Pass sbuf instead, we will have psum spilling in the current implementation dv_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=np.float32, buffer=nl.psum) dk_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), @@ -633,17 +673,20 @@ def flash_attn_bwd( # Prefetch dy, Q dy_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype) q_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype) - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_qk = nl.arange(d_head_tile_size)[:, None] - if_q = nl.arange(q_seq_tile_size)[None, :] - dy_local[i_d_head_tile, ip_qk, if_q] = nl.load( - dy_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_q_seq_tile * q_seq_tile_size + if_q], - dtype=kernel_dtype) + load_dy_q(dy_ref_hbm_tile = dy_ref[batch_id, head_id], + q_ref_hbm_tile = q_ref[batch_id, head_id], + dy_local=dy_local, q_local=q_local, d_head_n_tiles=d_head_n_tiles, + d_head_tile_size=d_head_tile_size, i_q_seq_tile=i_q_seq_tile, + q_seq_tile_size=q_seq_tile_size, softmax_scale=softmax_scale) - q_local[i_d_head_tile, ip_qk, if_q] = nl.load( - q_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_q_seq_tile * q_seq_tile_size + if_q], - dtype=kernel_dtype) * softmax_scale + logit_bias_tile = None + if logit_bias_ref is not None: + i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size) + logit_bias_tile = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), + buffer=nl.sbuf, dtype=kernel_dtype) + logit_bias_tile[:, :] = nl.load( + logit_bias_ref[0, 0, i_q_seq_dslice, i_k_seq_dslice]) _flash_attn_bwd_core( q_local=q_local, k_local=k_local, transposed_k_local=transposed_k_local, @@ -656,43 +699,102 @@ def flash_attn_bwd( kernel_dtype=kernel_dtype, mixed_dtype=mixed_dtype, softmax_scale=softmax_scale, seed_local=seed_local, dropout_p=dropout_p, dropout_p_local=dropout_p_local, + logit_bias_tile=logit_bias_tile ) # Write dK, dV - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_dkv = nl.arange(d_head_tile_size)[:, None] - if_dkv = nl.arange(k_seq_tile_size)[None, :] - - nl.store( - out_dv_ref[batch_id, head_id, - i_d_head_tile * d_head_tile_size + ip_dkv, - i_k_seq_tile * k_seq_tile_size + if_dkv], - value=dv_psum[i_d_head_tile, ip_dkv, if_dkv], - ) - - nl.store( - out_dk_ref[batch_id, head_id, - i_d_head_tile * d_head_tile_size + ip_dkv, - i_k_seq_tile * k_seq_tile_size + if_dkv], - value=dk_psum[i_d_head_tile, ip_dkv, if_dkv], - ) + store_dk_dv(out_dk_ref_hbm_tile=out_dk_ref[batch_id, head_id], + out_dv_ref_hbm_tile=out_dv_ref[batch_id, head_id], + local_dk=dk_psum, local_dv=dv_psum, i_k_seq_dslice=i_k_seq_dslice, + d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size) # Write dQ for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_dq = nl.arange(d_head_tile_size)[:, None] - if_dq = nl.arange(q_seq_tile_size)[None, :] - + i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size) + i_d_head_dslice = nl.ds(i_d_head_tile * d_head_tile_size, d_head_tile_size) nl.store( - out_dq_ref[batch_id, head_id, - i_d_head_tile * d_head_tile_size + ip_dq, - i_q_seq_tile * q_seq_tile_size + if_dq], - value=dq_local_reduced[i_q_seq_tile, i_d_head_tile, ip_dq, if_dq], + out_dq_ref[batch_id, head_id, i_d_head_dslice, i_q_seq_dslice], + value=dq_local_reduced[i_q_seq_tile, i_d_head_tile, :, :], ) return out_dq_ref, out_dk_ref, out_dv_ref +@nki.jit(mode='trace') +def load_dy_q(dy_ref_hbm_tile, q_ref_hbm_tile, dy_local, q_local, d_head_n_tiles, d_head_tile_size, i_q_seq_tile, + q_seq_tile_size, softmax_scale): + for i_d_head_tile in nl.affine_range(d_head_n_tiles): + i_d_head_dslice = nl.ds(i_d_head_tile * d_head_tile_size, d_head_tile_size) + i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size) + + dy_local[i_d_head_tile, :, :] = nl.load( + dy_ref_hbm_tile[i_d_head_dslice, i_q_seq_dslice], + dtype=dy_local.dtype) + + q_local[i_d_head_tile, :, :] = nl.load( + q_ref_hbm_tile[i_d_head_dslice, i_q_seq_dslice], + dtype=q_local.dtype) * softmax_scale + + +@nki.jit(mode='trace') +def store_dk_dv(out_dk_ref_hbm_tile, out_dv_ref_hbm_tile, local_dk, local_dv, + d_head_n_tiles, d_head_tile_size, i_k_seq_dslice): + for i in nl.affine_range(d_head_n_tiles): + i_d_head_dslice = nl.ds(i * d_head_tile_size, d_head_tile_size) + + nl.store(out_dv_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice], + value=local_dv[i, :, :]) + + nl.store(out_dk_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice], + value=local_dk[i, :, :]) + + +@nki.jit(mode='trace') +def load_kv(k_ref_hbm_tile, v_ref_hbm_tile, k_local, transposed_k_local, v_local, + d_head_n_tiles, d_head_tile_size, i_k_seq_tile, k_seq_tile_size, + k_seq_tile_size_backward): + k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward + + for i in nl.affine_range(d_head_n_tiles): + i_d_head_dslice = nl.ds(i * d_head_tile_size, d_head_tile_size) + i_k_seq_dslice = nl.ds(i_k_seq_tile * k_seq_tile_size, k_seq_tile_size) + k_local[i, :, :] = nl.load(k_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice], + dtype=k_local.dtype) + v_local[i, :, :] = nl.load(v_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice], + dtype=v_local.dtype) + ############################################################## + # Prefetch k transpose for the backward too + ############################################################## + for j in nl.affine_range(k_seq_fwd_bwd_tile_multipler): + i_k_dslice = nl.ds(j * k_seq_tile_size_backward, k_seq_tile_size_backward) + transposed_k_local[j, i, :, :] = nisa.nc_transpose(k_local[i, :, i_k_dslice]) + + +@nki.jit(mode='trace') +def compute_rowsum(dy_o_sum, dy_ref_hbm_tile, o_ref_hbm_tile, d_head_n_tiles, d_head_tile_size, q_seq_n_tiles, + q_seq_tile_size): + mixed_dtype = dy_o_sum.dtype + for i in nl.affine_range(q_seq_n_tiles): + dy_o_partial = nl.zeros((par_dim(q_seq_tile_size), d_head_n_tiles), dtype=mixed_dtype) + for j in nl.affine_range(d_head_n_tiles): + d_head_dslice = nl.ds(j * d_head_tile_size, d_head_tile_size) + q_seq_dslice = nl.ds(i * q_seq_tile_size, q_seq_tile_size) + + dy_local = nl.load_transpose2d(dy_ref_hbm_tile[d_head_dslice, q_seq_dslice], + dtype=mixed_dtype) + o_local = nl.load_transpose2d(o_ref_hbm_tile[d_head_dslice, q_seq_dslice], + dtype=mixed_dtype) + + dy_o = nl.multiply(dy_local, o_local, dtype=mixed_dtype) + dy_o_partial[:, j] = nisa.tensor_reduce(np.add, data=dy_o, axis=(1,), + dtype=mixed_dtype) + + dy_o_sum[i, :, 0] = nisa.tensor_reduce( + np.add, data=dy_o_partial[:, :], axis=(1,), dtype=mixed_dtype) + + +@nki.jit(mode='trace') def _flash_attn_bwd_core( q_local, k_local, transposed_k_local, v_local, dy_local, dk_psum, dv_psum, dq_local_reduced, @@ -703,11 +805,7 @@ def _flash_attn_bwd_core( kernel_dtype, mixed_dtype, softmax_scale, seed_local, dropout_p, dropout_p_local, - global_i_q_seq_tile = None, - global_i_k_seq_tile = None, - # Used for nl.loop_reduce on dQ if local_i_k_seq_tile is not an index e.g. if it has an offset - local_i_k_seq_tile_for_dq_reduce = None, -): + logit_bias_tile=None): """ The flash backward core function to calculate the gradients of Q, K and V of the given tiles. The result will be accumulated into the dk, dv, dq psum @@ -721,14 +819,7 @@ def _flash_attn_bwd_core( k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen_k // 128, 128 k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward - if global_i_q_seq_tile is None: - global_i_q_seq_tile = local_i_q_seq_tile - global_i_k_seq_tile = local_i_k_seq_tile - - if local_i_k_seq_tile_for_dq_reduce is None: - local_i_k_seq_tile_for_dq_reduce = local_i_k_seq_tile - - mask = global_i_q_seq_tile * q_seq_tile_size >= global_i_k_seq_tile * k_seq_tile_size if use_causal_mask else None + mask = local_i_q_seq_tile * q_seq_tile_size >= local_i_k_seq_tile * k_seq_tile_size if use_causal_mask else None # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F] qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=np.float32, buffer=nl.psum) @@ -736,66 +827,75 @@ def _flash_attn_bwd_core( batch_id = nl.program_id(axis=0) head_id = nl.program_id(axis=1) - # Tensor indices for accessing qk result in k_seq_tile_size - if_q = nl.arange(q_seq_tile_size)[None, :] - ip_qk = nl.arange(d_head_tile_size)[:, None] - - ip_q = nl.arange(q_seq_tile_size)[:, None] - if_k = nl.arange(k_seq_tile_size)[None, :] # Loop over contraction dim of QK matmul for i_d_head_tile in nl.affine_range(d_head_n_tiles): ############################################################## # Step 2.1 Compute Q^T@K, with matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) ############################################################## - qk_psum[ip_q, if_k] += nisa.nc_matmul(q_local[i_d_head_tile, ip_qk, if_q], - k_local[i_d_head_tile, ip_qk, if_k], + qk_psum[:, :] += nisa.nc_matmul(q_local[i_d_head_tile, :, :], + k_local[i_d_head_tile, :, :], mask=mask) ###################################### # Step 2.2. Apply optional causal mask ###################################### if use_causal_mask: - # Magic number -9984.0 to replace -inf similar to what Tensorizer uses - qk_res_buf[ip_q, if_k] = nisa.affine_select( - pred=(global_i_q_seq_tile * q_seq_tile_size + ip_q >= global_i_k_seq_tile * k_seq_tile_size + if_k), - on_true_tile=qk_psum[ip_q, if_k], on_false_value=-9984.0, dtype=mixed_dtype, - mask=mask) + iq, ik = nl.mgrid[0:q_seq_tile_size, 0:k_seq_tile_size] + causal_pred = (local_i_q_seq_tile * q_seq_tile_size + iq >= local_i_k_seq_tile * k_seq_tile_size + ik) + if logit_bias_tile is not None: + # Magic number -9984.0 to replace -inf similar to what Tensorizer uses + intermediate = \ + nl.add(qk_psum[:, :], logit_bias_tile[:, :], dtype=mixed_dtype, mask=mask) + qk_res_buf[:, :] = nisa.affine_select( + pred=causal_pred, + on_true_tile=intermediate, on_false_value=-9984.0, dtype=mixed_dtype, + mask=mask + ) + + else: + # Magic number -9984.0 to replace -inf similar to what Tensorizer uses + qk_res_buf[:, :] = nisa.affine_select( + pred=causal_pred, + on_true_tile=qk_psum[:, :], on_false_value=-9984.0, dtype=mixed_dtype, + mask=mask) else: - # Simply send psum result back to sbuf - qk_res_buf[ip_q, if_k] = \ - nl.copy(qk_psum[ip_q, if_k], dtype=mixed_dtype) + if logit_bias_tile is not None: + # Simply add logit bias which copies back to sbuf at the same time + qk_res_buf[:, :] = \ + nl.add(qk_psum[:, :], logit_bias_tile[:, :], dtype=mixed_dtype) + else: + # Simply send psum result back to sbuf + qk_res_buf[:, :] = \ + nl.copy(qk_psum[:, :], dtype=mixed_dtype) softmax_y = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf) - softmax_y[ip_q, if_k] = nisa.activation(np.exp, - data=qk_res_buf[ip_q, if_k], - bias=softmax_exp_bias[local_i_q_seq_tile, ip_q, 0], - scale=1.0, - mask=mask) + softmax_y[:, :] = nisa.activation(np.exp, + data=qk_res_buf[:, :], + bias=softmax_exp_bias[:, local_i_q_seq_tile], + scale=1.0, + mask=mask) ##################################################################### # Dropout ##################################################################### if dropout_p > 0.0: - offset = global_i_k_seq_tile + global_i_q_seq_tile * k_seq_n_tiles \ + offset = local_i_k_seq_tile + local_i_q_seq_tile * k_seq_n_tiles \ + head_id * k_seq_n_tiles * q_seq_n_tiles \ + batch_id * nheads * k_seq_n_tiles * q_seq_n_tiles offset_seed = nl.add(seed_local[0, 0], offset, mask=mask) nl.random_seed(seed=offset_seed, mask=mask) - softmax_y[ip_q, if_k] = nl.dropout(softmax_y[ip_q, if_k], rate=dropout_p_local[ip_q, 0], mask=mask) - softmax_y[ip_q, if_k] = nl.multiply(softmax_y[ip_q, if_k], 1 / (1 - dropout_p), mask=mask) + softmax_y[:, :] = nl.dropout(softmax_y[:, :], rate=dropout_p_local[:, 0], mask=mask) + softmax_y[:, :] = nl.multiply(softmax_y[:, :], 1 / (1 - dropout_p), mask=mask) ##################################################################### # Step 3.1 Calculate the backward gradients dL/dV, where y=softmax@V # in value projection with matmul(stationary=dy, moving=softmax) ##################################################################### for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_dv = nl.arange(d_head_tile_size)[:, None] - if_dv = nl.arange(k_seq_tile_size)[None, :] - if_trans_dy = nl.arange(q_seq_tile_size)[None, :] - trans_dy = nisa.nc_transpose(dy_local[i_d_head_tile, ip_dv, if_trans_dy], + trans_dy = nisa.nc_transpose(dy_local[i_d_head_tile, :, :], mask=mask) - dv_psum[i_d_head_tile, ip_dv, if_dv] += \ - nisa.nc_matmul(trans_dy, softmax_y[ip_q, if_k], mask=mask) + dv_psum[i_d_head_tile, :, :] += \ + nisa.nc_matmul(trans_dy, softmax_y[:, :], mask=mask) ##################################################################### # Step 3.2 Calculate the backward gradients dL/dsoftmax, where y=softmax@V @@ -804,15 +904,13 @@ def _flash_attn_bwd_core( softmax_dy_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=np.float32, buffer=nl.psum) for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_softmax_dy = nl.arange(d_head_tile_size)[:, None] - if_dy = nl.arange(q_seq_tile_size)[None, :] - softmax_dy_psum[ip_q, if_k] += \ - nisa.nc_matmul(dy_local[i_d_head_tile, ip_softmax_dy, if_dy], - v_local[i_d_head_tile, ip_softmax_dy, if_k], + softmax_dy_psum[:, :] += \ + nisa.nc_matmul(dy_local[i_d_head_tile, :, :], + v_local[i_d_head_tile, :, :], mask=mask) softmax_dy = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf) - softmax_dy[ip_q, if_k] = nl.copy(softmax_dy_psum[ip_q, if_k], dtype=kernel_dtype, + softmax_dy[:, :] = nl.copy(softmax_dy_psum[:, :], dtype=kernel_dtype, mask=mask) ##################################################################### @@ -820,61 +918,55 @@ def _flash_attn_bwd_core( # dL/dx = y * (dL/dy - rowsum(dO_O)), where y = softmax(x) ##################################################################### softmax_dx_local = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf) - softmax_dx_local[ip_q, if_k] = \ - nisa.scalar_tensor_tensor(data=softmax_dy[ip_q, if_k], + softmax_dx_local[:, :] = \ + nisa.scalar_tensor_tensor(data=softmax_dy[:, :], op0=np.subtract, - operand0=dy_o_sum[local_i_q_seq_tile, ip_q, 0], + operand0=dy_o_sum[local_i_q_seq_tile, :, 0], op1=np.multiply, - operand1=softmax_y[ip_q, if_k], + operand1=softmax_y[:, :], mask=mask) ##################################################################### # Step 5.1 Calculate dK, with matmul(stationary=Q, moving=softmax_dx) ##################################################################### for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_trans_q = nl.arange(d_head_tile_size)[:, None] - if_trans_q = nl.arange(q_seq_tile_size)[None, :] - ip_dk = nl.arange(d_head_tile_size)[:, None] - trans_q_local = nisa.nc_transpose(q_local[i_d_head_tile, ip_trans_q, if_trans_q], + trans_q_local = nisa.nc_transpose(q_local[i_d_head_tile, :, :], mask=mask) - dk_psum[i_d_head_tile, ip_dk, if_k] += \ + dk_psum[i_d_head_tile, :, :] += \ nisa.nc_matmul(trans_q_local, - softmax_dx_local[ip_q, if_k], + softmax_dx_local[:, :], mask=mask) ##################################################################### # Step 5.2 Calculate dQ ##################################################################### - if_k = nl.arange(k_seq_tile_size_backward)[None, :] - ip_dq = nl.arange(d_head_tile_size)[:, None] - if_dq = nl.arange(q_seq_tile_size)[None, :] - if_d = nl.arange(d_head_tile_size)[None, :] - ip_transposed_k = nl.arange(k_seq_tile_size_backward)[:, None] for i_d_head_tile in nl.affine_range(d_head_n_tiles): dq_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size), dtype=np.float32, buffer=nl.psum) for i_k_seq_tile_backward in nl.affine_range(k_seq_fwd_bwd_tile_multipler): + i_k_seq_dslice = nl.ds(i_k_seq_tile_backward * k_seq_tile_size_backward, + k_seq_tile_size_backward) transposed_softmax_dx_local = \ - nisa.nc_transpose(softmax_dx_local[ip_q, i_k_seq_tile_backward * k_seq_tile_size_backward + if_k], + nisa.nc_transpose(softmax_dx_local[:, i_k_seq_dslice], mask=mask) - dq_psum[ip_dq, if_dq] += nisa.nc_matmul( - transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, ip_transposed_k, if_d], + dq_psum[:, :] += nisa.nc_matmul( + transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, :, :], transposed_softmax_dx_local, mask=mask) - dq_local = nl.multiply(dq_psum[ip_dq, if_dq], softmax_scale, dtype=kernel_dtype, mask=mask) - dq_local_reduced[local_i_q_seq_tile, i_d_head_tile, ip_dq, if_dq] = nl.loop_reduce( - dq_local, op=np.add, loop_indices=(local_i_k_seq_tile_for_dq_reduce,), + dq_local = nl.multiply(dq_psum[:, :], softmax_scale, dtype=kernel_dtype, mask=mask) + dq_local_reduced[local_i_q_seq_tile, i_d_head_tile, :, :] = nl.loop_reduce( + dq_local, op=np.add, loop_indices=(local_i_k_seq_tile,), dtype=mixed_dtype, mask=mask) @nki.jit def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False, - mixed_percision=True): + mixed_precision=True): """ Fused self attention kernel for small head size Stable Diffusion workload. Computes softmax(QK^T)V. Decoder model can optionally include a causal mask - application. Does not include QKV rojection, output projection, dropout, + application. Does not include QKV projection, output projection, dropout, residual connection, etc. This kernel is designed to be used for Stable Diffusion models where the @@ -890,14 +982,14 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask= IO tensor dtypes: - This kernel assumes all IO tensors have the same dtype - - If mixed_percision is True, then all Tensor Engine operation will be performed in + - If mixed_precision is True, then all Tensor Engine operation will be performed in bfloat16 and accumulation will be performed in float32. Otherwise the intermediates will be in the same type as the inputs. """ # Use q_ref dtype as the intermediate tensor dtype # Assume all IO tensors have the same dtype kernel_dtype = q_ref.dtype - pe_in_dt = nl.bfloat16 if mixed_percision else np.float32 + pe_in_dt = nl.bfloat16 if mixed_precision else np.float32 assert q_ref.dtype == k_ref.dtype == v_ref.dtype # Shape checking diff --git a/src/reference/tutorial.py b/src/nki_samples/reference/tutorial.py similarity index 100% rename from src/reference/tutorial.py rename to src/nki_samples/reference/tutorial.py diff --git a/src/reference/vision.py b/src/nki_samples/reference/vision.py similarity index 100% rename from src/reference/vision.py rename to src/nki_samples/reference/vision.py diff --git a/src/tutorials/average_pool2d/average_pool2d_jax.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_jax.py similarity index 100% rename from src/tutorials/average_pool2d/average_pool2d_jax.py rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_jax.py diff --git a/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_nki_kernels.py similarity index 100% rename from src/tutorials/average_pool2d/average_pool2d_nki_kernels.py rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_nki_kernels.py diff --git a/src/tutorials/average_pool2d/average_pool2d_torch.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_torch.py similarity index 100% rename from src/tutorials/average_pool2d/average_pool2d_torch.py rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_torch.py diff --git a/src/tutorials/fused_mamba/mamba_nki_kernels.py b/src/nki_samples/tutorials/fused_mamba/mamba_nki_kernels.py similarity index 100% rename from src/tutorials/fused_mamba/mamba_nki_kernels.py rename to src/nki_samples/tutorials/fused_mamba/mamba_nki_kernels.py diff --git a/src/tutorials/fused_mamba/mamba_torch.py b/src/nki_samples/tutorials/fused_mamba/mamba_torch.py similarity index 100% rename from src/tutorials/fused_mamba/mamba_torch.py rename to src/nki_samples/tutorials/fused_mamba/mamba_torch.py diff --git a/src/tutorials/layernorm/layernorm_nki_kernel.py b/src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py similarity index 100% rename from src/tutorials/layernorm/layernorm_nki_kernel.py rename to src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py diff --git a/src/tutorials/layernorm/layernorm_torch.py b/src/nki_samples/tutorials/layernorm/layernorm_torch.py similarity index 100% rename from src/tutorials/layernorm/layernorm_torch.py rename to src/nki_samples/tutorials/layernorm/layernorm_torch.py diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py similarity index 100% rename from src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py rename to src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_torch.py similarity index 100% rename from src/tutorials/matrix_multiplication/matrix_multiplication_torch.py rename to src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_torch.py diff --git a/src/tutorials/rmsnorm/rmsnorm_jax.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_jax.py similarity index 100% rename from src/tutorials/rmsnorm/rmsnorm_jax.py rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_jax.py diff --git a/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py similarity index 100% rename from src/tutorials/rmsnorm/rmsnorm_nki_kernels.py rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py diff --git a/src/tutorials/rmsnorm/rmsnorm_torch.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_torch.py similarity index 100% rename from src/tutorials/rmsnorm/rmsnorm_torch.py rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_torch.py diff --git a/src/tutorials/sd_attention/sd_attention_nki_kernels.py b/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py similarity index 100% rename from src/tutorials/sd_attention/sd_attention_nki_kernels.py rename to src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py diff --git a/src/tutorials/sd_attention/sd_attention_torch.py b/src/nki_samples/tutorials/sd_attention/sd_attention_torch.py similarity index 100% rename from src/tutorials/sd_attention/sd_attention_torch.py rename to src/nki_samples/tutorials/sd_attention/sd_attention_torch.py diff --git a/src/tutorials/tensor_addition/tensor_addition_jax.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_jax.py similarity index 100% rename from src/tutorials/tensor_addition/tensor_addition_jax.py rename to src/nki_samples/tutorials/tensor_addition/tensor_addition_jax.py diff --git a/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py similarity index 100% rename from src/tutorials/tensor_addition/tensor_addition_nki_kernels.py rename to src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py diff --git a/src/tutorials/tensor_addition/tensor_addition_torch.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_torch.py similarity index 100% rename from src/tutorials/tensor_addition/tensor_addition_torch.py rename to src/nki_samples/tutorials/tensor_addition/tensor_addition_torch.py diff --git a/src/tutorials/transpose2d/transpose2d_jax.py b/src/nki_samples/tutorials/transpose2d/transpose2d_jax.py similarity index 100% rename from src/tutorials/transpose2d/transpose2d_jax.py rename to src/nki_samples/tutorials/transpose2d/transpose2d_jax.py diff --git a/src/tutorials/transpose2d/transpose2d_nki_kernels.py b/src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py similarity index 100% rename from src/tutorials/transpose2d/transpose2d_nki_kernels.py rename to src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py diff --git a/src/tutorials/transpose2d/transpose2d_torch.py b/src/nki_samples/tutorials/transpose2d/transpose2d_torch.py similarity index 100% rename from src/tutorials/transpose2d/transpose2d_torch.py rename to src/nki_samples/tutorials/transpose2d/transpose2d_torch.py diff --git a/src/reference/__init__.py b/src/reference/__init__.py deleted file mode 100644 index 922dd83..0000000 --- a/src/reference/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) 2023, Amazon.com. All Rights Reserved - -""" -Package containing public kernels for Neuron Kernel Interface (NKI). - -Kernels here are the same to the ones available in the -NKI Github Sample Repo. - -https://github.com/aws-neuron/nki-samples -""" -from neuronxcc.nki.kernels.attention import fused_self_attn_for_SD_small_head_size, flash_attn_bwd, flash_fwd -from neuronxcc.nki.kernels.vision import resize_nearest_fixed_dma_kernel, select_and_scatter_kernel -from neuronxcc.nki.kernels.tutorial import add_kernel_nx8x128x512 -from neuronxcc.nki.kernels.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size -from neuronxcc.nki.kernels.allocated_fused_linear import allocated_fused_rms_norm_qkv - -from neuronxcc.nki._private_kernels.legacy.attention import \ - (fused_self_attn_for_SD_small_head_size as _fused_self_attn_for_SD_small_head_size, - flash_attn_bwd as _flash_attn_bwd, flash_fwd as _flash_fwd) -from neuronxcc.nki._private_kernels.legacy.vision import ( - resize_nearest_fixed_dma_kernel as _resize_nearest_fixed_dma_kernel, - select_and_scatter_kernel as _select_and_scatter_kernel) -from neuronxcc.nki._private_kernels.legacy.tutorial import add_kernel_nx8x128x512 as _add_kernel_nx8x128x512 -from neuronxcc.nki._private_kernels.legacy.allocated_fused_linear import _allocated_fused_rms_norm_qkv - -fused_self_attn_for_SD_small_head_size._legacy_func = _fused_self_attn_for_SD_small_head_size -flash_attn_bwd._legacy_func = _flash_attn_bwd -flash_fwd._legacy_func = _flash_fwd -resize_nearest_fixed_dma_kernel._legacy_func = _resize_nearest_fixed_dma_kernel -select_and_scatter_kernel._legacy_func = _select_and_scatter_kernel -add_kernel_nx8x128x512._legacy_func = _add_kernel_nx8x128x512 -allocated_fused_rms_norm_qkv._legacy_func = _allocated_fused_rms_norm_qkv diff --git a/test/unit/README.md b/test/unit/README.md index e0835de..55dc937 100644 --- a/test/unit/README.md +++ b/test/unit/README.md @@ -1 +1,7 @@ -Tests under this folder are unit tests for the kernels in `neuronxcc.nki.kernels`, and they are part of the nki-samples Github Repo. Only public APIs can be used for tests in this folder. \ No newline at end of file +Tests under this folder are unit tests for the kernels in `src/nki_samples`. + +To execute the tests, we need to include `src/nki_samples` in the `PYTHONPATH`. + +For example, + +PYTHONPATH=$PYTHONPATH:/home/ubuntu/nki-samples/src/ pytest test_flash_attn_fwd.py \ No newline at end of file diff --git a/test/unit/__main__.py b/test/unit/__main__.py deleted file mode 100644 index 34fee3a..0000000 --- a/test/unit/__main__.py +++ /dev/null @@ -1,14 +0,0 @@ -import os -import sys - -# This file is basically a hack around the fact that pytest has a bug where it does not discover conftest.py correctly if you launch the test using --pyargs. -# https://github.com/pytest-dev/pytest/issues/1596 - - -# Todo: Using __file__ isn't the most robust. Figure out how to do this using importlib or similar. -test_root = os.path.dirname(__file__) - -if __name__ == "__main__": - import pytest - errcode = pytest.main([test_root] + sys.argv[1:]) - sys.exit(errcode) \ No newline at end of file diff --git a/test/unit/conftest.py b/test/unit/conftest.py new file mode 100644 index 0000000..cd663ae --- /dev/null +++ b/test/unit/conftest.py @@ -0,0 +1,28 @@ +import pytest + +def pytest_addoption(parser): + parser.addoption( + "--simulation-only", action="store_true", default=False, help="Run simulation only, it will run test with `simulation` marker in simulation mode" + ) + +def pytest_configure(config): + config.addinivalue_line( + "markers", "simulation: mark simulation test that can be executed without a NeuronDevice" + ) + +@pytest.fixture +def simulation_only(request): + return request.config.getoption("--simulation-only") + +def pytest_collection_modifyitems(session, config, items): + if config.getoption("--simulation-only"): + # Only run cases with `simulation marker` + result = [] + for item in items: + for marker in item.iter_markers(): + if marker.name == 'simulation': + result.append(item) + break + items.clear() + items.extend(result) + \ No newline at end of file diff --git a/test/unit/test_SD_attention_small_head.py b/test/unit/test_SD_attention_small_head.py index 32e6945..1a54a4b 100644 --- a/test/unit/test_SD_attention_small_head.py +++ b/test/unit/test_SD_attention_small_head.py @@ -3,14 +3,13 @@ """ import os import pytest -from neuronxcc.nki.kernels.attention import fused_self_attn_for_SD_small_head_size -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.attention import fused_self_attn_for_SD_small_head_size +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki.language as nl import numpy as np from scipy.special import softmax test_trace_file_path='local_trace.ntff' -numeric_func = baremetal(fused_self_attn_for_SD_small_head_size) bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size) def cpu_golden_attn(q, k, v): @@ -46,11 +45,12 @@ def test_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency): assert p50 <= latency*1.05 # short running kernels are subjected to hardware fluctuation assert os.path.getsize(test_trace_file_path) > 0 + @pytest.mark.simulation @pytest.mark.parametrize("bs,seqlen,d,dtype", [ [1, 4096, 128, np.float32], [1, 4096, 128, nl.bfloat16] ]) - def test_attention_for_SD_numberic(self, bs, seqlen, d, dtype): + def test_attention_for_SD_numberic(self, simulation_only, bs, seqlen, d, dtype): q = np.random.random_sample((bs, d, seqlen)).astype(np.float32) k = np.random.random_sample((bs, seqlen, d)).astype(np.float32) v = np.random.random_sample((bs, seqlen, d)).astype(np.float32) @@ -59,7 +59,11 @@ def test_attention_for_SD_numberic(self, bs, seqlen, d, dtype): k_dev = nl.static_cast(k, dtype) v_dev = nl.static_cast(v, dtype) - out = numeric_func[bs](q_dev, k_dev, v_dev) + numeric_func = baremetal(fused_self_attn_for_SD_small_head_size) + if simulation_only: + out = simulate_kernel(numeric_func[bs], q_dev, k_dev, v_dev) + else: + out = numeric_func[bs](q_dev, k_dev, v_dev) out = nl.static_cast(out, np.float32) golden_result = cpu_golden_attn(q, k, v) assert np.allclose(out, golden_result, atol=1e-2) diff --git a/test/unit/test_allocated_SD_attention_small_head.py b/test/unit/test_allocated_SD_attention_small_head.py index ee0de86..712148f 100644 --- a/test/unit/test_allocated_SD_attention_small_head.py +++ b/test/unit/test_allocated_SD_attention_small_head.py @@ -3,15 +3,15 @@ """ import os import pytest -from neuronxcc.nki.kernels.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki as nki import neuronxcc.nki.language as nl import numpy as np from scipy.special import softmax test_trace_file_path='local_trace.ntff' -numeric_func = baremetal(allocated_fused_self_attn_for_SD_small_head_size) + bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(allocated_fused_self_attn_for_SD_small_head_size) def cpu_golden_attn(q, k, v): @@ -47,12 +47,13 @@ def test_allocated_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency): assert p50 <= latency * 1.05 # short running kernels are subjected to hardware fluctuation assert os.path.getsize(test_trace_file_path) > 0 + @pytest.mark.simulation @pytest.mark.parametrize("bs,seqlen,d,dtype", [ [1, 4096, 128, np.float32], [1, 4096, 128, nl.bfloat16], [1, 5120, 128, nl.bfloat16] ]) - def test_allocated_attention_for_SD_numberic(self, bs, seqlen, d, dtype): + def test_allocated_attention_for_SD_numberic(self, simulation_only, bs, seqlen, d, dtype): q = np.random.random_sample((bs, d, seqlen)).astype(np.float32) k = np.random.random_sample((bs, d, seqlen)).astype(np.float32) v = np.random.random_sample((bs, seqlen, d)).astype(np.float32) @@ -61,7 +62,11 @@ def test_allocated_attention_for_SD_numberic(self, bs, seqlen, d, dtype): k_dev = nl.static_cast(k, dtype) v_dev = nl.static_cast(v, dtype) - out = numeric_func[bs](q_dev, k_dev, v_dev) + numeric_func = baremetal(allocated_fused_self_attn_for_SD_small_head_size) + if simulation_only: + out = simulate_kernel(numeric_func[bs], q_dev, k_dev, v_dev) + else: + out = numeric_func[bs](q_dev, k_dev, v_dev) out = nl.static_cast(out, np.float32) golden_result = cpu_golden_attn(q, k, v) assert np.allclose(out, golden_result, atol=1e-2) diff --git a/test/unit/test_flash_attn_bwd.py b/test/unit/test_flash_attn_bwd.py index 3aedab0..0f45f9f 100644 --- a/test/unit/test_flash_attn_bwd.py +++ b/test/unit/test_flash_attn_bwd.py @@ -2,14 +2,14 @@ Copyright (c) 2023, Amazon.com. All Rights Reserved """ import pytest -from neuronxcc.nki.kernels.attention import flash_attn_bwd -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.attention import flash_attn_bwd +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki.language as nl import numpy as np -from TestDecorators import xfail +xfail = pytest.mark.arch_specific_xfail + -numeric_func = baremetal(flash_attn_bwd) bench_func = benchmark(warmup=5, iters=10)(flash_attn_bwd) def softmax(x: np.ndarray, dim: int, zero_max_mode=False, @@ -110,13 +110,14 @@ def test_flash_attn_bwd_perf(self, bs, nheads, seqlen, d, dtype, latency): bench_func_(q, k, v, o_proj, dy, lse, seed, use_causal_mask=True, mixed_precision=True) latency_res = bench_func_.benchmark_result.nc_latency - p99 = latency_res.get_latency_percentile(99) + p99 = latency_res.get_latency_percentile(50) assert p99 <= latency + @pytest.mark.simulation @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype", [ [1, 4, 4096, 128, np.float32], ]) - def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype): + def test_flash_attn_bwd_numerical(self, simulation_only, bs, nheads, seqlen, d, dtype): q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 k = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 @@ -135,7 +136,13 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype): nl.tile_size.pmax).transpose(0, 1, 3, 2) lse = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal)) - out_dq, out_dk, out_dv = numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed, + numeric_func = baremetal(flash_attn_bwd) + if simulation_only: + out_dq, out_dk, out_dv = simulate_kernel(numeric_func[bs, nheads], q, k, v, o_proj, dy, lse, seed, + use_causal_mask=True, + mixed_precision=True) + else: + out_dq, out_dk, out_dv = numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed, use_causal_mask=True, mixed_precision=True) diff --git a/test/unit/test_flash_attn_fwd.py b/test/unit/test_flash_attn_fwd.py index fff4ac2..e52354d 100644 --- a/test/unit/test_flash_attn_fwd.py +++ b/test/unit/test_flash_attn_fwd.py @@ -2,12 +2,11 @@ Copyright (c) 2023, Amazon.com. All Rights Reserved """ import pytest -from neuronxcc.nki.kernels.attention import flash_fwd, FlashConfig -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.attention import flash_fwd, FlashConfig +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki.language as nl import numpy as np - -numeric_func = baremetal(flash_fwd) + bench_func = benchmark(warmup=5, iters=10)(flash_fwd) def softmax(x: np.ndarray, dim: int, zero_max_mode=False, @@ -95,9 +94,10 @@ def test_flash_attn_fwd_perf(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use bench_func_(q, k, v, seed, use_causal_mask=use_causal_mask, mixed_precision=mixed_precision, config=config) latency_res = bench_func_.benchmark_result.nc_latency - p99 = latency_res.get_latency_percentile(99) + p99 = latency_res.get_latency_percentile(50) assert p99 <= latency - + + @pytest.mark.simulation @pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\ training, tile_size, kv_heads, should_transpose_v", [ [1, 6, 4096, 4096, 128, np.float32, True, True, 2048, 3, False], @@ -105,7 +105,7 @@ def test_flash_attn_fwd_perf(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use [1, 1, 8192, 4096, 128, np.float32, True, False, 2048, None, False], [1, 1, 4096, 8192, 128, np.float32, True, False, 2048, None, False], ]) - def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, + def test_flash_attn_fwd_numerical(self, simulation_only, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, training, tile_size, kv_heads, should_transpose_v): q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2 k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen_k]) - 0.5) * 2 @@ -132,7 +132,15 @@ def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen_q, seqlen_k, d, dtype config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v}) heads = nheads if kv_heads is None else kv_heads - results = numeric_func[bs, heads](q, k, v, seed, + + numeric_func = baremetal(flash_fwd) + if simulation_only: + results = simulate_kernel(numeric_func[bs, heads], q, k, v, seed, + use_causal_mask=use_causal_mask, + mixed_precision=True, + config=config) + else: + results = numeric_func[bs, heads](q, k, v, seed, use_causal_mask=use_causal_mask, mixed_precision=True, config=config) diff --git a/test/unit/test_neuron_profile.py b/test/unit/test_neuron_profile.py new file mode 100644 index 0000000..e607705 --- /dev/null +++ b/test/unit/test_neuron_profile.py @@ -0,0 +1,86 @@ +from neuronxcc.nki import benchmark +from neuronxcc.nki import profile +import neuronxcc.nki.language as nl +import numpy as np +import pytest +import os +import shutil +import tempfile + + +WORKING_DIRECTORY = tempfile.mkdtemp() +SAVE_NEFF_NAME = "cus_file123.neff" +SAVE_TRACE_NAME = "profile-custom.ntff" +NUM_EXECS = 20 +PROFILE_NTH = 10 +JSON_REPORTS = "json_reports" + +@profile(working_directory=WORKING_DIRECTORY, save_neff_name=SAVE_NEFF_NAME, overwrite=False , save_trace_name=SAVE_TRACE_NAME, num_execs=NUM_EXECS, profile_nth=PROFILE_NTH) +def nki_tensor_tensor_add(a_tensor, b_tensor): + c_output = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, buffer=nl.shared_hbm) + + a = nl.load(a_tensor) + b = nl.load(b_tensor) + + c_tile = a + b + + nl.store(c_output, value=c_tile) + + return c_output + +class TestNeuronProfile: + def _get_ntff_path(self, trace_val): + """ + Prepares ntff file name based on execution trace number + """ + if trace_val == 1: + return os.path.join(WORKING_DIRECTORY, f"{os.path.splitext(os.path.basename(SAVE_TRACE_NAME))[0]}.ntff") + else: + return os.path.join(WORKING_DIRECTORY, f"{os.path.splitext(os.path.basename(SAVE_TRACE_NAME))[0]}_exec_{trace_val}.ntff") + + @pytest.fixture + def traces(self): + ret = [] + if NUM_EXECS < PROFILE_NTH: + ret.append(self._get_ntff_path(PROFILE_NTH)) + else: + curr = PROFILE_NTH + while curr <= NUM_EXECS: + ret.append(self._get_ntff_path(curr)) + curr += PROFILE_NTH + return ret + + @pytest.fixture + def num_reports(self): + if NUM_EXECS < PROFILE_NTH: + return 1 + else: + return NUM_EXECS // PROFILE_NTH + + def test_output_artifacts_created(self, traces, num_reports): + # delete artifact directory, only testing non-overwrite functionality + if os.path.exists(WORKING_DIRECTORY): + shutil.rmtree(WORKING_DIRECTORY) + + # creates dummy input to invoke profile kernel + a = np.zeros([128, 1024]).astype(np.float16) + b = np.random.random_sample([128, 1024]).astype(np.float16) + + output_nki = nki_tensor_tensor_add(a, b) + + # now asserting artifacts are correctly created + assert os.path.exists(os.path.join(WORKING_DIRECTORY, SAVE_NEFF_NAME)) # neff + + for trace in traces: + assert os.path.exists(trace) # trace + + # json reports + report_dir = os.path.join(WORKING_DIRECTORY, JSON_REPORTS) + + assert os.path.exists(report_dir) # actually exists + assert len(os.listdir(report_dir)) == num_reports # report all iterations queried + + # post condition cleanup + if os.path.exists(WORKING_DIRECTORY): + shutil.rmtree(WORKING_DIRECTORY) + diff --git a/test/unit/test_resize_nearest.py b/test/unit/test_resize_nearest.py index 2bbc601..72e7aef 100644 --- a/test/unit/test_resize_nearest.py +++ b/test/unit/test_resize_nearest.py @@ -3,12 +3,11 @@ """ import pytest -from neuronxcc.nki.kernels.vision import resize_nearest_fixed_dma_kernel -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.vision import resize_nearest_fixed_dma_kernel +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki.language as nl import numpy as np -numeric_func = baremetal(resize_nearest_fixed_dma_kernel) bench_func = benchmark(warmup=5, iters=10)(resize_nearest_fixed_dma_kernel) @@ -49,20 +48,25 @@ def test_resize_nearest_for_perf(self, in_b, in_h, in_w, in_c, out_b, out_h, out bench_func_ = bench_func[in_b] bench_func_(input_dev, (out_b, out_h, out_w, out_c)) latency_res = bench_func_.benchmark_result.nc_latency - p99 = latency_res.get_latency_percentile(99) + p99 = latency_res.get_latency_percentile(50) assert p99 <= latency + @pytest.mark.simulation @pytest.mark.parametrize("in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype", [ [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32], [1, 30, 20, 1280, 1, 59, 38, 1280, nl.float16], [1, 30, 20, 1280, 1, 59, 38, 1280, nl.bfloat16], ]) - def test_resize_nearest_for_numberic(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype): + def test_resize_nearest_for_numberic(self, simulation_only, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype): input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32) input_dev = nl.static_cast(input_tensor, dtype) - output_tensor = numeric_func[in_b](input_dev, (out_b, out_h, out_w, out_c)) + numeric_func = baremetal(resize_nearest_fixed_dma_kernel) + if simulation_only: + output_tensor = simulate_kernel(numeric_func[in_b], input_dev, (out_b, out_h, out_w, out_c)) + else: + output_tensor = numeric_func[in_b](input_dev, (out_b, out_h, out_w, out_c)) output_tensor = nl.static_cast(output_tensor, np.float32) golden_result = cpu_golden_result(input_tensor, output_tensor.shape) assert np.allclose(output_tensor, golden_result, atol=1e-2) diff --git a/test/unit/test_rmsnorm_qkv.py b/test/unit/test_rmsnorm_qkv.py index 24ad31c..28838d1 100644 --- a/test/unit/test_rmsnorm_qkv.py +++ b/test/unit/test_rmsnorm_qkv.py @@ -2,12 +2,11 @@ Copyright (c) 2024, Amazon.com. All Rights Reserved """ import pytest -from neuronxcc.nki.kernels.allocated_fused_linear import allocated_fused_rms_norm_qkv -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.allocated_fused_linear import allocated_fused_rms_norm_qkv +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki.language as nl import numpy as np -numeric_func = baremetal(allocated_fused_rms_norm_qkv) bench_func = benchmark(warmup=5, iters=10)(allocated_fused_rms_norm_qkv) np.random.seed(0) @@ -31,7 +30,7 @@ class TestRMSNormQKV: [1, 128, 512, 512, np.float16, 25], [1, 512, 1024, 384, nl.bfloat16, 40], [1, 128, 1024, 512, nl.bfloat16, 28], - [1, 1024, 8192, 512, nl.bfloat16, 301 * 1.02] + # [1, 1024, 8192, 512, nl.bfloat16, 301 * 1.02], # FIXME: performance is flaky ]) def test_allocated_rmsnorm_qkv_perf(self, batch, seqlen, dim, d_head, dtype, latency): hidden = np.random.random_sample((batch, seqlen, dim)).astype(np.float32) @@ -42,23 +41,28 @@ def test_allocated_rmsnorm_qkv_perf(self, batch, seqlen, dim, d_head, dtype, lat bench_func(hidden, weights) latency_res = bench_func.benchmark_result.nc_latency - p99 = latency_res.get_latency_percentile(99) + p99 = latency_res.get_latency_percentile(50) assert p99 <= latency + @pytest.mark.simulation @pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype", [ [1, 128, 512, 512, np.float16], [1, 512, 1024, 384, nl.bfloat16], [1, 128, 1024, 512, nl.bfloat16], [1, 1024, 8192, 512, nl.bfloat16] ]) - def test_allocated_rmsnorm_qkv_numeric(self, batch, seqlen, dim, d_head, dtype): + def test_allocated_rmsnorm_qkv_numeric(self, simulation_only, batch, seqlen, dim, d_head, dtype): hidden = np.random.random_sample((batch, seqlen, dim)) weights = np.random.random_sample((dim, d_head)) hidden_dev = nl.static_cast(hidden, dtype) weights_dev = nl.static_cast(weights, dtype) - out = numeric_func(hidden_dev, weights_dev) + numeric_func = baremetal(allocated_fused_rms_norm_qkv) + if simulation_only: + out = simulate_kernel(numeric_func, hidden_dev, weights_dev) + else: + out = numeric_func(hidden_dev, weights_dev) out = nl.static_cast(out, np.float32) golden_res = nl.static_cast(cpu_golden_result(hidden, None, weights, dtype, do_norm=True), np.float32) assert np.allclose(out, golden_res, atol=1e-2, rtol=1e-2) diff --git a/test/unit/test_select_and_scatter.py b/test/unit/test_select_and_scatter.py index fc99b37..08e787f 100644 --- a/test/unit/test_select_and_scatter.py +++ b/test/unit/test_select_and_scatter.py @@ -1,11 +1,10 @@ import pytest -from neuronxcc.nki.kernels.vision import select_and_scatter_kernel -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.vision import select_and_scatter_kernel +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki.language as nl import numpy as np -numeric_func = baremetal(select_and_scatter_kernel) bench_func = benchmark(warmup=5, iters=10)(select_and_scatter_kernel) np.random.seed(0) @@ -51,14 +50,15 @@ def test_select_and_scatter_for_perf(self, n, c, operand_h, operand_w, source_h, bench_func(operand_dev, source_dev) latency_res = bench_func.benchmark_result.nc_latency - p99 = latency_res.get_latency_percentile(99) + p99 = latency_res.get_latency_percentile(50) assert p99 <= latency + @pytest.mark.simulation @pytest.mark.parametrize("n, c, operand_h, operand_w, source_h, source_w, dtype", [ [8, 64, 112, 112, 56, 56, np.float32], [8, 64, 112, 112, 56, 56, nl.bfloat16], ]) - def test_select_and_scatter_for_numeric(self, n, c, operand_h, operand_w, source_h, source_w, dtype): + def test_select_and_scatter_for_numeric(self,simulation_only, n, c, operand_h, operand_w, source_h, source_w, dtype): operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype) source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype) @@ -66,7 +66,11 @@ def test_select_and_scatter_for_numeric(self, n, c, operand_h, operand_w, source operand_tensor = nl.static_cast(operand_dev, np.float32) source_tensor = nl.static_cast(source_dev, np.float32) - output_dev = numeric_func(operand_dev, source_dev) + numeric_func = baremetal(select_and_scatter_kernel) + if simulation_only: + output_dev = simulate_kernel(numeric_func, operand_dev, source_dev) + else: + output_dev = numeric_func(operand_dev, source_dev) golden_result = cpu_golden_result(operand_tensor, source_tensor) nki_result = nl.static_cast(output_dev, np.float32)