369 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			369 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# MIT License
 | 
						|
#
 | 
						|
# Copyright (c) 2015 Brian Warner and other contributors
 | 
						|
 | 
						|
# Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
						|
# of this software and associated documentation files (the "Software"), to deal
 | 
						|
# in the Software without restriction, including without limitation the rights
 | 
						|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | 
						|
# copies of the Software, and to permit persons to whom the Software is
 | 
						|
# furnished to do so, subject to the following conditions:
 | 
						|
#
 | 
						|
# The above copyright notice and this permission notice shall be included in all
 | 
						|
# copies or substantial portions of the Software.
 | 
						|
#
 | 
						|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | 
						|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
						|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
						|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
						|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
						|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
						|
# SOFTWARE.
 | 
						|
 | 
						|
import binascii, hashlib, itertools
 | 
						|
 | 
						|
# Field prime and the prime order of the main (order-L) subgroup.
Q = (1 << 255) - 19
L = (1 << 252) + 27742317777372353535851937790883648493


def inv(x):
    """Return the inverse of x modulo Q via Fermat's little theorem.

    Note that inv(0) is 0 rather than an error, since pow(0, Q-2, Q) == 0.
    """
    fermat_exponent = Q - 2
    return pow(x, fermat_exponent, Q)


# Curve constant d = -121665/121666 (deliberately left unreduced and negative;
# every use reduces modulo Q), and I, a square root of -1 modulo Q.
d = -121665 * inv(121666)
I = pow(2, (Q - 1) >> 2, Q)


def xrecover(y):
    """Return the even x coordinate paired with y on the curve.

    Computes a square root of (y^2 - 1) / (d*y^2 + 1) modulo Q. This does not
    verify that a root exists: when y is not the y coordinate of a curve
    point the result is meaningless, so callers must check with isoncurve().
    """
    x_squared = (y * y - 1) * inv(d * y * y + 1)
    root = pow(x_squared, (Q + 3) >> 3, Q)
    # The candidate may be off by a factor of sqrt(-1); fix it with I.
    if (root * root - x_squared) % Q:
        root = (root * I) % Q
    # Choose the even representative of the +/- pair.
    if root & 1:
        root = Q - root
    return root


# Base point: y = 4/5 together with its even x coordinate.
By = 4 * inv(5)
Bx = xrecover(By)
B = [Bx % Q, By % Q]
 | 
						|
 | 
						|
# Extended Coordinates: x=X/Z, y=Y/Z, x*y=T/Z
 | 
						|
# http://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html
 | 
						|
 | 
						|
def xform_affine_to_extended(pt):
    """Convert an affine (x, y) pair to extended coordinates (X, Y, Z, T).

    Z is fixed at 1, so X = x, Y = y and T = x*y, each reduced modulo Q.
    """
    x, y = pt
    X = x % Q
    Y = y % Q
    T = (x * y) % Q
    return (X, Y, 1, T)
 | 
						|
 | 
						|
def xform_extended_to_affine(pt):
    """Convert extended coordinates (X, Y, Z, T) back to an affine (x, y).

    Returns (X/Z mod Q, Y/Z mod Q); T is ignored.
    """
    (x, y, z, _) = pt
    # Modular inversion is the expensive step, so do it once, not per coord.
    z_inv = inv(z)
    return ((x * z_inv) % Q, (y * z_inv) % Q)
 | 
						|
 | 
						|
def double_element(pt): # extended->extended
    """Return 2*pt using the EFD "dbl-2008-hwcd" formulas (a = -1).

    Input and output are extended (X, Y, Z, T) tuples; the input T is unused.
    """
    (X1, Y1, Z1, _) = pt
    x_sq = X1 * X1                       # A
    y_sq = Y1 * Y1                       # B
    two_z_sq = 2 * Z1 * Z1               # C
    neg_x_sq = (-x_sq) % Q               # D = a*A with a = -1
    xy_sum = (X1 + Y1) % Q
    e = (xy_sum * xy_sum - x_sq - y_sq) % Q
    g = (neg_x_sq + y_sq) % Q
    f = (g - two_z_sq) % Q
    h = (neg_x_sq - y_sq) % Q
    return ((e * f) % Q, (g * h) % Q, (f * g) % Q, (e * h) % Q)
 | 
						|
 | 
						|
