// AES parameters (FIPS-197 notation):
//   Nb = number of 32-bit columns in the state (always 4),
//   Nk = cipher key length in 32-bit words (4 -> AES-128),
//   Nr = number of rounds. AES-128 uses 10 rounds, i.e. Nr + 1 = 11 round keys.
// Nr was previously 11, which made the key schedule run one round key too far
// and index the 11-entry Rcon table out of bounds.
#define Nb 4
#define Nk 4
#define Nr 10

// Expanded key schedule: Nb * (Nr + 1) words of 4 bytes (176 bytes for AES-128).
uint8_t w[Nb * (Nr + 1) * 4];

// Scratch buffer used by Shiftrow() while rotating the rows of the state.
uint8_t bentar[4][4];

// Plaintext block, indexed state[row][col] (rows are what Shiftrow rotates,
// columns are what mixColumns mixes). setup() encrypts it in place.
uint8_t state[4][4] = {
    {0x1f, 0x2d, 0x3a, 0x5c},
    {0x44, 0x5d, 0x6a, 0x7f},
    {0x7d, 0x8a, 0x9b, 0x22},
    {0x98, 0x7c, 0x92, 0xf2}
    };

// 128-bit cipher key (Nk words). These are the same bytes as the FIPS-197
// Appendix A.1 example key.
uint8_t key[Nk * 4] = {0x2b, 0x7e, 0x15, 0x16,
                      0x28, 0xae, 0xd2, 0xa6,
                      0xab, 0xf7, 0x15, 0x88,
                      0x09, 0xcf, 0x4f, 0x3c};


// AES forward S-box (FIPS-197, Fig. 7): Sbox[x] is the SubBytes substitution
// of byte x. Each printed row holds 16 entries, so the row is the high nibble
// of x and the column is the low nibble.
const uint8_t Sbox[256] ={
  0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
  0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
  0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
  0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
  0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
  0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
  0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
  0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
  0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
  0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
  0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
  0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
  0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
  0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
  0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
  0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16
  };

// Key-schedule round constants. Rcon[r] (r = 1..10) is XORed into the first
// byte of the word that starts round key r. Rcon[0] is an unused pad so that
// the key schedule can index the table with i / Nk directly (always >= 1).
const uint8_t Rcon[11] = {0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36};

/* Looks up the S-box substitution for a single byte (one SubBytes step). */
uint8_t getSBoxValue(uint8_t num)
{
  const uint8_t substituted = *(Sbox + num);
  return substituted;
}

/*
 * AES key schedule (FIPS-197 section 5.2).
 *
 * Expands the Nk-word cipher key into Nb * (Nr + 1) 32-bit words, stored
 * byte-wise in w: word i occupies w[4*i .. 4*i + 3].
 *
 * key: Nk * 4 bytes of cipher key.
 * w:   output buffer; must hold Nb * (Nr + 1) * 4 bytes.
 *
 * The round constant is produced by repeated doubling in GF(2^8)
 * (0x01, 0x02, ..., 0x80, 0x1b, 0x36, ...). This yields exactly Rcon[1..10]
 * for AES-128, but never indexes past the 11-entry Rcon table when the
 * schedule is longer than 44 words.
 */
