#include <Arduino.h>

// Returns the greatest common divisor of a and b using the iterative
// Euclidean algorithm: the pair (a, b) is repeatedly replaced by
// (b, a % b) until the second component reaches zero.
// When b is 0 on entry, a is returned unchanged. For negative inputs the
// sign of the result follows C++'s truncating remainder.
int gcd(int a, int b) {
    for (;;) {
        if (b == 0) {
            return a;
        }
        const int remainder = a % b;
        a = b;
        b = remainder;
    }
}

// Computes the modular multiplicative inverse of e modulo phi.
//
// Uses the extended Euclidean algorithm, so the running time is logarithmic
// in phi instead of a linear search, and all intermediate products are held
// in 64-bit integers so they cannot overflow `int` (which is only 16 bits
// wide on AVR-based boards).
//
// Parameters:
//   e   - value to invert; it is reduced modulo phi first, so values
//         outside [0, phi) (including negative ones) are accepted.
//   phi - modulus; must be greater than 1 for an inverse to exist.
//
// Returns the unique x in [1, phi) with (e * x) % phi == 1, or -1 when no
// inverse exists (phi <= 1, or e and phi are not coprime).
int modInverse(int e, int phi) {
    if (phi <= 1) {
        return -1;
    }

    // Invariant: remPrev == coefPrev * e (mod phi) and
    //            remCur  == coefCur  * e (mod phi).
    long long remPrev = phi;
    long long remCur = e % phi;
    if (remCur < 0) {
        remCur += phi; // normalise a negative remainder into [0, phi)
    }
    long long coefPrev = 0;
    long long coefCur = 1;

    while (remCur != 0) {
        const long long quotient = remPrev / remCur;

        const long long remNext = remPrev - quotient * remCur;
        remPrev = remCur;
        remCur = remNext;

        const long long coefNext = coefPrev - quotient * coefCur;
        coefPrev = coefCur;
        coefCur = coefNext;
    }

    // remPrev now holds gcd(e, phi); an inverse exists only when it is 1.
    if (remPrev != 1) {
        return -1;
    }

    // The Bezout coefficient may be negative; shift it into [1, phi).
    if (coefPrev < 0) {
        coefPrev += phi;
    }
    return static_cast<int>(coefPrev);
}

// Derives a toy RSA key pair from the two primes p and q.
//
// Outputs (written through the reference parameters):
//   n - the modulus p * q
//   e - the smallest integer >= 2 that is coprime to phi = (p - 1) * (q - 1);
//       if no such value below phi exists, e is left at the value where the
//       search stopped
//   d - the result of modInverse(e, phi), i.e. the private exponent, or -1
//       when modInverse reports that no inverse exists
void generateKeys(int p, int q, int &e, int &d, int &n) {
    const int phi = (p - 1) * (q - 1);
    n = p * q;

    // Walk upwards from 2 until a candidate coprime to phi is found.
    for (e = 2; e < phi && gcd(e, phi) != 1; ++e) {
    }

    // d satisfies d * e ≡ 1 (mod phi).
    d = modInverse(e, phi);
}

// Encrypts a message byte-by-byte with the public key (e, n).
//
// For each of the first `length` bytes m of `plaintext`, stores
// m^e mod n in the matching slot of `ciphertext`.
//
// Parameters:
//   e          - public exponent; values <= 0 yield 1 for every byte
//   n          - modulus; must be positive. A byte can only be recovered by
//                decrypt() when its value is smaller than n.
//   plaintext  - input bytes; each is read as an unsigned value (0..255) so
//                bytes >= 0x80 produce a non-negative residue
//   ciphertext - output array with room for at least `length` ints
//   length     - number of bytes to encrypt
//
// The power is computed with binary (square-and-multiply) exponentiation,
// which needs O(log e) multiplications instead of e of them. Products are
// formed in `long long` so they cannot overflow `int`, which is only 16 bits
// wide on AVR boards.
void encrypt(int e, int n, const char *plaintext, int *ciphertext, int length) {
    for (int i = 0; i < length; i++) {
        long long base = static_cast<unsigned char>(plaintext[i]) % n;
        long long result = 1;

        // Consume the exponent one bit at a time, least significant first.
        for (int exponent = e; exponent > 0; exponent >>= 1) {
            if (exponent & 1) {
                result = (result * base) % n;
            }
            base = (base * base) % n;
        }

        ciphertext[i] = static_cast<int>(result);
    }
}

