#include <Arduino.h>

// Substitution table: 16 single hex-digit strings. Every digit 0-F occurs exactly
// once, so the table is a permutation of the hex digits (and therefore invertible).
String substitution_box_dict[] = {
    "1", "0", "8", "C",   // indices 0x0 - 0x3
    "A", "B", "4", "5",   // indices 0x4 - 0x7
    "9", "2", "3", "D",   // indices 0x8 - 0xB
    "6", "7", "E", "F"    // indices 0xC - 0xF
};
// reverse_substitution_box_dict[i] is the 4-bit binary spelling of i (0-15).
String reverse_substitution_box_dict[] = {
    "0000", "0001", "0010", "0011",
    "0100", "0101", "0110", "0111",
    "1000", "1001", "1010", "1011",
    "1100", "1101", "1110", "1111"
};

// Split binary_string into consecutive chunks of chunk_size characters.
//
// Returns a heap-allocated array holding binary_string.length() / chunk_size
// Strings; trailing characters that do not fill a whole chunk are dropped.
// The caller owns the array and must release it with delete[].
// A non-positive chunk_size yields an empty (still delete[]-able) array
// instead of dividing by zero.
String* split_binary_into_chunks(String binary_string, int chunk_size) {
    if (chunk_size <= 0) {
        return new String[0];
    }
    int num_chunks = binary_string.length() / chunk_size;
    String* chunks = new String[num_chunks];
    for (int i = 0; i < num_chunks; i++) {
        chunks[i] = binary_string.substring(i * chunk_size, (i + 1) * chunk_size);
    }
    return chunks;
}

// Concatenate the first num_chunks entries of chunk_list, in order, into one String.
// Joining the output of split_binary_into_chunks() gives back its input minus any
// dropped trailing characters.
String merge_chunks_into_binary(String* chunk_list, int num_chunks) {
    String merged;
    for (int idx = 0; idx < num_chunks; ++idx) {
        merged.concat(chunk_list[idx]);
    }
    return merged;
}

// Bitwise XOR of two '0'/'1' strings, compared position by position: equal
// characters give '0', different characters give '1'. The result has the length
// of binary_str1, so both inputs are expected to be the same length.
String xor_binary_strings(String binary_str1, String binary_str2) {
    String xored = "";
    unsigned int pos = 0;
    while (pos < binary_str1.length()) {
        bool differ = binary_str1.charAt(pos) != binary_str2.charAt(pos);
        xored += String(differ ? '1' : '0');
        ++pos;
    }
    return xored;
}

// Encryption functions

// Pack input_str into one integer by writing each character's byte value as a
// three-digit decimal group, e.g. "AB" -> 65 * 1000 + 66 = 65066.
//
// NOTE: the result is a long, so only very short strings are represented exactly:
// about 3 characters where long is 32 bits, about 6 where it is 64 bits. Longer
// inputs wrap around and cannot be decrypted back. The arithmetic is done in
// unsigned long so that this wrap-around is well defined rather than
// signed-overflow undefined behaviour.
long encrypt_string_to_decimal(String input_str) {
    unsigned long packed = 0;
    for (unsigned int i = 0; i < input_str.length(); i++) {
        // Read the byte as unsigned so codes 128-255 stay 128-255 instead of
        // turning negative on targets where char is signed.
        unsigned char code = (unsigned char)input_str.charAt(i);
        packed = packed * 1000UL + code;
    }
    return (long)packed;
}

// Render decimal_number in base 2 (Arduino String(value, BIN)) and left-pad it
// with '0' characters until its length is a multiple of 128. A string whose
// length is already a multiple of 128 gets no padding.
String decimal_to_multiple_128bit_binary(long decimal_number) {
    const String digits = String(decimal_number, BIN);

    const unsigned int remainder = digits.length() % 128;
    const unsigned int zeros_needed = (remainder == 0) ? 0 : 128 - remainder;

    String padded = "";
    for (unsigned int n = 0; n < zeros_needed; ++n) {
        padded += '0';
    }
    padded += digits;
    return padded;
}

// Bitwise complement of a '0'/'1' string: each '0' becomes '1' and every other
// character becomes '0'. Despite the name, the length is not enforced; the
// callers in this file pass 32-character blocks.
String NOT_gate_32_bit(String input_bits) {
    String inverted = "";
    const unsigned int count = input_bits.length();
    for (unsigned int k = 0; k < count; ++k) {
        const char bit = input_bits.charAt(k);
        inverted += (bit == '0') ? '1' : '0';
    }
    return inverted;
}

// Expand a 32-character block to 64 characters laid out as
//   [ chars 16..31 | NOT(all 32 chars) | chars 0..15 ]
// The two 16-bit halves swap ends, and the complemented block sits between them.
// reverse_expansion() undoes this by complementing characters 16..47.
String split_and_swap_32_to_64(String input_block) {
    String expanded = input_block.substring(16, 32);   // right half goes first
    expanded += NOT_gate_32_bit(input_block);          // complemented block in the middle
    expanded += input_block.substring(0, 16);          // left half goes last
    return expanded;
}

