


from Crypto.Util.number import *
from secret import flag

# Challenge 1 generator: RSA with a small (299-bit) private exponent, plus
# the top 70 bits of each 512-bit prime leaked as hints.
p, q = getPrime(512), getPrime(512)
n = p * q

# The private exponent is picked first; the public exponent is derived from it.
d = getPrime(299)
e = inverse(d, (p - 1) * (q - 1))

m = bytes_to_long(flag)
c = pow(m, e, n)

# Keep only the 70 most significant bits of each prime.
hint1, hint2 = (prime >> (512 - 70) for prime in (p, q))

print(f"n = {n}")
print(f"e = {e}")
print(f"c = {c}")
print(f"hint1 = {hint1}")
print(f"hint2 = {hint2}")

n = 123789043095302886784777548580725867919630872720308267296330863659260260632444171595208750648710642616709290340791408935502415290984231140635423328808872594955139658822363033096014857287439409252367248420356169878044065798634016290690979979625051287064109800759113475629317869327100941592970373827299442569489
e = 112070481298571389221611833986644006256566240788306316765530852688390558290807060037831460397016038678699757261874520899143918664293504728402666398893964929840011110057060969775245481057773655679041350091817099143204028098431544760662690479779286160425059494739419234859710815966582837874194763305328789592245
c = 63662561509209168743977531923281040338804656992093161358503738280395090747786427812762995865224617853709000826994250614233562094619845247321880231488631212423212167167713869682181551433686816142488666533035193128298379649809096863305651271646535125466745409868274019550361728139482502448613835444108383177119
hint1 = 897446442156802074692
hint2 = 1069442646630079275131
# sage
from Crypto.Util.number import *
import itertools

# Setting debug to true will display more information
# about the lattice, the bounds, the vectors...
debug = True

# Setting strict to true will stop the algorithm (and
# return (-1, -1)) if we don't have a correct
# upper bound on the determinant. Note that this
# doesn't necessarily mean that no solutions
# will be found since the theoretical upper bound is
# usually far away from actual results. That is why
# you should probably use `strict = False`
strict = False

# This is experimental, but has provided remarkable results
# so far. It tries to reduce the lattice as much as it can
# while keeping its efficiency. I see no reason not to use
# this option, but if things don't work, you should try
# disabling it
helpful_only = True
dimension_min = 7  # stop removing if lattice reaches that dimension

# Functions

# display stats on helpful vectors
def helpful_vectors(BB, modulus):
    """Print how many diagonal entries of BB are >= modulus.

    A basis vector whose diagonal entry reaches ``modulus`` is counted as
    "not helpful".

    BB      -- matrix-like object exposing ``dimensions()`` and ``BB[i, j]``
    modulus -- threshold the diagonal entries are compared against
    """
    dim = BB.dimensions()[0]
    nothelpful = sum(1 for ii in range(dim) if BB[ii, ii] >= modulus)

    print(nothelpful, "/", dim, " vectors are not helpful")

# display matrix picture with 0 and X
def matrix_overview(BB, bound):
    """Print one line per row of BB showing which entries are non-zero.

    Each entry is rendered as '0' (zero) or 'X' (non-zero). Entries are
    separated by a space when the matrix has fewer than 60 rows, and a row
    whose diagonal entry is >= ``bound`` is flagged with a trailing '~'.
    """
    nrows = BB.dimensions()[0]
    ncols = BB.dimensions()[1]
    spaced = nrows < 60
    for ii in range(nrows):
        a = ('%02d ' % ii)
        for jj in range(ncols):
            a += '0' if BB[ii, jj] == 0 else 'X'
            if spaced:
                a += ' '
        if BB[ii, ii] >= bound:
            a += '~'
        print(a)

# tries to remove unhelpful vectors
# we start at current = n-1 (last vector)
def remove_unhelpful(BB, monomials, bound, current):
    """Recursively drop unhelpful rows/columns from the lattice basis.

    A vector ``ii`` is unhelpful when ``BB[ii, ii] >= bound``. It is removed
    when no later row has a non-zero entry in column ``ii``, or together with
    the single later row that does, provided that row affects nobody else and
    is further from ``bound`` than the unhelpful one is not (see the
    comparison below).

    BB        -- square basis matrix (needs ``dimensions``, ``delete_rows``,
                 ``delete_columns`` and ``BB[i, j]``)
    monomials -- list of monomials indexed like the columns of BB; it is
                 updated IN PLACE so that it stays aligned with the columns
                 of the returned matrix
    bound     -- helpfulness threshold for diagonal entries
    current   -- index of the last vector still to be examined

    Returns the (possibly smaller) matrix. Stops as soon as the dimension
    reaches the module-level ``dimension_min``.
    """
    # end of our recursive function
    if current == -1 or BB.dimensions()[0] <= dimension_min:
        return BB

    # we start by checking from the end
    for ii in range(current, -1, -1):
        # if it is unhelpful:
        if BB[ii, ii] >= bound:
            affected_vectors = 0
            affected_vector_index = 0
            # let's check if it affects other vectors
            for jj in range(ii + 1, BB.dimensions()[0]):
                # if another vector is affected:
                # we increase the count
                if BB[jj, ii] != 0:
                    affected_vectors += 1
                    affected_vector_index = jj

            # level:0
            # if no other vectors end up affected
            # we remove it
            if affected_vectors == 0:
                print("* removing unhelpful vector", ii)
                BB = BB.delete_columns([ii])
                BB = BB.delete_rows([ii])
                # keep the monomial list aligned with the remaining columns
                monomials.pop(ii)
                BB = remove_unhelpful(BB, monomials, bound, ii-1)
                return BB

            # level:1
            # if just one was affected we check
            # if it is affecting someone else
            elif affected_vectors == 1:
                affected_deeper = True
                for kk in range(affected_vector_index + 1, BB.dimensions()[0]):
                    # if it is affecting even one vector
                    # we give up on this one
                    if BB[kk, affected_vector_index] != 0:
                        affected_deeper = False
                # remove both it if no other vector was affected and
                # this helpful vector is not helpful enough
                # compared to our unhelpful one
                if affected_deeper and abs(bound - BB[affected_vector_index, affected_vector_index]) < abs(bound - BB[ii, ii]):
                    print("* removing unhelpful vectors", ii, "and", affected_vector_index)
                    BB = BB.delete_columns([affected_vector_index, ii])
                    BB = BB.delete_rows([affected_vector_index, ii])
                    # affected_vector_index > ii, so pop the larger index first
                    monomials.pop(affected_vector_index)
                    monomials.pop(ii)
                    BB = remove_unhelpful(BB, monomials, bound, ii-1)
                    return BB
    # nothing happened
    return BB

def boneh_durfee(pol, modulus, mm, tt, XX, YY):
    """
    Boneh and Durfee revisited by Herrmann and May

    finds a solution if:
    * d < N^delta
    * |x| < e^delta
    * |y| < e^0.5
    whenever delta < 1 - sqrt(2)/2 ~ 0.292

    Parameters:
    * pol      bivariate polynomial in x, y with a small root modulo `modulus`
    * modulus  the modulus (the RSA public exponent e in this attack)
    * mm, tt   lattice parameters (number of x-shift levels / y-shifts)
    * XX, YY   upper bounds on |x| and |y|

    Returns:
    * 0,0   if it fails
    * -1,-1 if `strict=true`, and determinant doesn't bound
    * x0,y0 the solutions of `pol`

    Reads the module-level switches `debug`, `strict` and `helpful_only`.
    """

    # substitution (Herrman and May)
    PR.<u, x, y> = PolynomialRing(ZZ)
    Q = PR.quotient(x*y + 1 - u) # u = xy + 1
    polZ = Q(pol).lift()

    UU = XX*YY + 1

    # x-shifts
    gg = []
    for kk in range(mm + 1):
        for ii in range(mm - kk + 1):
            xshift = x^ii * modulus^(mm - kk) * polZ(u, x, y)^kk
            gg.append(xshift)
    # sorted so that row ii only involves monomials[0..ii]; the matrix
    # construction below fills columns 0..ii only
    gg.sort()

    # x-shifts list of monomials
    monomials = []
    for polynomial in gg:
        for monomial in polynomial.monomials():
            if monomial not in monomials:
                monomials.append(monomial)
    monomials.sort()

    # y-shifts (selected by Herrman and May)
    for jj in range(1, tt + 1):
        for kk in range(floor(mm/tt) * jj, mm + 1):
            yshift = y^jj * polZ(u, x, y)^kk * modulus^(mm - kk)
            yshift = Q(yshift).lift()
            gg.append(yshift) # substitution

    # y-shifts list of monomials
    for jj in range(1, tt + 1):
        for kk in range(floor(mm/tt) * jj, mm + 1):
            monomials.append(u^kk * y^jj)

    # construct lattice B
    nn = len(monomials)
    BB = Matrix(ZZ, nn)
    for ii in range(nn):
        BB[ii, 0] = gg[ii](0, 0, 0)
        for jj in range(1, ii + 1):
            if monomials[jj] in gg[ii].monomials():
                BB[ii, jj] = gg[ii].monomial_coefficient(monomials[jj]) * monomials[jj](UU,XX,YY)

    # Prototype to reduce the lattice
    if helpful_only:
        # automatically remove
        BB = remove_unhelpful(BB, monomials, modulus^mm, nn-1)
        # reset dimension
        nn = BB.dimensions()[0]
        if nn == 0:
            print("failure")
            return 0,0

    # check if vectors are helpful
    if debug:
        helpful_vectors(BB, modulus^mm)

    # check if determinant is correctly bounded
    det = BB.det()
    bound = modulus^(mm*nn)
    if det >= bound:
        print("We do not have det < bound. Solutions might not be found.")
        print("Try with highers m and t.")
        if debug:
            diff = (log(det) - log(bound)) / log(2)
            print("size det(L) - size e^(m*n) = ", floor(diff))
        if strict:
            return -1, -1
    else:
        print("det(L) < e^(m*n) (good! If a solution exists < N^delta, it will be found)")

    # display the lattice basis
    if debug:
        matrix_overview(BB, modulus^mm)

    # LLL
    if debug:
        print("optimizing basis of the lattice via LLL, this can take a long time")

    BB = BB.LLL()

    if debug:
        print("LLL is done!")

    # transform vector i & j -> polynomials 1 & 2
    if debug:
        print("looking for independent vectors in the lattice")
    found_polynomials = False

    for pol1_idx in range(nn - 1):
        for pol2_idx in range(pol1_idx + 1, nn):
            # for i and j, create the two polynomials
            PR.<w,z> = PolynomialRing(ZZ)
            pol1 = pol2 = 0
            for jj in range(nn):
                pol1 += monomials[jj](w*z+1,w,z) * BB[pol1_idx, jj] / monomials[jj](UU,XX,YY)
                pol2 += monomials[jj](w*z+1,w,z) * BB[pol2_idx, jj] / monomials[jj](UU,XX,YY)

            # resultant
            PR.<q> = PolynomialRing(ZZ)
            rr = pol1.resultant(pol2)

            # are these good polynomials?
            if rr.is_zero() or rr.monomials() == [1]:
                continue

            print("found them, using vectors", pol1_idx, "and", pol2_idx)
            found_polynomials = True
            break
        if found_polynomials:
            break

    if not found_polynomials:
        print("no independant vectors could be found. This should very rarely happen...")
        return 0, 0

    rr = rr(q, q)

    # solutions
    soly = rr.roots()

    if len(soly) == 0:
        print("Your prediction (delta) is too small")
        return 0, 0

    soly = soly[0][0]
    ss = pol1(q, soly)
    solx = ss.roots()[0][0]

    return solx, soly

nbit = 1024
n = 123789043095302886784777548580725867919630872720308267296330863659260260632444171595208750648710642616709290340791408935502415290984231140635423328808872594955139658822363033096014857287439409252367248420356169878044065798634016290690979979625051287064109800759113475629317869327100941592970373827299442569489
e = 112070481298571389221611833986644006256566240788306316765530852688390558290807060037831460397016038678699757261874520899143918664293504728402666398893964929840011110057060969775245481057773655679041350091817099143204028098431544760662690479779286160425059494739419234859710815966582837874194763305328789592245
c = 63662561509209168743977531923281040338804656992093161358503738280395090747786427812762995865224617853709000826994250614233562094619845247321880231488631212423212167167713869682181551433686816142488666533035193128298379649809096863305651271646535125466745409868274019550361728139482502448613835444108383177119
hint1 = 897446442156802074692
hint2 = 1069442646630079275131
# Known high parts of p and q: each hint is shifted back up by 442 bits,
# leaving the low 442 bits (the unknown part) as zero.
ph = int(hint1<<442)
qh = int(hint2<<442)

#part1 boneh and durfee to get d
# Lattice parameters passed to boneh_durfee() below.
m = 8
delta = 0.30
t = int((1-2*delta) * m) # optimization from Herrmann and May
X = 2*floor(n^delta) # this _might_ be too much
# NOTE(review): Y is used as the bound on the second unknown y of `pol`;
# here it is set to the size of the unknown low bits (442) rather than
# e^0.5 -- the inherited comment below looks stale, confirm.
Y = 2^442 # correct if p, q are ~ same size

# From e*d = 1 + k*(n + 1 - p - q) with p = ph + pl, q = qh + ql:
#   n + 1 - ph - qh = 2*A   (n + 1 and ph + qh are both even),
# so 1 + x*(A + y) = 0 (mod e) with x = 2*k and y = -(pl + ql)/2.
A = (n+1)//2 - (ph+qh)//2
PR.<x,y> = PolynomialRing(ZZ)
pol = 1+x*(A+y)
# The lattice step is slow, so the call is commented out and its result is
# pasted below.
# res = boneh_durfee(pol, e, m, t, X, Y)
# print(res)

# Hard-coded (x, y) root -- presumably the output of the commented-out call.
k2,pl_ql = 1476216873354030897123807900312494643299166545025397520836636477078919248983223765331825148, -5248359230286975998007421528138879480670687275342665043065128666287838030017492213260091350908248072270884523636416440683845503716271
# e*d = 1 + x*(A + y), hence the private exponent:
d = (1+k2*(A+pl_ql)) // e
# Recorded result of the final decryption (the decryption statement itself
# is not part of this listing):
# b'wdflag{f2cd2e39-7009-4f8f-89f3-285d50c6ca0f}'


# coding: utf-8
#!/usr/bin/env python2

import gmpy2
import random
import binascii
from hashlib import sha256
from sympy import nextprime
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from Crypto.Util.number import long_to_bytes
from FLAG import flag
#flag = 'wdflag{123}'

def victory_encrypt(plaintext, key):
    """Vigenere-encrypt ``plaintext`` with ``key`` (both upper-cased first).

    Alphabetic characters are shifted by the key letter at the same index
    (``key[i % len(key)]``); every other character is copied unchanged. The
    key index advances on every character, not only on letters.

    Returns the upper-case ciphertext string.
    """
    key = key.upper()
    key_length = len(key)
    base = ord('A')
    pieces = []

    for i, char in enumerate(plaintext.upper()):
        if char.isalpha():
            shift = ord(key[i % key_length]) - base
            pieces.append(chr((ord(char) - base + shift) % 26 + base))
        else:
            # non-letters pass through untouched
            pieces.append(char)

    return ''.join(pieces)

victory_key = "WANGDINGCUP"
victory_encrypted_flag = victory_encrypt(flag, victory_key)

# Short Weierstrass curve y^2 = x^3 + a*x + b over GF(p); these constants
# match the secp256k1 domain parameters (base point G of order n, cofactor h).
p = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f
a = 0
b = 7
xG = 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798
yG = 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8
G = (xG, yG)
n = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141
h = 1
# (0, 0) is used as the encoding of the point at infinity.
zero = (0,0)

# Private key: the next prime after a random value below n.
dA = nextprime(random.randint(0, n))

# nextprime() can step past n.
# NOTE(review): the body of this check was missing from the listing; a
# warning is the most conservative completion -- confirm against the original.
if dA > n:
    print("warning: dA is larger than the group order n")

def addition(t1, t2):
    """Add two affine points on the curve defined by the module-level a, p.

    ``zero`` (the module-level (0, 0) tuple) stands for the point at
    infinity. Returns the sum as a tuple of Python ints.
    """
    # identity element: O + Q = Q and P + O = P
    if t1 == zero:
        return t2
    if t2 == zero:
        return t1
    (m1, n1) = t1
    (m2, n2) = t2
    if m1 == m2:
        # same x: either P + (-P) = O, or a doubling
        if n1 == 0 or n1 != n2:
            return zero
        # tangent slope
        k = (3 * m1 * m1 + a) % p * gmpy2.invert(2 * n1, p) % p
    else:
        # chord slope
        k = (n2 - n1 + p) % p * gmpy2.invert((m2 - m1 + p) % p, p) % p
    m3 = (k * k % p - m1 - m2 + p * 2) % p
    n3 = (k * (m1 - m3) % p - n1 + p) % p
    return (int(m3), int(n3))

def multiplication(x, k):
    """Return the scalar multiple k*x using double-and-add.

    Scans the bits of ``k`` from least to most significant: ``x`` is doubled
    on every step and added into the accumulator when the current bit is set.
    """
    ans = zero
    t = 1
    while t <= k:
        if (k & t) > 0:
            ans = addition(ans, x)
        x = addition(x, x)
        t <<= 1
    return ans

def getrs(z, k):
    """Return the signature pair (r, s) for hash value ``z`` and nonce ``k``.

    Reads the module-level point ``P`` (its x-coordinate becomes r), the
    private key ``dA`` and the group order ``n``:
        s = (z + r*dA) * k^-1 mod n
    """
    (xp, yp) = P
    r = xp
    s = (z + r * dA % n) % n * gmpy2.invert(k, n) % n
    return r, s

# Two random "hash" values, both signed with one and the same nonce k.
z1, z2 = (random.randint(0, p) for _ in range(2))
k = random.randint(0, n)
P = multiplication(G, k)
hA = multiplication(G, dA)
(r1, s1), (r2, s2) = getrs(z1, k), getrs(z2, k)

# Publish both signatures and the signed values.
for template, value in (
    ("r1 = {}", r1),
    ("r2 = {}", r2),
    ("s1 = {}", s1),
    ("s2 = {}", s2),
    ("z1 = {}", z1),
    ("z2 = {}", z2),
):
    print(template.format(value))

# The AES key is derived from the private key dA; the IV is prepended to the
# ciphertext before hex-encoding.
key = sha256(long_to_bytes(dA)).digest()
cipher = AES.new(key, AES.MODE_CBC)
iv = cipher.iv
encrypted_flag = cipher.encrypt(
    pad(victory_encrypted_flag.encode(), AES.block_size)
)
encrypted_flag_hex = binascii.hexlify(iv + encrypted_flag).decode('utf-8')

print("Encrypted flag (AES in CBC mode, hex):", encrypted_flag_hex)

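# Solution idea: both signatures were made with the same nonce k (r1 == r2),
# so subtracting the two signing equations gives k:
#     k = (z1-z2)*inverse(s1-s2, n) % n
#
# Then, from s = (z + r * dA % n) % n * gmpy2.invert(k, n) % n,
# solve for dA:
#     dA = (s1*k-z1)*inverse(r1, n) % n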



from hashlib import sha256
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
from Crypto.Util.number import *

# 给定的值
r1 = 11455446275324978121918764201975007376236576456980676251003353868423934779741
r2 = 11455446275324978121918764201975007376236576456980676251003353868423934779741
s1 = 40939314972385973234538799073526528418921185412931089406906904671430814294251
s2 = 20521265832237322311837491925934175279870547563516618597235100170259508682349
z1 = 86373733089658748931377346977504497606857910789811212034370600969224209571891
z2 = 114906325375159287808320541183977807561041047925384600130919891200501605749979
n = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141
# Same nonce in both signatures: (s1 - s2) * k = z1 - z2 (mod n)
k = (z1-z2)*inverse(s1-s2, n) % n
# s1 * k = z1 + r1 * dA (mod n)
dA = (s1*k-z1)*inverse(r1, n) % n
key = sha256(long_to_bytes(dA)).digest()

# TODO(review): the ciphertext literal was lost from this listing. Paste the
# hex string printed by the challenge ("Encrypted flag (AES in CBC mode,
# hex): ...") here. Decoding from hex (rather than from an integer) also
# keeps any leading zero bytes of the IV.
enc_flag_hex = ""
enc_flag = bytes.fromhex(enc_flag_hex)

# The challenge prepends the 16-byte IV to the ciphertext.
iv = enc_flag[:16]
dec_flag = enc_flag[16:]
aes = AES.new(key, AES.MODE_CBC, iv)
# The plaintext was a padded, encoded str; undo both steps.
dec_flag = unpad(aes.decrypt(dec_flag), AES.block_size).decode()
# Vigenere decryption

def victory_decrypt(ciphertext, key):
    """Invert ``victory_encrypt``: Vigenere-decrypt ``ciphertext`` with ``key``.

    Both inputs are upper-cased first. Alphabetic characters are shifted back
    by the key letter at the same index (``key[i % len(key)]``); every other
    character is copied unchanged.

    Returns the upper-case plaintext string.
    """
    key = key.upper()
    key_length = len(key)
    base = ord('A')
    pieces = []

    for i, char in enumerate(ciphertext.upper()):
        if char.isalpha():
            shift = ord(key[i % key_length]) - base
            pieces.append(chr((ord(char) - base - shift) % 26 + base))
        else:
            # non-letters pass through untouched
            pieces.append(char)

    return ''.join(pieces)

# Undo the Vigenere layer with the same key the challenge used, then print
# the result lower-cased (the cipher upper-cases its input).
victory_key = "WANGDINGCUP"
flag = victory_decrypt(dec_flag, victory_key)
print("Recovered flag:", flag.lower())
# Recorded output of the line above:
# Recovered flag: wdflag{8ed62e3b409bb214e6fdb0a78f63569c}
















from Crypto.Util.number import long_to_bytes, bytes_to_long, getPrime
import random, gmpy2

class RSAEncryptor:
    """RSA key whose primes share a large common factor g in p-1 and q-1.

    The primes are p = 2*g*a + 1 and q = 2*g*b + 1, so that
    N = p*q = 2*g*h + 1 with h = 2*g*a*b + a + b.
    """

    def __init__(self):
        self.g = self.a = self.b = 0
        self.e = 65537

    def factorGen(self):
        """Generate g, a, b, h and N; returns once h is prime."""
        while True:
            self.g = getPrime(500)
            # a = b = 0 initially gives 2*g*0 + 1 = 1 (not prime), so both
            # loops sample at least once on the first pass.
            while not gmpy2.is_prime(2*self.g*self.a+1):
                self.a = random.randint(2**523, 2**524)
            while not gmpy2.is_prime(2*self.g*self.b+1):
                self.b = random.randint(2**523, 2**524)
            self.h = 2*self.g*self.a*self.b+self.a+self.b
            if gmpy2.is_prime(self.h):
                self.N = 2*self.h*self.g+1
                # parameters found: stop searching
                return

    def encrypt(self, msg):
        """Return msg^e mod N; ``msg`` may be an int or big-endian bytes."""
        if isinstance(msg, bytes):
            msg = int.from_bytes(msg, 'big')
        return gmpy2.powmod(msg, self.e, self.N)

    def product(self):
        """Read /flag and store its encryption in ``self.enc``."""
        with open('/flag', 'rb') as f:
            self.flag = f.read()
        self.enc = self.encrypt(self.flag)

    def show(self):
        """Print the public values N, e, g (and enc once it exists).

        NOTE(review): the original body was missing from the listing; these
        are the values the accompanying solver takes as input -- confirm the
        exact output format against the original challenge.
        """
        print(f"N = {self.N}")
        print(f"e = {self.e}")
        print(f"g = {self.g}")
        if hasattr(self, 'enc'):
            print(f"enc = {self.enc}")





# sage
from sage.groups.generic import bsgs

# TODO(review): the challenge output values were lost from this listing;
# fill in g, N, e and enc before running.
g = None
N = None
e = None
enc = None
nbits = 2048
gamma = 0.2446
cbits = ceil(nbits * (0.5 - 2 * gamma))

# N = 2*g*h + 1 with h = 2*g*a*b + a + b, so M = h.
M = (N - 1) // (2 * g)
# Split h in base 2*g: h = 2*g*u + v. Since a + b can exceed 2*g there is a
# small carry c with a*b = u - c and a + b = v + 2*g*c.
u = M // (2 * g)
v = M - 2 * g * u
GF = Zmod(N)
x = GF.random_element()
y = x ^ (2 * g)
# The size of c is expected to be close to N^(0.5 - 2*gamma)
c = bsgs(y, y ^ u, (Integer(2**(cbits-1)), Integer(2**(cbits+1))))
ab = u - c
apb = v + 2 * g * c
# a and b are the roots of X^2 - (a+b)*X + a*b
P.<x> = ZZ[]
f = x ^ 2 - apb * x + ab
roots = f.roots(multiplicities=False)
if not roots:
    raise ValueError("no integer roots found for a, b; try another random x")
# roots[-1] also covers the double-root case a == b
a, b = roots[0], roots[-1]
p = 2 * g * a + 1
q = 2 * g * b + 1
assert p * q == N

# This import must not be placed at the top of the file
# (presumably because the star import would shadow Sage names used above).
from Crypto.Util.number import *
d = inverse(e,(p-1)*(q-1))


from Crypto.Util.number import *
from secrets import flag
from math import ceil
import sys

class RSA():
    """Textbook RSA (no padding) built from an explicit key pair."""

    def __init__(self, privatekey, publickey):
        """privatekey is (p, q, d); publickey is (n, e)."""
        self.p, self.q, self.d = privatekey
        self.n, self.e = publickey

    def encrypt(self, plaintext):
        """Return plaintext^e mod n; bytes input is converted to an int first."""
        if isinstance(plaintext, bytes):
            plaintext = bytes_to_long(plaintext)
        ciphertext = pow(plaintext, self.e, self.n)
        return ciphertext

    def decrypt(self, ciphertext):
        """Return ciphertext^d mod n; bytes input is converted to an int first."""
        if isinstance(ciphertext, bytes):
            ciphertext = bytes_to_long(ciphertext)
        plaintext = pow(ciphertext, self.d, self.n)
        return plaintext

def get_keypair(nbits, e = 65537):
    """Generate an RSA key pair with an ``nbits``-bit modulus.

    Two distinct ``nbits//2``-bit primes are drawn until ``e`` is invertible
    modulo phi = (p-1)*(q-1) (equal to n - p - q + 1).

    Returns ((p, q, d), (n, e)).
    """
    from math import gcd

    while True:
        p = getPrime(nbits//2)
        q = getPrime(nbits//2)
        phi = (p - 1) * (q - 1)
        # retry in the rare case that e has no inverse modulo phi
        if p != q and gcd(e, phi) == 1:
            break
    n = p * q
    d = inverse(e, phi)
    return (p, q, d), (n, e)

if __name__ == '__main__':
    # Everything printed below goes to ./output.txt. The file is opened as
    # UTF-8 because the banner lines contain the non-ASCII character '┃',
    # and stdout is restored (and the file closed) when the block ends.
    pt = './output.txt'
    _original_stdout = sys.stdout
    with open(pt, 'w', encoding='utf-8') as fout:
        sys.stdout = fout
        try:
            # Split the flag into (up to) three equally sized blocks.
            block_size = ceil(len(flag)/3)
            flag = [flag[i:i+block_size] for i in range(0, len(flag), block_size)]
            e = 65537

            print(f'[+] Welcome to my apbq game')
            # stage 1: leak p + q
            print(f'┃ stage 1: p + q')
            prikey1, pubkey1 = get_keypair(1024)
            RSA1 = RSA(prikey1, pubkey1)
            enc1 = RSA1.encrypt(flag[0])
            print(f'┃ hints = {prikey1[0] + prikey1[1]}')
            print(f'┃ public key = {pubkey1}')
            print(f'┃ enc1 = {enc1}')

            # stage 2: leak 100 combinations a_i*p + b_i*q with 180-bit a_i, b_i
            print(f'┃ stage 2: ai*p + bi*q')
            prikey2, pubkey2 = get_keypair(1024)
            RSA2 = RSA(prikey2, pubkey2)
            enc2 = RSA2.encrypt(flag[1])
            kbits = 180
            a = [getRandomNBitInteger(kbits) for i in range(100)]
            b = [getRandomNBitInteger(kbits) for i in range(100)]
            c = [a[i]*prikey2[0] + b[i]*prikey2[1] for i in range(100)]
            print(f'┃ hints = {c}')
            print(f'┃ public key = {pubkey2}')
            print(f'┃ enc2 = {enc2}')

            # stage 3: leak a*p + q and p + b*q with 512-bit a, b
            print(f'┃ stage 3: a*p + q, p + bq')
            prikey3, pubkey3 = get_keypair(1024)
            RSA3 = RSA(prikey3, pubkey3)
            # NOTE: the third block is encrypted with RSA2, not RSA3. This
            # is kept as-is on purpose: the published output and solution
            # depend on it (enc3 is decrypted with the stage-2 key).
            enc3 = RSA2.encrypt(flag[2])
            kbits = 512
            a = getRandomNBitInteger(kbits)
            b = getRandomNBitInteger(kbits)
            c1 = a*prikey3[0] + prikey3[1]
            c2 = prikey3[0] + b*prikey3[1]
            print(f'┃ hints = {c1, c2}')
            print(f'┃ public key = {pubkey3}')
            print(f'┃ enc3 = {enc3}')
        finally:
            sys.stdout = _original_stdout
# flag1: solve the system p*q = n, p + q = hint for p and q.





# sage
import itertools
from Crypto.Util.number import *
from sympy import symbols,solve

p_q = 18978581186415161964839647137704633944599150543420658500585655372831779670338724440572792208984183863860898382564328183868786589851370156024615630835636170
n, e = (89839084450618055007900277736741312641844770591346432583302975236097465068572445589385798822593889266430563039645335037061240101688433078717811590377686465973797658355984717210228739793741484666628342039127345855467748247485016133560729063901396973783754780048949709195334690395217112330585431653872523325589, 65537)
enc1 = 23664702267463524872340419776983638860234156620934868573173546937679196743146691156369928738109129704387312263842088573122121751421709842579634121187349747424486233111885687289480494785285701709040663052248336541918235910988178207506008430080621354232140617853327942136965075461701008744432418773880574136247
# Stage 1: p + q and p*q are both known, so p and q follow from solving
# the two equations symbolically.
p, q = symbols('p q')
res = solve([p*q-n, p+q-p_q], [p, q])
# take the first solution pair and convert both values to plain ints
p, q = (int(root) for root in res[0])
d = inverse(e, (p - 1) * (q - 1))
flag1 = long_to_bytes(int(pow(enc1, d, n))).decode()

n = 73566307488763122580179867626252642940955298748752818919017828624963832700766915409125057515624347299603944790342215380220728964393071261454143348878369192979087090394858108255421841966688982884778999786076287493231499536762158941790933738200959195185310223268630105090119593363464568858268074382723204344819
e = 65537
hints = [18167664006612887319059224902765270796893002676833140278828762753019422055112981842474960489363321381703961075777458001649580900014422118323835566872616431879801196022002065870575408411392402196289546586784096, 16949724497872153018185454805056817009306460834363366674503445555601166063612534131218872220623085757598803471712484993846679917940676468400619280027766392891909311628455506176580754986432394780968152799110962, 17047826385266266053284093678595321710571075374778544212380847321745757838236659172906205102740667602435787521984776486971187349204170431714654733175622835939702945991530565925393793706654282009524471957119991, 25276634064427324410040718861523090738559926416024529567298785602258493027431468948039474136925591721164931318119534505838854361600391921633689344957912535216611716210525197658061038020595741600369400188538567]
c = 30332590230153809507216298771130058954523332140754441956121305005101434036857592445870499808003492282406658682811671092885592290410570348283122359319554197485624784590315564056341976355615543224373344781813890901916269854242660708815123152440620383035798542275833361820196294814385622613621016771854846491244
# Stage 2: the hints are h_i = a_i*p + b_i*q with small a_i, b_i.
# Two LLL passes recover (small multiples of) the coefficient vector a:
# first the integer vectors orthogonal to the hint vector, then the short
# vectors orthogonal to those.
V = hints
k = 2^800
M = Matrix.column([k * v for v in V]).augment(Matrix.identity(len(V)))
B = [b[1:] for b in M.LLL()]
M = (k * Matrix(B[:len(V)-2])).T.augment(Matrix.identity(len(V)))
B = [b[-len(V):] for b in M.LLL() if set(b[:len(V)-2]) == {0}]

# Try small combinations of the two basis vectors as the candidate a-vector:
# a1*h2 - a2*h1 = (a1*b2 - a2*b1)*q, so a gcd with n exposes a multiple of q.
for s, t in itertools.product(range(4), repeat=2):
    T = s*B[0] + t*B[1]
    a1, a2 = T[0], T[1]
    kq = gcd(a1 * hints[1] - a2 * hints[0], n)
    if 1 < kq < n:
        print('find!!', kq, s, t)
        break
else:
    raise ValueError("no non-trivial factor of n found from the hints")

# Strip any small cofactor that is left on kq.
for i in range(2**16, 1, -1):
    while kq % i == 0:
        kq //= i
q = int(kq)
p = int(n // kq)
d = inverse(e,(p-1)*(q-1))
flag2 = long_to_bytes(int(pow(c,d,n))).decode()

enc3 = 17737974772490835017139672507261082238806983528533357501033270577311227414618940490226102450232473366793815933753927943027643033829459416623683596533955075569578787574561297243060958714055785089716571943663350360324047532058597960949979894090400134473940587235634842078030727691627400903239810993936770281755
# d and n still hold the stage-2 key at this point: the challenge encrypted
# the third block with RSA2, so the same private key decrypts enc3.
flag3 = long_to_bytes(int(pow(enc3,d,n))).decode()

# The three blocks concatenated give the full flag.
flag = flag1 + flag2 + flag3

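# If flag3 had NOT been encrypted with RSA2 (i.e. stage 3 had to be solved on
# its own): given h1 = a*p + q and h2 = p + b*q, we have
#     (h1 - q) * (h2 - p) = a*b*p*q = 0 (mod n),
# i.e. h1*h2 - h1*p - h2*q = 0 (mod n), a linear modular equation in the
# small unknowns p and q, which the lattice solver below handles: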

# sage
from sage.all import *
import fpylll
import operator
from functools import reduce
import warnings
from Crypto.Util.number import *

def solve_linear_mod(equations, bounds, guesses=None):
    ''' Solve an arbitrary system of modular linear equations over different moduli.

    equations: A sequence of (lhs == rhs, M) pairs, where lhs and rhs are expressions and M is the modulus.
    bounds: A dictionary of {var: M} entries, where var is a variable and M is the maximum of that variable (the bound).
        All variables used in the equations must be bounded.
    guesses: An *optional* dictionary containing an approximation (guess) for each variable.
        For e.g. uniformly distributed numbers, just use the mean (expected value).
        Variables for which a guess is not specified are assumed to lie uniformly in [0, bound),
        i.e. they will have a guess of bound/2.

    Returns a dictionary {var: value}.

    Raises TypeError if a relation is not an equality, and ValueError if a
    relation is not linear with integer coefficients, uses an unbounded
    variable, or if CVP does not reproduce the equation targets.

    NOTE: Bounds are *soft*. This function may return solutions above the bounds. If this happens, and the result
    is incorrect, make some bounds smaller and try again.

    >>> k = var('k')
    >>> # solve CRT
    >>> solve_linear_mod([(k == 2, 3), (k == 4, 5), (k == 3, 7)], {k: 3*5*7})
    {k: 59}

    >>> x,y = var('x,y')
    >>> solve_linear_mod([(2*x + 3*y == 7, 11), (3*x + 5*y == 3, 13), (2*x + 5*y == 6, 143)], {x: 143, y: 143})
    {x: 62, y: 5}
    '''

    # The general idea is to set up an integer matrix equation Ax=y by introducing extra variables for the quotients,
    # then use LLL to solve the equation. We introduce extra axes in the lattice to observe the actual solution x,
    # which works so long as the solutions are known to be bounded (which is of course the case for modular equations).

    unknowns = list(bounds)
    if guesses is None:
        guesses = {}

    NR = len(equations)
    NV = len(unknowns)
    B = fpylll.IntegerMatrix(NR+NV, NR+NV)
    Y = [None] * (NR + NV)

    # B format (columns are the basis for the lattice):
    # [ eqns:NRxNV mods:NRxNR
    #   vars:NVxNV 0 ]
    # eqns correspond to equation axes, fi(...) = yi mod mi
    # vars correspond to variable axes, which effectively "observe" elements of the solution vector (x in Ax=y)

    # Compute scale such that the variable axes can't interfere with the equation axes
    nS = 1
    for sym in unknowns:
        nS = max(nS, int(bounds[sym]).bit_length())
    # NR + NV is a fudge to make CVP return correct results despite the 2^(n/2) error bound
    S = (1 << (nS + (NR + NV + 1)))

    # Compute per-variable scale such that the variable axes are scaled roughly equally
    scales = {}
    for vi, sym in enumerate(unknowns):
        scale = S >> (int(bounds[sym]).bit_length())
        scales[sym] = scale
        # Fill in vars block of B
        B[NR + vi, vi] = scale
        # Fill in "guess" for variable axis - try reducing bounds if the result is wrong
        Y[NR + vi] = guesses.get(sym, int(bounds[sym]) >> 1) * scale

    # Extract coefficients from equations
    for ri, (rel, m) in enumerate(equations):
        op = rel.operator()
        if op is not operator.eq:
            raise TypeError('relation %s: not an equality relation' % rel)

        expr = (rel - rel.rhs()).lhs().expand()
        for sym in expr.variables():
            if sym not in bounds:
                raise ValueError('relation %s: variable %s is not bounded' % (rel, sym))

        # Fill in eqns block of B
        for vi, sym in enumerate(unknowns):
            if expr.degree(sym) >= 2:
                raise ValueError('relation %s: equation is not linear in %s' % (rel, sym))
            coeff = expr.coefficient(sym)
            if not coeff.is_constant():
                raise ValueError('relation %s: coefficient of %s is not constant (equation is not linear)' % (rel, sym))
            if not coeff.is_integer():
                raise ValueError('relation %s: coefficient of %s is not an integer' % (rel, sym))

            B[ri, vi] = (int(coeff) % m) * S

        # Fill in mods block of B
        B[ri, NV + ri] = m * S

        const = expr.subs({sym: 0 for sym in unknowns})
        if not const.is_constant():
            raise ValueError('relation %s: failed to extract constant' % rel)
        if not const.is_integer():
            raise ValueError('relation %s: constant is not integer' % rel)

        # Fill in corresponding equation axes of target Y
        Y[ri] = (int(-const) % m) * S

    # Note that CVP requires LLL to be run first, and that LLL/CVP use the rows as the basis
    Bt = B.transpose()
    fpylll.LLL.reduction(Bt)
    result = fpylll.CVP.closest_vector(Bt, Y)

    # Check result for sanity
    if list(map(int, result[:NR])) != list(map(int, Y[:NR])):
        raise ValueError("CVP returned an incorrect result: input %s, output %s (try increasing your bounds?)" % (Y, result))

    res = {}
    for vi, sym in enumerate(unknowns):
        # undo the per-variable scaling; a non-zero remainder means the
        # closest vector is not an exact multiple of the scale
        aa = result[NR + vi] // scales[sym]
        bb = result[NR + vi] % scales[sym]
        if bb:
            warnings.warn("CVP returned suspicious result: %s=%d is not scaled correctly (try adjusting your bounds?)" % (sym, result[NR + vi]))
        res[sym] = aa

    return res
if __name__ == '__main__':
    # TODO(review): the challenge values were lost from this listing; fill in
    # c, n, e and the two stage-3 hints h1, h2 before running.
    c = None
    n = None
    e = None
    h1 = None
    h2 = None
    # Bind the symbols explicitly instead of relying on var() injecting
    # them into the global namespace.
    p, q = var('p q')
    bounds = {p: 2**512, q: 2**512}
    # NOTE(review): with h1 = a*p + q and h2 = p + b*q the relation that
    # holds is p*h1 + q*h2 = h1*h2 (mod n). As written here the symbol named
    # p therefore resolves to the challenge's q and vice versa; this is
    # harmless because only p*q and (p-1)*(q-1) are used afterwards.
    eqs = [(q*h1 + p*h2 - h1*h2==0, n)]
    sol = solve_linear_mod(eqs, bounds)
    p = sol[p]
    q = sol[q]
    # bounds are soft, so verify the factorisation before using it
    assert p * q == n
    d = inverse(e, (p-1)*(q-1))