// Decrypts a sequence of ciphertext values with the private key (d, n).
//
// For each of the first `length` entries c of `ciphertext`, computes
// c^d mod n, narrows it to a char and stores it in `plaintext`. A
// terminating '\0' is written at plaintext[length].
//
// Parameters:
//   d          - private exponent; values <= 0 yield 1 for every entry
//   n          - modulus; must be positive
//   ciphertext - input array with at least `length` entries
//   plaintext  - output buffer with room for at least length + 1 chars
//   length     - number of entries to decrypt; a negative value is treated
//                as a no-op so that nothing is written before the buffer
//
// The power is computed with binary (square-and-multiply) exponentiation,
// which needs O(log d) multiplications instead of d of them. Products are
// formed in `long long` so they cannot overflow `int`, which is only 16 bits
// wide on AVR boards.
void decrypt(int d, int n, const int *ciphertext, char *plaintext, int length) {
    if (length < 0) {
        return;
    }

    for (int i = 0; i < length; i++) {
        long long base = ciphertext[i] % n;
        long long result = 1;

        // Consume the exponent one bit at a time, least significant first.
        for (int exponent = d; exponent > 0; exponent >>= 1) {
            if (exponent & 1) {
                result = (result * base) % n;
            }
            base = (base * base) % n;
        }

        plaintext[i] = static_cast<char>(static_cast<int>(result));
    }
    plaintext[length] = '\0'; // Null-terminate the string
}

// Times a single call of an encrypt-shaped function.
//
// Invokes func(e, n, plaintext, ciphertext, length) once and returns the
// difference between the micros() readings taken immediately after and
// immediately before the call. The subtraction is done in unsigned
// arithmetic.
unsigned long execuTime(void (*func)(int, int, const char*, int*, int), int e, int n, const char* plaintext, int* ciphertext, int length) {
    const unsigned long startedAt = micros();
    (*func)(e, n, plaintext, ciphertext, length);
    return micros() - startedAt;
}

// Overload that times a single call of a decrypt-shaped function.
//
// Invokes func(d, n, ciphertext, plaintext, length) once and returns the
// difference between the micros() readings taken immediately after and
// immediately before the call. The subtraction is done in unsigned
// arithmetic.
unsigned long execuTime(void (*func)(int, int, const int*, char*, int), int d, int n, const int* ciphertext, char* plaintext, int length) {
    const unsigned long startedAt = micros();
    (*func)(d, n, ciphertext, plaintext, length);
    return micros() - startedAt;
}

void setup() {
    Serial.begin(115200);
    Serial.println();
    Serial.println();
    Serial.println();

    int p = 61; // Example prime number
    int q = 53; // Example prime number
    int e, d, n;
    
    // Generate RSA keys
    generateKeys(p, q, e, d, n);

    // Print the public and private keys
    Serial.print("Public Key (e, n): (");
    Serial.print(e);
    Serial.print(", ");
    Serial.print(n);
    Serial.println(")");

    Serial.print("Private Key (d, n): (");
    Serial.print(d);
    Serial.print(", ");
    Serial.print(n);
    Serial.println(")");
    Serial.println();

    const char *message = "hello GTA world";
    int messageLength = strlen(message);
    int ciphertext[messageLength];

    // Measure execution time for encryption
    unsigned long encryptionTime = execuTime(encrypt, e, n, message, ciphertext, messageLength);
    Serial.print("Encrypted message: ");
    for (int i = 0; i < messageLength; i++) {
        Serial.print(ciphertext[i]);
        Serial.print(" ");
    }
    Serial.println();
    Serial.print("Encryption time: ");
    Serial.print(encryptionTime);
    Serial.println(" µs");
    Serial.println();

    // Measure execution time for decryption
    char decryptedMessage[messageLength + 1];
    unsigned long decryptionTime = execuTime(decrypt, d, n, ciphertext, decryptedMessage, messageLength);
    Serial.print("Decrypted message: ");
    Serial.println(decryptedMessage);
    Serial.print("Decryption time: ");
    Serial.print(decryptionTime);
    Serial.println(" µs");
}

void loop() {
    // Intentionally empty: the whole demonstration runs once in setup(),
    // so there is nothing to repeat here.
}