# Copyright 2024-2025 Olivier Romain, Francis Blais, Vincent Girouard, Marius Trudeau
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
This module solves the Diophantine equation :math:`\xi = t \cdot t^\dagger` for :math:`t \in \mathbb{D}[\omega]` where :math:`\xi
\in \mathbb{D}[\sqrt{2}]` is given. The solution :math:`t` is returned if it exists, or `None` otherwise. This
module is an implementation of the algorithm presented in Section 6 and Appendix C of :cite:`diophantine_ross`.

| **Input:** :math:`\xi \in \mathbb{D}[\sqrt{2}]`
| **Output:** :math:`t \in \mathbb{D}[\omega]`, the solution to the equation :math:`\xi = t \cdot t^\dagger`, or `None` if no solution exists for the specified :math:`\xi`

**Example:**

.. code-block:: python

    >>> from qdecomp.rings import *
    >>> from qdecomp.utils.diophantine import solve_xi_eq_ttdag_in_d

    # Solve a Diophantine equation that has a solution
    >>> xi = Dsqrt2(D(13, 1), D(4, 1))  # Input
    >>> t = solve_xi_eq_ttdag_in_d(xi)  # Compute the solution
    >>> print(f"{xi = }")
    xi = 13/2^1+2/2^0√2
    >>> print(f"{t = }")
    t = -2/2^0ω3 + 1/2^1ω2 + 0/2^0ω + 3/2^1

    # Check the solution
    >>> xi_calculated_in_Domega = t * t.complex_conjugate()  # Calculate (t * t†)
    >>> xi_calculated = Dsqrt2.from_ring(xi_calculated_in_Domega)  # Convert the result from D[omega] to D[sqrt(2)]
    >>> print(f"{xi_calculated = }")
    xi_calculated = 13/2^1+2/2^0√2
    >>> print(f"{xi == xi_calculated = }")
    xi == xi_calculated = True

    # Solve a Diophantine equation that doesn't have any solution
    >>> xi = Dsqrt2(D(9, 1), D(3, 1))  # Input
    >>> t = solve_xi_eq_ttdag_in_d(xi)  # Compute the solution
    >>> print(f"{xi = }")
    xi = 9/2^1+3/2^1√2
    >>> print(f"{t = }")
    t = None