// Substitution layer: read input_bits as consecutive 4-character groups of
// '0'/'1' (most significant bit first) and replace each group with
// substitution_box_dict[value], a single hex digit. 64 bits in -> 16 hex digits out.
//
// Each group is decoded to its numeric value before the table lookup, because
// comparing the raw 4-character group with the 1-character table entries can
// never match. A trailing group shorter than 4 characters, or a group containing
// anything other than '0'/'1', is skipped.
String substitution_box(String input_bits) {
    String s_box_output = "";
    const unsigned int usable = input_bits.length() - input_bits.length() % 4;
    for (unsigned int i = 0; i < usable; i += 4) {
        int value = 0;
        bool valid = true;
        for (unsigned int b = 0; b < 4; b++) {
            const char bit = input_bits.charAt(i + b);
            if (bit != '0' && bit != '1') {
                valid = false;
                break;
            }
            value = (value << 1) | (bit - '0');
        }
        if (valid) {
            s_box_output += substitution_box_dict[value];
        }
    }
    return s_box_output;
}

// Reorder one 8-character chunk: output position k takes input character order[k].
// reverse_permute() applies the inverse ordering.
String permute(String input_text) {
    static const unsigned char order[8] = {7, 0, 6, 4, 5, 3, 2, 1};
    String output_text = "";
    for (unsigned char k = 0; k < 8; ++k) {
        output_text += String(input_text.charAt(order[k]));
    }
    return output_text;
}

// Permutation layer: apply permute() to each consecutive 8-character chunk of
// input_text and join the results in order. A trailing chunk shorter than
// 8 characters is still passed to permute() as is.
String permutation_box(String input_text) {
    String permuted = "";
    const unsigned int total = input_text.length();
    for (unsigned int start = 0; start < total; start += 8) {
        permuted += permute(input_text.substring(start, start + 8));
    }
    return permuted;
}

// Decryption functions

// Turn a packed decimal back into text: split the decimal digits of
// decimal_value into three-digit groups and convert each group to a character.
//
// String(decimal_value) drops leading zeros, so when the first character's code
// is below 100 (e.g. 'Y' = 89, ' ' = 32) the leftmost group has only one or two
// digits. The digits are therefore left-padded with '0' to a multiple of three
// before grouping; otherwise every following group would be misaligned.
// Non-positive values yield "" (0 is what the empty string packs to).
String decrypt_decimal_to_string(long decimal_value) {
    if (decimal_value <= 0) {
        return "";
    }

    const String digits = String(decimal_value);
    String padded = "";
    for (unsigned int extra = (3 - digits.length() % 3) % 3; extra > 0; --extra) {
        padded += '0';
    }
    padded += digits;

    String decrypted_string = "";
    for (unsigned int i = 0; i + 3 <= padded.length(); i += 3) {
        int ascii_value = padded.substring(i, i + 3).toInt();
        decrypted_string += (char)ascii_value;  // convert character code back to a char
    }
    return decrypted_string;
}

// Parse a string of '0'/'1' characters (e.g. the padded output of
// decimal_to_multiple_128bit_binary) back into a long.
// trim() strips only surrounding whitespace. The leading zero padding needs no
// special handling because strtol accepts leading zeros. Per the C standard,
// strtol returns LONG_MAX when the value does not fit in a long.
long multiple_128bit_binary_to_decimal(String binary_number) {
    binary_number.trim();
    return strtol(binary_number.c_str(), NULL, 2);
}

// Undo split_and_swap_32_to_64(): characters 16..47 of the 64-character block hold
// the complemented 32-bit input, so complementing them again recovers the block.
String reverse_expansion(String input_64) {
    return NOT_gate_32_bit(input_64.substring(16, 48));
}

// Map hex digits back to bits. For each character of input_bits, find the index j
// with substitution_box_dict[j] equal to it (case-insensitively) and emit
// reverse_substitution_box_dict[j], the 4-bit binary spelling of j.
// 16 hex digits in -> 64 bits out.
//
// substitution_box_dict is a permutation of the hex digits, so the digit must be
// looked up in the table rather than decoded by its own numeric value
// (String::toInt() would also yield 0 for 'A'-'F'). Characters that are not
// found in the table are skipped.
String reverse_substitution_box(String input_bits) {
    String s_box_output = "";
    for (unsigned int i = 0; i < input_bits.length(); i++) {
        const String digit = input_bits.substring(i, i + 1);
        for (int j = 0; j < 16; j++) {
            if (digit.equalsIgnoreCase(substitution_box_dict[j])) {
                s_box_output += reverse_substitution_box_dict[j];
                break;
            }
        }
    }
    return s_box_output;
}

