Naptime (crypto)

I'm pretty tired. Don't leak my flag while I'm asleep.

The given sage file implements the encryption scheme:

from random import randint
from Crypto.Util.number import getPrime, bytes_to_long, long_to_bytes
import numpy as np

def get_b(n):
    b = []
    b.append(randint(2**(n-1), 2**n))
    for i in range(n - 1):
        lb = sum(b)
        found = False
        while not found:
            num = randint(max(2**(n + i), lb + 1), 2**(n + i + 1))
            if num > lb:
                found = True
                b.append(num)

    return b

def get_MW(b):
    lb = sum(b)
    M = randint(lb + 1, 2*lb)
    W = getPrime(int(1.5*len(b)))

    return M, W

def get_a(b, M, W):
    a_ = []
    for num in b:
        a_.append(num * W % M)
    pi = np.random.permutation(list(i for i in range(len(b)))).tolist()
    a = [a_[pi[i]] for i in range(len(b))]
    return a, pi


def enc(flag, a, n):
    bitstrings = []
    for c in flag:
        # c -> int -> 8-bit binary string
        bitstrings.append(bin(ord(c))[2:].zfill(8))
    ct = []
    for bits in bitstrings:
        curr = 0
        for i, b in enumerate(bits):
            if b == "1":
                curr += a[i]
        ct.append(curr)

    return ct

def dec(ct, a, b, pi, M, W, n):
    # construct inverse permuation to pi
    pii = np.argsort(pi).tolist()
    m = ""
    U = pow(W, -1, M)
    ct = [c * U % M for c in ct]
    for c in ct:
        # find b_pi(j)
        diff = 0
        bits = ["0" for _ in range(n)]
        for i in reversed(range(n)):
            if c - diff > sum(b[:i]):
                diff += b[i]
                bits[pii[i]] = "1"
        # convert bits to character
        m += chr(int("".join(bits), base=2))

    return m


def main():
    flag = 'uiuctf{I_DID_NOT_LEAVE_THE_FLAG_THIS_TIME}'

    # generate cryptosystem
    n = 8
    b = get_b(n)
    M, W = get_MW(b)
    a, pi = get_a(b, M, W)

    # encrypt
    ct = enc(flag, a, n)

    # public information
    print(f"{a =  }")
    print(f"{ct = }")

    # decrypt
    res = dec(ct, a, b, pi, M, W, n)

if __name__ == "__main__":
    main()

Let's look at the main() function first. It sets n = 8 and calls get_b(n) to build b. At first glance get_b seems to return an array of 8 random integers, but each new element is drawn so that it is larger than the sum of all previous elements, which makes b a superincreasing sequence. Then M and W are generated using get_MW(b). M is a random integer between sum(b) + 1 and 2*sum(b), whereas W is a random 12-bit prime (int(1.5 * 8) = 12). Finally, a and pi are generated using get_a(b, M, W).

def get_a(b, M, W):
    a_ = []
    for num in b:
        a_.append(num * W % M)
    pi = np.random.permutation(list(i for i in range(len(b)))).tolist()
    a = [a_[pi[i]] for i in range(len(b))]
    return a, pi

get_a builds a_, an array containing each element of b multiplied by W and reduced modulo M. pi is a random permutation of [0, 1, 2, ..., 7]. Then a is a_ shuffled according to the indexes given by pi, i.e. a[i] = a_[pi[i]].
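To make the key generation concrete, here is a minimal sketch of the same steps on a toy key. The values of b, M, W and pi below are hypothetical (they are not the challenge's private key, which we never see), and is_superincreasing is a helper name introduced only for this illustration.

```python
def is_superincreasing(seq):
    """Return True if every element is larger than the sum of all earlier ones."""
    total = 0
    for x in seq:
        if x <= total:
            return False
        total += x
    return True

# Toy private key (hypothetical values chosen for illustration)
b = [2, 3, 7, 14, 30, 57, 120, 251]   # superincreasing, sum(b) = 484
M = 491                               # modulus, larger than sum(b)
W = 41                                # multiplier, coprime to M
pi = [3, 0, 7, 1, 6, 2, 5, 4]         # fixed permutation so the run is reproducible

# Same two steps as get_a: scale modulo M, then shuffle by pi
a_ = [x * W % M for x in b]
a = [a_[pi[i]] for i in range(len(b))]

print(is_superincreasing(b))   # the private sequence has the structure
print(is_superincreasing(a_))  # the scaled sequence no longer does
print(a)
```

The point of the multiplication and the shuffle is to hide the superincreasing structure of b, so the published a looks like an arbitrary list of integers.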

Once the cryptosystem generation is done, the ciphertext ct is obtained using enc(flag, a, n).

def enc(flag, a, n):
    bitstrings = []
    for c in flag:
        # c -> int -> 8-bit binary string
        bitstrings.append(bin(ord(c))[2:].zfill(8))
    ct = []
    for bits in bitstrings:
        curr = 0
        for i, b in enumerate(bits):
            if b == "1":
                curr += a[i]
        ct.append(curr)

    return ct

The encryption algorithm iterates through each character c in the flag, converts it into an 8-bit bitstring, and appends it to the array bitstrings. It then iterates through each bitstring bits in that array while maintaining a running sum curr: whenever the bit at index i is 1, a[i] is added to curr. For example, the bitstring 01011101 gives curr = a[1] + a[3] + a[4] + a[5] + a[7]. After all 8 bits have been processed, curr is appended to the array ct, so each character of the flag becomes one integer in the ciphertext.
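As a quick sanity check of the example above, the sketch below encrypts a single character with the public a that the challenge prints (the same array used in the solve script further down). encrypt_char is a hypothetical helper that mirrors the inner loop of enc for one character.

```python
# Public key printed by the challenge
a = [66128, 61158, 36912, 65196, 15611, 45292, 84119, 65338]

def encrypt_char(c, a):
    """Sum the elements of a selected by the 1-bits of c (most significant bit first)."""
    bits = bin(ord(c))[2:].zfill(8)
    return sum(a[i] for i, bit in enumerate(bits) if bit == "1")

# "]" has the bit pattern 01011101 used in the example above
example = encrypt_char("]", a)
manual = a[1] + a[3] + a[4] + a[5] + a[7]
print(example, manual)

# "u" (01110101) should reproduce the first ciphertext integer
print(encrypt_char("u", a))
```

Encrypting "u" gives 273896, which matches the first entry of ct and is consistent with a flag that starts with "uiuctf{".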

We are given the array a and the ciphertext ct.

Thinking about this problem, every 8-bit character is encrypted as the sum of a subset of the elements in a. This is essentially a subset-sum problem (a special case of the 0/1 knapsack problem), where we want to find out which elements of a were added together for each ciphertext integer in ct.

To solve this problem, we don't need the private key at all. Since n = 8, there are only 2^8 = 256 possible characters, so we can iterate through all of them and calculate the sum each character would give us. Then, for each integer in ct, we just need to match it with the character that produces that sum. This allows us to work backwards from integers in ct to characters in the plaintext!
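The lookup only works if no two bit patterns produce the same sum. A small sketch of how one could check that assumption is shown below; find_collisions is a hypothetical helper, and the toy weights are chosen only to demonstrate what a collision looks like.

```python
from collections import defaultdict

def find_collisions(a):
    """Group all 2^len(a) bit patterns by subset sum; return the sums hit more than once."""
    by_sum = defaultdict(list)
    n = len(a)
    for i in range(2 ** n):
        bits = bin(i)[2:].zfill(n)
        total = sum(a[j] for j, bit in enumerate(bits) if bit == "1")
        by_sum[total].append(bits)
    return {s: patterns for s, patterns in by_sum.items() if len(patterns) > 1}

# Toy weights (hypothetical): 1 + 2 == 3, so "110" and "001" collide
toy = find_collisions([1, 2, 3])
print(toy)

# Powers of two never collide, so the result is empty
print(find_collisions([1, 2, 4]))
```

Running find_collisions on the challenge's a is a cheap way to confirm the table is unambiguous before trusting the recovered flag.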

The script below uses a dictionary whose keys are the sums and whose values are the characters that produce them. After the dictionary is filled with all 256 key-value pairs, I iterate through each integer in ct, look up the corresponding plaintext character in the dictionary, and append it to flag.

a =  [66128, 61158, 36912, 65196, 15611, 45292, 84119, 65338]
ct = [273896, 179019, 273896, 247527, 208558, 227481, 328334, 179019, 336714, 292819, 102108, 208558, 336714, 312723, 158973, 208700, 208700, 163266, 244215, 336714, 312723, 102108, 336714, 142107, 336714, 167446, 251565, 227481, 296857, 336714, 208558, 113681, 251565, 336714, 227481, 158973, 147400, 292819, 289507]

# enumerate all 2^8 subsets of a; each element is either taken or not
mapping = dict()
for i in range(0, 256):
    bitstring = bin(i)[2:].zfill(8)
    curr = 0
    for j, b in enumerate(bitstring):
        if b == "1":
            curr += a[j]
    mapping[curr] = chr(i)
# print(mapping)
flag = ""

for i in ct:
    flag += mapping[i]

print(flag) # uiuctf{i_g0t_sleepy_s0_I_13f7_th3_fl4g}