"""
from fractions import Fraction
from math import isqrt, log, sqrt
from typing import Union

from qdecomp.rings import *
from qdecomp.utils.diophantine.tonelli_shanks import tonelli_shanks_algo
# ----------------------------- #
# Functions for Rings algebra #
# ----------------------------- #
[docs]
def gcd_Zomega(x: Zomega, y: Zomega) -> Zomega:
    r"""
    Compute a greatest common divisor (:math:`gcd`) of :math:`x` and :math:`y` in the ring :math:`\mathbb{Z}[\omega]`.

    The computation is Euclid's algorithm, with :func:`euclidean_div_Zomega` providing the division step. Because
    :math:`\mathbb{Z}[\omega]` has non-trivial units, the returned value is a gcd only up to multiplication by a unit.

    Args:
        x (Zomega): First number
        y (Zomega): Second number

    Returns:
        Zomega: A greatest common divisor of :math:`x` and :math:`y`
    """
    dividend, divisor = x, y
    # Replace (a, b) by (b, a mod b) until the remainder vanishes; the last non-zero value is the gcd.
    while divisor != 0:
        remainder = euclidean_div_Zomega(dividend, divisor)[1]
        dividend, divisor = divisor, remainder
    return dividend
[docs]
def euclidean_div_Zomega(num: Zomega, div: Zomega) -> tuple[Zomega, Zomega]:
    r"""
    Compute the euclidean division of :math:`num` by :math:`div`, where :math:`num` and :math:`div` are elements of
    :math:`\mathbb{Z}[\omega]`. This function returns :math:`q` and :math:`r` such that :math:`num = q \cdot div + r`.

    The exact quotient :math:`num / div` is first rewritten with an integer denominator. It is multiplied above and
    below by :math:`div^\dagger` and by the :math:`\sqrt{2}` conjugate of :math:`div \cdot div^\dagger`. Each
    coefficient is then rounded to the nearest integer (ties to even) to obtain :math:`q`.

    Args:
        num (Zomega): Number to be divided
        div (Zomega): Divisor (must be non-zero)

    Returns:
        tuple: :math:`(q, r)` where :math:`q` is the quotient and :math:`r` is the remainder

    Raises:
        ZeroDivisionError: If ``div`` is zero.
    """
    div_cc = div.complex_conjugate()  # Complex conjugate of the divisor
    div_div_cc = div * div_cc  # div * div†, a real element of Z[omega]
    div_div_cc_conj = div_div_cc.sqrt2_conjugate()

    # Multiplying div * div† by its √2 conjugate leaves only the integer coefficient
    denom = (div_div_cc * div_div_cc_conj).d

    # Apply the same multiplication to the numerator, so that num / div = numer / denom
    numer = num * div_cc * div_div_cc_conj

    # Round each coefficient of numer / denom to the nearest integer. Fraction keeps this exact: a float quotient
    # loses precision once the coefficients exceed 2**53, which would give a wrong quotient and a large remainder.
    q = Zomega(
        round(Fraction(numer.a, denom)),
        round(Fraction(numer.b, denom)),
        round(Fraction(numer.c, denom)),
        round(Fraction(numer.d, denom)),
    )
    r = num - q * div  # Remainder of the division
    return q, r
[docs]
def euclidean_div_Zsqrt2(num: Zsqrt2, div: Zsqrt2) -> tuple[Zsqrt2, Zsqrt2]:
    r"""
    Perform the euclidean division of :math:`num` by :math:`div` in :math:`\mathbb{Z}[\sqrt{2}]`. This function returns
    :math:`q` and :math:`r` such that :math:`num = q \cdot div + r`.

    The exact quotient is written as :math:`(num \cdot div^{\bullet}) / (div \cdot div^{\bullet})`, whose denominator
    is an integer (possibly negative). Both coefficients are then rounded to the nearest integer (ties to even).

    Args:
        num (Zsqrt2): Number to be divided
        div (Zsqrt2): Divisor (must be non-zero)

    Returns:
        tuple: :math:`(q, r)` where :math:`q` is the quotient and :math:`r` is the remainder

    Raises:
        ZeroDivisionError: If ``div`` is zero.
    """
    div_conj = div.sqrt2_conjugate()
    num_ = num * div_conj  # Numerator of num / div after rationalization
    den_ = (div * div_conj).a  # Norm a^2 - 2b^2 of the divisor, an integer that may be negative

    # Exact rounding: float division would lose precision for coefficients larger than 2**53
    q = Zsqrt2(round(Fraction(num_.a, den_)), round(Fraction(num_.b, den_)))
    r = num - q * div
    return q, r
[docs]
def are_sim_Zsqrt2(x: Zsqrt2, y: Zsqrt2) -> bool:
    r"""
    Determine if :math:`x \sim y`, i.e. if there exists a unit :math:`u` such that :math:`x = u \cdot y`.
    :math:`x`, :math:`y` and :math:`u` are elements of :math:`\mathbb{Z}[\sqrt{2}]`.

    Args:
        x (Zsqrt2): First number
        y (Zsqrt2): Second number

    Returns:
        bool: `True` if :math:`x \sim y`, `False` otherwise
    """
    # 0 is associated only with itself; handling it first also avoids dividing by zero below
    if x == 0 or y == 0:
        return bool(x == y)

    # x ~ y exactly when y divides x and x divides y
    _, r1 = euclidean_div_Zsqrt2(x, y)
    _, r2 = euclidean_div_Zsqrt2(y, x)
    return (r1 == 0) and (r2 == 0)
[docs]
def is_unit_Zsqrt2(x: Zsqrt2) -> bool:
    r"""
    Tell whether :math:`x` is invertible in the ring :math:`\mathbb{Z}[\sqrt{2}]`.

    An element is a unit exactly when its norm :math:`x \cdot x^{\bullet}` equals :math:`\pm 1`, where
    :math:`^{\bullet}` denotes the :math:`\sqrt{2}` conjugate.

    Args:
        x (Zsqrt2): The number to test

    Returns:
        bool: `True` if :math:`x` is a unit, `False` otherwise
    """
    norm = x * x.sqrt2_conjugate()
    norm_is_one = norm == 1
    return norm_is_one or norm == -1
# ----------------------------- #
# Functions to solve the Diophantine equation
# ----------------------------- #
[docs]
def is_square(n: int) -> bool:
    """
    Check if :math:`n` is a perfect square.

    Args:
        n (int): An integer

    Returns:
        bool: `True` if :math:`n` is a perfect square, `False` otherwise
    """
    if n < 0:
        return False

    # Cheap filter: a square must be congruent to 0, 1, 4 or 9 modulo 16
    if n % 16 not in (0, 1, 4, 9):
        return False

    # Exact integer square root. A float sqrt misses squares above ~2**52 and overflows for huge n.
    # int() keeps integral floats (e.g. 16.0) working; non-integral floats never pass the filter above.
    root = isqrt(int(n))
    return root * root == n
[docs]
def solve_usquare_eq_a_mod_p(a: int, p: int) -> int:
    r"""
    Find an integer :math:`u` such that :math:`u^2 \equiv -a \pmod{p}`, where :math:`p` is a prime.

    Note that the right-hand side is :math:`-a` (not :math:`a`). The square root is obtained with the Tonelli-Shanks
    algorithm. The degenerate case :math:`p = 1`, :math:`a = 1` is answered directly.

    Args:
        a (int): An integer
        p (int): A prime integer

    Returns:
        int: The solution :math:`u` of :math:`u^2 = -a\ (\text{mod p})` returned by the Tonelli-Shanks algorithm
    """
    # Modulo 1 every integer is a solution; return 1 without calling Tonelli-Shanks
    if a == 1 and p == 1:
        return 1

    target = -a  # Tonelli-Shanks is asked for a square root of -a modulo p
    return tonelli_shanks_algo(target, p)
[docs]
def integer_fact(p: int) -> list[tuple[int, int]]:
    """
    Find the factorization of an integer :math:`p`. This function returns a list of tuples :math:`(p_i, m_i)` where
    :math:`p_i` is a prime factor of :math:`p` and :math:`m_i` is its power. The factors are listed in increasing order.

    Args:
        p (int): Number to factorize

    Returns:
        list of tuples: The prime factors of :math:`p` and their powers. Each tuple is of the form
        :math:`(p_i, m_i)` where :math:`p_i` is a prime factor of :math:`p` and :math:`m_i` is its power.

    Raises:
        ValueError: If the number is less than 2.
        ValueError: If the number is not an integer.
    """
    if p < 2:
        raise ValueError("The number must be greater than 1.")
    if int(p) != p:
        raise ValueError(f"The number must be an integer. Got {p}.")

    n = int(p)  # Work on an exact int even if p was given as an integral float (e.g. 12.0)
    factors = []  # List of tuples (p_i, m_i)

    # Remove the factor 2 first, so that only odd candidates need to be tried afterwards
    counter = 0
    while n % 2 == 0:
        counter += 1
        n //= 2
    if counter > 0:
        factors.append((2, counter))

    # Trial division by odd candidates. The bound i * i <= n is exact integer arithmetic; a float sqrt
    # loses precision (and can overflow) for large n.
    i = 3
    while i * i <= n:
        counter = 0
        while n % i == 0:
            counter += 1
            n //= i
        if counter > 0:
            factors.append((i, counter))
        i += 2

    # Whatever remains (> 1) has no divisor up to its square root, so it is prime
    if n != 1:
        factors.append((n, 1))
    return factors
[docs]
def xi_fact(xi: Zsqrt2) -> list[tuple[Zsqrt2, int]]:
    r"""
    Factorize :math:`\xi \in \mathbb{Z}[\sqrt{2}]` into prime elements of :math:`\mathbb{Z}[\sqrt{2}]`, up to a unit.

    The rational primes :math:`p_i` dividing the norm :math:`\xi \cdot \xi^{\bullet}` determine the prime factors of
    :math:`\xi`:

    - :math:`p_i = 2` contributes :math:`\sqrt{2}`;
    - :math:`p_i\ \%\ 8 = 1 \text{ or } 7` splits as :math:`\xi_i \cdot \xi_i^{\bullet}`. Trial division by
      :math:`\xi_i` decides how many occurrences belong to :math:`\xi_i` and how many to :math:`\xi_i^{\bullet}`;
    - :math:`p_i\ \%\ 8 = 3 \text{ or } 5` stays prime and divides :math:`\xi` with half its power in the norm.

    A negative norm is recorded by a leading factor :math:`(-1, 1)`. Zero and units are returned as a single factor
    (themselves) with power 1.

    Args:
        xi (Zsqrt2): An element of :math:`\mathbb{Z}[\sqrt{2}]`

    Returns:
        list of tuples: The prime factors of :math:`\xi` and their powers. Each tuple is of the form
        :math:`(\xi_i, m_i)` where :math:`\xi_i` is a prime factor of :math:`\xi` and :math:`m_i` is its power.
    """
    # Zero has no factorization
    if xi == 0:
        return [(Zsqrt2(0, 0), 1)]

    norm = (xi * xi.sqrt2_conjugate()).a
    # A norm of ±1 means xi is a unit: nothing to factorize
    if norm == 1 or norm == -1:
        return [(xi, 1)]

    factors = []
    if norm < 0:
        # Record the sign as a unit factor, then factorize |norm|
        norm = -norm
        factors.append((Zsqrt2(-1, 0), 1))

    for prime, power in integer_fact(norm):
        if prime == 2:
            factors.append((Zsqrt2(0, 1), power))  # 2 ramifies: its prime is √2
        elif prime % 8 in (1, 7):
            # Split prime: prime = xi_i * xi_i•. Count how many times xi_i divides xi (at most `power` times);
            # the conjugate accounts for the remaining occurrences.
            xi_i = pi_fact_into_xi(prime)
            count = 0
            quotient = xi
            while count < power:
                candidate, remainder = euclidean_div_Zsqrt2(quotient, xi_i)
                if remainder != 0:
                    break
                quotient = candidate
                count += 1
            if count != 0:
                factors.append((xi_i, count))
            if count != power:
                factors.append((xi_i.sqrt2_conjugate(), power - count))
        else:
            # Inert prime (prime % 8 == 3 or 5): it divides xi power / 2 times, since its norm is prime**2
            factors.append((Zsqrt2(prime, 0), power // 2))
    return factors
[docs]
def pi_fact_into_xi(pi: int) -> Union[Zsqrt2, None]:
    r"""
    Solve the equation :math:`p_i = \xi_i \cdot \xi_i^{\bullet} = a^2 - 2 \cdot b^2` where :math:`^{\bullet}` denotes
    the :math:`\sqrt{2}` conjugate. :math:`p_i` is a prime integer and :math:`\xi_i = a + b \sqrt{2}` is an element of
    :math:`\mathbb{Z}[\sqrt{2}]`. :math:`p_i` has a factorization only if :math:`p_i\ \%\ 8 = 1 \text{ or } 7` or if
    :math:`p_i = 2`. In any other case, the function returns `None`.

    For :math:`p_i = 2` the returned :math:`\sqrt{2}` satisfies :math:`\sqrt{2} \cdot \sqrt{2}^{\bullet} = -2`, so the
    factorization holds only up to the unit :math:`-1`.

    Otherwise :math:`b = 1, 2, \ldots` is searched until :math:`p_i + 2 b^2` is a perfect square :math:`a^2`. The
    search only terminates if such a :math:`b` exists. This is guaranteed when :math:`p_i` is a prime with
    :math:`p_i\ \%\ 8 = 1 \text{ or } 7`.

    Args:
        pi (int): A prime integer

    Returns:
        Zsqrt2 or None: A number :math:`\xi_i` for which :math:`p_i = \xi_i \cdot \xi_i^{\bullet}`, or `None` if
        :math:`p_i\ \%\ 8 \neq 1 \text{ or } 7`
    """
    if pi == 2:
        return Zsqrt2(0, 1)
    if pi % 8 not in (1, 7):
        return None

    b = 1
    while True:
        a_squared = pi + 2 * b * b
        # Exact integer square-root test (no float rounding). Negative candidates, which occur for a negative pi,
        # cannot be squares and are skipped.
        if a_squared >= 0:
            a = isqrt(int(a_squared))
            if a * a == a_squared:
                return Zsqrt2(a, b)
        b += 1
[docs]
def xi_i_fact_into_ti(xi_i: Zsqrt2, check_prime: bool = False) -> Union[Zomega, None]:
    r"""
    Solve the equation :math:`\xi_i \sim t_i \cdot t_i^\dagger`, i.e. :math:`\xi_i = t_i \cdot t_i^\dagger` up to a unit,
    where :math:`^\dagger` denotes the complex conjugate. :math:`\xi_i` is a prime element in
    :math:`\mathbb{Z}[\sqrt{2}]` and :math:`t_i` is an element of :math:`\mathbb{Z}[\omega]`.

    Let :math:`p_i = |\xi_i \cdot \xi_i^{\bullet}|`, or :math:`p_i = |\xi_i|` when :math:`\xi_i` is a rational integer.
    A solution is computed when :math:`\xi_i = \sqrt{2}` (giving :math:`\delta = 1 + \omega`) or when
    :math:`p_i\ \%\ 8 = 1, 3 \text{ or } 5`. In the latter case :math:`t_i` is a gcd in :math:`\mathbb{Z}[\omega]` of
    :math:`\xi_i` and :math:`u + i` (or :math:`u + i\sqrt{2}`), where :math:`u` is a square root of :math:`-1` (or
    :math:`-2`) modulo :math:`p_i`. When :math:`p_i\ \%\ 8 = 7` no solution exists.

    Note: this function assumes :math:`\xi_i` is a prime element in :math:`\mathbb{Z}[\sqrt{2}]`. No check is performed to
    verify this assumption unless specified by the `check_prime` argument.

    Args:
        xi_i (Zsqrt2): A prime element in :math:`\mathbb{Z}[\sqrt{2}]`
        check_prime (bool): If set to `True`, the function will check if :math:`\xi_i` is a prime in :math:`\mathbb{Z}[\sqrt{2}]`

    Returns:
        Zomega or None: A number :math:`t_i` for which :math:`\xi_i \sim t_i \cdot t_i^\dagger`, or `None` if
        :math:`p_i\ \%\ 8 = 7` or if :math:`\xi_i` is not of a supported form (e.g. an associate of :math:`\sqrt{2}`
        other than :math:`\sqrt{2}` itself)

    Raises:
        ValueError: If the input argument is not a prime in :math:`\mathbb{Z}[\sqrt{2}]` (only if `check_prime` is `True`,
            because this verification is computationally expensive)
    """
    # Verify if ξ_i is a prime in Z[√2]
    if check_prime:
        factors = xi_fact(xi_i)
        is_prime = True
        if len(factors) >= 3:  # The first factor might be a unit, but there is still more than one prime
            is_prime = False
        if len(factors) == 1:  # A single factor must not be a unit
            if is_unit_Zsqrt2(factors[0][0]):
                is_prime = False
        if len(factors) == 2:  # With two factors, at least one of them must be a unit
            if not (is_unit_Zsqrt2(factors[0][0]) or is_unit_Zsqrt2(factors[1][0])):
                is_prime = False
        if any(m > 1 for _, m in factors):  # A prime has no repeated factor
            is_prime = False
        if not is_prime:
            raise ValueError("The input argument must be a prime in Z[sqrt(2)].")

    if xi_i == Zsqrt2(0, 1):  # ξ_i = √2
        delta = Zomega(0, 0, 1, 1)  # δ = 1 + ω, with δ·δ† = 2 + √2 = √2·(1 + √2) ~ √2
        return delta

    if xi_i.b == 0:  # ξ_i is already a rational integer
        pi = xi_i.a
    else:
        pi = (xi_i * xi_i.sqrt2_conjugate()).a
    # The norm (or the integer itself) can be negative, e.g. ξ_i = 1 + 2√2 has norm -7. Python's % would then
    # select the wrong branch below, and a negative modulus would reach Tonelli-Shanks.
    pi = abs(pi)

    if pi % 4 == 1:  # p_i % 8 == 1 or 5: -1 is a square modulo p_i
        u = solve_usquare_eq_a_mod_p(1, pi)
        xi_i_converted = Zomega.from_ring(xi_i)
        return gcd_Zomega(xi_i_converted, Zomega(0, 1, 0, u))  # Second term: u + i

    if pi % 8 == 3:  # ξ_i = p_i is an integer in that case; -2 is a square modulo p_i
        u = solve_usquare_eq_a_mod_p(2, pi)
        xi_i_converted = Zomega.from_ring(xi_i)
        return gcd_Zomega(xi_i_converted, Zomega(1, 0, 1, u))  # Second term: u + i√2

    # p_i % 8 == 7 (no solution), or an unsupported input such as an even p_i other than ξ_i = √2
    return None
[docs]
def solve_xi_sim_ttdag_in_z(xi: Zsqrt2) -> Union[Zomega, None]:
    r"""
    Solve the equation :math:`\xi \sim t \cdot t^\dagger` for :math:`t` where :math:`^\dagger` denotes the complex
    conjugate and :math:`\sim` means "equal up to a unit". :math:`\xi` is an element of :math:`\mathbb{Z}[\sqrt{2}]` and
    :math:`t` is an element of :math:`\mathbb{Z}[\omega]`. This function returns the first solution of the equation. If
    no solution exists, the function returns `None`.

    :math:`\xi` is factorized into primes :math:`\xi_i^{m_i}`. An even power contributes :math:`\xi_i^{m_i/2}`, which is
    real and therefore equal to its own complex conjugate. An odd power contributes :math:`t_i^{m_i}` with
    :math:`\xi_i \sim t_i \cdot t_i^\dagger`; if some such :math:`\xi_i` cannot be split, there is no solution.

    Args:
        xi (Zsqrt2): A number

    Returns:
        Zomega or None: A number :math:`t` for which :math:`\xi \sim t \cdot t^\dagger`, or `None` if no solution exists
    """
    # 0 = 0 · 0†
    if xi == 0:
        return Zomega(0, 0, 0, 0)

    # Every unit (norm ±1) is associated with 1 = 1 · 1†
    norm = (xi * xi.sqrt2_conjugate()).a
    if norm == 1 or norm == -1:
        return Zomega(0, 0, 0, 1)

    t = Zomega(0, 0, 0, 1)
    for xi_i, mi in xi_fact(xi):
        if xi_i == -1:  # Unit factor: irrelevant up to association
            continue
        if mi % 2 == 0:  # For even exponents, ξ_i ** mi = ξ_i ** (mi // 2) * ξ_i ** (mi // 2)
            t *= Zomega.from_ring(xi_i ** (mi // 2))
        else:
            ti_i = xi_i_fact_into_ti(xi_i)
            if ti_i is None:
                return None
            t *= ti_i**mi
    return t
[docs]
def _is_nonnegative_dsqrt2(x: Dsqrt2) -> bool:
    r"""
    Decide exactly whether :math:`x = a + b\sqrt{2} \geq 0` for :math:`x \in \mathbb{D}[\sqrt{2}]`.

    Only integer arithmetic is used. The test therefore stays correct when :math:`a` and :math:`b` are large and nearly
    cancel, where a floating-point evaluation of :math:`a + b\sqrt{2}` can get the sign wrong.

    Args:
        x (Dsqrt2): The number to test

    Returns:
        bool: `True` if :math:`x \geq 0`, `False` otherwise
    """
    # Put both coefficients over the common denominator 2**k; a positive denominator does not change the sign
    k = max(x.a.denom, x.b.denom)
    a = x.a.num * 2 ** (k - x.a.denom)
    b = x.b.num * 2 ** (k - x.b.denom)
    if a >= 0 and b >= 0:
        return True
    if a <= 0 and b <= 0:
        return False  # At least one coefficient is strictly negative here
    # Opposite signs: compare |a| with |b|·√2 through their squares
    if a > 0:
        return a * a >= 2 * b * b
    return 2 * b * b >= a * a


def solve_xi_eq_ttdag_in_d(xi: Dsqrt2) -> Union[Domega, None]:
    r"""
    Solve the equation :math:`\xi = t \cdot t^\dagger` for :math:`t` where :math:`^\dagger` denotes the complex conjugate.
    :math:`\xi` is an element of :math:`\mathbb{D}[\sqrt{2}]` and :math:`t` is an element of :math:`\mathbb{D}[\omega]`.
    This function returns the first solution of the equation. If no solution exists, it returns `None`.

    A solution can only exist if :math:`\xi` is doubly positive (:math:`\xi \geq 0` and :math:`\xi^{\bullet} \geq 0`);
    this is checked exactly. The steps are:

    - :math:`\xi` is scaled by :math:`\sqrt{2}^l` into :math:`\mathbb{Z}[\sqrt{2}]`;
    - :math:`\xi' \sim s \cdot s^\dagger` is solved in :math:`\mathbb{Z}[\omega]`;
    - :math:`s` is rescaled by :math:`\delta^{-l}` (:math:`\delta = 1 + \omega`);
    - the remaining unit :math:`u = \lambda^{2n}` (:math:`\lambda = 1 + \sqrt{2}`) is absorbed by multiplying the result
      by :math:`\lambda^n`.

    Args:
        xi (Dsqrt2): A number

    Returns:
        Domega or None: A number :math:`t` for which :math:`\xi = t \cdot t^\dagger`, or `None` if no solution exists
    """
    # The equation only has a solution if ξ is doubly positive, i.e. ξ >= 0 and ξ• >= 0.
    # Checked with exact integer arithmetic rather than float(), whose sign can be wrong under cancellation.
    if not (_is_nonnegative_dsqrt2(xi) and _is_nonnegative_dsqrt2(xi.sqrt2_conjugate())):
        return None

    # If ξ = 0, the solution is 0
    if xi == 0:
        return Domega((0, 0), (0, 0), (0, 0), (0, 0))

    l = (xi * xi.sqrt2_conjugate()).a.denom  # Greatest denominator power of 2
    xi_prime_temp = Dsqrt2(D(0, 0), D(1, 0)) ** l * xi  # ξ' = √2**l * ξ is in Z[√2]
    xi_prime = Zsqrt2.from_ring(xi_prime_temp)  # Convert ξ' to Z[√2]

    s = solve_xi_sim_ttdag_in_z(xi_prime)  # Solve the equation ξ' ~ s * s†
    if s is None:  # If there is no solution to the equation ξ' ~ s * s†
        return None

    delta = Zomega(0, 0, 1, 1)  # δ = 1 + ω
    # δ**-1 = δ * λ**-1 * ω**-1 / √2
    delta_inv = (
        Domega.from_ring(delta)
        * Domega((-1, 0), (0, 0), (1, 0), (-1, 0))
        * Domega((0, 0), (-1, 1), (0, 0), (1, 1))
    )
    delta_inv_l = delta_inv**l  # δ ** -l
    t = delta_inv_l * Domega.from_ring(s)  # t = δ**-l * s
    tt = Dsqrt2.from_ring(t * t.complex_conjugate())  # tt = t * t†

    # Find the unit u such that ξ = u * t * t†, computed exactly as u = ξ * tt• / N(tt)
    denom = (tt * tt.sqrt2_conjugate()).a  # Element of ring D
    u_temp = xi * tt.sqrt2_conjugate() * int(2**denom.denom)
    u = Zsqrt2(u_temp.a.num // denom.num, u_temp.b.num // denom.num)

    # u is a doubly positive unit, so u = λ**(2n). Take the log of whichever of u (n >= 0) or its conjugate
    # u• = λ**(-2n) (n < 0) has two non-negative coefficients: its float value then suffers no cancellation.
    two_log_lambda = 2 * log(float(Zsqrt2(1, 1)))
    if u.b >= 0:
        n = round(log(float(u)) / two_log_lambda)
    else:
        n = -round(log(float(u.sqrt2_conjugate())) / two_log_lambda)

    # v**2 = u => v = λ**n
    if n > 0:
        v = Domega((-1, 0), (0, 0), (1, 0), (1, 0)) ** n  # λ**n
    elif n == 0:
        v = Domega((0, 0), (0, 0), (0, 0), (1, 0))  # 1
    else:
        v = Domega((-1, 0), (0, 0), (1, 0), (-1, 0)) ** -n  # (λ**-1)**(-n)
    return t * v