Add FMHA PAXML test #830

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 52 additions & 6 deletions .github/container/test-pax.sh
@@ -17,12 +17,13 @@ usage() {
echo " --dtype Batch size, defaults to bfloat16."
echo " --enable-te If set, will run with env var ENABLE_TE=1."
echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1."
echo " --disable-fused-attn Whether disable TE fused attention."
echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M"
echo " --evaluate Whether to test evaluation rather than training."
echo " -s, --steps Number of steps to run, defaults to 500."
echo " --multiprocess Enable the multiprocess GPU mode."
echo " -o, --output NAME Name for the output folder, a temporary folder will be created if none specified."
echo " --save-hlo {0, 1} 1 to save the dumped hlo, 0 to remove the hlo dumped folder"
echo " --enable-fmha {0, 1} 1 to enable fmha testing, 0 to run test without fmha; default is 0"
Contributor
nit: default doesn't match below

echo " --data-parallel Data parallelism to use. Defaults to 1."
echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1."
echo " --tensor-parallel Tensor parallelism to use. Defaults to 1."
@@ -32,7 +33,8 @@ usage() {
exit $1
}

args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,model-type:,enable-fmha:,evaluate,steps:,help,multiprocess,output:,save-hlo:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")

if [[ $? -ne 0 ]]; then
exit $1
fi
@@ -55,6 +57,8 @@ NVTE_FUSED_ATTN=1
DROPOUT=0
EVALUATE=0
ADDITIONAL_ARGS=""
ENABLE_FMHA=${ENABLE_FMHA:-0}
SAVE_HLO=${SAVE_HLO:-0}

eval set -- "$args"
while [ : ]; do
@@ -75,14 +79,15 @@ while [ : ]; do
ENABLE_TE=1
shift 1
;;
--enable-fmha)
ENABLE_FMHA="$2"
NVTE_FUSED_ATTN="$2"
shift 2
;;
--enable-dropout)
DROPOUT='0.1'
shift 1
;;
--disable-fused-attn)
NVTE_FUSED_ATTN=0
shift 1
;;
--model-type)
MODEL_TYPE=$2
shift 2
@@ -103,6 +108,10 @@ while [ : ]; do
OUTPUT=$2
shift 2
;;
--save-hlo)
SAVE_HLO="$2"
shift 2
;;
--data-parallel)
DP="$2"
shift 2
@@ -136,6 +145,21 @@ while [ : ]; do
esac
done

# Set hlo dump folder after output folder is set.
HLO_DIR=${OUTPUT}/hlo
export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}}"
Contributor @DwarKapex (May 17, 2024)
Could you please explain the logic here: if BASE_XLA_FLAGS is already set, do you always skip setting HLO_DIR? If so, maybe you can add a warning message that the XLA dump is not set?

Contributor Author
Dumping the HLO is enabled by default via BASE_XLA_FLAGS, and BASE_XLA_FLAGS is appended to the XLA_FLAGS env var. If the user wants to test FMHA, then BASE_XLA_FLAGS_FMHA is also set and appended to XLA_FLAGS. The idea is to preserve whatever XLA_FLAGS the user had before executing this script.
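
A minimal sketch of the composition described above (the pre-existing XLA_FLAGS value and the output path are hypothetical):

# Hypothetical flag already exported by the caller before running test-pax.sh.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"

HLO_DIR=/output/pax-test/hlo   # hypothetical output folder

# Fall back to the HLO-dump flags only when the caller did not set BASE_XLA_FLAGS.
BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}}"

# Prepend them while preserving the caller's original XLA_FLAGS.
export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}"

echo "$XLA_FLAGS"
# --xla_dump_hlo_as_text --xla_dump_to=/output/pax-test/hlo --xla_gpu_enable_latency_hiding_scheduler=true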

Contributor @DwarKapex (Jun 10, 2024)
OK, let me clarify my question. Line 150 literally means:

if [[ -z "$BASE_XLA_FLAGS" ]]; then
    BASE_XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}"
fi

Meaning that if BASE_XLA_FLAGS is already set (by a previous script, globally in the system, etc.), ${HLO_DIR} will not have any effect at all.

Is that the expected behaviour?
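
A quick illustration of that parameter-expansion behaviour (values are arbitrary):

unset BASE_XLA_FLAGS
echo "${BASE_XLA_FLAGS:-dump-flags}"      # prints "dump-flags" (the default applies)

BASE_XLA_FLAGS="--my-own-flags"
echo "${BASE_XLA_FLAGS:-dump-flags}"      # prints "--my-own-flags" (the HLO-dump default is skipped)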

Contributor
And why do you export it? You use it only locally.

Contributor Author
The mechanism was added per the review comment on the equivalent PR for T5X: #442 (comment); refer to that discussion for the implementation details. The expression BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}}" keeps any previous definition of BASE_XLA_FLAGS if there is one and otherwise falls back to the XLA HLO-dump flags, which are then appended to XLA_FLAGS. This also gives us the flexibility to "zero out" the env var without modifying code in this script, by just doing BASE_XLA_FLAGS="".
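
For illustration, two hypothetical invocations showing that behaviour (output paths are made up):

# No BASE_XLA_FLAGS in the environment: the script adds the default HLO-dump flags.
bash .github/container/test-pax.sh -o /tmp/pax-out

# Caller supplies BASE_XLA_FLAGS: it replaces the default dump flags entirely.
BASE_XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/my/dump/dir" \
  bash .github/container/test-pax.sh -o /tmp/pax-out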

export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}"
echo "HLO will be dumped in ${HLO_DIR} dir."

## Setting the env variables for FMHA
if [[ "$ENABLE_FMHA" -eq "1" ]]; then
echo "Setting XLA FMHA Flags";
export BASE_XLA_FLAGS_FMHA="${BASE_XLA_FLAGS_FMHA:---xla_gpu_fused_attention_use_cudnn_rng=true --xla_gpu_enable_cudnn_fmha=true}"
Contributor
Same here as above.

export XLA_FLAGS="${BASE_XLA_FLAGS_FMHA} ${XLA_FLAGS:-}"
fi

echo "XLA FLAGS: $XLA_FLAGS"

## Set derived variables

GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU')
@@ -149,8 +173,10 @@ print_var NGPUS
print_var OUTPUT
print_var MULTIPROCESS
print_var ENABLE_TE
print_var ENABLE_FMHA
print_var NVTE_FUSED_ATTN
print_var EVALUATE
print_var SAVE_HLO
print_var DROPOUT
print_var DP
print_var FSDP
@@ -422,5 +448,25 @@ else
$([[ $MULTIPROCESS != 0 ]] && echo --multiprocess_gpu)
fi

echo "Checking for FMHA instructions in HLO!"

if [[ "$ENABLE_FMHA" -eq "1" ]]; then
## Check whether FMHA instructions are present in the dumped HLO files.
fmha_regex="fmha[-bmm]?[-scale]?[-bias]?[-mask]?[-softmax]?[-dropout]?[-bmm]?[-backward]?*"
result=$(grep -irlnE "$fmha_regex" "${HLO_DIR}/"*.txt)

if [ -z "$result" ]; then
echo "E: No FMHA instructions were found in the hlo files!"
exit 1
else
echo -e "Found FMHA instructions in the following HLO files: \n $result"
fi
fi

if [[ $SAVE_HLO -eq 0 ]]; then
rm -rf $HLO_DIR
echo "Removed dumped HLO directory!"
fi

set +x
echo "Output at ${OUTPUT}"
39 changes: 29 additions & 10 deletions .github/workflows/_test_upstream_pax.yaml
@@ -30,12 +30,22 @@ on:

jobs:

single-process-multi-device:
pax-single-process-multi-device:
strategy:
matrix:
PARALLEL_CONFIG:
- [1, 8, 1, 1]
- [1, 1, 2, 4]
include:
- TEST_NAME: 8DP1FSDP1TP1PP
PARALLEL_CONFIG: [1, 8, 1, 1]
BATCH_SIZE: 4
ADDITIONAL_ARGS: ""
- TEST_NAME: 8DP2FSDP4TP1PP
PARALLEL_CONFIG: [1, 1, 2, 4]
BATCH_SIZE: 4
ADDITIONAL_ARGS: ""
- TEST_NAME: 8DP1FSDP1TP1PP_fmha
PARALLEL_CONFIG: [1, 8, 1, 1]
BATCH_SIZE: 4
ADDITIONAL_ARGS: "--enable-fmha 1 --save-hlo 1"
fail-fast: false

runs-on: ubuntu-22.04
@@ -67,7 +77,7 @@ jobs:
shell: bash -x -e {0}
run: |
IMAGE="$(echo ${{inputs.PAX_IMAGE}} | sed 's/\//#/')"
TEST_CASE_NAME=${{ matrix.PARALLEL_CONFIG[1] }}DP${{ matrix.PARALLEL_CONFIG[2] }}FSDP${{ matrix.PARALLEL_CONFIG[3] }}TP${{ matrix.PARALLEL_CONFIG[0] }}PP_single_process
TEST_CASE_NAME=${{ matrix.TEST_NAME }}_single_process
MAX_GPUS_PER_NODE=8
NODES=1
GPUS_PER_NODE=8
@@ -112,13 +122,14 @@ jobs:
test-pax.sh \
--output /output/${{ steps.meta.outputs.TEST_CASE_NAME }} \
--dtype bfloat16 \
--batch-per-gpu 4 \
--batch-per-gpu ${{ matrix.BATCH_SIZE }} \
--steps 500 \
--pipeline-parallel ${{ matrix.PARALLEL_CONFIG[0] }} \
--data-parallel ${{ matrix.PARALLEL_CONFIG[1] }} \
--fsdp ${{ matrix.PARALLEL_CONFIG[2] }} \
--tensor-parallel ${{ matrix.PARALLEL_CONFIG[3] }} \
--nodes ${{ steps.meta.outputs.NODES }}
--nodes ${{ steps.meta.outputs.NODES }} \
${{ matrix.ADDITIONAL_ARGS }}
EOF
)

@@ -210,6 +221,14 @@ jobs:
BATCH_SIZE: 4
EVALUATE: true
ADDITIONAL_ARGS: "--model-type LLaMA70BProxy --evaluate"
- TEST_NAME: 2DP1FSDP1TP4PP_fmha
PARALLEL_CONFIG: [4, 2, 1, 1]
BATCH_SIZE: 4
ADDITIONAL_ARGS: "--enable-fmha 1 --save-hlo 1"
- TEST_NAME: 16DP1FSDP1TP1PP_fmha
PARALLEL_CONFIG: [1, 16, 1, 1]
BATCH_SIZE: 4
ADDITIONAL_ARGS: "--enable-fmha 1 --save-hlo 1"
fail-fast: false

runs-on: ubuntu-22.04
@@ -354,7 +373,7 @@ jobs:
path: |
output/*

single-process-evaluation:
pax-single-process-evaluation:
strategy:
matrix:
PARALLEL_CONFIG:
@@ -503,7 +522,7 @@ jobs:

metrics:
name: test-upstream-pax-metrics
needs: [single-process-multi-device, pax-multi-node, single-process-evaluation]
needs: [pax-single-process-multi-device, pax-multi-node, pax-single-process-evaluation]
runs-on: ubuntu-22.04

steps:
@@ -549,7 +568,7 @@ jobs:
summary:
name: test-upstream-pax-summary
runs-on: ubuntu-22.04
needs: [single-process-multi-device, pax-multi-node, single-process-evaluation]
needs: [pax-single-process-multi-device, pax-multi-node, pax-single-process-evaluation]
if: "!cancelled()"
steps:
- name: Generate TensorBoard query URL