void KeyExpansion(uint8_t key[Nk * 4], uint8_t w[Nb * (Nr + 1) * 4]) {
  // The first Nk words of the schedule are the cipher key itself.
  for (int i = 0; i < Nk * 4; i++) {
    w[i] = key[i];
  }

  uint8_t rcon = 0x01;  // round constant for i / Nk == 1
  for (int i = Nk; i < Nb * (Nr + 1); i++) {
    // temp = previous word w[i - 1]
    uint8_t temp[4];
    for (int j = 0; j < 4; j++) {
      temp[j] = w[(i - 1) * 4 + j];
    }

    if (i % Nk == 0) {
      // RotWord: cyclic left rotation by one byte.
      const uint8_t t = temp[0];
      temp[0] = temp[1];
      temp[1] = temp[2];
      temp[2] = temp[3];
      temp[3] = t;
      // SubWord: substitute every byte through the S-box.
      for (int j = 0; j < 4; j++) {
        temp[j] = getSBoxValue(temp[j]);
      }
      temp[0] ^= rcon;
      // Advance to the next round constant: rcon * x in GF(2^8).
      rcon = static_cast<uint8_t>((rcon << 1) ^ ((rcon & 0x80) ? 0x1B : 0x00));
    } else if (Nk > 6 && i % Nk == 4) {
      // Keys longer than six words (AES-256) need an extra SubWord halfway
      // through each Nk-word group; never taken for Nk == 4.
      for (int j = 0; j < 4; j++) {
        temp[j] = getSBoxValue(temp[j]);
      }
    }

    // w[i] = w[i - Nk] XOR temp
    for (int j = 0; j < 4; j++) {
      w[i * 4 + j] = w[(i - Nk) * 4 + j] ^ temp[j];
    }
  }
}

/*
 * SubBytes step: replaces every byte of state with its S-box value, in place,
 * and echoes the resulting matrix to Serial (hex, one row per line).
 */
void Subbyte(uint8_t state[4][4])
{
  for (int row = 0; row < 4; ++row) {
    uint8_t *cells = state[row];
    for (int col = 0; col < 4; ++col) {
      cells[col] = Sbox[cells[col]];  // substitute through the S-box
      Serial.print(cells[col], HEX);  // print the substituted byte
      Serial.print(" ");
    }
    Serial.println();
  }
}

/*
 * ShiftRows step: rotates row r of state left by r positions, in place, using
 * the global bentar buffer as scratch, and prints each resulting row in hex.
 */
void Shiftrow(uint8_t state[4][4])
{
  for (int row = 0; row < 4; ++row) {
    uint8_t *src = state[row];
    uint8_t *scratch = bentar[row];

    // Stage the rotated row first; writing straight into src would overwrite
    // bytes that are still needed.
    for (int col = 0; col < 4; ++col)
      scratch[col] = src[(col + row) % 4];

    for (int col = 0; col < 4; ++col) {
      src[col] = scratch[col];
      Serial.print(src[col], HEX);
      Serial.print(" ");
    }
    Serial.println();
  }
}


// MixColumns coefficient matrix (FIPS-197 section 5.1.3): each output column is
// mix_matrix times the input column over GF(2^8). It is only ever read, so it
// is declared const.
const uint8_t mix_matrix[4][4] = {{2, 3, 1, 1},
                                  {1, 2, 3, 1},
                                  {1, 1, 2, 3},
                                  {3, 1, 1, 2}};

// Forward declaration: GF2() is defined after mixColumns(). Standard C++
// requires a declaration before use; without this line the file compiled only
// because the Arduino IDE auto-generates function prototypes for sketches.
uint8_t GF2(uint8_t a, uint8_t b);

/*
 * MixColumns step: replaces each column c of state (bytes state[0..3][c]) with
 * mix_matrix x column. Multiplication is GF2(); addition in GF(2^8) is XOR.
 * Operates in place and prints nothing.
 */
void mixColumns(uint8_t state[4][4]) {
  for (int col = 0; col < 4; col++) {
    uint8_t mixed[4];
    for (int row = 0; row < 4; row++) {
      mixed[row] = 0;
      for (int k = 0; k < 4; k++) {
        mixed[row] ^= GF2(mix_matrix[row][k], state[k][col]);
      }
    }
    // Write back only once the whole column is computed: every output byte
    // depends on all four input bytes of the column.
    for (int row = 0; row < 4; row++) {
      state[row][col] = mixed[row];
    }
  }
}

/*
 * Multiplies a and b in GF(2^8) modulo the AES polynomial x^8+x^4+x^3+x+1
 * (reduction constant 0x1B), using shift-and-add ("Russian peasant").
 */
uint8_t GF2(uint8_t a, uint8_t b) {
  uint8_t product = 0;
  for (; b != 0; b >>= 1) {
    if (b & 1)
      product ^= a;
    const bool overflow = (a & 0x80) != 0;  // bit that x * a would push out
    a <<= 1;
    if (overflow)
      a ^= 0x1B;
  }
  return product;
}