// Inverse of permute() for one 8-character chunk: output position k takes input
// character order[k], which puts every character back where permute() took it from.
String reverse_permute(String input_text) {
    static const unsigned char order[8] = {1, 7, 6, 5, 3, 4, 2, 0};
    String output_text = "";
    for (unsigned char k = 0; k < 8; ++k) {
        output_text += String(input_text.charAt(order[k]));
    }
    return output_text;
}

// Inverse permutation layer: apply reverse_permute() to each consecutive
// 8-character chunk of input_text and join the results in order.
String reverse_permutation_box(String input_text) {
    String restored = "";
    const unsigned int total = input_text.length();
    for (unsigned int start = 0; start < total; start += 8) {
        restored += reverse_permute(input_text.substring(start, start + 8));
    }
    return restored;
}

// Encryption

// Encrypt input_text with key k1 and return the ciphertext as hex digits.
//
// Pipeline: text -> packed decimal -> binary padded to a multiple of 128 bits ->
// 32-bit blocks, each expanded to 64 bits and XORed with k1 -> substitution
// layer on 4-bit groups -> permutation layer on 8-character groups.
// k1 is expected to be 64 '0'/'1' characters, the width of an expanded block.
//
// NOTE: the packed decimal is a long, so only very short texts survive the
// round trip (see encrypt_string_to_decimal).
String encryption(String input_text, String k1) {
    long a = encrypt_string_to_decimal(input_text);
    String b = decimal_to_multiple_128bit_binary(a);

    // arr is a plain pointer, so sizeof cannot recover its element count; derive
    // the block count from the string length, exactly as split_binary_into_chunks does.
    const int block_bits = 32;
    int num_blocks = b.length() / block_bits;
    String* arr = split_binary_into_chunks(b, block_bits);

    String encrypted_text = "";
    for (int i = 0; i < num_blocks; i++) {
        String temp = split_and_swap_32_to_64(arr[i]);
        temp = xor_binary_strings(temp, k1);
        encrypted_text += temp;
    }
    delete[] arr;  // split_binary_into_chunks allocates the array with new[]

    encrypted_text = substitution_box(encrypted_text);
    encrypted_text = permutation_box(encrypted_text);
    return encrypted_text;
}

// Decryption

// Decrypt a hex ciphertext produced by encryption() using the same key k1.
// decrypted_text is the ciphertext on entry and is reused as scratch space.
//
// Steps, reversing encryption(): undo the permutation layer -> undo the
// substitution layer -> for each 64-bit block, XOR with k1 and recover the
// original 32 bits -> parse the joined bits as a long -> unpack the characters.
String decryption(String decrypted_text, String k1) {
    decrypted_text = reverse_permutation_box(decrypted_text);
    decrypted_text = reverse_substitution_box(decrypted_text);

    // arr is a plain pointer, so sizeof cannot recover its element count; derive
    // the block count from the string length, exactly as split_binary_into_chunks does.
    const int block_bits = 64;
    int num_blocks = decrypted_text.length() / block_bits;
    String* arr = split_binary_into_chunks(decrypted_text, block_bits);

    String decrypted_text2 = "";
    for (int i = 0; i < num_blocks; i++) {
        String temp = xor_binary_strings(arr[i], k1);
        temp = reverse_expansion(temp);
        decrypted_text2 += temp;
    }
    delete[] arr;  // split_binary_into_chunks allocates the array with new[]

    return decrypt_decimal_to_string(multiple_128bit_binary_to_decimal(decrypted_text2));
}

// One-shot demo: encrypt a fixed text, print sizes and ciphertext over Serial,
// decrypt it again and report whether the round trip matched.
void setup() {
    Serial.begin(9600);

    // Define plain_text with predefined value.
    // NOTE: encrypt_string_to_decimal() packs the whole text into a single long,
    // so only very short texts (about 3 characters where long is 32 bits) decrypt
    // back correctly; longer texts run through the pipeline but do not round-trip.
    String plain_text = "Your predefined text goes here";

    // 64 '0'/'1' characters: one key bit per character, XORed with each 64-bit block.
    String key_1 = "1010101010101010101010101010101010101010101010101010101010101010";
    String encrypted_text = encryption(plain_text, key_1);
    Serial.println("Original text is:    " + plain_text);
    // Each character of plain_text is one byte, i.e. 8 bits.
    Serial.println("Original text is of " + String(plain_text.length() * 8) + " bits");
    Serial.println("Key for encryption:  " + key_1);
    Serial.println("Key is of " + String(key_1.length()) + " bits");
    Serial.println("Encrypted text is:   " + encrypted_text);
    // The ciphertext is written as hex digits, each carrying 4 bits.
    Serial.println("Encrypted text is of " + String(encrypted_text.length() * 4) + " bits");
    String decrypted_text = decryption(encrypted_text, key_1);
    Serial.println("Decrypted text is:   " + decrypted_text);
    if (plain_text.equals(decrypted_text)) {
        Serial.println("Decryption successful and matched with input text");
    } else {
        Serial.println("Decrypted text does not match input text");
    }
}

void loop() {
    // Intentionally empty: the whole demo runs once from setup().
}