Cracking AES Without any one of its Operations

Bill Elim

When learning AES, especially for CTFs, one usually starts with attacks that don’t really require deep knowledge of its internals (e.g. the AES-ECB padding attack or the AES-CBC bit-flipping attack). But the other day I stumbled across this forum and thought it would be interesting to explore the inner workings of AES and look at some new ways of attacking a faulty AES.

In short, AES transforms the plaintext using these 4 operations:

1. SubBytes: a byte-wise substitution, just like a substitution cipher
2. ShiftRows: shifting some rows in the matrix
3. MixColumns: mixing the bytes within each column of the matrix
4. AddRoundKey: XORing the state with values derived from the key (the round key is different each round)

Combined, these operations create a pretty secure encryption scheme, especially against things like chosen-plaintext attacks. This raises the question: what if we remove one of the operations?

We’ll first take a quick look at how these operations work, including whether or not they’re linear and key-dependent. Linear means the transformation is, well… linear (it can be expressed as Mx + B; strictly speaking that is an affine map, but we’ll call it linear here). Key-dependent means a different key gives a different result; a transformation that is not key-dependent doesn’t need the key at all.

1. SubBytes: non-linear, not key-dependent
2. ShiftRows: linear, not key-dependent
3. MixColumns: linear, not key-dependent
4. AddRoundKey: linear, key-dependent
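The linear/non-linear split in this list can be checked with a short sketch. Any map of the form Mx + B over GF(2) satisfies f(a) ⊕ f(b) ⊕ f(a ⊕ b) = f(0), so a single counterexample is enough to show that a function is not linear in this sense. The code below is not taken from the AES library used in this post: it rebuilds the S-box from its textbook definition (field inverse followed by a fixed affine map), and the helper names (gf_mul, make_sbox, is_affine) are made up for illustration.

```python
def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo the AES polynomial 0x11B."""
    result = 0
    for _ in range(8):
        if b & 1:
            result ^= a
        b >>= 1
        a <<= 1
        if a & 0x100:
            a ^= 0x11B
    return result

def make_sbox():
    """Build the AES S-box: field inverse (a^254, with 0 -> 0), then the affine map."""
    rotl = lambda x, n: ((x << n) | (x >> (8 - n))) & 0xFF
    box = []
    for a in range(256):
        inv = 1
        for _ in range(254):
            inv = gf_mul(inv, a)
        box.append(inv ^ rotl(inv, 1) ^ rotl(inv, 2) ^ rotl(inv, 3) ^ rotl(inv, 4) ^ 0x63)
    return box

def is_affine(table):
    """Check f(a) ^ f(b) ^ f(a ^ b) == f(0) for every pair of bytes."""
    return all(table[a] ^ table[b] ^ table[a ^ b] == table[0]
               for a in range(256) for b in range(256))

SBOX = make_sbox()
times2 = [gf_mul(x, 2) for x in range(256)]   # building block of MixColumns
add_key = [x ^ 0x42 for x in range(256)]      # AddRoundKey with one fixed key byte

print(is_affine(times2))   # True
print(is_affine(add_key))  # True
print(is_affine(SBOX))     # False
```

ShiftRows only moves bytes around, so it passes the same check trivially; SubBytes is the only operation that fails it.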

We can see that there is one non-linear function and one key-dependent function. The key-dependent function is important for obvious reasons. But why is a non-linear function so important? We’ll discuss that later, but first:

Nerfed AES 1: AES Without AddRoundKey

Consider the following challenge (from now on we will use Bo Zhu’s AES implementation):

from aes import AES
import os

do_nothing = lambda *x: None

c = AES(os.urandom(16))
c._add_round_key = do_nothing
p = b"Very secret text"
ciphertext = c.encrypt(p)
print(ciphertext.hex())
# 11633a27deebd11d08c18fa6619b008e

Here we have changed the AES AddRoundKey function to do nothing, essentially removing the only key-dependent function from the algorithm. This means that we don’t even need the key anymore; we can just reverse the entire transformation using any key:

from aes import AES

ciphertext = bytes.fromhex("11633a27deebd11d08c18fa6619b008e")
c = AES(b"blahdoesntmatter")
c._add_round_key = lambda *x: None
print(c.decrypt(ciphertext))

Nerfed AES 2: AES Without SubBytes

Well, the first one was easy, but what about this one? To understand it, we first need to really understand how these 4 operations are used in AES.

AES-128 transforms our plaintext in 10 rounds, preceded by an initial AddRoundKey (called round 0 here):

Round 0: AddRoundKey
Round 1–9: SubBytes, ShiftRows, MixColumns, AddRoundKey
Round 10: SubBytes, ShiftRows, AddRoundKey

If SubBytes is gone, the resulting transformation looks like this:

Round 0: AddRoundKey
Round 1–9: ShiftRows, MixColumns, AddRoundKey
Round 10: ShiftRows, AddRoundKey

This makes the entire transformation linear. Let’s see how we can exploit it.

Prerequisite: Transformation Matrix

Before we get into the attack, we must first understand linear transformations. To operate on a 16-byte plaintext, AES converts it to a matrix. ShiftRows works like a regular transposition (moving the elements of the matrix to different positions), and MixColumns, although it also combines bytes, is still a linear operation.

Consider the following data:

[1, 2, 3, 4, 5]

Now let’s say that we want to shuffle the data into something like [2, 5, 4, 1, 3]. We can turn this into an equation by treating the data as a column vector and multiplying it by a transformation matrix, for example:

[0 1 0 0 0]   [1]   [2]
[0 0 0 0 1]   [2]   [5]
[0 0 0 1 0] × [3] = [4]
[1 0 0 0 0]   [4]   [1]
[0 0 1 0 0]   [5]   [3]

With this in mind, we can now model ShiftRows and MixColumns as matrix transformations. You can check the details here, but we will simply call the ShiftRows transformation matrix S and the MixColumns transformation matrix M.

After that, we can turn the entire AES into an equation.
Let’s treat our plaintext as a matrix P and simulate the entire nerfed AES. All additions here are over GF(2), that is, XOR. In the first round we perform AddRoundKey:

C_0 = P + k_0

k_i is the i-th round’s key, and C_i is the state after round i.

Then, for the next 9 rounds, we perform ShiftRows, MixColumns, and AddRoundKey:

C_i = MSC_(i-1) + k_i

For simplicity let’s create a new matrix A = MS, which gives us:

C_i = AC_(i-1) + k_i

Finally, in the last round we only use ShiftRows and AddRoundKey:

C = SC_9 + k_10

Now let’s expand the equation:

C = SA⁹P + SA⁹k_0 + SA⁸k_1 + … + SAk_8 + Sk_9 + k_10

Let’s introduce a new variable K where

K = SA⁹k_0 + SA⁸k_1 + … + SAk_8 + Sk_9 + k_10

Now the whole equation becomes:

C = SA⁹P + K

Notice that K is only there because of the AddRoundKey function; that is, without AddRoundKey, the entire ciphertext becomes SA⁹P.
Also notice that SA⁹ is not key-dependent, which means that we CAN precalculate it.

So here’s the plan:

  1. Find a plaintext-ciphertext pair that is encrypted without SubBytes
  2. Compute SA⁹P (Just AES encryption without SubBytes and AddRoundKey)
  3. Recover K = C - SA⁹P
  4. From now on to recover any plaintext given C, simply calculate
    SA⁹P = C - K and use your favorite matrix equation solving technique to find P

There is just one problem: pre-computing SA⁹ is not that simple. Looking at our previous example, here is another matrix that is also a valid transformation matrix:

So there are multiple solutions when trying to infer SA⁹ from SA⁹P. Something like solve_left in SageMath won’t work here, because it only gives you one solution.
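To make the ambiguity concrete, here is a small sketch using the [1, 2, 3, 4, 5] example. The second matrix is a hypothetical alternative chosen for illustration (it is not necessarily the one from the original figure). Both matrices send this particular vector to [2, 5, 4, 1, 3], but they disagree on other inputs, which is why a single input-output pair cannot pin down the transformation.

```python
def mat_vec(m, v):
    """Multiply a matrix (given as a list of rows) by a column vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

data = [1, 2, 3, 4, 5]

# Permutation matrix: row i has a single 1 in the column of the element it picks.
perm = [
    [0, 1, 0, 0, 0],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 1, 0],
    [1, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
]

# Hypothetical alternative: the first row computes 2 * data[0] instead of picking data[1].
other = [[2, 0, 0, 0, 0]] + perm[1:]

print(mat_vec(perm, data))    # [2, 5, 4, 1, 3]
print(mat_vec(other, data))   # [2, 5, 4, 1, 3]

# A different input exposes the difference between the two matrices.
probe = [5, 4, 3, 2, 1]
print(mat_vec(perm, probe))   # [4, 1, 2, 5, 3]
print(mat_vec(other, probe))  # [10, 1, 2, 5, 3]
```

Building the matrix directly from the definitions of ShiftRows and MixColumns, as the PoC in this section does, avoids the ambiguity altogether.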

After doing some research I found Daniel’s answer, which explains the calculation in great detail, so I decided to implement it.

Here is a quick PoC

from sage.all import *
from aes import AES
import os
do_nothing = lambda *x: None

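def bytes2mat(b):
    # Convert bytes to a single-row matrix of bits over GF(2), most significant bit first
    a = []
    for i in b:
        tmp = bin(i)[2:].zfill(8)
        for j in tmp:
            a.append(int(j))
    return Matrix(GF(2), a)

def mat2bytes(m):
    # Convert a 1 x 128 matrix of bits back to 16 bytes
    a = ""
    for i in range(128):
        a += str(m[0, i])
    a = [a[i:i+8] for i in range(0, 128, 8)]
    a = [int(i, 2) for i in a]
    return bytes(a)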

I = identity_matrix(GF(2), 8)
X = Matrix(GF(2), 8, 8)
for i in range(7):
    X[i, i+1] = 1
X[3, 0] = 1
X[4, 0] = 1
X[6, 0] = 1
X[7, 0] = 1

C = block_matrix([
[X, X+I, I, I],
[I, X, X+I, I],
[I, I, X, X+I],
[X+I, I, I, X]
])

zeros = Matrix(GF(2), 8, 8)
zeros2 = Matrix(GF(2), 32, 32)
o0 = block_matrix([
[I, zeros, zeros, zeros],
[zeros, zeros, zeros, zeros],
[zeros, zeros, zeros, zeros],
[zeros, zeros, zeros, zeros]
])

o1 = block_matrix([
[zeros, zeros, zeros, zeros],
[zeros, I, zeros, zeros],
[zeros, zeros, zeros, zeros],
[zeros, zeros, zeros, zeros]
])

o2 = block_matrix([
[zeros, zeros, zeros, zeros],
[zeros, zeros, zeros, zeros],
[zeros, zeros, I, zeros],
[zeros, zeros, zeros, zeros]
])

o3 = block_matrix([
[zeros, zeros, zeros, zeros],
[zeros, zeros, zeros, zeros],
[zeros, zeros, zeros, zeros],
[zeros, zeros, zeros, I]
])

S = block_matrix([
[o0, o1, o2, o3],
[o3, o0, o1, o2],
[o2, o3, o0, o1],
[o1, o2, o3, o0]
])

M = block_matrix([
[C, zeros2, zeros2, zeros2],
[zeros2, C, zeros2, zeros2],
[zeros2, zeros2, C, zeros2],
[zeros2, zeros2, zeros2, C]
])

R = M*S
A = S*(R**9) # sorry for the inconsistency in the variable name, this is supposed to be SA^9 that I talked about

c = AES(os.urandom(16))
c._sub_bytes = do_nothing
p = b"Very secret text"
ct = c.encrypt(p)
p2 = b"Known plaintextt"
ct2 = c.encrypt(p2)

p2 = bytes2mat(p2).transpose()
ct2 = bytes2mat(ct2).transpose()

K = ct2 - A*p2
recovered_plaintext = mat2bytes((A.inverse() * (bytes2mat(ct).transpose() - K)).transpose())
print(recovered_plaintext)

As you can see, with only a single plaintext-ciphertext pair (p2, ct2), we are able to recover another plaintext given only its ciphertext.
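The same attack can also be sketched without SageMath. Because SA⁹ is just the cipher with every round key set to zero, its inverse is the corresponding keyless decryption, so no explicit 128×128 matrix is needed. The code below is a standalone illustration and not the PoC above: it reimplements ShiftRows and MixColumns in pure Python, and it uses 11 random round keys instead of the real key schedule (an assumption made for brevity; the attack only ever sees their combined effect K). All function names are made up for this example.

```python
import os

def xtime(a):
    """Multiply a byte by 2 in GF(2^8)."""
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a

def xor(a, b):
    return [x ^ y for x, y in zip(a, b)]

# The state is a flat list of 16 bytes in column-major order: index = row + 4 * column.
def shift_rows(s):
    return [s[r + 4 * ((c + r) % 4)] for c in range(4) for r in range(4)]

def inv_shift_rows(s):
    return [s[r + 4 * ((c - r) % 4)] for c in range(4) for r in range(4)]

def mix_columns(s):
    out = []
    for c in range(4):
        a = s[4 * c:4 * c + 4]
        t = a[0] ^ a[1] ^ a[2] ^ a[3]
        out += [a[i] ^ t ^ xtime(a[i] ^ a[(i + 1) % 4]) for i in range(4)]
    return out

def inv_mix_columns(s):
    # Standard trick: a cheap pre-processing step followed by the forward MixColumns.
    s = list(s)
    for c in range(4):
        u = xtime(xtime(s[4 * c] ^ s[4 * c + 2]))
        v = xtime(xtime(s[4 * c + 1] ^ s[4 * c + 3]))
        s[4 * c] ^= u
        s[4 * c + 1] ^= v
        s[4 * c + 2] ^= u
        s[4 * c + 3] ^= v
    return mix_columns(s)

def encrypt_no_subbytes(block, round_keys):
    s = xor(block, round_keys[0])
    for i in range(1, 10):
        s = xor(mix_columns(shift_rows(s)), round_keys[i])
    return xor(shift_rows(s), round_keys[10])

ZERO_KEYS = [[0] * 16] * 11

def linear_part(block):
    """Apply SA^9: the nerfed cipher with all round keys set to zero."""
    return encrypt_no_subbytes(block, ZERO_KEYS)

def linear_part_inverse(block):
    """Undo SA^9 by running the keyless rounds backwards."""
    s = inv_shift_rows(block)
    for _ in range(9):
        s = inv_shift_rows(inv_mix_columns(s))
    return s

# Victim side: round keys the attacker never sees.
round_keys = [list(os.urandom(16)) for _ in range(11)]
ct = encrypt_no_subbytes(list(b"Very secret text"), round_keys)
known = list(b"Known plaintextt")
ct_known = encrypt_no_subbytes(known, round_keys)

# Attacker side: K = C + SA^9 P from the known pair, then P = (SA^9)^-1 (C + K).
K = xor(ct_known, linear_part(known))
recovered = bytes(linear_part_inverse(xor(ct, K)))
print(recovered)  # b'Very secret text'
```

The matrix approach and this one compute the same thing; the matrix version makes the algebra explicit, while this version simply reuses the inverse round operations.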

Nerfed AES 3: AES Without ShiftRows

Take a look at this image (from Wikipedia)

As you can see, ShiftRows… well, shifts rows, while MixColumns mixes columns. So if there is no ShiftRows, the columns never interact with each other: a change in one column will not affect the output of any other column. This means we can treat a single block of text as 4 separate blocks!

Here is a quick PoC

from aes import AES
import os
do_nothing = lambda *x: None
c = AES(os.urandom(16))
c._shift_rows = do_nothing
p1 = b"A"*4 + b"B"*12
c1 = c.encrypt(p1)
print(c1.hex())

p2 = b"A"*4 + b"C"*12
c2 = c.encrypt(p2)
print(c2.hex())
# 8badc0d82d94b3be32287566848507c0
# 8badc0d8aeb319f02b18448aa827d433

As you can see, the first 4 bytes are 8badc0d8 in both ciphertexts, even though the rest of the block has changed.

Although we can treat this as 4 blocks of 4 bytes each, a full lookup table for one column still needs around 256⁴ plaintext-ciphertext pairs. However, there can be situations where this attack is feasible. For example, if each 4-byte group of the plaintext consists of one repeated byte (e.g. AAAABBBBCCCCDDDD), we can build a lookup table by encrypting AAAAAAAAAAAAAAAA, BBBBBBBBBBBBBBBB, CCCCCCCCCCCCCCCC, and so on, which only requires 256 plaintext-ciphertext pairs.
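The repeated-byte lookup table can be sketched as follows. This is a standalone illustration rather than code built on the AES library used in this post: it reimplements SubBytes and MixColumns in pure Python, assumes the usual column-major state layout (bytes 0 to 3 form the first column), and uses 11 random round keys in place of the real key schedule. The function names are made up for this example.

```python
import os

def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo the AES polynomial 0x11B."""
    result = 0
    for _ in range(8):
        if b & 1:
            result ^= a
        b >>= 1
        a <<= 1
        if a & 0x100:
            a ^= 0x11B
    return result

def make_sbox():
    """Build the AES S-box: field inverse (a^254, with 0 -> 0), then the affine map."""
    rotl = lambda x, n: ((x << n) | (x >> (8 - n))) & 0xFF
    box = []
    for a in range(256):
        inv = 1
        for _ in range(254):
            inv = gf_mul(inv, a)
        box.append(inv ^ rotl(inv, 1) ^ rotl(inv, 2) ^ rotl(inv, 3) ^ rotl(inv, 4) ^ 0x63)
    return box

SBOX = make_sbox()

def xor(a, b):
    return [x ^ y for x, y in zip(a, b)]

def sub_bytes(s):
    return [SBOX[x] for x in s]

def mix_columns(s):
    out = []
    for c in range(4):
        a = s[4 * c:4 * c + 4]
        out += [gf_mul(a[i], 2) ^ gf_mul(a[(i + 1) % 4], 3) ^ a[(i + 2) % 4] ^ a[(i + 3) % 4]
                for i in range(4)]
    return out

def encrypt_no_shiftrows(block, round_keys):
    s = xor(block, round_keys[0])
    for i in range(1, 10):
        s = xor(mix_columns(sub_bytes(s)), round_keys[i])
    return bytes(xor(sub_bytes(s), round_keys[10]))

round_keys = [list(os.urandom(16)) for _ in range(11)]
secret = b"AAAABBBBCCCCDDDD"
ct = encrypt_no_shiftrows(secret, round_keys)

# One chosen plaintext per byte value; record each column's 4-byte ciphertext.
lookup = [{} for _ in range(4)]
for v in range(256):
    c = encrypt_no_shiftrows(bytes([v]) * 16, round_keys)
    for col in range(4):
        lookup[col][c[4 * col:4 * col + 4]] = v

# Each ciphertext column reveals the byte that was repeated in that plaintext column.
recovered = bytes(lookup[col][ct[4 * col:4 * col + 4]] for col in range(4) for _ in range(4))
print(recovered)  # b'AAAABBBBCCCCDDDD'
```

Each column is encrypted by its own bijection on 4 bytes, so every table entry is unique and 256 chosen plaintexts are enough for this restricted plaintext format.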

Nerfed AES 4: AES Without MixColumns

A good thing about MixColumns is that it combines multiple bytes to create a new set of bytes, so it is not just a transposition but also mixes the bytes together.

When MixColumns is gone, every byte of the plaintext is processed on its own. This is even worse than removing ShiftRows, because we can now divide a single block of 16 bytes into 16 SEPARATE BLOCKS: every byte is independent and is not affected by any other byte.

Figuring out which plaintext byte ends up in which ciphertext byte after ShiftRows is not that hard. It actually follows quite a pattern; here is the diagram that I drew to showcase it:

As you can see, each byte of the plaintext is responsible for exactly one byte of the ciphertext, which means that changing any other byte of the plaintext will not affect that byte of the ciphertext.

For example if we have this:

Changing one byte will not affect the other
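The pattern can also be derived instead of drawn. Without MixColumns, the only operation that moves bytes around is ShiftRows, which runs once in each of the 10 rounds, so applying it 10 times to a list of positions shows where each plaintext byte ends up. This is a minimal sketch that assumes the usual column-major AES state layout (index = row + 4 × column); shift_rows here is a standalone helper written for this example, not the library’s method.

```python
def shift_rows(s):
    """Rotate row r of the column-major state left by r positions."""
    return [s[r + 4 * ((c + r) % 4)] for c in range(4) for r in range(4)]

# Track positions instead of data: start with each slot holding its own index.
positions = list(range(16))
for _ in range(10):
    positions = shift_rows(positions)

# positions[j] is the plaintext index whose byte lands at ciphertext index j.
# Invert it so that relation[i] is the ciphertext index for plaintext byte i.
relation = [0] * 16
for j, i in enumerate(positions):
    relation[i] = j

print(relation)  # [0, 9, 2, 11, 4, 13, 6, 15, 8, 1, 10, 3, 12, 5, 14, 7]
```

Rows 0 and 2 end up where they started, while rows 1 and 3 are shifted by two columns, which is the same list that is hard-coded as relation in the PoC that follows.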

Here is a quick PoC

from aes import AES
import os
relation = [0, 9, 2, 11, 4, 13, 6, 15, 8, 1, 10, 3, 12, 5, 14, 7]
c = AES(os.urandom(16))
c._mix_columns = lambda *x: None
pt = b"Very secret text"
ct = c.encrypt(pt)

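# lookup[j] maps a ciphertext byte at position j back to the plaintext byte that produced it
lookup = [{} for _ in range(16)]
for i in range(256):
    ptest = bytes([i]*16)
    ctest = c.encrypt(ptest)
    for j in range(16):
        lookup[j][ctest[j]] = i

# Plaintext byte i ends up at ciphertext position relation[i]
recovered = ["??" for _ in range(16)]
for i in range(16):
    recovered[i] = lookup[relation[i]][ct[relation[i]]]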

print("".join([chr(i) for i in recovered]))

And there you have it, we can now exploit AES without any one of its operations!

As a bonus, I was playing Hology, an Indonesian CTF, the other day and stumbled upon this challenge:

import os
from aes import AES

# A function that does nothing
no_op = lambda *x: None

def main():
    k = os.urandom(16)
    c = AES(k)
    s = b''.join(k[i:i+1]*4 for i in range(16))

    flag = os.environ.get('FLAG', 'Hology6{***********************MISSING*************************}').encode()
    assert len(flag) == 64

    flag = b''.join([c.encrypt(flag[i:i+16]) for i in range(0, 64, 16)])
    print(f'Here is encrypted flag: {flag.hex()}.')

    opts = ['sb', 'sr', 'mc', 'ark']
    sopts = ['data', 'secret']

    for _ in range(128):
        [opt, suboption, *more] = input('> ').split(' ')
        if opt not in opts: raise Exception('invalid option!')
        if suboption not in sopts: raise Exception('invalid suboption!')

        if suboption == 'secret':
            opts.remove(opt)
            msg = s
        else:
            msg = bytes.fromhex(more[0])
            if len(msg) != 16: raise Exception('invalid length!')
            msg = msg * 4

        if opt == 'sb':
            c = AES(k)
            c._sub_bytes = no_op
            ct = c.encrypt(msg[0:16])

        elif opt == 'sr':
            c = AES(k)
            c._shift_rows = no_op
            ct = c.encrypt(msg[16:32])

        elif opt == 'mc':
            c = AES(k)
            c._mix_columns = no_op
            ct = c.encrypt(msg[32:48])

        elif opt == 'ark':
            c = AES(k)
            c._add_round_key = no_op
            ct = c.encrypt(msg[48:64])

        print(ct.hex())

if __name__ == '__main__':
    main()

Can you figure this out? I will probably make a separate post to talk about this challenge.