def add_elements(pt1, pt2): # extended->extended
    """Return pt1 + pt2 using the EFD "add-2008-hwcd-3" formulas (k = 2*d).

    These formulas are unified (they also work when pt1 == pt2), which makes
    this the safe general-purpose addition; add-2008-hwcd-4 is slightly
    faster but not unified.
    """
    (X1, Y1, Z1, T1) = pt1
    (X2, Y2, Z2, T2) = pt2
    a = ((Y1 - X1) * (Y2 - X2)) % Q
    b = ((Y1 + X1) * (Y2 + X2)) % Q
    c = T1 * (2 * d) * T2 % Q
    zz2 = Z1 * 2 * Z2 % Q                # "D" in the EFD listing
    e = (b - a) % Q
    f = (zz2 - c) % Q
    g = (zz2 + c) % Q
    h = (b + a) % Q
    return ((e * f) % Q, (g * h) % Q, (f * g) % Q, (e * h) % Q)
 | 
						|
 | 
						|
def scalarmult_element_safe_slow(pt, n):
    """Return n*pt (extended coordinates) for any non-negative integer n.

    Uses the unified add_elements, so it tolerates arbitrary points,
    including those outside the main 1*L subgroup: Zero (order 1) and the
    points of order 2, 4 and 8. Slightly slower than scalarmult_element.
    """
    assert n >= 0
    result = xform_affine_to_extended((0, 1))  # the neutral element
    if n == 0:
        return result
    # Left-to-right double-and-add over the binary digits of n, most
    # significant first.
    for digit in bin(n)[2:]:
        result = double_element(result)
        if digit == "1":
            result = add_elements(result, pt)
    return result
 | 
						|
 | 
						|
def _add_elements_nonunfied(pt1, pt2): # extended->extended
    """Return pt1 + pt2 using the EFD "add-2008-hwcd-4" formulas.

    NOT unified: only valid for pt1 != pt2. It is about 10% faster than the
    unified add_elements, and safe inside scalarmult as long as no points of
    order 1/2/4/8 are involved.
    """
    (X1, Y1, Z1, T1) = pt1
    (X2, Y2, Z2, T2) = pt2
    a = ((Y1 - X1) * (Y2 + X2)) % Q
    b = ((Y1 + X1) * (Y2 - X2)) % Q
    c = (Z1 * 2 * T2) % Q
    t2z = (T1 * 2 * Z2) % Q              # "D" in the EFD listing
    e, f = (t2z + c) % Q, (b - a) % Q
    g, h = (b + a) % Q, (t2z - c) % Q
    return ((e * f) % Q, (g * h) % Q, (f * g) % Q, (e * h) % Q)
 | 
						|
 | 
						|
def scalarmult_element(pt, n): # extended->extended
    """Return n*pt using double-and-add with the faster non-unified addition.

    Only gives correct answers for points in the main 1*L subgroup (it also
    works for points of order 2*L/4*L/8*L). It gives wrong answers for Zero
    and the points of order 2/4/8, because _add_elements_nonunfied is not
    valid when both inputs are equal. Use scalarmult_element_safe_slow for
    arbitrary points.
    """
    assert n >= 0
    acc = xform_affine_to_extended((0, 1))  # the neutral element
    if n == 0:
        return acc
    # Most-significant bit first: double, then add pt for every set bit.
    for digit in bin(n)[2:]:
        acc = double_element(acc)
        if digit == "1":
            acc = _add_elements_nonunfied(acc, pt)
    return acc
 | 
						|
 | 
						|
# points are encoded as 32 bytes: little-endian y in bits 0..254, with bit 255 holding the low ("sign") bit of x
 | 
						|
 | 
						|
def encodepoint(P):
    """Encode the affine point P = (x, y, ...) as 32 little-endian bytes.

    Bits 0..254 carry y; bit 255 (the MSB of the last byte) carries x & 1.
    y must already satisfy 0 <= y < 2**255.
    """
    x, y = P[0], P[1]
    assert 0 <= y < (1<<255) # always < 0x7fff..ff
    sign = (x & 1) << 255
    big_endian = binascii.unhexlify("%064x" % (y + sign))
    return big_endian[::-1]
 | 
						|
 | 
						|
