From 4c0b90e4aaf6decb83f0c4254b4586ea5954e2d3 Mon Sep 17 00:00:00 2001 From: Alex Rogozhnikov Date: Fri, 20 Dec 2024 17:51:36 -0800 Subject: [PATCH] add tests for different failure modes --- einops/tests/test_layers.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/einops/tests/test_layers.py b/einops/tests/test_layers.py index 6d72e635..a70a5d67 100644 --- a/einops/tests/test_layers.py +++ b/einops/tests/test_layers.py @@ -4,7 +4,7 @@ import numpy import pytest -from einops import rearrange, reduce +from einops import rearrange, reduce, EinopsError from einops.tests import collect_test_backends, is_backend_tested, FLOAT_REDUCTIONS as REDUCTIONS __author__ = "Alex Rogozhnikov" @@ -433,3 +433,37 @@ def test_einmix_decomposition(): assert mixin7.einsum_pattern == "a...bc,cdb->a...db" assert mixin7.saved_weight_shape == [3, 4, 2] assert mixin7.saved_bias_shape == [1, 4, 2] # (a) d b, ellipsis does not participate + + +def test_einmix_restrictions(): + """ + Testing different cases + """ + from einops.layers._einmix import _EinmixDebugger + + with pytest.raises(EinopsError): + _EinmixDebugger( + "a b c d e -> e d c b a", + weight_shape="d a b", + d=2, a=3, # missing b + ) # fmt: off + + with pytest.raises(EinopsError): + _EinmixDebugger( + "a b c d e -> e d c b a", + weight_shape="w a b", + d=2, a=3, b=1 # missing d + ) # fmt: off + + with pytest.raises(EinopsError): + _EinmixDebugger( + "(...) a -> ... a", + weight_shape="a", a=1, # ellipsis on the left + ) # fmt: off + + with pytest.raises(EinopsError): + _EinmixDebugger( + "(...) a -> a ...", + weight_shape="a", a=1, # ellipsis on the right side after bias axis + bias_shape='a', + ) # fmt: off