Added a GitHub Action for running the JAX testsuite
rsuderman committed May 6, 2023
1 parent aecbf3b commit ca0218a
Showing 2 changed files with 189 additions and 0 deletions.
69 changes: 69 additions & 0 deletions .github/workflows/run_jaxtests_cpu.yml
@@ -0,0 +1,69 @@
# Installs prebuilt binaries and runs the JAX testsuite.

name: Run CPU Jax Testsuite

on:
  workflow_dispatch:
  schedule:
    # Run the nightly testsuite at 2:30 PM PDT (21:30 UTC).
    - cron: '30 21 * * *'

env:
  # This duplicates the variable from ci.yml. The variable needs to be in env
  # instead of the outputs of setup because it contains the run attempt and we
  # want that to be the current attempt, not whatever attempt the setup step
  # last ran in. It therefore can't be passed in via inputs because the env
  # context isn't available there.
  GCS_DIR: gs://iree-github-actions-jax-testsuite-artifacts/


concurrency:
  # A PR number if a pull request and otherwise the commit hash. This cancels
  # queued and in-progress runs for the same PR (presubmit) or commit
  # (postsubmit).
  group: run_jax_testsuite_${{ github.event.number || github.sha }}
  cancel-in-progress: true

# Jobs are organized into groups and topologically sorted by dependencies
jobs:
  build:
    runs-on: ubuntu-20.04-64core
    steps:
      - name: Get current date
        id: date
        # Sets up: ${{ steps.date.outputs.date }}
        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"

- name: "Checking out repository"
uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0

- name: "Setting up Python"
uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa # v2.3.3
with:
python-version: "3.10"

      - name: Sync and install versions
        run: |
          # TODO: https://github.com/openxla/openxla-pjrt-plugin/issues/30
          sudo apt install -y lld
          # Since we only build the runtime, exclude the compiler deps (expensive).
          python ./sync_deps.py --depth 1 --submodules-depth 1 --exclude-submodule "iree:third_party/(llvm|mlir-hlo)"
          pip install -r requirements.txt

      - name: Setup Bazelisk
        uses: bazelbuild/setup-bazelisk@v2

- name: "Configure"
run: |
python ./configure.py --cc=clang --cxx=clang++
- name: "Build CPU Plugin"
run: |
bazel build //iree/integrations/pjrt/cpu:pjrt_plugin_iree_cpu.so
- name: "Run JAX Testsuite"
run: |
source .env.sh
JAX_PLATFORMS=iree_cpu python test/test_jax.py external/jax/tests/nn_test.py \
--passing jaxsuite_passing.txt \
--failing jaxsuite_failing.txt
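
As a sanity check before the full suite, one can confirm the plugin registers under JAX at all. The snippet below is an illustrative sketch, not part of this commit; it assumes `source .env.sh` has been run so the plugin is discoverable, as in the final workflow step:

# Hypothetical smoke check (not part of this commit): verifies the IREE CPU
# PJRT plugin registered, assuming `source .env.sh` exported the plugin path.
import os
os.environ.setdefault("JAX_PLATFORMS", "iree_cpu")

import jax
import jax.numpy as jnp

print(jax.devices())                   # expect iree_cpu device(s)
print(jax.jit(jnp.tanh)(jnp.ones(4)))  # trivial compiled dispatch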
120 changes: 120 additions & 0 deletions test/test_jax.py
@@ -0,0 +1,120 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import multiprocessing
import re
import subprocess
import sys

parser = argparse.ArgumentParser(
    prog='test_jax.py',
    description='Run the jax testsuite hermetically')
parser.add_argument('testfiles', nargs="*")
parser.add_argument('-t', '--timeout', type=int, default=20)
parser.add_argument('-l', '--logdir', default="/tmp/jaxtest")
parser.add_argument('-p', '--passing', default=None)
parser.add_argument('-f', '--failing', default=None)
parser.add_argument('-e', '--expected', default=None)

args = parser.parse_args()

PYTEST_CMD = ["pytest", "-p", "openxla_pjrt_artifacts",
              f"--openxla-pjrt-artifact-dir={args.logdir}"]

def get_tests(tests):
  # Enumerate test functions via `pytest --setup-only`, then build fully
  # qualified test ids by prefixing each `::Class::function` match with the
  # file it came from.
  testlist = []
  for test in sorted(tests):
    result = subprocess.run(PYTEST_CMD + ["--setup-only", test],
                            capture_output=True, text=True)
    funcs = re.findall(r'::\S+::\S+', result.stdout)
    testlist += [test + func for func in funcs]
  return testlist
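
# For reference, a `pytest --setup-only` output line looks roughly like the
# (hypothetical) id below; the regex above captures the `::Class::function`
# suffix and get_tests() re-attaches the file path:
#
#   external/jax/tests/nn_test.py::NNFunctionsTest::testSoftmax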

def generate_test_commands(tests, timeout=False):
  # Optionally wrap each pytest invocation in `timeout` so a hung test
  # cannot stall the whole suite.
  cmd = ["timeout", f"{args.timeout}"] if timeout else []
  cmd += PYTEST_CMD
  cmds = []
  for test in tests:
    cmds.append(cmd + [test])
  return cmds

def exec_test(command):
  result = subprocess.run(command, stdout=subprocess.DEVNULL,
                          stderr=subprocess.DEVNULL)
  # Emit a progress dot per completed test.
  sys.stdout.write(".")
  sys.stdout.flush()
  return result.returncode

def exec_testsuite(commands):
  # Run one pytest process per test in parallel, one worker per CPU.
  with multiprocessing.Pool() as p:
    returncodes = p.map(exec_test, commands)
  print("")
  passing = []
  failing = []
  for code, cmd in zip(returncodes, commands):
    testname = re.search("[^ ]*::[^ ]*::[^ ]*", " ".join(cmd))[0]
    if code == 0:
      passing.append(testname)
    else:
      failing.append(testname)
  return passing, failing

def write_results(filename, results):
  if filename is not None:
    with open(filename, 'w') as f:
      for line in results:
        f.write(line + "\n")

def load_results(filename):
  if not filename:
    return []
  expected = []
  with open(filename, 'r') as f:
    for line in f:
      expected.append(line.strip())
  return expected

def compare_results(expected, passing):
  # Tests now passing that were not expected to (progressions) and tests
  # expected to pass that no longer do (regressions).
  passing = set(passing)
  expected = set(expected)
  new_failures = expected - passing
  new_passing = passing - expected
  return new_passing, new_failures


tests = get_tests(args.testfiles)
# Note: only the first 36 discovered tests are run.
tests = tests[:36]

print("Generating test suite")
test_commands = generate_test_commands(tests, timeout=True)

print (f"Executing {len(test_commands)} tests")
passing, failing = exec_testsuite(test_commands)
expected = load_results(args.expected)

write_results(args.passing, passing)
write_results(args.failing, failing)

print("Total:", len(test_commands))
print("Passing:", len(passing))
print("Failing:", len(failing))

if expected:
new_passing, new_failures = compare_results(expected, passing)

if new_passing:
print ("Newly Passing Tests:")
for test in new_passing:
print(" ", test)

if new_failures:
print ("Newly Failing Tests:")
for test in new_failures:
print(" ", test)

if len(expected) > len(passing):
exit(1)
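
For illustration (not part of the commit), here is how compare_results() behaves on hypothetical expected/passing lists; this is the contract behind the --expected, --passing, and --failing flags. The test ids below are made up:

# Hypothetical example of compare_results() semantics.
expected = ["nn_test.py::NNFunctionsTest::testSoftmax",
            "nn_test.py::NNFunctionsTest::testReluGrad"]
passing = ["nn_test.py::NNFunctionsTest::testSoftmax",
           "nn_test.py::NNFunctionsTest::testGelu"]
new_passing, new_failures = compare_results(expected, passing)
assert new_passing == {"nn_test.py::NNFunctionsTest::testGelu"}
assert new_failures == {"nn_test.py::NNFunctionsTest::testReluGrad"}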
