Snore Signatures (crypto)

These signatures are a bore!

The following Python code implements the signature scheme:

#!/usr/bin/env python3

from Crypto.Util.number import isPrime, getPrime, long_to_bytes, bytes_to_long
from Crypto.Random.random import getrandbits, randint
from Crypto.Hash import SHA512

LOOP_LIMIT = 2000


def hash(val, bits=1024):
    output = 0
    for i in range((bits//512) + 1):
        h = SHA512.new()
        h.update(long_to_bytes(val) + long_to_bytes(i))
        output = int(h.hexdigest(), 16) << (512 * i) ^ output
    return output

def gen_snore_group(N=512):
    q = getPrime(N)
    for _ in range(LOOP_LIMIT):
        X = getrandbits(2*N)
        p = X - X % (2 * q) + 1
        if isPrime(p):
            break
    else:
        raise Exception("Failed to generate group")

    r = (p - 1) // q

    for _ in range(LOOP_LIMIT):
        h = randint(2, p - 1)
        if pow(h, r, p) != 1:
            break
    else:
        raise Exception("Failed to generate group")

    g = pow(h, r, p)

    return (p, q, g)

def snore_gen(p, q, g, N=512):
    x = randint(1, q - 1)
    y = pow(g, -x, p)
    return (x, y)

def snore_sign(p, q, g, x, m):
    k = randint(1, q - 1)
    r = pow(g, k, p)
    e = hash((r + m) % p) % q
    s = (k + x * e) % q
    return (s, e)

def snore_verify(p, q, g, y, m, s, e):
    if not (0 < s < q):
        return False

    rv = (pow(g, s, p) * pow(y, e, p)) % p
    ev = hash((rv + m) % p) % q

    return ev == e


def main():
    p, q, g = gen_snore_group()
    print(f"p = {p}") # p is a prime
    print(f"q = {q}") # q is a prime
    print(f"g = {g}") # pow(h, r, p) where r = (p - 1) // q, h is a randint

    queries = []
    for _ in range(10):
        # x is priv key, y is pub key
        x, y = snore_gen(p, q, g) # y = pow(g, -x, p)
        print(f"y = {y}")

        print('you get one query to the oracle')

        m = int(input("m = "))
        queries.append(m)
        s, e = snore_sign(p, q, g, x, m) # we get the signature of m
        print(f"s = {s}")
        print(f"e = {e}")

        print('can you forge a signature?')
        # need to find an m for which s gives a valid signature
        m = int(input("m = "))
        s = int(input("s = "))
        # you can't change e >:)
        # since we can't change e, m' = m + ip
        if m in queries:
            print('nope')
            return
        if not snore_verify(p, q, g, y, m, s, e):
            print('invalid signature!')
            return
        queries.append(m)
        print('correct signature!')

    print('you win!')
    print(open('flag.txt').read())

if __name__ == "__main__":
    main()

We start with the main() function, which obtains p, q and g by calling gen_snore_group(). That function picks a 512-bit prime q, searches for a prime p of up to 1024 bits with p - 1 a multiple of 2q, computes r = (p - 1) // q, and finally sets g = pow(h, r, p) for a random h chosen so that g ≠ 1. Since g^q = h^(p - 1) ≡ 1 (mod p) and q is prime, g generates the subgroup of order q. p, q and g are all given to us as public information.
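As a rough illustration of what gen_snore_group produces, the sketch below mirrors its steps at toy size using only the standard library. The trial-division is_prime, the helper random_prime and the 12-bit default for N are my own stand-ins for isPrime/getPrime and the real 512-bit setting, so treat this as a sketch rather than the challenge's code.

```python
import random


def is_prime(n):
    # Trial division: adequate for toy sizes, far too slow for 512-bit primes
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True


def random_prime(bits):
    # Stand-in for getPrime: random odd candidates with the top bit set
    while True:
        c = random.getrandbits(bits) | (1 << (bits - 1)) | 1
        if is_prime(c):
            return c


def toy_snore_group(N=12, loop_limit=2000):
    # Same steps as gen_snore_group, just with a tiny N
    q = random_prime(N)
    for _ in range(loop_limit):
        X = random.getrandbits(2 * N)
        p = X - X % (2 * q) + 1          # p - 1 is a multiple of 2q, so q divides p - 1
        if is_prime(p):
            break
    else:
        raise RuntimeError("Failed to generate group")

    r = (p - 1) // q
    for _ in range(loop_limit):
        h = random.randint(2, p - 1)
        if pow(h, r, p) != 1:
            break
    else:
        raise RuntimeError("Failed to generate group")

    g = pow(h, r, p)                     # g^q = h^(p-1) = 1 (mod p), and g != 1
    return p, q, g


p, q, g = toy_snore_group()
print(p, q, g, pow(g, q, p))
```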

This is where things get interesting. A for loop runs 10 times. In each iteration, a fresh key pair is created with snore_gen(p, q, g): the private key x is a random integer in [1, q - 1] and the public key is y = pow(g, -x, p).

$$y = g^{-x} \bmod p$$
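To see what the negative exponent in snore_gen means, here is a small check on a hand-picked toy group (q = 1019, p = 2q + 1 = 2039, g = 4, and private key x = 123; none of these come from the challenge). Python 3.8+ accepts a negative exponent in three-argument pow and returns the modular inverse of g^x.

```python
# Toy Schnorr group (hand-picked): q = 1019 and p = 2q + 1 = 2039 are prime,
# and g = 2^2 mod p = 4 has order q. These are illustrative values only.
p, q, g = 2039, 1019, 4

x = 123                  # toy private key, any value in [1, q - 1]
y = pow(g, -x, p)        # public key as in snore_gen (needs Python 3.8+)

# A negative exponent means "inverse of g^x mod p", so y * g^x = 1 (mod p)
print(y * pow(g, x, p) % p)
# Because g has order q, g^(-x) is the same element as g^(q - x)
print(y == pow(g, q - x, p))
```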

We are given y, but not x. Then, we get one query to a signing oracle. The oracle takes in a message m of our choice, then calculates its signature (s, e) using s, e = snore_sign(p, q, g, x, m).

Let's try to understand the math behind snore_sign.

def snore_sign(p, q, g, x, m):
    k = randint(1, q - 1)
    r = pow(g, k, p)
    e = hash((r + m) % p) % q
    s = (k + x * e) % q
    return (s, e)

It first picks a random nonce k in [1, q - 1] and computes r = g^k mod p. It then hashes (r + m) mod p and reduces the result mod q to get e, and finally combines k, x and e into s. Expressing this mathematically,

$$r = g^k \bmod p$$
$$e = H\big((r + m) \bmod p\big) \bmod q$$
$$s = (k + xe) \bmod q$$
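To make the three equations concrete, here is a minimal sketch of the signing step on a toy group. It assumes p = 2039, q = 1019, g = 4 (hand-picked small values, not the server's) and a fixed toy private key x = 123, and it re-implements the challenge's hash() with hashlib, assuming pycryptodome's long_to_bytes is the minimal big-endian encoding. The optional k argument is my addition so the nonce can be fixed for testing.

```python
import hashlib
import random

# Toy Schnorr group (hand-picked): q = 1019, p = 2q + 1 = 2039, g = 4 has order q
p, q, g = 2039, 1019, 4
x = 123  # toy private key


def long_to_bytes(n):
    # Minimal big-endian encoding (assumed equivalent to Crypto.Util.number.long_to_bytes)
    return n.to_bytes(max(1, (n.bit_length() + 7) // 8), "big")


def snore_hash(val, bits=1024):
    # hashlib port of the challenge's hash(): XOR of shifted SHA-512 digests
    output = 0
    for i in range(bits // 512 + 1):
        digest = hashlib.sha512(long_to_bytes(val) + long_to_bytes(i)).digest()
        output ^= int.from_bytes(digest, "big") << (512 * i)
    return output


def snore_sign(m, k=None):
    # r = g^k mod p, e = H((r + m) mod p) mod q, s = (k + xe) mod q
    if k is None:
        k = random.randint(1, q - 1)
    r = pow(g, k, p)
    e = snore_hash((r + m) % p) % q
    s = (k + x * e) % q
    return s, e


s, e = snore_sign(1337, k=42)
print(f"s = {s}, e = {e}")
```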

After we get the signature (s, e) from the query, we are asked to provide a message m that has not been used before, together with a valid signature for it. However, the server verifies against the e returned by the oracle, so the only part of the signature we get to choose is s.

If we can repeat this process of querying the signing oracle and forging a signature 10 times, we get the flag.

Next, let's look at the code for the signature verification:

def snore_verify(p, q, g, y, m, s, e):
    if not (0 < s < q):
        return False

    rv = (pow(g, s, p) * pow(y, e, p)) % p
    ev = hash((rv + m) % p) % q

    return ev == e

snore_verify takes the m and s we provide, along with the e returned by the oracle. It rejects s unless 0 < s < q, then computes rv followed by ev, and accepts only if ev matches e.
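A minimal sketch of the verification side, again on a toy group (p = 2039, q = 1019, g = 4, hand-picked) with a random toy key. snore_hash is a hashlib port of the challenge's hash() that assumes long_to_bytes is the minimal big-endian encoding; snore_sign and snore_verify take fewer arguments than the originals because the group and keys are module-level constants here.

```python
import hashlib
import random

p, q, g = 2039, 1019, 4          # toy group (hand-picked), g has order q
x = random.randint(1, q - 1)     # toy private key
y = pow(g, -x, p)                # public key, as in snore_gen


def long_to_bytes(n):
    # Assumed to match Crypto.Util.number.long_to_bytes for n >= 0
    return n.to_bytes(max(1, (n.bit_length() + 7) // 8), "big")


def snore_hash(val, bits=1024):
    # hashlib port of the challenge's hash()
    output = 0
    for i in range(bits // 512 + 1):
        digest = hashlib.sha512(long_to_bytes(val) + long_to_bytes(i)).digest()
        output ^= int.from_bytes(digest, "big") << (512 * i)
    return output


def snore_sign(m):
    k = random.randint(1, q - 1)
    r = pow(g, k, p)
    e = snore_hash((r + m) % p) % q
    return (k + x * e) % q, e


def snore_verify(m, s, e):
    # Mirrors the challenge: range check on s, rebuild r_v, re-hash, compare with e
    if not (0 < s < q):
        return False
    rv = (pow(g, s, p) * pow(y, e, p)) % p
    return snore_hash((rv + m) % p) % q == e


s, e = snore_sign(2024)
# True unless s happened to be 0 (probability about 1/q with this toy q)
print(snore_verify(2024, s, e))
```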

Let's work out why verification succeeds when m and s come from a genuine signature.

From snore_verify, we have

$$r_v = \big((g^s \bmod p) \times (y^e \bmod p)\big) \bmod p = (g^s \times y^e) \bmod p$$

Combining this with the equations from snore_sign, we have

$$r_v = g^{(k + xe) \bmod q} \times \left(g^{-x}\right)^e \bmod p$$

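Because g has order q ($g^q \equiv 1 \pmod p$), exponents of g can be reduced modulo q without changing the result, so $g^{(k + xe) \bmod q} = g^{k + xe} \bmod p$. The expression therefore simplifies to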

$$r_v = g^{k + xe} \times g^{-xe} \bmod p = g^k \bmod p = r$$

Then we have

$$e_v = H\big((r_v + m) \bmod p\big) \bmod q$$

If s and m come from a genuine signature, then $r_v = r$, so $e_v$ is computed from exactly the same hash input as e and the two match.
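The key identity in this derivation, $r_v = r$, can be checked numerically. This sketch uses a hand-picked toy group (p = 2039, q = 1019, g = 4) and draws random x, k and e; it confirms both the exponent reduction mod q and the final identity $g^s y^e \equiv g^k \pmod p$. No hash is needed, since the identity holds for any e.

```python
import random

# Toy Schnorr group (hand-picked): q = 1019 and p = 2q + 1 = 2039 are prime, g = 4 has order q
p, q, g = 2039, 1019, 4

for _ in range(1000):
    x = random.randint(1, q - 1)      # private key
    k = random.randint(1, q - 1)      # signing nonce
    e = random.randrange(q)           # any challenge value in [0, q)
    y = pow(g, -x, p)                 # public key g^(-x) mod p
    r = pow(g, k, p)                  # r = g^k mod p
    s = (k + x * e) % q               # s = (k + xe) mod q

    # Reducing the exponent mod q changes nothing, because g^q = 1 (mod p)
    assert pow(g, s, p) == pow(g, k + x * e, p)
    # r_v = g^s * y^e = g^(k + xe) * g^(-xe) = g^k = r
    assert pow(g, s, p) * pow(y, e, p) % p == r

print("r_v == r in every trial")
```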

So... how do we forge a signature? The key is a weakness in how $e_v$ is computed. A sound Schnorr signature would hash $r_v$ concatenated with m; here the two are added together and reduced mod p, so any two messages that are congruent mod p produce exactly the same hash input. To forge a signature, we can follow the process below:

  1. Query the oracle using a message m, and obtain the signature (s, e).

  2. To forge a signature, send m' = m + ip, where i is a positive integer, together with the same s.
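The weakness behind step 2 can be isolated from the rest of the scheme. In this sketch, h() is plain SHA-512 standing in for the challenge's hash(), p = 2039 is a toy modulus, and concat_input is a hypothetical encoding that places r and m side by side; none of these names come from the challenge.

```python
import hashlib

p = 2039  # toy modulus (hand-picked); the same argument applies to the server's p


def long_to_bytes(n):
    # Minimal big-endian encoding of a non-negative integer
    return n.to_bytes(max(1, (n.bit_length() + 7) // 8), "big")


def h(data):
    # Plain SHA-512 as an integer, standing in for the challenge's hash()
    return int.from_bytes(hashlib.sha512(data).digest(), "big")


def challenge_input(r, m):
    # What the challenge actually hashes: (r + m) mod p
    return long_to_bytes((r + m) % p)


def concat_input(r, m):
    # Hypothetical alternative: r and m placed side by side
    return long_to_bytes(r) + long_to_bytes(m)


r, m = 1234, 77
for i in range(1, 4):
    m_forged = m + i * p
    # Addition mod p: m and m + ip collapse to the same hash input
    assert challenge_input(r, m) == challenge_input(r, m_forged)
    assert h(challenge_input(r, m)) == h(challenge_input(r, m_forged))
    # Concatenation keeps them apart
    assert h(concat_input(r, m)) != h(concat_input(r, m_forged))

print("addition mod p collides; concatenation does not")
```

Plain concatenation of variable-length encodings is itself ambiguous (for example, r = 1, m = 0x0203 and r = 0x0102, m = 3 give the same bytes), so a real fix would also fix the width of each field or length-prefix it.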

Why does this work?

  • Since we reuse s, it still equals (k + xe) mod q even though we don't know k or x, and y and e are unchanged. The derivation above therefore still gives $r_v = r$.

  • Since m' = m + ip, we have $(r_v + m') \bmod p = (r_v + m) \bmod p$. The hash input is unchanged, so $e_v = e$.
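Putting the two steps together, here is a single-round forgery against a local stand-in for the server. It assumes a hand-picked toy group (p = 2039, q = 1019, g = 4), a random toy key, and a hashlib port of the challenge's hash() that assumes long_to_bytes is the minimal big-endian encoding; forge is a hypothetical helper name.

```python
import hashlib
import random

# Toy stand-ins (hand-picked, not the server's values): g has order q mod p
p, q, g = 2039, 1019, 4
x = random.randint(1, q - 1)   # the server's secret key, unknown to the attacker
y = pow(g, -x, p)


def long_to_bytes(n):
    # Assumed to match Crypto.Util.number.long_to_bytes for n >= 0
    return n.to_bytes(max(1, (n.bit_length() + 7) // 8), "big")


def snore_hash(val, bits=1024):
    # hashlib port of the challenge's hash()
    output = 0
    for i in range(bits // 512 + 1):
        digest = hashlib.sha512(long_to_bytes(val) + long_to_bytes(i)).digest()
        output ^= int.from_bytes(digest, "big") << (512 * i)
    return output


def snore_sign(m):
    # Signing oracle: uses the secret x
    k = random.randint(1, q - 1)
    r = pow(g, k, p)
    e = snore_hash((r + m) % p) % q
    return (k + x * e) % q, e


def snore_verify(m, s, e):
    # Verifier: uses only public values
    if not (0 < s < q):
        return False
    rv = (pow(g, s, p) * pow(y, e, p)) % p
    return snore_hash((rv + m) % p) % q == e


def forge(m, s, i=1):
    # Hypothetical helper: shift m by i*p and keep s (e is fixed by the server)
    return m + i * p, s


m = 123
s, e = snore_sign(m)              # step 1: one oracle query
m_forged, s_forged = forge(m, s)  # step 2: m' = m + p, same s
print(m_forged != m, snore_verify(m_forged, s_forged, e))
```

The forged pair verifies exactly when the honest one does; the only failure case is s = 0, which is rejected by the range check and is negligible for the real 512-bit q.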

Using pwntools, I wrote a script that repeats this procedure 10 times and obtains the flag from the server. Since the server records every m it has seen and rejects a forged m that was already used, I use the following values of m and m' for the 10 iterations:

$$m_0 = 123,\ m_0' = 123 + p,\ m_1 = 124,\ m_1' = 124 + p,\ \ldots,\ m_9 = 132,\ m_9' = 132 + p$$
from pwn import *

conn = remote("snore-signatures.chal.uiuc.tf", 1337, ssl=True)

context.log_level = "debug"

# Only p is needed for the forgery; q, g and the first y are read but unused
conn.recvuntil(b"p = ")
p = int(conn.recvuntil(b"\n")[:-1].decode())
conn.recvuntil(b"q = ")
q = conn.recvuntil(b"\n")[:-1]
conn.recvuntil(b"g = ")
g = conn.recvuntil(b"\n")[:-1]
conn.recvuntil(b"y = ")
y = conn.recvuntil(b"\n")[:-1]

curr_m = 123

for _ in range(10):
    # send query to oracle (the new "y = ..." line of later rounds is skipped by sendlineafter)
    conn.sendlineafter(b"m = ", str(curr_m).encode())
    conn.recvuntil(b"s = ")
    s = conn.recvuntil(b"\n")[:-1]
    conn.recvuntil(b"e = ")
    e = conn.recvuntil(b"\n")[:-1]
    # forge: m' = m + p gives the same hash input, so the oracle's s still verifies
    conn.sendlineafter(b"m = ", str(curr_m + p).encode())
    conn.sendlineafter(b"s = ", s)
    # move to a fresh m so no message is reused in later rounds
    curr_m += 1

info(conn.recvall().decode()) # uiuctf{add1ti0n_i5_n0t_c0nc4t3n4ti0n}
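Before pointing the script at the server, the message schedule can be rehearsed offline. The sketch below replays main()'s checks for 10 rounds with a fresh key pair per round, on a hand-picked toy group (p = 2039, q = 1019, g = 4) and a hashlib port of hash() that assumes long_to_bytes is the minimal big-endian encoding; play is a hypothetical helper. With such a small q, s = 0 occurs with probability about 1/q per round and makes that round fail the 0 < s < q check; for the real 512-bit q this is negligible.

```python
import hashlib
import random

# Toy group (hand-picked, not the server's values): g has order q mod p
p, q, g = 2039, 1019, 4


def long_to_bytes(n):
    # Assumed to match Crypto.Util.number.long_to_bytes for n >= 0
    return n.to_bytes(max(1, (n.bit_length() + 7) // 8), "big")


def snore_hash(val, bits=1024):
    # hashlib port of the challenge's hash()
    output = 0
    for i in range(bits // 512 + 1):
        digest = hashlib.sha512(long_to_bytes(val) + long_to_bytes(i)).digest()
        output ^= int.from_bytes(digest, "big") << (512 * i)
    return output


def play(rounds=10, start=123):
    """Replay main()'s checks locally with m_i = start + i and m_i' = m_i + p."""
    queries = []
    m = start
    for _ in range(rounds):
        x = random.randint(1, q - 1)          # fresh key pair each round, as in main()
        y = pow(g, -x, p)

        # Oracle query on m
        queries.append(m)
        k = random.randint(1, q - 1)
        r = pow(g, k, p)
        e = snore_hash((r + m) % p) % q
        s = (k + x * e) % q

        # Forgery: same s and e, message shifted by p
        m_forged = m + p
        if m_forged in queries:
            return False                      # server would print 'nope'
        if not (0 < s < q):
            return False                      # only when s == 0
        rv = (pow(g, s, p) * pow(y, e, p)) % p
        if snore_hash((rv + m_forged) % p) % q != e:
            return False                      # server would print 'invalid signature!'
        queries.append(m_forged)
        m += 1
    return True


print(play())
```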
