#include <stdint.h>
#include <string.h>
// S-AES nibble substitution box: SAES_SBOX[x] is the substitute for nibble x.
// Rows are labelled with the first input nibble they cover.
static const uint8_t SAES_SBOX[16] = {
    /* 0x0-0x3 */ 0x09, 0x04, 0x0A, 0x0B,
    /* 0x4-0x7 */ 0x0D, 0x01, 0x08, 0x05,
    /* 0x8-0xB */ 0x06, 0x02, 0x00, 0x03,
    /* 0xC-0xF */ 0x0C, 0x0E, 0x0F, 0x07
};
// Inverse S-box: SAES_RSBOX[SAES_SBOX[x]] == x for every nibble x.
// Rows are labelled with the first input nibble they cover.
static const uint8_t SAES_RSBOX[16] = {
    /* 0x0-0x3 */ 0x0A, 0x05, 0x09, 0x0B,
    /* 0x4-0x7 */ 0x01, 0x07, 0x08, 0x0F,
    /* 0x8-0xB */ 0x06, 0x00, 0x02, 0x03,
    /* 0xC-0xF */ 0x0C, 0x04, 0x0D, 0x0E
};
// Round constants XORed into the key schedule. MiniAES16::setKey reads
// SAES_RCON[1] and SAES_RCON[2]; index 0 is a placeholder.
static const uint8_t SAES_RCON[] = { 0x00, 0x01, 0x02 };
/**
 * Multiplies two elements of GF(2^4) (4-bit polynomials over GF(2)).
 *
 * Shift-and-add over the low 4 bits of `b`. Whenever the running multiple of
 * `a` would overflow bit 3, it is reduced by XOR with 0x3. This replaces x^4
 * with x + 1, i.e. the modulus is x^4 + x + 1.
 *
 * @param a  multiplicand, expected in 0..15.
 * @param b  multiplier; only bits 0..3 are used.
 * @return   the product; in 0..15 when `a` is in 0..15.
 */
static uint8_t gf16mul(uint8_t a, uint8_t b) {
    uint8_t product = 0;
    uint8_t multiple = a;  // a * x^bit (reduced) for the bit being examined
    for (uint8_t bit = 0; bit < 4; ++bit) {
        if ((b >> bit) & 0x1) {
            product ^= multiple;
        }
        const bool overflows = (multiple & 0x8) != 0;
        multiple = (uint8_t)((multiple << 1) & 0xF);
        if (overflows) {
            multiple ^= 0x3;  // reduce: x^4 == x + 1
        }
    }
    return product;
}
/**
 * Reads one 4-bit nibble from a 16-bit block.
 *
 * @param block  the 16-bit value.
 * @param pos    nibble index 0..3, where 0 is the most significant nibble.
 * @return       the nibble value in 0..15.
 */
static uint8_t getNib(uint16_t block, uint8_t pos) {
    const unsigned shift = (3u - pos) * 4u;
    return (uint8_t)((block >> shift) & 0x0Fu);
}
/**
 * Returns `block` with one 4-bit nibble replaced.
 *
 * @param block  the original 16-bit value.
 * @param pos    nibble index 0..3, where 0 is the most significant nibble.
 * @param val    new nibble value; only its low 4 bits are used.
 * @return       the updated block; all other nibbles are unchanged.
 */
static uint16_t setNib(uint16_t block, uint8_t pos, uint8_t val) {
    const unsigned shift = (3u - pos) * 4u;
    // Unsigned literal keeps the shift well-defined even where int is 16 bits.
    const uint16_t mask = (uint16_t)(0x000Fu << shift);
    const uint16_t field = (uint16_t)(((unsigned)val << shift) & mask);
    return (uint16_t)((block & (uint16_t)~mask) | field);
}
/**
 * MiniAES16: a toy 16-bit block cipher modelled on Simplified AES (S-AES).
 *
 * A block is handled as four 4-bit nibbles, with nibble 0 the most
 * significant (getNib/setNib). For mixCols, nibbles 0-1 form column 0 and
 * nibbles 2-3 form column 1. swapRow exchanges nibbles 1 and 3.
 *
 * encrypt() = AddKey(K0), SubNib, SwapRow, MixCols, AddKey(K1),
 *             SubNib, SwapRow, AddKey(K2). decrypt() undoes these in reverse.
 *
 * NOTE(security): the key is only 8 bits (256 possible keys). Each 16-bit
 * round key is one expanded byte repeated twice (pack()). Payloads are
 * encrypted block by block with no chaining. Treat this as a teaching /
 * obfuscation toy, not as real encryption.
 */
class MiniAES16 {
public:
    /// Starts with all-zero round keys. Calling encrypt()/decrypt() before
    /// setKey() is then deterministic instead of reading uninitialized memory
    /// (this matters for local, non-static instances).
    MiniAES16() {
        roundKey[0] = 0;
        roundKey[1] = 0;
        roundKey[2] = 0;
    }

