diff --git a/flax/funcs.py b/flax/funcs.py index 0f4e3b2..cfccd4e 100644 --- a/flax/funcs.py +++ b/flax/funcs.py @@ -352,21 +352,23 @@ def iota(x): elif type2strn(x) == "mpc": return [[mpc(j[0], j[1]) for j in i] for i in iota([x.real, x.imag])] else: - res = cartesian_product(*(iota(a) for a in x)) - for e in x: - res = split(int(abs(e)) if type2strn(e) == "mpc" else int(e), res) + res = cartesian_product(*(iota(i) for i in x)) + for i in x: + res = split(int(abs(i)) if type2strn(i) != "lst" else len(i), res) return res[0] def iota1(x): """iota1: iota but 1 based""" if type2strn(x) in ["int", "dec"]: - return [i + 1 for i in range(int(x))] - - res = cartesian_product(*([i + 1 for i in range(int(a))] for a in x)) - for e in x: - res = split(int(e), res) - return res[0] + return list(range(1, int(x) + 1)) + elif type2strn(x) == "mpc": + return [[mpc(j[0], j[1]) for j in i] for i in iota1([x.real, x.imag])] + else: + res = cartesian_product(*(iota1(i) for i in x)) + for i in x: + res = split(int(abs(i)) if type2strn(i) != "lst" else len(i), res) + return res[0] def iterable(x, digits_=False, range_=False, copy_=False):