diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py index d85440d..ef14f86 100644 --- a/tests/test_state_dict.py +++ b/tests/test_state_dict.py @@ -186,7 +186,7 @@ def init(Block2, key): # check for some keys: assert "0.a.0.bias" in state_dict - assert "1.a.1." in state_dict + assert "1.a.1.weight" in state_dict z = tree_zeros_like(m_stacked)