#include <Arduino.h>

// Substitution box dictionaries
// S-box table: entry i is the one-character output symbol for the 4-bit
// input value i (0..15). The 16 entries are the hex digits 0-9/A-F in a
// shuffled order, so every symbol appears exactly once and the mapping can
// be inverted by searching the table.
const char* substitution_box_dict[] = {
    "1", "0", "8", "C", "A", "B", "4", "5",   // inputs 0x0 .. 0x7
    "9", "2", "3", "D", "6", "7", "E", "F"    // inputs 0x8 .. 0xF
};

// Entry i is the 4-character binary spelling of i (0..15), most-significant
// bit first. Despite the name, this table is not itself the inverse of
// substitution_box_dict; it only converts a table index back into bits.
const char* reverse_substitution_box_dict[] = {
    "0000", "0001", "0010", "0011", "0100", "0101", "0110", "0111",  // 0 .. 7
    "1000", "1001", "1010", "1011", "1100", "1101", "1110", "1111"   // 8 .. 15
};

// Common Functions
// Split binary_string into consecutive chunks of chunk_size characters.
//
// Chunk k (0-based) is copied into result[k]; the final chunk is shorter than
// chunk_size when strlen(binary_string) is not a multiple of it. Every chunk
// is NUL-terminated (strncpy left full-size chunks unterminated), so each
// result[k] buffer must hold at least chunk_size + 1 bytes, and result must
// provide ceil(strlen(binary_string) / chunk_size) buffers.
// Does nothing when binary_string or result is NULL, or chunk_size <= 0
// (a non-positive step never advanced the old loop).
void split_binary_into_chunks(const char* binary_string, int chunk_size, char* result[]) {
    if (binary_string == NULL || result == NULL || chunk_size <= 0) {
        return;
    }
    const size_t length = strlen(binary_string);  // computed once, not per iteration
    const size_t step = static_cast<size_t>(chunk_size);
    for (size_t i = 0, k = 0; i < length; i += step, ++k) {
        const size_t remaining = length - i;
        const size_t n = (remaining < step) ? remaining : step;
        memcpy(result[k], binary_string + i, n);
        result[k][n] = '\0';
    }
}

// Concatenate the first num_chunks strings of chunk_list into one new string.
//
// chunk_size is the nominal length of each chunk and is kept only for API
// compatibility: the buffer is now sized from the chunks' real lengths, so a
// chunk longer than chunk_size can no longer overflow it. NULL entries are
// skipped and a negative num_chunks is treated as 0.
// Returns a malloc'd string the caller must free(), or NULL if chunk_list is
// NULL while num_chunks > 0, or if allocation fails.
char* merge_chunks_into_binary(char* chunk_list[], int num_chunks, int chunk_size) {
    (void)chunk_size;
    const size_t count = (num_chunks > 0) ? static_cast<size_t>(num_chunks) : 0;
    if (count > 0 && chunk_list == NULL) {
        return NULL;
    }

    size_t total = 0;
    for (size_t i = 0; i < count; ++i) {
        if (chunk_list[i] != NULL) {
            total += strlen(chunk_list[i]);
        }
    }

    char* result = static_cast<char*>(malloc(total + 1));  // +1 for the terminator
    if (result == NULL) {
        return NULL;
    }
    // Append with a running offset instead of strcat, which rescans the
    // whole result on every call.
    size_t offset = 0;
    for (size_t i = 0; i < count; ++i) {
        if (chunk_list[i] == NULL) {
            continue;
        }
        const size_t len = strlen(chunk_list[i]);
        memcpy(result + offset, chunk_list[i], len);
        offset += len;
    }
    result[offset] = '\0';
    return result;
}

// Bitwise XOR of two strings of '0'/'1' characters.
//
// The result has the same length as binary_str1. Position i is '0' when the
// inputs agree and '1' otherwise. If binary_str2 is shorter, its missing
// positions are treated as '0' (binary_str1's bit passes through) instead of
// reading past its terminator.
// Returns a malloc'd string the caller must free(), or NULL if either input
// is NULL or allocation fails.
char* xor_binary_strings(const char* binary_str1, const char* binary_str2) {
    if (binary_str1 == NULL || binary_str2 == NULL) {
        return NULL;
    }
    const size_t length = strlen(binary_str1);
    char* result = static_cast<char*>(malloc(length + 1));
    if (result == NULL) {
        return NULL;
    }
    bool str2_ended = false;
    for (size_t i = 0; i < length; ++i) {
        // Once the terminator of binary_str2 is seen, never index it again.
        if (!str2_ended && binary_str2[i] == '\0') {
            str2_ended = true;
        }
        const char other = str2_ended ? '0' : binary_str2[i];
        result[i] = (binary_str1[i] == other) ? '0' : '1';
    }
    result[length] = '\0';
    return result;
}

