Without a Trace (crypto)

Gone with the wind, can you find my flag?

We are given the following Python code being run on the remote server:

import numpy as np
from Crypto.Util.number import bytes_to_long
from itertools import permutations
from SECRET import FLAG

def inputs():
    print("[WAT] Define diag(u1, u2, u3. u4, u5)")
    M = [
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ]
    for i in range(5):
        try:
            M[i][i] = int(input(f"[WAT] u{i + 1} = "))
        except:
            return None
    return M

def handler(signum, frame):
    raise Exception("[WAT] You're trying too hard, try something simpler")

def check(M):
    def sign(sigma):
        l = 0
        for i in range(5):
            for j in range(i + 1, 5):
                if sigma[i] > sigma[j]:
                    l += 1
        return (-1)**l

    res = 0
    for sigma in permutations([0,1,2,3,4]):
        curr = 1
        for i in range(5):
            curr *= M[sigma[i]][i]
        res += sign(sigma) * curr
    return res

def fun(M):
    f = [bytes_to_long(bytes(FLAG[5*i:5*(i+1)], 'utf-8')) for i in range(5)]
    F = [
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ]
    for i in range(5):
        F[i][i] = f[i]

    try:
        R = np.matmul(F, M)
        return np.trace(R) # return sum along the diagonals

    except:
        print("[WAT] You're trying too hard, try something simpler")
        return None

def main():
    print("[WAT] Welcome")
    M = inputs() # generates a diagonal matrix with entries of our choice
    if M is None:
        print("[WAT] You tried something weird...")
        return
    elif check(M) == 0:
        print("[WAT] It's not going to be that easy...")
        return

    res = fun(M)
    if res == None:
        print("[WAT] You tried something weird...")
        return
    print(f"[WAT] Have fun: {res}")

if __name__ == "__main__":
    main()

main() first calls inputs(), which prompts us for inputs. Our input is then used as diagonal entries for a matrix M: M[i][i] = int(input(f"[WAT] u{i + 1} = "))

Then our matrix is checked using check(M). If check(M) == 0, our inputs are rejected. Looking at check(), it computes the determinant of the matrix with the Leibniz formula: it sums, over every permutation of the column indices, the signed product of the selected entries. Since our matrix is diagonal, every term except the one for the identity permutation contains a 0, so the determinant is just the product of the diagonal entries. Hence, to pass the check, none of our inputs can be 0.
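To make the determinant argument concrete, here is a small local sketch that mirrors the logic of check(). The names leibniz_det and diag are hypothetical helpers introduced for illustration; they are not part of the challenge source.

```python
from itertools import permutations
from math import prod

def leibniz_det(M):
    """Determinant via the Leibniz formula, mirroring the challenge's check()."""
    n = len(M)
    total = 0
    for sigma in permutations(range(n)):
        # The sign of a permutation is (-1)^(number of inversions).
        inversions = sum(
            1 for i in range(n) for j in range(i + 1, n) if sigma[i] > sigma[j]
        )
        term = prod(M[sigma[i]][i] for i in range(n))
        total += (-1) ** inversions * term
    return total

def diag(values):
    """Build a square diagonal matrix (hypothetical helper)."""
    n = len(values)
    return [[values[i] if i == j else 0 for j in range(n)] for i in range(n)]

det_nonzero = leibniz_det(diag([2, 1, 1, 1, 1]))  # product of the entries
det_zero = leibniz_det(diag([2, 0, 1, 1, 1]))     # a single 0 kills the product
print(det_nonzero, det_zero)
```

Any input containing a 0 therefore fails the check, while every other choice of integers passes.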

After the check, the server prints fun(M). The line

f = [bytes_to_long(bytes(FLAG[5*i:5*(i+1)], 'utf-8')) for i in range(5)]

first breaks the first 25 characters of the flag into 5 groups of 5 bytes each. Then, each group is converted to an integer with bytes_to_long (big-endian), giving us a list of 5 integers. Following this, we have:

F = [
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
]
for i in range(5):
    F[i][i] = f[i]

which takes the longs in f and uses them as diagonal entries for F.
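The chunking and diagonal construction can be reproduced locally as a rough sketch. FAKE_FLAG is a made-up 25-character placeholder, not the real flag, and int.from_bytes(..., "big") is used here as a standard-library stand-in for bytes_to_long, which performs the same big-endian conversion.

```python
# Hypothetical placeholder with the same length (25 characters) as the real flag.
FAKE_FLAG = "ctf{placeholder_flag_25c}"

# Split into 5 groups of 5 characters, as the list comprehension in fun() does.
chunks = [FAKE_FLAG[5 * i:5 * (i + 1)] for i in range(5)]

# Big-endian bytes -> integer; equivalent to bytes_to_long for these inputs.
f = [int.from_bytes(chunk.encode("utf-8"), "big") for chunk in chunks]

# Place the integers on the diagonal of a 5x5 matrix.
F = [[f[i] if i == j else 0 for j in range(5)] for i in range(5)]

# The conversion is reversible, which is what makes the attack useful.
recovered = b"".join(x.to_bytes(5, "big") for x in f)
print(chunks)
print(recovered)
```

Each 5-byte group fits in 40 bits, so every f[i] is well below 2**40.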

fun() then does

R = np.matmul(F, M)
return np.trace(R) # return sum along the diagonals

So it multiplies the 5x5 diagonal matrix F (representing the flag) with our given 5x5 diagonal matrix M, obtaining another 5x5 diagonal matrix R. Then we are given the sum of R's diagonal entries.

$$R_{5\times5}=\begin{bmatrix} f_0 & 0 & 0 & 0 & 0\\ 0 & f_1 & 0 & 0 & 0\\ 0 & 0 & f_2 & 0 & 0 \\ 0 & 0 & 0 & f_3 & 0 \\ 0 & 0 & 0 & 0 & f_4 \end{bmatrix}\begin{bmatrix} m_0 & 0 & 0 & 0 & 0\\ 0 & m_1 & 0 & 0 & 0\\ 0 & 0 & m_2 & 0 & 0 \\ 0 & 0 & 0 & m_3 & 0 \\ 0 & 0 & 0 & 0 & m_4 \end{bmatrix}=\begin{bmatrix} f_0m_0 & 0 & 0 & 0 & 0\\ 0 & f_1m_1 & 0 & 0 & 0\\ 0 & 0 & f_2m_2 & 0 & 0 \\ 0 & 0 & 0 & f_3m_3 & 0 \\ 0 & 0 & 0 & 0 & f_4m_4 \end{bmatrix}$$

Therefore,

$$res=f_0m_0+f_1m_1+f_2m_2+f_3m_3+f_4m_4$$
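As a quick sanity check of this identity, the following pure-Python sketch multiplies two diagonal matrices and takes the trace. The helpers diag, matmul and trace are hypothetical stand-ins for the NumPy calls in fun(), and the small numbers are arbitrary example values rather than real flag chunks.

```python
def diag(values):
    """Square diagonal matrix from a list of values (hypothetical helper)."""
    n = len(values)
    return [[values[i] if i == j else 0 for j in range(n)] for i in range(n)]

def matmul(A, B):
    """Plain nested-loop matrix product, standing in for np.matmul."""
    n = len(A)
    return [
        [sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]

def trace(A):
    """Sum of the diagonal entries, standing in for np.trace."""
    return sum(A[i][i] for i in range(len(A)))

f = [3, 5, 7, 11, 13]  # arbitrary example values, not the real flag chunks
m = [2, 1, 1, 1, 1]    # our chosen diagonal entries

res = trace(matmul(diag(f), diag(m)))
expected = sum(fi * mi for fi, mi in zip(f, m))
print(res, expected)
```

Both values agree, confirming that the server effectively returns the dot product of the flag chunks with our inputs.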

Even though none of the $m_i$ can be 0, we can still recover each $f_i$. In one connection we set $m_0=2, m_1=m_2=m_3=m_4=1$, which gives $res_0=2f_0+f_1+f_2+f_3+f_4$. In another connection we set $m_0=m_1=m_2=m_3=m_4=1$, which gives $res_1=f_0+f_1+f_2+f_3+f_4$. Subtracting yields $f_0=res_0-res_1$. Repeating this with the 2 moved to each of the other positions yields $f_1, f_2, f_3$ and $f_4$. Finally, we can use long_to_bytes from Crypto.Util.number to turn each integer back into the 5 bytes of its group.
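The differencing idea can be tried offline before touching the server. The sketch below is only a simulation under stated assumptions: oracle is a hypothetical local stand-in for one connection to the challenge, and FAKE_FLAG is a made-up 25-character placeholder rather than the real flag.

```python
# Hypothetical placeholder flag; the real one lives only on the server.
FAKE_FLAG = "ctf{placeholder_flag_25c}"

def oracle(m):
    """Local stand-in for one connection: returns sum(f_i * m_i), or None if rejected."""
    # For a diagonal matrix the determinant check fails exactly when an entry is 0.
    if any(v == 0 for v in m):
        return None
    f = [
        int.from_bytes(FAKE_FLAG[5 * i:5 * (i + 1)].encode("utf-8"), "big")
        for i in range(5)
    ]
    return sum(fi * mi for fi, mi in zip(f, m))

# Baseline query: all ones gives f_0 + f_1 + f_2 + f_3 + f_4.
baseline = oracle([1, 1, 1, 1, 1])

recovered = b""
for i in range(5):
    m = [1, 1, 1, 1, 1]
    m[i] = 2                       # doubling entry i adds exactly f_i
    f_i = oracle(m) - baseline     # isolate f_i by subtraction
    recovered += f_i.to_bytes(5, "big")

print(recovered.decode())
```

Because the flag does not change between connections, the baseline only needs to be queried once, for six queries in total.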

from pwn import remote, context
from Crypto.Util.number import long_to_bytes

context.log_level = "debug"

def query(values):
    # Open a fresh connection, send the five diagonal entries, and return the trace.
    p = remote("without-a-trace.chal.uiuc.tf", 1337, ssl=True)
    for v in values:
        p.sendlineafter(b"= ", str(v).encode())
    line = p.recvline()  # b"[WAT] Have fun: <trace>\n"
    p.close()
    return int(line.decode().split(": ")[1])

# Baseline: all ones gives f_0 + f_1 + f_2 + f_3 + f_4.
# The flag is fixed on the server, so one baseline query is enough.
baseline = query([1, 1, 1, 1, 1])

elems = []
for i in range(5):
    new_arr = [1, 1, 1, 1, 1]
    new_arr[i] += 1  # doubling entry i adds exactly f_i to the trace
    elems.append(query(new_arr) - baseline)

print(elems)

# Convert each recovered integer back into its 5-byte group and join them.
flag = b"".join(long_to_bytes(x) for x in elems)

print(flag) # uiuctf{tr4c1ng_&&_mult5!}