def isoncurve(P):
    """Return True if affine P = (x, y, ...) satisfies -x^2 + y^2 = 1 + d*x^2*y^2 (mod Q)."""
    x, y = P[0], P[1]
    x_sq = x * x
    y_sq = y * y
    return (y_sq - x_sq - 1 - d * x_sq * y_sq) % Q == 0
 | 
						|
 | 
						|
class NotOnCurve(Exception):
    """Raised by decodepoint() when the encoded y has no matching curve point."""
 | 
						|
 | 
						|
def decodepoint(s):
    """Decode the point in the first 32 bytes of s into an affine [x, y] list.

    Bits 0..254 (little-endian) give y. Bit 255 selects between the two x
    values for that y: the one whose low bit matches is returned. y values
    >= Q are not rejected.

    Raises NotOnCurve if y is not the y coordinate of a curve point.
    """
    encoded = int(binascii.hexlify(s[:32][::-1]), 16)
    y = encoded & ((1 << 255) - 1)  # strip the sign bit
    want_odd_x = bool(encoded & (1 << 255))
    x = xrecover(y)
    if bool(x & 1) != want_odd_x:
        x = Q - x
    P = [x, y]
    if not isoncurve(P):
        raise NotOnCurve("decoding point that is not on curve")
    return P
 | 
						|
 | 
						|
# scalars are encoded as 32-byte little-endian integers
 | 
						|
 | 
						|
def bytes_to_scalar(s):
    """Interpret exactly 32 bytes as a little-endian unsigned integer."""
    assert len(s) == 32, len(s)
    reversed_hex = binascii.hexlify(s[::-1])
    return int(reversed_hex, 16)
 | 
						|
 | 
						|
def bytes_to_clamped_scalar(s):
    """Decode 32 little-endian bytes and apply Ed25519-style clamping.

    Clears the three lowest bits, so the scalar is a multiple of 8 and a
    small-subgroup attack cannot learn its low bits. Clears bit 255 and sets
    bit 254, so every result lies in [2**254, 2**255) and has the same bit
    length. Net effect: the top two bits are 01 and the bottom three are 000.
    """
    value = bytes_to_scalar(s)
    value &= ~7                 # clear bits 0..2
    value &= (1 << 254) - 1     # clear bit 254 and everything above it
    value |= 1 << 254           # then set bit 254
    return value
 | 
						|
 | 
						|
def random_scalar(entropy_f): # 0..L-1 inclusive
    """Return a random scalar in 0..L-1.

    entropy_f is called once as entropy_f(64) and must return 64 bytes. The
    512-bit value, interpreted big-endian, is far larger than L (about
    2**252), so reducing it mod L leaves negligible bias.
    """
    extra_wide = entropy_f(64)
    return int(binascii.hexlify(extra_wide), 16) % L
 | 
						|
 | 
						|
def password_to_scalar(pw):
    """Hash pw with SHA-512 and reduce the big-endian digest modulo L.

    pw must be bytes-like, as hashlib.sha512 requires.
    """
    digest = hashlib.sha512(pw).digest()
    digest_int = int(binascii.hexlify(digest), 16)
    return digest_int % L
 | 
						|
 | 
						|
def scalar_to_bytes(y):
    """Reduce y modulo L and encode it as 32 little-endian bytes."""
    reduced = y % L
    assert 0 <= reduced < 2**256
    big_endian = binascii.unhexlify("%064x" % reduced)
    return big_endian[::-1]
 | 
						|
 | 
						|
# Elements, of various orders
 | 
						|
 | 
						|
def is_extended_zero(XYTZ):
    """Return True if an extended-coordinate point is the neutral element Zero.

    XYTZ is an (X, Y, Z, T) tuple (T last, despite the name). Zero, the affine
    point (0, 1), is represented by any tuple with X == 0 and Y == Z != 0
    modulo Q; T is not examined.

    All three examined coordinates are reduced modulo Q. Previously only Y and
    Z were, so an X congruent to 0 but not literally 0 was missed.
    """
    (X, Y, Z, _) = XYTZ
    X = X % Q
    Y = Y % Q
    Z = Z % Q
    return X == 0 and Y == Z and Y != 0
 | 
						|
 | 
						|