    /**
     * Expands an 8-bit key into the three round keys K0, K1, K2.
     *
     * The key is split into nibbles w[0] (high) and w[1] (low). Two more
     * nibble pairs are derived with SAES_SBOX and SAES_RCON, and each pair
     * (w[2k], w[2k+1]) becomes round key k.
     *
     * @param key  the 8-bit cipher key.
     */
    void setKey(uint8_t key) {
        uint8_t w[6];
        w[0] = (key >> 4) & 0xF;
        w[1] = key & 0xF;
        for (uint8_t i = 2; i < 6; i += 2) {
            const uint8_t r = i / 2;  // 1 while deriving K1, 2 for K2
            // Every operand is a nibble, so w[] stays in 0..15 and is always
            // a valid SAES_SBOX index.
            w[i]     = w[i - 2] ^ SAES_SBOX[w[i - 1]] ^ SAES_RCON[r];
            w[i + 1] = w[i - 1] ^ SAES_SBOX[w[i]];
        }
        roundKey[0] = pack(w[0], w[1]);
        roundKey[1] = pack(w[2], w[3]);
        roundKey[2] = pack(w[4], w[5]);
    }

    /**
     * Encrypts one 16-bit block with the current round keys.
     * @param block  plaintext block.
     * @return       ciphertext block.
     */
    uint16_t encrypt(uint16_t block) const {
        block ^= roundKey[0];
        block = subNib(block);   // round 1
        block = swapRow(block);
        block = mixCols(block);
        block ^= roundKey[1];
        block = subNib(block);   // round 2 (no MixCols)
        block = swapRow(block);
        block ^= roundKey[2];
        return block;
    }

    /**
     * Decrypts one 16-bit block; the exact inverse of encrypt().
     * swapRow and mixCols are their own inverses, so only the S-box step
     * needs a dedicated inverse (invSubNib).
     * @param block  ciphertext block.
     * @return       plaintext block.
     */
    uint16_t decrypt(uint16_t block) const {
        block ^= roundKey[2];
        block = swapRow(block);
        block = invSubNib(block);
        block ^= roundKey[1];
        block = mixCols(block);
        block = swapRow(block);
        block = invSubNib(block);
        block ^= roundKey[0];
        return block;
    }

    /**
     * Encrypts `len` bytes as consecutive big-endian 16-bit blocks.
     *
     * Each block is encrypted independently, so identical input pairs give
     * identical output pairs. An odd trailing byte is padded with 0x00, so
     * the output is `len` rounded up to an even number of bytes. The caller
     * must remember the original length to drop the pad after decryption.
     *
     * @param in   plaintext, at least `len` bytes.
     * @param len  plaintext length. Must be <= 254 so the padded output
     *             length fits in the uint8_t return value.
     * @param out  ciphertext buffer of at least `len` rounded up to even bytes.
     * @return     number of bytes written to `out`. Returns 0, having written
     *             nothing, when `len` is 0 or 255.
     */
    uint8_t encryptPayload(const uint8_t* in, uint8_t len, uint8_t* out) const {
        // 255 input bytes would pad to 256 output bytes, a length the uint8_t
        // return value cannot express, so that size is rejected up front.
        if (len == 0xFF) {
            return 0;
        }
        uint8_t n = 0;
        // Wide index: a uint8_t index would wrap past 255 instead of ending.
        for (uint16_t i = 0; i < len; i += 2) {
            const uint8_t hi = in[i];
            const uint8_t lo = (i + 1u < len) ? in[i + 1] : 0x00;
            const uint16_t cipher = encrypt((uint16_t)(((uint16_t)hi << 8) | lo));
            out[n++] = (uint8_t)(cipher >> 8);
            out[n++] = (uint8_t)(cipher & 0xFF);
        }
        return n;
    }

    /**
     * Decrypts a buffer produced by encryptPayload().
     *
     * Only whole 2-byte blocks are processed. A dangling odd final byte
     * (malformed input) is ignored rather than read past the end of `in`.
     *
     * @param in      ciphertext, at least `encLen` bytes.
     * @param encLen  ciphertext length; normally even.
     * @param out     plaintext buffer of at least `encLen` bytes.
     * @return        number of bytes written to `out` (always even),
     *                including any zero padding added by encryptPayload().
     */
    uint8_t decryptPayload(const uint8_t* in, uint8_t encLen, uint8_t* out) const {
        uint8_t n = 0;
        for (uint16_t i = 0; i + 1u < encLen; i += 2) {
            const uint16_t cipher = (uint16_t)(((uint16_t)in[i] << 8) | in[i + 1]);
            const uint16_t plain = decrypt(cipher);
            out[n++] = (uint8_t)(plain >> 8);
            out[n++] = (uint8_t)(plain & 0xFF);
        }
        return n;
    }

private:
    // K0 is applied first, K1 after round 1, K2 last (see encrypt()).
    uint16_t roundKey[3];