// Encryption functions

// Fold the character codes of input_str into one number, left to right:
// value = value * 1000 + code, so each code occupies a 3-decimal-digit slot.
//
// Codes are read as unsigned char (0..255), so bytes >= 0x80 are no longer
// added as negative values where plain char is signed.
// NOTE: unsigned long is 32 bits on typical Arduino targets, so inputs longer
// than 3 characters wrap around (modulo 2^32 — defined for unsigned types,
// but information is lost and the string cannot be recovered).
// Returns 0 for NULL or "".
unsigned long encrypt_string_to_decimal(const char* input_str) {
    unsigned long decimal_values_concatenated = 0;
    if (input_str == NULL) {
        return decimal_values_concatenated;
    }
    for (const char* p = input_str; *p != '\0'; ++p) {
        const unsigned char code = static_cast<unsigned char>(*p);
        decimal_values_concatenated = decimal_values_concatenated * 1000UL + code;
    }
    return decimal_values_concatenated;
}

// Render decimal_number in base 2 ('0'/'1' characters) and left-pad it with
// '0' so its length is a multiple of 128 (at least 128, because even 0
// renders as "0").
//
// Padding goes on the left so the string still denotes the same value; the
// old version prepended one extra "0" and appended the padding on the right.
// That changed the value, made the length 129 rather than a multiple of 128,
// and wrote one byte past the allocation.
// Returns a malloc'd string the caller must free(), or NULL if allocation fails.
char* decimal_to_multiple_128bit_binary(unsigned long decimal_number) {
    // Collect digits least-significant first. The buffer fits the widest
    // unsigned long; itoa took an int (truncating the value) and its 33-byte
    // buffer was too small for 64-bit longs.
    char digits[sizeof(unsigned long) * 8];
    size_t digit_count = 0;
    do {
        digits[digit_count++] = (decimal_number & 1UL) ? '1' : '0';
        decimal_number >>= 1;
    } while (decimal_number != 0);

    const size_t block_bits = 128;
    const size_t padding_length = (block_bits - digit_count % block_bits) % block_bits;
    const size_t total = padding_length + digit_count;

    char* padded_binary_number = static_cast<char*>(malloc(total + 1));
    if (padded_binary_number == NULL) {
        return NULL;
    }
    memset(padded_binary_number, '0', padding_length);
    for (size_t i = 0; i < digit_count; ++i) {
        // Reverse into most-significant-first order after the padding.
        padded_binary_number[padding_length + i] = digits[digit_count - 1 - i];
    }
    padded_binary_number[total] = '\0';
    return padded_binary_number;
}

// Bitwise NOT of a 32-character '0'/'1' string.
//
// Writes the complement of input_bits plus a terminator into output_bits,
// which must hold at least 33 bytes. Any character other than '0' becomes '0'.
// If input_bits is NULL or not exactly 32 characters long, output_bits is set
// to the empty string. Previously it was left unwritten, and callers went on
// to read uninitialised memory.
void NOT_gate_32_bit(const char* input_bits, char* output_bits) {
    if (output_bits == NULL) {
        return;
    }
    if (input_bits == NULL || strlen(input_bits) != 32) {
        output_bits[0] = '\0';
        return;
    }

    for (int i = 0; i < 32; i++) {
        output_bits[i] = (input_bits[i] == '0') ? '1' : '0';
    }
    output_bits[32] = '\0';
}

// Expand a 32-bit block (exactly 32 '0'/'1' characters) to 64 characters.
//
// Result layout: right half (input chars 16..31), then the bitwise NOT of the
// whole block (32 chars, via NOT_gate_32_bit), then the left half (input chars
// 0..15).
// Returns a malloc'd 64-character string the caller must free(), or NULL if
// input_block is NULL, not exactly 32 characters long, or allocation fails.
char* split_and_swap_32_to_64(const char* input_block) {
    if (input_block == NULL || strlen(input_block) != 32) {
        return NULL;
    }

    char complement[33];
    NOT_gate_32_bit(input_block, complement);

    char* expanded = static_cast<char*>(malloc(65));
    if (expanded == NULL) {
        return NULL;
    }
    memcpy(expanded, input_block + 16, 16);  // right half first
    memcpy(expanded + 16, complement, 32);   // complemented block in the middle
    memcpy(expanded + 48, input_block, 16);  // left half last
    expanded[64] = '\0';
    return expanded;
}

