diff --git a/flax/funcs.py b/flax/funcs.py index 9d93511..4805ad7 100644 --- a/flax/funcs.py +++ b/flax/funcs.py @@ -724,7 +724,10 @@ def type2strn(x): def transpose(x, filler=None): """transpose: transpose x""" - return [[j for j in i if j is not None] for i in itertools.zip_longest(*[iterable(i) for i in x], fillvalue=filler)] + return [ + [j for j in i if j is not None] + for i in itertools.zip_longest(*[iterable(i) for i in x], fillvalue=filler) + ] def trim(w, x): diff --git a/test/test_funcs.py b/test/test_funcs.py index 4f913c8..3c405c9 100644 --- a/test/test_funcs.py +++ b/test/test_funcs.py @@ -64,23 +64,40 @@ def test_depth(): def test_diagonal_leading(): - assert diagonal_leading([[1,0,2],[2,3,4],[5,6,7]]) == [1,3,7] + assert diagonal_leading([[1, 0, 2], [2, 3, 4], [5, 6, 7]]) == [1, 3, 7] + def test_diagonal_trailing(): - assert diagonal_trailing([[1,0,2],[2,3,4],[5,6,7]]) == [2,3,5] + assert diagonal_trailing([[1, 0, 2], [2, 3, 4], [5, 6, 7]]) == [2, 3, 5] + def test_diagonals(): - assert diagonals([[1,2,3],[4,5,6],[7,8,9]]) == [[1], [4, 2], [7, 5, 3], [8, 6], [9]] - assert diagonals([[1,2,3],[4,5,6],[7,8,9]], antidiagonals=True) == [[7], [4, 8], [1, 5, 9], [2, 6], [3]] + assert diagonals([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) == [ + [1], + [4, 2], + [7, 5, 3], + [8, 6], + [9], + ] + assert diagonals([[1, 2, 3], [4, 5, 6], [7, 8, 9]], antidiagonals=True) == [ + [7], + [4, 8], + [1, 5, 9], + [2, 6], + [3], + ] + def test_digits(): assert digits(123) == [1, 2, 3] - assert digits(3.1415) == [3,1,4,1,5] - assert digits(mpc(123,456)) == [mpc(1,4), mpc(2,5), mpc(3,6)] + assert digits(3.1415) == [3, 1, 4, 1, 5] + assert digits(mpc(123, 456)) == [mpc(1, 4), mpc(2, 5), mpc(3, 6)] + def test_digits_i(): - assert digits_i([3,1,4,1,5]) == 31415 - assert digits_i([mpc(1,4), mpc(2,5), mpc(3,6)]) == mpc(123,456) + assert digits_i([3, 1, 4, 1, 5]) == 31415 + assert digits_i([mpc(1, 4), mpc(2, 5), mpc(3, 6)]) == mpc(123, 456) + "test_enumerate_md" "test_ensure_square"