-
Notifications
You must be signed in to change notification settings - Fork 310
/
Copy pathmultivariable_crt.py
57 lines (44 loc) · 1.46 KB
/
multivariable_crt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def mat_sub(A, B):
    """Return the element-wise difference A - B of two matrices given as lists of rows."""
    difference = []
    for row_a, row_b in zip(A, B):
        difference.append([p - q for p, q in zip(row_a, row_b)])
    return difference
def mat_mul(A, B):
    """Return the matrix product A @ B for matrices given as lists of rows.

    Entry (r, c) of the result is the dot product of row r of A with
    column c of B (columns are obtained by transposing B with zip).
    """
    product = []
    for row in A:
        product.append([sum(p * q for p, q in zip(row, col)) for col in zip(*B)])
    return product
def gcd(x, y):
    """Return the greatest common divisor of integers x and y.

    The result is always non-negative (and gcd(0, 0) == 0), following the
    usual mathematical convention and the behavior of ``math.gcd``.
    """
    while y:
        x, y = y, x % y
    # Python's % takes the sign of the divisor, so the Euclidean loop can
    # finish on a negative value (e.g. gcd(4, -6) would end at -2).
    return abs(x)
def extended_gcd(a, b):
    """Run the extended Euclidean algorithm on integers a and b.

    Returns a tuple (g, s, t) such that a * s + b * t == g, where g is a
    greatest common divisor of a and b. When b is 0 the result is (a, 1, 0).
    The sign of g is not normalized (e.g. a negative a with b == 0 gives a
    negative g).
    """
    prev_rem, rem = a, b
    # Bezout coefficient of `a` belonging to prev_rem and rem respectively.
    prev_coef, coef = 1, 0
    while rem != 0:
        quotient = prev_rem // rem
        prev_rem, rem = rem, prev_rem - quotient * rem
        prev_coef, coef = coef, prev_coef - quotient * coef
    if not b:
        return prev_rem, prev_coef, 0
    # a * prev_coef + b * t == prev_rem, so t follows by exact division.
    return prev_rem, prev_coef, (prev_rem - prev_coef * a) // b
def modinv(a, m):
    """Return the inverse of a modulo m, reduced into range(m), or None.

    None is returned when a (reduced modulo m) and m are not coprime,
    i.e. when no inverse exists.
    """
    g, coef, _ = extended_gcd(a % m, m)
    if g != 1:
        return None
    return coef % m
def pivot(A, m):
    """Choose, for each row of A, a column whose coefficient is coprime to that row's modulus.

    Row i is tested against m[i]. When several columns qualify, the last
    one is chosen. When none qualifies, 0 is used as a placeholder; in that
    case it is not actually a usable pivot.
    """
    pivots = [0] * len(A)
    for i, row in enumerate(A):
        coprime_cols = [j for j, coef in enumerate(row) if gcd(coef, m[i]) == 1]
        if coprime_cols:
            pivots[i] = coprime_cols[-1]
    return pivots
def is_sol(A, x, b, m):
    """Check whether A x == b (mod m) holds row by row.

    Row ``i`` of ``A`` together with ``b[i]`` is a linear congruence that is
    checked modulo ``m[i]``. ``x`` and ``b`` may be flat lists of integers
    (the form ``mcrt`` takes and returns) or column matrices (lists of
    one-element rows).

    Returns:
        True if every congruence holds, False otherwise.

    Raises:
        ValueError: if ``A``, ``b`` and ``m`` do not have the same number of
            rows.
    """
    def _as_vector(v):
        # Accept a column matrix [[v0], [v1], ...] as well as [v0, v1, ...].
        return [e[0] if isinstance(e, (list, tuple)) else e for e in v]

    xs = _as_vector(x)
    bs = _as_vector(b)
    if not len(A) == len(bs) == len(m):
        raise ValueError("A, b and m must have the same number of rows")
    return all(
        (sum(a * v for a, v in zip(row, xs)) - bi) % mod == 0
        for row, bi, mod in zip(A, bs, m)
    )
def mcrt(A, b, m):
    """Return an integer vector x such that A x == b (mod m) row by row.

    Row ``i`` of ``A`` together with ``b[i]`` is a linear congruence modulo
    ``m[i]``. Rows are solved in order: the correction for row ``i`` is added
    to a single pivot variable (chosen by ``pivot``) as a multiple of the
    product of all previous moduli, so the congruences already satisfied
    stay satisfied.

    The method needs ``m_prod * A[i][pivot]`` to be invertible modulo
    ``m[i]``, which in practice means the moduli must be pairwise coprime
    and each row must contain a coefficient coprime to its modulus.

    Args:
        A: list of rows of integer coefficients (one row per congruence;
            all rows are assumed to have the same length).
        b: list of integer right-hand sides, one per row.
        m: list of integer moduli, one per row.

    Returns:
        list of ``len(A[0])`` integers (empty if ``A`` is empty); the values
        are not reduced modulo anything.

    Raises:
        ValueError: if a required modular inverse does not exist.
    """
    piv = pivot(A, m)
    # One unknown per column, not per row: A need not be square.
    x = [0] * (len(A[0]) if A else 0)
    m_prod = 1
    for i, Ai in enumerate(A):
        tot = sum(a * v for a, v in zip(Ai, x))
        inv = modinv(m_prod * Ai[piv[i]], m[i])
        if inv is None:
            raise ValueError(
                "row {}: {} is not invertible modulo {}; the moduli must be "
                "pairwise coprime and each row needs a coefficient coprime "
                "to its modulus".format(i, m_prod * Ai[piv[i]], m[i])
            )
        tmp = (inv * (b[i] - tot)) % m[i]
        # Adding a multiple of the previous moduli' product leaves rows
        # 0..i-1 unchanged modulo their own moduli.
        x[piv[i]] += tmp * m_prod
        m_prod *= m[i]
    return x