// Apply the S-box to a string of '0'/'1' characters.
//
// Input is read in groups of 4 characters. Each group's binary value (0..15)
// selects a one-character symbol from substitution_box_dict. A trailing group
// shorter than 4 characters is read as a smaller binary number and still
// emits a symbol; the buffer now reserves room for it. Parsing a group stops
// at the first character that is not '0' or '1', so the table index is always
// 0..15. strtol accepted a sign, and "-1" indexed before the table.
// Returns a malloc'd string the caller must free(), or NULL if input_bits is
// NULL or allocation fails.
char* substitution_box(const char* input_bits) {
    if (input_bits == NULL) {
        return NULL;
    }
    const size_t length = strlen(input_bits);
    const size_t groups = (length + 3) / 4;  // round up for a partial final group
    char* s_box_output = static_cast<char*>(malloc(groups + 1));
    if (s_box_output == NULL) {
        return NULL;
    }

    for (size_t g = 0; g < groups; ++g) {
        const char* group = input_bits + 4 * g;
        unsigned int value = 0;
        for (size_t j = 0; j < 4; ++j) {
            // Also stops at the terminator of a partial final group.
            if (group[j] != '0' && group[j] != '1') {
                break;
            }
            value = value * 2 + static_cast<unsigned int>(group[j] - '0');
        }
        s_box_output[g] = substitution_box_dict[value][0];
    }
    s_box_output[groups] = '\0';
    return s_box_output;
}

// Rotate input_text left by 7 characters: result = input[7:] + input[:7].
//
// For the 8-character chunks produced by permutation_box this is a rotate
// right by one ("ABCDEFGH" -> "HABCDEFG"). Strings of 7 characters or fewer
// come back unchanged, following slice semantics. The old code read past the
// terminator for strings shorter than 7.
// Returns a malloc'd string the caller must free(), or NULL if input_text is
// NULL or allocation fails.
char* permute(const char* input_text) {
    if (input_text == NULL) {
        return NULL;
    }
    const size_t length = strlen(input_text);
    const size_t shift = (length < 7) ? length : 7;
    char* output_text = static_cast<char*>(malloc(length + 1));
    if (output_text == NULL) {
        return NULL;
    }
    memcpy(output_text, input_text + shift, length - shift);  // input[7:]
    memcpy(output_text + (length - shift), input_text, shift);  // input[:7]
    output_text[length] = '\0';
    return output_text;
}

// Apply permute() to each consecutive 8-character chunk of input_text (the
// last chunk may be shorter) and concatenate the results in order.
// Returns a malloc'd string the caller must free(), or NULL if input_text is
// NULL, an allocation fails, or permute() changes a chunk's length (which
// would otherwise overflow the output buffer).
char* permutation_box(const char* input_text) {
    if (input_text == NULL) {
        return NULL;
    }
    const size_t length = strlen(input_text);
    char* to_return = static_cast<char*>(malloc(length + 1));
    if (to_return == NULL) {
        return NULL;
    }

    size_t written = 0;
    for (size_t i = 0; i < length; i += 8) {
        const size_t n = (length - i < 8) ? (length - i) : 8;
        char chunk[9];
        memcpy(chunk, input_text + i, n);
        chunk[n] = '\0';

        char* to_add = permute(chunk);
        if (to_add == NULL || strlen(to_add) != n) {
            free(to_add);
            free(to_return);
            return NULL;
        }
        // Append at a running offset instead of strcat.
        memcpy(to_return + written, to_add, n);
        written += n;
        free(to_add);
    }
    to_return[written] = '\0';
    return to_return;
}

// Decryption functions

// Undo split_and_swap_32_to_64: recover the original 32-bit block by
// complementing the middle 32 characters (positions 16..47) of the
// 64-character expansion.
// Returns a malloc'd string the caller must free(). It is empty when input_64
// is NULL or shorter than 48 characters; before, such input was read out of
// bounds and an uninitialised buffer was copied. Returns NULL only if
// allocation fails.
char* reverse_expansion(const char* input_64) {
    char* result = static_cast<char*>(malloc(33));
    if (result == NULL) {
        return NULL;
    }
    result[0] = '\0';
    if (input_64 == NULL || strlen(input_64) < 48) {
        return result;
    }

    char middle[33];
    memcpy(middle, input_64 + 16, 32);
    middle[32] = '\0';
    NOT_gate_32_bit(middle, result);  // writes 32 chars + terminator into result
    return result;
}

