From 106d3e1f4ff0302aff4fb287f4502647e8b89fdb Mon Sep 17 00:00:00 2001 From: caetano melone Date: Thu, 7 Mar 2024 20:53:20 -0800 Subject: [PATCH] update tests to use spec strings instead of dict payloads --- gantry/tests/defs/prediction.py | 50 +++++++------------------------- gantry/tests/test_prediction.py | 51 ++++++++++++++++----------------- 2 files changed, 36 insertions(+), 65 deletions(-) diff --git a/gantry/tests/defs/prediction.py b/gantry/tests/defs/prediction.py index 9da2ed4..90408a1 100644 --- a/gantry/tests/defs/prediction.py +++ b/gantry/tests/defs/prediction.py @@ -1,62 +1,34 @@ # flake8: noqa # fmt: off -NORMAL_BUILD = { - "hash": "testing", - "package": { - "name": "py-torch", - "version": "2.2.1", - "variants": "~caffe2+cuda+cudnn~debug+distributed+fbgemm+gloo+kineto~magma~metal+mkldnn+mpi~nccl+nnpack+numa+numpy+onnx_ml+openmp+qnnpack~rocm+tensorpipe~test+valgrind+xnnpack build_system=python_pip cuda_arch=80", - }, - "compiler": { - "name": "gcc", - "version": "11.4.0", - }, -} +from gantry.util.spec import parse_alloc_spec + +NORMAL_BUILD = parse_alloc_spec( + "py-torch@2.2.1 ~caffe2+cuda+cudnn~debug+distributed+fbgemm+gloo+kineto~magma~metal+mkldnn+mpi~nccl+nnpack+numa+numpy+onnx_ml+openmp+qnnpack~rocm+tensorpipe~test+valgrind+xnnpack build_system=python_pip cuda_arch=80%gcc@11.4.0" +) # everything in NORMAL_BUILD["package"]["variants"] except removing build_system=python_pip # in order to test the expensive variants filter -EXPENSIVE_VARIANT_BUILD = { - "hash": "testing", - "package": { - "name": "py-torch", - "version": "2.2.1", - "variants": "~caffe2+cuda+cudnn~debug+distributed+fbgemm+gloo+kineto~magma~metal+mkldnn+mpi~nccl+nnpack+numa+numpy+onnx_ml+openmp+qnnpack~rocm+tensorpipe~test+valgrind+xnnpack cuda_arch=80", - }, - "compiler": { - "name": "gcc", - "version": "11.4.0", - }, -} +EXPENSIVE_VARIANT_BUILD = parse_alloc_spec( + "py-torch@2.2.1 ~caffe2+cuda+cudnn~debug+distributed+fbgemm+gloo+kineto~magma~metal+mkldnn+mpi~nccl+nnpack+numa+numpy+onnx_ml+openmp+qnnpack~rocm+tensorpipe~test+valgrind+xnnpack cuda_arch=80%gcc@11.4.0" +) # no variants should match this, so we expect the default prediction -BAD_VARIANT_BUILD = { - "hash": "testing", - "package": { - "name": "py-torch", - "version": "2.2.1", - "variants": "+no~expensive~variants+match", - }, - "compiler": { - "name": "gcc", - "version": "11.4.0", - }, -} +BAD_VARIANT_BUILD = parse_alloc_spec( + "py-torch@2.2.1 +no~expensive~variants+match%gcc@11.4.0" +) # calculated by running the baseline prediction algorithm on the sample data in gantry/tests/sql/insert_prediction.sql NORMAL_PREDICTION = { - "hash": "testing", "variables": { "KUBERNETES_CPU_REQUEST": "12", "KUBERNETES_MEMORY_REQUEST": "9576M", }, } - # this is what will get returned when there are no samples in the database # that match what the client wants DEFAULT_PREDICTION = { - "hash": "testing", "variables": { "KUBERNETES_CPU_REQUEST": "1", "KUBERNETES_MEMORY_REQUEST": "2000M", diff --git a/gantry/tests/test_prediction.py b/gantry/tests/test_prediction.py index ac56eb8..8f93055 100644 --- a/gantry/tests/test_prediction.py +++ b/gantry/tests/test_prediction.py @@ -2,7 +2,7 @@ from gantry.routes.prediction import prediction from gantry.tests.defs import prediction as defs -from gantry.util.prediction import validate_payload +from gantry.util.spec import parse_alloc_spec @pytest.fixture @@ -57,7 +57,7 @@ async def test_partial_match(db_conn_inserted): # same as NORMAL_BUILD, but with a different compiler name to test partial matching diff_compiler_build = defs.NORMAL_BUILD.copy() - diff_compiler_build["compiler"]["name"] = "gcc-different" + diff_compiler_build["compiler_name"] = "gcc-different" assert ( await prediction.predict_single(db_conn_inserted, diff_compiler_build) @@ -75,37 +75,36 @@ async def test_empty_sample(db_conn): # Test validate_payload +def test_valid_spec(): + """Tests that a valid spec is parsed correctly.""" + assert parse_alloc_spec("emacs@29.2 +json+native+treesitter%gcc@12.3.0") == { + "pkg_name": "emacs", + "pkg_version": "29.2", + "pkg_variants": '{"json": true, "native": true, "treesitter": true}', + "pkg_variants_dict": {"json": True, "native": True, "treesitter": True}, + "compiler_name": "gcc", + "compiler_version": "12.3.0", + } -def test_valid_payload(): - """Tests that a valid payload returns True""" - assert validate_payload(defs.NORMAL_BUILD) is True +def test_invalid_specs(): + """Test a series of invalid specs""" + # not a spec + assert parse_alloc_spec("hi") == {} -def test_invalid_payloads(): - """Test a series of invalid payloads""" - - # non dict - assert validate_payload("hi") is False - - build = defs.NORMAL_BUILD.copy() # missing package - del build["package"] - assert validate_payload(build) is False + assert parse_alloc_spec("@29.2 +json+native+treesitter%gcc@12.3.0") == {} - build = defs.NORMAL_BUILD.copy() # missing compiler - del build["compiler"] - assert validate_payload(build) is False + assert parse_alloc_spec("emacs@29.2 +json+native+treesitter") == {} + + # variants not spaced correctly + assert parse_alloc_spec("emacs@29.2+json+native+treesitter%gcc@12.3.0") == {} - # name and version are strings in the package and compiler - for key in ["name", "version"]: - for field in ["package", "compiler"]: - build = defs.NORMAL_BUILD.copy() - build[field][key] = 123 - assert validate_payload(build) is False + # missing versions + assert parse_alloc_spec("emacs@29.2 +json+native+treesitter%gcc@") == {} + assert parse_alloc_spec("emacs@ +json+native+treesitter%gcc@12.3.0") == {} # invalid variants - build = defs.NORMAL_BUILD.copy() - build["package"]["variants"] = "+++++" - assert validate_payload(build) is False + assert parse_alloc_spec("emacs@29.2 this_is_not_a_thing%gcc@12.3.0") == {}