From ca0218a7a4d80d7f6f6465847fe420f1fecd05cb Mon Sep 17 00:00:00 2001
From: Robert Suderman
Date: Sat, 6 May 2023 00:09:08 +0000
Subject: [PATCH] Add a GitHub Action for running the JAX testsuite

---
 .github/workflows/run_jaxtests_cpu.yml |  69 ++++++++++++++
 test/test_jax.py                       | 120 +++++++++++++++++++++
 2 files changed, 189 insertions(+)
 create mode 100644 .github/workflows/run_jaxtests_cpu.yml
 create mode 100644 test/test_jax.py

diff --git a/.github/workflows/run_jaxtests_cpu.yml b/.github/workflows/run_jaxtests_cpu.yml
new file mode 100644
index 00000000..0d9c05eb
--- /dev/null
+++ b/.github/workflows/run_jaxtests_cpu.yml
@@ -0,0 +1,69 @@
+# Builds the IREE PJRT CPU plugin and runs the JAX testsuite against it.
+
+name: Run CPU Jax Testsuite
+
+on:
+  workflow_dispatch:
+  schedule:
+    # Run nightly at 2:30 PM PDT (21:30 UTC).
+    - cron: '30 21 * * *'
+
+env:
+  # This duplicates the variable from ci.yml. The variable needs to be in env
+  # instead of the outputs of setup because it contains the run attempt and we
+  # want that to be the current attempt, not whatever attempt the setup step
+  # last ran in. It therefore can't be passed in via inputs because the env
+  # context isn't available there.
+  GCS_DIR: gs://iree-github-actions-jax-testsuite-artifacts/
+
+concurrency:
+  # A PR number if a pull request and otherwise the commit hash. This cancels
+  # queued and in-progress runs for the same PR (presubmit) or commit
+  # (postsubmit).
+  group: run_jax_testsuite_${{ github.event.number || github.sha }}
+  cancel-in-progress: true
+
+jobs:
+  build:
+    runs-on: ubuntu-20.04-64core
+    steps:
+      - name: Get current date
+        id: date
+        # Sets up: ${{ steps.date.outputs.date }}
+        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
+
+      - name: "Checking out repository"
+        uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0
+
+      - name: "Setting up Python"
+        uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa # v2.3.3
+        with:
+          python-version: "3.10"
+
+      - name: Sync and install versions
+        run: |
+          # TODO: https://github.com/openxla/openxla-pjrt-plugin/issues/30
+          sudo apt install -y lld
+          # Since only building the runtime, exclude compiler deps (expensive).
+          python ./sync_deps.py --depth 1 --submodules-depth 1 --exclude-submodule "iree:third_party/(llvm|mlir-hlo)"
+          pip install -r requirements.txt
+
+      - name: Setup Bazelisk
+        uses: bazelbuild/setup-bazelisk@v2
+
+      - name: "Configure"
+        run: |
+          python ./configure.py --cc=clang --cxx=clang++
+
+      - name: "Build CPU Plugin"
+        run: |
+          bazel build //iree/integrations/pjrt/cpu:pjrt_plugin_iree_cpu.so
+
+      - name: "Run JAX Testsuite"
+        run: |
+          source .env.sh
+          JAX_PLATFORMS=iree_cpu python test/test_jax.py external/jax/tests/nn_test.py \
+            --passing jaxsuite_passing.txt \
+            --failing jaxsuite_failing.txt
diff --git a/test/test_jax.py b/test/test_jax.py
new file mode 100644
index 00000000..0f5a6f12
--- /dev/null
+++ b/test/test_jax.py
@@ -0,0 +1,120 @@
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import argparse
+import multiprocessing
+import re
+import subprocess
+import sys
+
+parser = argparse.ArgumentParser(
+    prog='test_jax.py',
+    description='Run the JAX testsuite hermetically')
+parser.add_argument('testfiles', nargs="*")
+parser.add_argument('-t', '--timeout', type=int, default=20)
+parser.add_argument('-l', '--logdir', default="/tmp/jaxtest")
+parser.add_argument('-p', '--passing', default=None)
+parser.add_argument('-f', '--failing', default=None)
+parser.add_argument('-e', '--expected', default=None)
+
+args = parser.parse_args()
+
+PYTEST_CMD = [
+    "pytest", "-p", "openxla_pjrt_artifacts",
+    f"--openxla-pjrt-artifact-dir={args.logdir}"
+]
+
+
+def get_tests(tests):
+  # Collect the individual test IDs from each test file without running them.
+  testlist = []
+  for test in sorted(tests):
+    result = subprocess.run(PYTEST_CMD + ["--setup-only", test],
+                            capture_output=True)
+    found = re.findall(r'::\S+::\S+', result.stdout.decode())
+    testlist += [test + func for func in found]
+  return testlist
+
+
+def generate_test_commands(tests, timeout=False):
+  # Build one pytest invocation per test, optionally wrapped in `timeout`.
+  cmd = ["timeout", f"{args.timeout}"] if timeout else []
+  cmd += PYTEST_CMD
+  cmds = []
+  for test in tests:
+    cmds.append(cmd + [test])
+  return cmds
+
+
+def exec_test(command):
+  result = subprocess.run(command, stdout=subprocess.DEVNULL,
+                          stderr=subprocess.DEVNULL)
+  sys.stdout.write(".")
+  sys.stdout.flush()
+  return result.returncode
+
+
+def exec_testsuite(commands):
+  # Each test runs in its own subprocess so a crash only takes down that test.
+  with multiprocessing.Pool() as p:
+    returncodes = p.map(exec_test, commands)
+  print("")
+  passing = []
+  failing = []
+  for code, cmd in zip(returncodes, commands):
+    testname = re.search("[^ ]*::[^ ]*::[^ ]*", " ".join(cmd))[0]
+    if code == 0:
+      passing.append(testname)
+    else:
+      failing.append(testname)
+  return passing, failing
+
+
+def write_results(filename, results):
+  if filename is not None:
+    with open(filename, 'w') as f:
+      for line in results:
+        f.write(line + "\n")
+
+
+def load_results(filename):
+  if not filename:
+    return []
+  expected = []
+  with open(filename, 'r') as f:
+    for line in f:
+      expected.append(line.strip())
+  return expected
+
+
+def compare_results(expected, passing):
+  passing = set(passing)
+  expected = set(expected)
+  new_failures = expected - passing
+  new_passing = passing - expected
+  return new_passing, new_failures
+
+
+tests = get_tests(args.testfiles)
+# Limit the run to the first 36 collected tests.
+tests = tests[:36]
+
+print("Generating test suite")
+test_commands = generate_test_commands(tests, timeout=True)
+
+print(f"Executing {len(test_commands)} tests")
+passing, failing = exec_testsuite(test_commands)
+expected = load_results(args.expected)
+
+write_results(args.passing, passing)
+write_results(args.failing, failing)
+
+print("Total:", len(test_commands))
+print("Passing:", len(passing))
+print("Failing:", len(failing))
+
+if expected:
+  new_passing, new_failures = compare_results(expected, passing)
+
+  if new_passing:
+    print("Newly Passing Tests:")
+    for test in new_passing:
+      print("  ", test)
+
+  if new_failures:
+    print("Newly Failing Tests:")
+    for test in new_failures:
+      print("  ", test)
+
+  # Fail the run if the overall pass count regressed relative to the baseline.
+  if len(expected) > len(passing):
+    sys.exit(1)
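
Note for reviewers (not part of the patch): the workflow above only writes the passing/failing lists, while test_jax.py also accepts an --expected baseline file. A minimal, self-contained sketch of the set comparison that compare_results() performs is below; the test names are hypothetical placeholders, not real JAX testsuite entries.

# Sketch of the baseline comparison in test_jax.py (placeholder test names).
expected = {"nn_test.py::NNFunctionsTest::testRelu",
            "nn_test.py::NNFunctionsTest::testSoftmax"}   # prior baseline (--expected)
passing = {"nn_test.py::NNFunctionsTest::testRelu",
           "nn_test.py::NNFunctionsTest::testGelu"}       # results of this run

new_passing = passing - expected    # improvements worth folding into the baseline
new_failures = expected - passing   # regressions; the script exits 1 when the
                                    # pass count drops below the baseline size
print("Newly Passing:", sorted(new_passing))
print("Newly Failing:", sorted(new_failures))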