char* reverse_substitution_box(const char* input_bits) {
    char* s_box_output = (char*)malloc(strlen(input_bits) / 4 + 1);
    s_box_output[0] = '\0';
    for (int i = 0; i < strlen(input_bits); i++) {
        char chunk[2];
        strncpy(chunk, input_bits + i, 1);
        chunk[1] = '\0';
        strcat(s_box_output, reverse_substitution_box_dict[chunk[0] - '0']);
    }
    return s_box_output;
}

// Invert permute(). For n = strlen(input_text) >= 7 this rotates right by 7:
// result = input[n-7:] + input[:n-7].
//
// For 8-character chunks this is a rotate left by one ("HABCDEFG" ->
// "ABCDEFGH"). Strings shorter than 7 are returned unchanged, matching
// permute(). The old version appended 7 extra characters, writing
// 14 characters per 8-character chunk into a 9-byte buffer.
// Returns a malloc'd string the caller must free(), or NULL if input_text is
// NULL or allocation fails.
char* reverse_permute(const char* input_text) {
    if (input_text == NULL) {
        return NULL;
    }
    const size_t length = strlen(input_text);
    // Number of leading characters that permute() moved in front of the original head.
    const size_t split = (length >= 7) ? (length - 7) : 0;
    char* output_text = static_cast<char*>(malloc(length + 1));
    if (output_text == NULL) {
        return NULL;
    }
    memcpy(output_text, input_text + split, length - split);
    memcpy(output_text + (length - split), input_text, split);
    output_text[length] = '\0';
    return output_text;
}

// Apply reverse_permute() to each consecutive 8-character chunk of input_text
// (the last chunk may be shorter) and concatenate the results in order.
// Returns a malloc'd string the caller must free(), or NULL if input_text is
// NULL, an allocation fails, or reverse_permute() changes a chunk's length
// (which would otherwise overflow the output buffer).
char* reverse_permutation_box(const char* input_text) {
    if (input_text == NULL) {
        return NULL;
    }
    const size_t length = strlen(input_text);
    char* to_return = static_cast<char*>(malloc(length + 1));
    if (to_return == NULL) {
        return NULL;
    }

    size_t written = 0;
    for (size_t i = 0; i < length; i += 8) {
        const size_t n = (length - i < 8) ? (length - i) : 8;
        char chunk[9];
        memcpy(chunk, input_text + i, n);
        chunk[n] = '\0';

        char* restored = reverse_permute(chunk);
        if (restored == NULL || strlen(restored) != n) {
            free(restored);
            free(to_return);
            return NULL;
        }
        memcpy(to_return + written, restored, n);
        written += n;
        free(restored);
    }
    to_return[written] = '\0';
    return to_return;
}

// Encryption

// Encrypt input_text with key k1, a string of at least 64 '0'/'1' characters.
//
// Pipeline:
//   text -> number (encrypt_string_to_decimal)
//   -> binary padded to a multiple of 128 bits
//   -> for every whole 32-bit block: expand to 64 bits
//      (split_and_swap_32_to_64) and XOR with k1
//   -> S-box -> P-box.
// The old code passed the whole padded string to split_and_swap_32_to_64,
// which accepts exactly 32 bits and returned NULL; the XOR step then
// dereferenced NULL. It also leaked the expanded block.
// Returns a malloc'd ciphertext the caller must free(), or NULL on invalid
// arguments (NULL, or a key shorter than 64 characters) or on any failure.
char* encryption(const char* input_text, const char* k1) {
    const size_t block_bits = 32;     // split_and_swap_32_to_64 accepts exactly this many
    const size_t expanded_bits = 64;  // width of an expanded block, and of the key used

    if (input_text == NULL || k1 == NULL || strlen(k1) < expanded_bits) {
        return NULL;
    }

    char* padded = decimal_to_multiple_128bit_binary(encrypt_string_to_decimal(input_text));
    if (padded == NULL) {
        return NULL;
    }

    // Expand and key-mix every whole 32-bit block; trailing bits that do not
    // fill a block are ignored.
    const size_t num_blocks = strlen(padded) / block_bits;
    char* mixed = static_cast<char*>(malloc(num_blocks * expanded_bits + 1));
    bool ok = (mixed != NULL);
    size_t written = 0;
    for (size_t blk = 0; ok && blk < num_blocks; ++blk) {
        char block[block_bits + 1];
        memcpy(block, padded + blk * block_bits, block_bits);
        block[block_bits] = '\0';

        char* expanded = split_and_swap_32_to_64(block);
        char* keyed = (expanded != NULL) ? xor_binary_strings(expanded, k1) : NULL;
        free(expanded);
        if (keyed == NULL) {
            ok = false;
        } else {
            // xor_binary_strings returns as many characters as `expanded` (64).
            memcpy(mixed + written, keyed, expanded_bits);
            written += expanded_bits;
        }
        free(keyed);
    }
    free(padded);
    if (!ok) {
        free(mixed);
        return NULL;
    }
    mixed[written] = '\0';

    char* s_box_output = substitution_box(mixed);
    free(mixed);
    if (s_box_output == NULL) {
        return NULL;
    }
    char* p_box_output = permutation_box(s_box_output);
    free(s_box_output);
    return p_box_output;
}

