Skip to content

Commit

Permalink
Add support for fmpz_mod and nmod interop
Browse files Browse the repository at this point in the history
  • Loading branch information
Jake Moss committed Aug 16, 2024
1 parent 4526d2b commit 043ca9d
Show file tree
Hide file tree
Showing 4 changed files with 458 additions and 152 deletions.
32 changes: 9 additions & 23 deletions src/flint/test/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -2755,13 +2755,13 @@ def _all_mpolys():
(
flint.fmpz_mod_mpoly,
lambda *args, **kwargs: flint.fmpz_mod_mpoly_ctx.get_context(*args, **kwargs, modulus=101),
flint.fmpz,
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(101)),
True,
),
(
flint.fmpz_mod_mpoly,
lambda *args, **kwargs: flint.fmpz_mod_mpoly_ctx.get_context(*args, **kwargs, modulus=100),
flint.fmpz,
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(100)),
False,
),
(
Expand Down Expand Up @@ -2981,17 +2981,13 @@ def quick_poly():
assert +quick_poly() \
== quick_poly()

assert -quick_poly() \
== mpoly(
{(0, 0): -1, (0, 1): -2, (1, 0): -3, (2, 2): -4} if P is not flint.nmod_mpoly
else {k: ctx.modulus() + v for k, v in {(0, 0): -1, (0, 1): -2, (1, 0): -3, (2, 2): -4}.items()}
)
assert -quick_poly() == mpoly({(0, 0): -1, (0, 1): -2, (1, 0): -3, (2, 2): -4})

assert quick_poly() \
+ mpoly({(0, 0): 5, (0, 1): 6, (1, 0): 7, (2, 2): 8}) \
== mpoly({(0, 0): 6, (0, 1): 8, (1, 0): 10, (2, 2): 12})

for T in [int, S, lambda x: P(x, ctx=ctx)]:
for T in [int, S, flint.fmpz, lambda x: P(x, ctx=ctx)]:
p = quick_poly()
p += T(1)
q = quick_poly()
Expand All @@ -3008,22 +3004,15 @@ def quick_poly():
assert raises(lambda: quick_poly().iadd(None), NotImplementedError)

assert quick_poly() - mpoly({(0, 0): 5, (0, 1): 6, (1, 0): 7, (2, 2): 8}) \
== mpoly(
{(0, 0): -4, (0, 1): -4, (1, 0): -4, (2, 2): -4} if P is not flint.nmod_mpoly
else {k: ctx.modulus() + v for k, v in {(0, 0): -4, (0, 1): -4, (1, 0): -4, (2, 2): -4}.items()}
)
== mpoly({(0, 0): -4, (0, 1): -4, (1, 0): -4, (2, 2): -4})

for T in [int, S, int, lambda x: P(x, ctx=ctx)]:
for T in [int, S, flint.fmpz, lambda x: P(x, ctx=ctx)]:
p = quick_poly()
p -= T(1)
q = quick_poly()
assert q.isub(T(1)) is None
assert quick_poly() - T(1) == p == q == mpoly({(0, 1): 2, (1, 0): 3, (2, 2): 4})
assert T(1) - quick_poly() == \
mpoly(
{(0, 1): -2, (1, 0): -3, (2, 2): -4} if P is not flint.nmod_mpoly
else {k: ctx.modulus() + v for k, v in {(0, 1): -2, (1, 0): -3, (2, 2): -4}.items()}
)
assert T(1) - quick_poly() == mpoly({(0, 1): -2, (1, 0): -3, (2, 2): -4})

assert raises(lambda: quick_poly() - None, TypeError)
assert raises(lambda: None - quick_poly(), TypeError)
Expand All @@ -3042,7 +3031,7 @@ def quick_poly():
(0, 1): 6
})

for T in [int, S, int, lambda x: P(x, ctx=ctx)]:
for T in [int, S, flint.fmpz, lambda x: P(x, ctx=ctx)]:
p = quick_poly()
p *= T(2)
q = quick_poly()
Expand Down Expand Up @@ -3083,10 +3072,7 @@ def quick_poly():
assert divmod(quick_poly(), S(1)) == (quick_poly(), P(ctx=ctx))

if is_field:
if (P is flint.fmpz_mod_mpoly or P is flint.nmod_mpoly):
assert quick_poly() / 3 == mpoly({(0, 0): S(34), (0, 1): S(68), (1, 0): S(1), (2, 2): S(35)})
else:
assert quick_poly() / 3 == mpoly({(0, 0): S(1, 3), (0, 1): S(2, 3), (1, 0): S(1), (2, 2): S(4, 3)})
assert quick_poly() / 3 == mpoly({(0, 0): S(1) / 3, (0, 1): S(2) / 3, (1, 0): S(1), (2, 2): S(4) / 3})
else:
assert raises(lambda: quick_poly() / 3, DomainError)

Expand Down
Loading

0 comments on commit 043ca9d

Please sign in to comment.