-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added a github action for running the jax testsuite
- Loading branch information
Showing
2 changed files
with
189 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Installs prebuilt binaries and runs jax testsuite | ||
|
||
name: Run CPU Jax Testsuite | ||
|
||
on: | ||
workflow_dispatch: | ||
schedule: | ||
# Do the nightly dep roll at 2:30 PDT. | ||
- cron: '30 21 * * *' | ||
|
||
env: | ||
# This duplicates the variable from ci.yml. The variable needs to be in env | ||
# instead of the outputs of setup because it contains the run attempt and we | ||
# want that to be the current attempt, not whatever attempt the setup step | ||
# last ran in. It therefore can't be passed in via inputs because the env | ||
# context isn't available there. | ||
GCS_DIR: gs://iree-github-actions-jax-testsuite-artifacts/ | ||
|
||
|
||
concurrency: | ||
# A PR number if a pull request and otherwise the commit hash. This cancels | ||
# queued and in-progress runs for the same PR (presubmit) or commit | ||
# (postsubmit). | ||
group: run_jax_testsuite_${{ github.event.number || github.sha }} | ||
cancel-in-progress: true | ||
|
||
# Jobs are organized into groups and topologically sorted by dependencies | ||
jobs: | ||
build: | ||
runs-on: ubuntu-20.04-64core | ||
steps: | ||
- name: Get current date | ||
id: date | ||
# Sets up: ${{ steps.date.outputs.date }} | ||
run: echo "::set-output name=date::$(date +'%Y-%m-%d')" | ||
|
||
- name: "Checking out repository" | ||
uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 | ||
|
||
- name: "Setting up Python" | ||
uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa # v2.3.3 | ||
with: | ||
python-version: "3.10" | ||
|
||
- name: Sync and install versions | ||
run: | | ||
# TODO: https://github.com/openxla/openxla-pjrt-plugin/issues/30 | ||
sudo apt install -y lld | ||
# Since only building the runtime, exclude compiler deps (expensive). | ||
python ./sync_deps.py --depth 1 --submodules-depth 1 --exclude-submodule "iree:third_party/(llvm|mlir-hlo)" | ||
pip install -r requirements.txt | ||
- name: Setup Bazelisk | ||
uses: bazelbuild/setup-bazelisk@v2 | ||
|
||
- name: "Configure" | ||
run: | | ||
python ./configure.py --cc=clang --cxx=clang++ | ||
- name: "Build CPU Plugin" | ||
run: | | ||
bazel build //iree/integrations/pjrt/cpu:pjrt_plugin_iree_cpu.so | ||
- name: "Run JAX Testsuite" | ||
run: | | ||
source .env.sh | ||
JAX_PLATFORMS=iree_cpu python test/test_jax.py external/jax/tests/nn_test.py \ | ||
--passing jaxsuite_passing.txt \ | ||
--failing jaxsuite_failing.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import argparse | ||
import multiprocessing | ||
import re | ||
import subprocess | ||
import sys | ||
|
||
parser = argparse.ArgumentParser( | ||
prog='test_jax.py', | ||
description='Run jax testsuite hermetically') | ||
parser.add_argument('testfiles', nargs="*") | ||
parser.add_argument('-t', '--timeout', default=20) | ||
parser.add_argument('-l', '--logdir', default="/tmp/jaxtest") | ||
parser.add_argument('-p', '--passing', default=None) | ||
parser.add_argument('-f', '--failing', default=None) | ||
parser.add_argument('-e', '--expected', default=None) | ||
|
||
args = parser.parse_args() | ||
|
||
PYTEST_CMD= ["pytest", "-p", "openxla_pjrt_artifacts", f"--openxla-pjrt-artifact-dir={args.logdir}"] | ||
|
||
def get_tests(tests): | ||
testlist = [] | ||
for test in sorted(tests): | ||
stdout = subprocess.run(PYTEST_CMD + ["--setup-only", test], capture_output=True) | ||
testlist += re.findall('::[^ ]*::[^ ]*', str(stdout)) | ||
testlist = [test + func for func in testlist] | ||
return testlist | ||
|
||
def generate_test_commands(tests, timeout=False): | ||
cmd = ["timeout", f"{args.timeout}"] if timeout else [] | ||
cmd += PYTEST_CMD | ||
cmds = [] | ||
for test in tests: | ||
test_cmd = cmd + [test] | ||
cmds.append(test_cmd) | ||
|
||
return cmds | ||
|
||
def exec_test(command): | ||
result = subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) | ||
sys.stdout.write(".") | ||
sys.stdout.flush() | ||
return result.returncode | ||
|
||
def exec_testsuite(commands): | ||
returncodes = [] | ||
with multiprocessing.Pool() as p: | ||
returncodes = p.map(exec_test, commands) | ||
print("") | ||
passing = [] | ||
failing = [] | ||
for code, cmd in zip(returncodes, commands): | ||
testname = " ".join(cmd) | ||
testname = re.search("[^ ]*::[^ ]*::[^ ]*", testname)[0] | ||
|
||
if code == 0: | ||
passing.append(testname) | ||
else: | ||
failing.append(testname) | ||
return passing, failing | ||
|
||
def write_results(filename, results): | ||
if (filename is not None): | ||
with open(filename, 'w') as f: | ||
for line in results: | ||
f.write(line+"\n") | ||
|
||
def load_results(filename): | ||
if not filename: | ||
return [] | ||
expected = [] | ||
with open(filename, 'r') as f: | ||
for line in f: | ||
expected.append(line.strip()) | ||
return expected | ||
|
||
def compare_results(expected, passing): | ||
passing = set(passing) | ||
expected = set(expected) | ||
new_failures = expected - passing | ||
new_passing = passing - expected | ||
return new_passing, new_failures | ||
|
||
|
||
tests = get_tests(args.testfiles) | ||
tests = tests[:36] | ||
|
||
print("Generating test suite") | ||
test_commands = generate_test_commands(tests, timeout=True) | ||
|
||
print (f"Executing {len(test_commands)} tests") | ||
passing, failing = exec_testsuite(test_commands) | ||
expected = load_results(args.expected) | ||
|
||
write_results(args.passing, passing) | ||
write_results(args.failing, failing) | ||
|
||
print("Total:", len(test_commands)) | ||
print("Passing:", len(passing)) | ||
print("Failing:", len(failing)) | ||
|
||
if expected: | ||
new_passing, new_failures = compare_results(expected, passing) | ||
|
||
if new_passing: | ||
print ("Newly Passing Tests:") | ||
for test in new_passing: | ||
print(" ", test) | ||
|
||
if new_failures: | ||
print ("Newly Failing Tests:") | ||
for test in new_failures: | ||
print(" ", test) | ||
|
||
if len(expected) > len(passing): | ||
exit(1) |