class ElementOfUnknownGroup:
    """A curve point that may lie outside the main order-L subgroup.

    Used directly for points of order 2, 4, 8, 2*L, 4*L or 8*L, and as the base
    class of Element (main subgroup) and _ZeroElement (the neutral element).
    self.XYTZ holds extended coordinates as an (X, Y, Z, T) tuple (T is the
    last entry, despite the attribute name).
    """

    def __init__(self, XYTZ):
        assert isinstance(XYTZ, tuple)
        assert len(XYTZ) == 4
        self.XYTZ = XYTZ

    def add(self, other):
        """Return self + other, or the Zero singleton if the sum is the identity.

        Raises TypeError if other is not an element.
        """
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        sum_XYTZ = add_elements(self.XYTZ, other.XYTZ)
        if is_extended_zero(sum_XYTZ):
            return Zero
        return ElementOfUnknownGroup(sum_XYTZ)

    def scalarmult(self, s):
        """Return s*self for an integer s >= 0, using the unified (safe) ladder.

        The scalar is NOT reduced mod L, because that is wrong outside the main
        subgroup. The result is always an ElementOfUnknownGroup, even when it
        is the identity; test it with is_extended_zero().
        Raises TypeError if s is an element.
        """
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        assert s >= 0
        product = scalarmult_element_safe_slow(self.XYTZ, s)
        return ElementOfUnknownGroup(product)

    def to_bytes(self):
        """Return the 32-byte encoding of this point."""
        return encodepoint(xform_extended_to_affine(self.XYTZ))

    def __eq__(self, other):
        # Elements compare by their canonical encoding. For non-elements, defer
        # to Python's default handling instead of calling a to_bytes() they may
        # not have (or that means something unrelated, like int.to_bytes).
        if not isinstance(other, ElementOfUnknownGroup):
            return NotImplemented
        return self.to_bytes() == other.to_bytes()

    def __ne__(self, other):
        return not self == other
 | 
						|
 | 
						|
class Element(ElementOfUnknownGroup):
    """A point of the main 1*L subgroup.

    Never holds Zero, or elements of order 1/2/4/8 or 2*L/4*L/8*L, which lets
    scalar multiplication reduce the scalar mod L and use the faster
    non-unified addition.
    """

    def add(self, other):
        """Return self + other.

        Returns Zero if the sum is the identity, and an Element when other is
        an Element or Zero. Otherwise returns a plain ElementOfUnknownGroup,
        since membership of the subgroup is then unknown.
        Raises TypeError if other is not an element.
        """
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        if isinstance(other, _ZeroElement):
            # x + Zero == x; keep the Element type (and skip the addition).
            return self
        sum_element = ElementOfUnknownGroup.add(self, other)
        if sum_element is Zero:
            return sum_element
        if isinstance(other, Element):
            # adding two subgroup elements results in another subgroup
            # element, or Zero, and we've already excluded Zero
            return Element(sum_element.XYTZ)
        # not necessarily a subgroup member, so assume not
        return sum_element

    def scalarmult(self, s):
        """Return s*self; s may be any integer (it is reduced mod L).

        Returns Zero when s is a multiple of L.
        Raises TypeError if s is an element.
        """
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        # scalarmult of subgroup members can be done modulo the subgroup
        # order, and using the faster non-unified function.
        s = s % L
        # scalarmult(s=0) gets you Zero
        if s == 0:
            return Zero
        # scalarmult(s=1) gets you self, which is a subgroup member
        # scalarmult(s<grouporder) gets you a different subgroup member
        return Element(scalarmult_element(self.XYTZ, s))

    # negation and subtraction only make sense for the main subgroup
    def negate(self):
        """Return -self.

        On a twisted Edwards curve -(x, y) == (-x, y). In extended coordinates
        (with T = X*Y/Z) that is (-X, Y, Z, -T). The previous implementation
        multiplied by L-2, which yields -2*self, not -self.
        """
        (X, Y, Z, T) = self.XYTZ
        return Element(((-X) % Q, Y, Z, (-T) % Q))

    def subtract(self, other):
        """Return self - other, computed as self + (-other)."""
        return self.add(other.negate())
 | 
						|
 | 
						|