// Decryption

// Decrypt a ciphertext produced by encryption() with the same key k1, a
// string of at least 64 '0'/'1' characters. The parameter keeps its
// historical name, but it holds the ciphertext.
//
// Inverse pipeline:
//   inverse P-box -> inverse S-box (back to bits)
//   -> for every whole 64-bit block: XOR with k1, then undo the expansion
//      (reverse_expansion)
//   -> concatenation of the recovered 32-bit blocks.
// The result is a '0'/'1' string, not the original text.
// The old code XORed the whole bit string with the 64-bit key, reading past
// the key, and leaked both intermediate buffers.
// Returns a malloc'd string the caller must free(), or NULL on invalid
// arguments or on any failure.
char* decryption(const char* decrypted_text, const char* k1) {
    const size_t block_bits = 32;     // size of each recovered plaintext block
    const size_t expanded_bits = 64;  // size of each expanded block, and of the key used

    if (decrypted_text == NULL || k1 == NULL || strlen(k1) < expanded_bits) {
        return NULL;
    }

    char* unpermuted = reverse_permutation_box(decrypted_text);
    if (unpermuted == NULL) {
        return NULL;
    }
    char* bits = reverse_substitution_box(unpermuted);
    free(unpermuted);
    if (bits == NULL) {
        return NULL;
    }

    const size_t num_blocks = strlen(bits) / expanded_bits;
    char* result = static_cast<char*>(malloc(num_blocks * block_bits + 1));
    bool ok = (result != NULL);
    size_t written = 0;
    for (size_t blk = 0; ok && blk < num_blocks; ++blk) {
        char block[expanded_bits + 1];
        memcpy(block, bits + blk * expanded_bits, expanded_bits);
        block[expanded_bits] = '\0';

        char* unkeyed = xor_binary_strings(block, k1);
        char* recovered = (unkeyed != NULL) ? reverse_expansion(unkeyed) : NULL;
        free(unkeyed);
        if (recovered == NULL || strlen(recovered) != block_bits) {
            ok = false;
        } else {
            memcpy(result + written, recovered, block_bits);
            written += block_bits;
        }
        free(recovered);
    }
    free(bits);
    if (!ok) {
        free(result);
        return NULL;
    }
    result[written] = '\0';
    return result;
}

// Arduino setup hook: opens the serial port that loop() prints to.
void setup() {
    const unsigned long serial_baud_rate = 9600;
    Serial.begin(serial_baud_rate);
}

void loop() {
    // Example usage
    const char* plain_text = "HelloWorld";
    const char* key_1 = "1010101010101010101010101010101010101010101010101010101010101010";

    Serial.print("Original text is: ");
    Serial.println(plain_text);
    Serial.print("Key for encryption: ");
    Serial.println(key_1);

    char* encrypted_text = encryption(plain_text, key_1);
    Serial.print("Encrypted text is: ");
    Serial.println(encrypted_text);

    char* decrypted_text = decryption(encrypted_text, key_1);
    Serial.print("Decrypted text is: ");
    Serial.println(decrypted_text);

    free(encrypted_text);
    free(decrypted_text);

    delay(5000); // Delay for 5 seconds
}