Snore Signatures (crypto)
These signatures are a bore!
The following Python code implements the signature scheme:
```python
#!/usr/bin/env python3
from Crypto.Util.number import isPrime, getPrime, long_to_bytes, bytes_to_long
from Crypto.Random.random import getrandbits, randint
from Crypto.Hash import SHA512

LOOP_LIMIT = 2000

def hash(val, bits=1024):
    output = 0
    for i in range((bits//512) + 1):
        h = SHA512.new()
        h.update(long_to_bytes(val) + long_to_bytes(i))
        output = int(h.hexdigest(), 16) << (512 * i) ^ output
    return output

def gen_snore_group(N=512):
    q = getPrime(N)
    for _ in range(LOOP_LIMIT):
        X = getrandbits(2*N)
        p = X - X % (2 * q) + 1
        if isPrime(p):
            break
    else:
        raise Exception("Failed to generate group")
    r = (p - 1) // q
    for _ in range(LOOP_LIMIT):
        h = randint(2, p - 1)
        if pow(h, r, p) != 1:
            break
    else:
        raise Exception("Failed to generate group")
    g = pow(h, r, p)
    return (p, q, g)

def snore_gen(p, q, g, N=512):
    x = randint(1, q - 1)
    y = pow(g, -x, p)
    return (x, y)

def snore_sign(p, q, g, x, m):
    k = randint(1, q - 1)
    r = pow(g, k, p)
    e = hash((r + m) % p) % q
    s = (k + x * e) % q
    return (s, e)

def snore_verify(p, q, g, y, m, s, e):
    if not (0 < s < q):
        return False
    rv = (pow(g, s, p) * pow(y, e, p)) % p
    ev = hash((rv + m) % p) % q
    return ev == e

def main():
    p, q, g = gen_snore_group()
    print(f"p = {p}") # p is a prime
    print(f"q = {q}") # q is a prime
    print(f"g = {g}") # pow(h, r, p) where r = (p - 1) // q, h is a randint
    queries = []
    for _ in range(10):
        # x is priv key, y is pub key
        x, y = snore_gen(p, q, g) # y = pow(g, -x, p)
        print(f"y = {y}")
        print('you get one query to the oracle')
        m = int(input("m = "))
        queries.append(m)
        s, e = snore_sign(p, q, g, x, m) # we get the signature of m
        print(f"s = {s}")
        print(f"e = {e}")
        print('can you forge a signature?')
        # need to find an m for which s gives a valid signature
        m = int(input("m = "))
        s = int(input("s = "))
        # you can't change e >:)
        # since we can't change e, m' = m + ip
        if m in queries:
            print('nope')
            return
        if not snore_verify(p, q, g, y, m, s, e):
            print('invalid signature!')
            return
        queries.append(m)
        print('correct signature!')
    print('you win!')
    print(open('flag.txt').read())

if __name__ == "__main__":
    main()
```

We start with the main() function. It obtains p, q and g by calling gen_snore_group(). gen_snore_group() first picks a 512-bit prime q, then searches for a prime p of the form p = X - (X mod 2q) + 1 (with X a random 1024-bit integer), so that 2q divides p - 1. It then computes r = (p - 1) // q and sets g = pow(h, r, p) for a random h with pow(h, r, p) != 1, which makes g a generator of the subgroup of order q. p, q and g are all given to us as public information.
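To see why g ends up in a subgroup of order q, here is a small toy analogue of gen_snore_group(). It is a sketch under stated assumptions: the helper names is_prime and toy_snore_group are hypothetical, primality is checked by trial division, and p is found by a deterministic search over p = 2qt + 1 instead of by sampling a random 2N-bit X, so it only works for tiny q.

```python
def is_prime(n):
    # Trial division: only practical for toy-sized n
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True


def toy_snore_group(q):
    # Toy analogue of gen_snore_group() for a given small prime q
    if not is_prime(q):
        raise ValueError("q must be prime")
    # Smallest t with p = 2*q*t + 1 prime (the challenge samples a random 2N-bit X instead)
    t = 1
    while not is_prime(2 * q * t + 1):
        t += 1
    p = 2 * q * t + 1
    r = (p - 1) // q
    # Any h with h^r != 1 (mod p) gives g = h^r of order exactly q, because q is prime
    for h in range(2, p):
        g = pow(h, r, p)
        if g != 1:
            return p, q, g
    raise ValueError("no generator found")


p, q, g = toy_snore_group(1019)
print(p, q, g)            # 2039 1019 4
print(pow(g, q, p) == 1)  # True: g lies in the subgroup of order q
```

The key property is g^q = h^(p-1) = 1 (mod p) by Fermat's little theorem, so every exponent of g only matters modulo q.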
This is where things get interesting. A for loop runs 10 times. Each iteration first generates a fresh key pair with snore_gen(p, q, g): the private key x is a random integer in [1, q - 1], and the public key is y = pow(g, -x, p), the inverse of g^x modulo p.
We are given y, but not x. Then, we get one query to a signing oracle. The oracle takes in a message m of our choice, then calculates its signature (s, e) using s, e = snore_sign(p, q, g, x, m). 
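A minimal sketch of the key generation, assuming a hypothetical toy group p = 2039, q = 1019, g = 4 (far too small to be secure) and Python's secrets module in place of Crypto.Random. It checks that y really is the modular inverse of g^x, and that g^(-x) = g^(q - x) because g has order q.

```python
import secrets

# Hypothetical toy group: p = 2q + 1 with q = 1019, and g = 4 of order q
P, Q, G = 2039, 1019, 4


def toy_snore_gen(p, q, g):
    x = 1 + secrets.randbelow(q - 1)  # private key, uniform in [1, q - 1]
    y = pow(g, -x, p)                 # public key g^(-x) mod p (negative exponents need Python 3.8+)
    return x, y


x, y = toy_snore_gen(P, Q, G)
print(y * pow(G, x, P) % P)   # 1, since y is the inverse of g^x
print(y == pow(G, Q - x, P))  # True, since g has order q
```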
Let's try to understand the math behind snore_sign. 
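```python
def snore_sign(p, q, g, x, m):
    k = randint(1, q - 1)
    r = pow(g, k, p)
    e = hash((r + m) % p) % q
    s = (k + x * e) % q
    return (s, e)
```

It first generates a random nonce k in [1, q - 1] and uses it to compute r. Then r, m, p and q determine e, and k, x, e and q determine s. Expressing this mathematically, with H denoting the hash() function above,

$$r = g^{k} \bmod p$$

$$e = H\big((r + m) \bmod p\big) \bmod q$$

$$s = (k + xe) \bmod q$$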
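Reducing s modulo q looks like it might lose information, but it does not: g has order q, so g^s = g^(k + xe) (mod p). The sketch below checks this on a hypothetical toy group. toy_snore_sign is a hypothetical variant of snore_sign that takes the nonce k as a parameter so it can be inspected, and toy_hash is a SHA-512 stand-in for the challenge's hash(), not a byte-for-byte copy.

```python
import hashlib
import secrets

# Hypothetical toy group: p = 2q + 1 with q = 1019, and g = 4 of order q
P, Q, G = 2039, 1019, 4


def toy_hash(val):
    # SHA-512 stand-in for the challenge's hash(); not byte-for-byte identical
    return int.from_bytes(hashlib.sha512(str(val).encode()).digest(), "big")


def toy_snore_sign(p, q, g, x, m, k):
    # Same steps as snore_sign, but the nonce k is passed in so we can inspect it
    r = pow(g, k, p)
    e = toy_hash((r + m) % p) % q
    s = (k + x * e) % q
    return s, e


x = 1 + secrets.randbelow(Q - 1)  # private key in [1, q - 1]
k = 1 + secrets.randbelow(Q - 1)  # nonce in [1, q - 1]
s, e = toy_snore_sign(P, Q, G, x, 123, k)

# s is reduced mod q, yet g^s equals g^(k + x*e) because g has order q
print(pow(G, s, P) == pow(G, k + x * e, P))  # True
```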
After we get the signature (s, e) from the query, we must provide a message m that has not been used before (the server checks it against its queries list) and a value s such that (s, e) is a valid signature for it. Our e must be the same as the one obtained from the query, so on the signature side we can only change s.
If we can repeat this process of querying the signing oracle and forging a signature 10 times, we get the flag.
Next, let's look at the code for the signature verification:
```python
def snore_verify(p, q, g, y, m, s, e):
    if not (0 < s < q):
        return False
    rv = (pow(g, s, p) * pow(y, e, p)) % p
    ev = hash((rv + m) % p) % q
    return ev == e
```

snore_verify takes the m and s we provide, along with the public key y and the e returned by the oracle. It first rejects any s outside the range 0 < s < q, then computes rv and, from it, ev. Finally, it checks whether ev matches e.
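As a sanity check, a toy re-implementation of the whole scheme (hypothetical toy group, SHA-512 stand-in for hash(), secrets for randomness) shows that honest signatures pass verification. The one exception is s = 0, which the range check rejects; with the challenge's 512-bit q this essentially never happens, but with a toy q it occurs about once in q signatures.

```python
import hashlib
import secrets

# Hypothetical toy group: p = 2q + 1 with q = 1019, and g = 4 of order q
P, Q, G = 2039, 1019, 4


def toy_hash(val):
    # SHA-512 stand-in for the challenge's hash(); not byte-for-byte identical
    return int.from_bytes(hashlib.sha512(str(val).encode()).digest(), "big")


def toy_snore_gen(p, q, g):
    x = 1 + secrets.randbelow(q - 1)
    return x, pow(g, -x, p)


def toy_snore_sign(p, q, g, x, m):
    k = 1 + secrets.randbelow(q - 1)
    r = pow(g, k, p)
    e = toy_hash((r + m) % p) % q
    return (k + x * e) % q, e


def toy_snore_verify(p, q, g, y, m, s, e):
    if not (0 < s < q):
        return False
    rv = (pow(g, s, p) * pow(y, e, p)) % p  # recomputes r for an honest signature
    ev = toy_hash((rv + m) % p) % q
    return ev == e


x, y = toy_snore_gen(P, Q, G)
s, e = toy_snore_sign(P, Q, G, x, 123)
# Honest signatures verify; the only exception is s == 0, which the range check rejects
print(s == 0 or toy_snore_verify(P, Q, G, y, 123, s, e))  # True
```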
Let's try to work out why the verification would work if our given m and s were indeed correct.
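We are given

$$y = g^{-x} \bmod p$$

Combining this with the equations from snore_sign (and using the fact that g has order q, so reducing s modulo q does not change g^s), we have

$$r_v = g^{s} y^{e} \equiv g^{k + xe} \cdot g^{-xe} \pmod p$$

which simplifies to

$$r_v = g^{k} \bmod p = r$$

Then we have

$$e_v = H\big((r_v + m) \bmod p\big) \bmod q = H\big((r + m) \bmod p\big) \bmod q$$

If our provided s and m were correct, this is exactly the e computed by snore_sign, so ev == e and the signature is accepted.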
So how do we forge a signature? The key is a weakness in how the hash input $(r_v + m) \bmod p$ is formed: the message is added to r_v and the sum is reduced modulo p, when r_v and m should really be concatenated. Because of the reduction, any two messages that differ by a multiple of p produce the same hash input. To forge a signature, we can follow the process below:
- Query the oracle with a message m and obtain the signature (s, e).
- To forge a signature, send m' = m + ip, where i is a positive integer, together with the same s.
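The collision behind both steps can be checked directly. In this sketch, additive_input mirrors what the challenge passes to hash(), while concat_input is a hypothetical length-prefixed encoding of r || m, shown only for comparison; p, r and m are arbitrary toy values.

```python
# Hypothetical toy values: p stands in for the ~1024-bit challenge prime, and r < p
p = 2039
r, m = 1234, 123


def additive_input(r, m, p):
    # What snore_sign and snore_verify actually pass to hash()
    return (r + m) % p


def int_bytes(n):
    # Minimal big-endian encoding (at least one byte)
    return n.to_bytes((n.bit_length() + 7) // 8 or 1, "big")


def concat_input(r, m):
    # Hypothetical length-prefixed encoding of r || m, shown for comparison only
    rb = int_bytes(r)
    return len(rb).to_bytes(4, "big") + rb + int_bytes(m)


for i in range(1, 4):
    forged = m + i * p
    print(additive_input(r, m, p) == additive_input(r, forged, p),  # True: collision
          concat_input(r, m) == concat_input(r, forged))            # False: no collision
```

Each iteration prints `True False`: the additive encoding maps m and m + ip to the same hash input, while an injective encoding keeps them apart.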
Why does this work?
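- Since we are reusing s, we know it is still congruent to k + xe (mod q), even though we don't know k or x! This gives us
  $$r_v = g^{s} y^{e} \equiv g^{k + xe} \cdot g^{-xe} \equiv g^{k} = r \pmod p$$
- Since m' = m + ip, we have $(r_v + m') \bmod p = (r + m + ip) \bmod p = (r + m) \bmod p$. As such, we get
  $$e_v = H\big((r_v + m') \bmod p\big) \bmod q = H\big((r + m) \bmod p\big) \bmod q = e$$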
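Putting the two facts together, a toy end-to-end forgery looks like the sketch below. It uses a hypothetical toy group (p = 2039, q = 1019, g = 4) and a SHA-512 stand-in for hash(); forge is a hypothetical helper that only shifts m by ip and keeps the oracle's s.

```python
import hashlib
import secrets

# Hypothetical toy group: p = 2q + 1 with q = 1019, and g = 4 of order q
P, Q, G = 2039, 1019, 4


def toy_hash(val):
    # SHA-512 stand-in for the challenge's hash(); not byte-for-byte identical
    return int.from_bytes(hashlib.sha512(str(val).encode()).digest(), "big")


def toy_snore_sign(p, q, g, x, m):
    k = 1 + secrets.randbelow(q - 1)
    r = pow(g, k, p)
    e = toy_hash((r + m) % p) % q
    return (k + x * e) % q, e


def toy_snore_verify(p, q, g, y, m, s, e):
    if not (0 < s < q):
        return False
    rv = (pow(g, s, p) * pow(y, e, p)) % p
    return toy_hash((rv + m) % p) % q == e


def forge(p, m, s, i=1):
    # Hypothetical helper: shift the message by i*p and keep the oracle's s unchanged
    return m + i * p, s


x = 1 + secrets.randbelow(Q - 1)      # unknown to the attacker in the real challenge
y = pow(G, -x, P)
m = 123
s, e = toy_snore_sign(P, Q, G, x, m)  # the single oracle query
for i in range(1, 4):
    m_forged, s_forged = forge(P, m, s, i)
    # Accepted with the oracle's e, unless the oracle happened to return s == 0
    print(m_forged, s == 0 or toy_snore_verify(P, Q, G, y, m_forged, s_forged, e))
```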
Using pwntools, I wrote a script to repeat this procedure 10 times and obtain the flag from the server. Since the server remembers the values of m from previous rounds and rejects a forged m it has already seen, I query with m = 123, 124, ..., 132 and forge with m' = m + p (that is, i = 1) across the 10 iterations:
```python
from pwn import *

context.log_level = "debug"
conn = remote("snore-signatures.chal.uiuc.tf", 1337, ssl=True)

# Public group parameters; only p is needed for the forgery
conn.recvuntil(b"p = ")
p = int(conn.recvline().strip())
conn.recvuntil(b"q = ")
q = int(conn.recvline().strip())
conn.recvuntil(b"g = ")
g = int(conn.recvline().strip())

curr_m = 123
for _ in range(10):
    # Each round prints a fresh public key y (not needed for the attack)
    conn.recvuntil(b"y = ")
    y = int(conn.recvline().strip())

    # One query to the signing oracle with m = curr_m
    conn.sendlineafter(b"m = ", str(curr_m).encode())
    conn.recvuntil(b"s = ")
    s = conn.recvline().strip()
    conn.recvuntil(b"e = ")
    e = conn.recvline().strip()

    # Forge: m' = m + p gives the same hash input (r + m) % p, so the same s verifies
    conn.sendlineafter(b"m = ", str(curr_m + p).encode())
    conn.sendlineafter(b"s = ", s)
    curr_m += 1

print(conn.recvall().decode())  # uiuctf{add1ti0n_i5_n0t_c0nc4t3n4ti0n}
```
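Before pointing the script at the server, the same strategy can be rehearsed offline. The sketch below mirrors main()'s bookkeeping (a fresh key pair per round and a queries list that rejects reused messages) against a toy group built from q = 1000000007, with a SHA-512 stand-in for hash(). All helper names are hypothetical; the larger q just makes an unlucky s = 0 negligible, so all 10 rounds pass in practice.

```python
import hashlib
import secrets


def is_prime(n):
    # Trial division: fine for this small toy modulus, hopeless at real sizes
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    d = 3
    while d * d <= n:
        if n % d == 0:
            return False
        d += 2
    return True


def toy_group(q):
    # p = 2*q*t + 1 for the smallest t making p prime; g = h^((p-1)/q) != 1 has order q
    t = 1
    while not is_prime(2 * q * t + 1):
        t += 1
    p = 2 * q * t + 1
    r = (p - 1) // q
    h = 2
    while pow(h, r, p) == 1:
        h += 1
    return p, q, pow(h, r, p)


def toy_hash(val):
    # SHA-512 stand-in for the challenge's hash(); not byte-for-byte identical
    return int.from_bytes(hashlib.sha512(str(val).encode()).digest(), "big")


def sign(p, q, g, x, m):
    k = 1 + secrets.randbelow(q - 1)
    r = pow(g, k, p)
    e = toy_hash((r + m) % p) % q
    return (k + x * e) % q, e


def verify(p, q, g, y, m, s, e):
    if not (0 < s < q):
        return False
    rv = (pow(g, s, p) * pow(y, e, p)) % p
    return toy_hash((rv + m) % p) % q == e


def rehearse(p, q, g, rounds=10, start_m=123):
    # Mirrors main(): forged messages must not already appear in `queries`
    queries = []
    curr_m = start_m
    for _ in range(rounds):
        x = 1 + secrets.randbelow(q - 1)
        y = pow(g, -x, p)
        queries.append(curr_m)            # the oracle query
        s, e = sign(p, q, g, x, curr_m)
        forged_m = curr_m + p             # m' = m + p, i.e. i = 1
        if forged_m in queries or not verify(p, q, g, y, forged_m, s, e):
            return False
        queries.append(forged_m)
        curr_m += 1
    return True


P, Q, G = toy_group(1_000_000_007)
print(rehearse(P, Q, G))  # True unless some round drew s == 0 (probability about 10/q)
```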