class _ZeroElement(ElementOfUnknownGroup):
    """The neutral (identity) element; the module-level Zero is its instance.

    Applies the same argument type checks as ElementOfUnknownGroup and Element,
    so misuse raises TypeError instead of silently returning non-elements.
    """

    def add(self, other):
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        return other # zero+anything = anything

    def scalarmult(self, s):
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        return self # zero*anything = zero

    def negate(self):
        return self # -zero = zero

    def subtract(self, other):
        # zero - other == -other
        return self.add(other.negate())
 | 
						|
 | 
						|
 | 
						|
# The standard base point B (a member of the order-L subgroup) as an Element.
Base = Element(xform_affine_to_extended(B))
# The neutral (identity) element: affine (0, 1).
Zero = _ZeroElement(xform_affine_to_extended((0, 1)))

# Cached encoding of Zero, so decoding can recognise it by a bytes comparison.
_zero_bytes = Zero.to_bytes()
 | 
						|
 | 
						|
 | 
						|
def arbitrary_element(seed): # unknown DL
    """Deterministically derive an Element with unknown discrete log from seed.

    seed (bytes) is hashed with SHA-512 and reduced mod Q to get a starting y
    coordinate. y, y+1, y+2, ... (mod Q) are tried in turn. For each y, only
    the even-x point from xrecover() is considered:
      * y values with no curve point (roughly half, by the original author's
        estimate) are skipped;
      * the point is multiplied by 8, the cofactor (the curve has order 8*L).
        That removes any order-2/4/8 component and leaves a point of the
        order-L subgroup, except when the point itself had order 1/2/4/8, in
        which case the product is Zero and the next y is tried.
    The first non-Zero 8-multiple is returned as an Element.
    """
    # TODO: if we don't need uniformity, maybe use just sha256 here?
    digest = hashlib.sha512(seed).digest()
    start_y = int(binascii.hexlify(digest), 16) % Q

    for offset in itertools.count(0):
        y = (start_y + offset) % Q
        candidate = [xrecover(y), y]  # no attempt to use the odd-x twin
        if not isoncurve(candidate):
            continue

        P = ElementOfUnknownGroup(xform_affine_to_extended(candidate))
        P8 = P.scalarmult(8)
        if is_extended_zero(P8.XYTZ):
            # picked one of the 8 low-order points; try the next y
            continue

        # Sanity check that P8 really has order L. This deliberately uses
        # ElementOfUnknownGroup.scalarmult, which does NOT reduce the scalar
        # mod L (Element.scalarmult does, which would make L*P8 trivially
        # Zero and defeat the check).
        assert is_extended_zero(P8.scalarmult(L).XYTZ)
        return Element(P8.XYTZ)
    # never reached: itertools.count() is infinite
 | 
						|
 | 
						|
def bytes_to_unknown_group_element(bytes):
    """Decode a 32-byte point encoding without any subgroup check.

    Returns the Zero singleton for the identity's encoding, otherwise an
    ElementOfUnknownGroup (possibly of order 2/4/8 or 2*L/4*L/8*L).
    Raises NotOnCurve (via decodepoint) for encodings of no curve point.
    The parameter name shadows the builtin but is kept for keyword callers.
    """
    if bytes != _zero_bytes:
        affine = decodepoint(bytes)
        return ElementOfUnknownGroup(xform_affine_to_extended(affine))
    return Zero
 | 
						|
 | 
						|
def bytes_to_element(bytes):
    """Decode a 32-byte encoding, accepting only points of the order-L subgroup.

    Raises ValueError if the point is Zero or lies outside the main subgroup,
    and NotOnCurve (via decodepoint) if it is not on the curve at all.
    """
    candidate = bytes_to_unknown_group_element(bytes)
    if candidate is Zero:
        raise ValueError("element was Zero")
    # L*candidate is the identity exactly when candidate's order divides L;
    # Zero is already excluded, so that leaves order L itself. This uses the
    # unreduced (safe) scalarmult of ElementOfUnknownGroup.
    times_order = candidate.scalarmult(L)
    if not is_extended_zero(times_order.XYTZ):
        raise ValueError("element is not in the right group")
    # Verified: promote to a correct-group Element.
    return Element(candidate.XYTZ)
 |