/*
 * AddRoundKey step: XORs each byte of the 4x4 round key into state, in place.
 * key is only read.
 */
void AddRoundKey(uint8_t state[4][4], uint8_t key[4][4])
{
  for (int idx = 0; idx < 16; ++idx) {
    const int r = idx / 4;
    const int c = idx % 4;
    state[r][c] ^= key[r][c];
  }
}

/*
 * Prints a 4x4 round key to Serial: one row per line, bytes in hex separated
 * by spaces, followed by an extra blank line.
 */
void Keyss(uint8_t keys[4][4])
{
  for (int row = 0; row < 4; ++row) {
    const uint8_t *line = keys[row];
    for (int col = 0; col < 4; ++col) {
      Serial.print(line[col], HEX);
      Serial.print(" ");
    }
    Serial.print("\n");
  }
  Serial.print("\n");
}

// Prints a 4x4 state matrix to Serial: one row per line, bytes in hex.
static void printState(uint8_t s[4][4])
{
  for (int i = 0; i < 4; i++)
  {
    for (int j = 0; j < 4; j++)
    {
      Serial.print(s[i][j], HEX);
      Serial.print(" ");
    }
    Serial.println("");
  }
}

/*
 * Arduino entry point (runs once). Encrypts the global `state` block with the
 * global `key` and traces every intermediate step over Serial.
 *
 * Round structure:
 *   - an initial AddRoundKey with round key 0;
 *   - JumlahRonde (9) rounds of SubBytes / ShiftRows / MixColumns / AddRoundKey;
 *   - a final round without MixColumns.
 * That is 10 rounds in total, as AES-128 requires.
 */
void setup() {
  Serial.begin(9600);

  Serial.println("-----AES Encryption-----");
  Serial.println("START OF ROUND");
  printState(state);

  // Number of rounds that include MixColumns. The last loop iteration
  // (i == JumlahRonde) is the final round, which skips it.
  const int JumlahRonde = 9;

  KeyExpansion(key, w);

  // Split the expanded key into 11 round keys of 16 bytes. Each 4-byte word
  // of w becomes one column: keys[round][row][col] = w[round*16 + col*4 + row].
  uint8_t keys[11][4][4];
  for (int i = 0; i < 11; i++)
  {
    for (int j = 0; j < 4; j++)
    {
      for (int k = 0; k < 4; k++)
      {
        keys[i][k][j] = w[i * 16 + j * 4 + k];
      }
    }
  }

  Serial.println("\nROUND KEY VALUE");
  Keyss(keys[0]);
  Serial.println("\n----- Round 1 -----");
  Serial.println("\nSTART OF ROUND");
  AddRoundKey(state, keys[0]);
  printState(state);

  for (int i = 0; i <= JumlahRonde; i++)
  {
    Serial.println("\nAFTER SUBBYTES:");
    Subbyte(state);
    Serial.println("\nAFTER SHIFTROWS:");
    Shiftrow(state);
    if (i < JumlahRonde)
    {
      Serial.println("\nAFTER MIXCOLUMN");
      mixColumns(state);
      // mixColumns() prints nothing itself, so show its result here.
      printState(state);
      Serial.println("\nROUND KEY VALUE");
      Keyss(keys[i + 1]);
      Serial.print("\n======= RONDE ");
      Serial.print(i + 2);
      Serial.print(" =======");
      Serial.println("\nSTART OF ROUND");
    }
    AddRoundKey(state, keys[i + 1]);
    if (i < JumlahRonde)
    {
      // The state after AddRoundKey is the "START OF ROUND" value announced
      // above (the same convention used for round 1).
      printState(state);
    }
  }

  Serial.println("ROUND KEY VALUE");
  Keyss(keys[10]);
  Serial.println("CIPHERTEXT");
  printState(state);
}
  // setup() above holds the code that runs once; loop() below runs repeatedly.



void loop() {
  // Intentionally empty: the whole AES demonstration runs once in setup().
}