    /// Builds a 16-bit round key from two nibbles laid out as hi|lo|hi|lo.
    uint16_t pack(uint8_t hi, uint8_t lo) const {
        return ((uint16_t)hi << 12) | ((uint16_t)lo << 8)
             | ((uint16_t)hi << 4) | (uint16_t)lo;
    }

    /// Applies SAES_SBOX to each of the four nibbles.
    uint16_t subNib(uint16_t b) const {
        for (uint8_t i = 0; i < 4; i++)
            b = setNib(b, i, SAES_SBOX[getNib(b, i)]);
        return b;
    }

    /// Applies SAES_RSBOX (the inverse S-box) to each of the four nibbles.
    uint16_t invSubNib(uint16_t b) const {
        for (uint8_t i = 0; i < 4; i++)
            b = setNib(b, i, SAES_RSBOX[getNib(b, i)]);
        return b;
    }

    /// Swaps nibbles 1 and 3; applying it twice is the identity.
    uint16_t swapRow(uint16_t b) const {
        const uint8_t n1 = getNib(b, 1);
        const uint8_t n3 = getNib(b, 3);
        b = setNib(b, 1, n3);
        b = setNib(b, 3, n1);
        return b;
    }

    /// Multiplies each column (a0, a1) by the GF(2^4) matrix [[2, 3], [3, 2]].
    /// The matrix is its own inverse: 2*2 ^ 3*3 = 4 ^ 5 = 1 and
    /// 2*3 ^ 3*2 = 0. That is why decrypt() reuses this function unchanged.
    uint16_t mixCols(uint16_t b) const {
        for (uint8_t c = 0; c < 2; c++) {
            const uint8_t a0 = getNib(b, c * 2);
            const uint8_t a1 = getNib(b, c * 2 + 1);
            b = setNib(b, c * 2,     gf16mul(2, a0) ^ gf16mul(3, a1));
            b = setNib(b, c * 2 + 1, gf16mul(3, a0) ^ gf16mul(2, a1));
        }
        return b;
    }
};
// The single cipher instance used by the demo. It is value-initialized here
// and keyed at runtime via setKey() in setup().
MiniAES16 aes = MiniAES16();
// Demo plaintext: the 15 ASCII bytes "Hello World! AB" (no NUL terminator).
// Its odd length exercises the zero-padding of the final 2-byte block.
const uint8_t payload[15] = {
    'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o',
    'r', 'l', 'd', '!', ' ', 'A', 'B'
};
/**
 * Prints `len` bytes of `buf` to Serial on one line.
 *
 * Each byte is written as two uppercase hex digits followed by a space,
 * e.g. "0A FF 10 ". A line ending is emitted at the end.
 *
 * @param buf  bytes to print.
 * @param len  number of bytes to print.
 */
void printHexBuf(const uint8_t* buf, uint8_t len) {
    const uint8_t* const end = buf + len;
    for (const uint8_t* cursor = buf; cursor != end; ++cursor) {
        const uint8_t byte = *cursor;
        if (byte < 0x10) {
            Serial.print('0');  // Serial.print(x, HEX) does not zero-pad
        }
        Serial.print(byte, HEX);
        Serial.print(' ');
    }
    Serial.println();
}
void setup() {
Serial.begin(9600);
while (!Serial);
uint8_t key = 0xA7;
aes.setKey(key);
Serial.print("Key: 0x"); Serial.println(key, HEX);
Serial.print("Payload len: "); Serial.print(sizeof(payload)); Serial.println(" bytes (padded to 16)");
Serial.print("Plaintext: "); printHexBuf(payload, sizeof(payload));
uint8_t encrypted[16];
uint8_t encLen = aes.encryptPayload(payload, sizeof(payload), encrypted);
Serial.print("Encrypted: "); printHexBuf(encrypted, encLen);
uint8_t decrypted[16];
aes.decryptPayload(encrypted, encLen, decrypted);
Serial.print("Decrypted: "); printHexBuf(decrypted, sizeof(payload));
bool ok = (memcmp(payload, decrypted, sizeof(payload)) == 0);
Serial.println(ok ? "OK: round-trip passed!" : "ERROR: mismatch!");
Serial.println();
Serial.println("-- Single block test --");
uint16_t pt = 0x1234;
uint16_t ct = aes.encrypt(pt);
uint16_t rt = aes.decrypt(ct);
Serial.print("Plaintext: 0x"); Serial.println(pt, HEX);
Serial.print("Encrypted: 0x"); Serial.println(ct, HEX);
Serial.print("Decrypted: 0x"); Serial.println(rt, HEX);
Serial.println(rt == pt ? "OK: round-trip passed!" : "ERROR: mismatch!");
}
/// Arduino main loop. It is intentionally empty because the whole demo runs
/// once from setup().
void loop() {
    // nothing to do
}