/*This code computes the multiplicative inverse of a number (x) in a finite field using Fermat's Little Theorem. The theorem states that if (p) is a prime number, then for any integer (a) that is not divisible by (p):
--------------------------------------------------> a^{p-1} = 1 (mod p)
Factoring out one power of (a) gives:
--------------------------------------------------> a * a^{p-2} = 1 (mod p)
Thus, (a^{p-2}) is the multiplicative inverse of (a) modulo (p).

Note: The code uses the efficient **exponentiation by squaring** method to compute a^{p-2}.*/
#include <Arduino.h>

// Declared here because mod25519 is defined further down in this file; without
// this the call below only compiles if the build tool generates prototypes.
void mod25519(uint8_t *result, uint8_t *a);

// Modular multiplication: result = (a * b) reduced by mod25519().
//
// a, b    - 32-byte operands, least-significant byte first (only read).
// result  - 32-byte output buffer; it may be the same buffer as a or b because
//           the full product is built in a local array before result is written.
void mul32_mod25519(const uint8_t* a, const uint8_t* b, uint8_t* result)
{
  uint8_t product[64] = {0}; // 512-bit schoolbook product, least-significant byte first

  for (int k = 0; k < 32; ++k)
  {
    uint16_t carry = 0;
    for (int j = 0; j < 32; ++j)
    {
      // Worst case is 255 * 255 + 255 + 255 = 65535, so 16 bits are enough.
      const uint16_t sum = (uint16_t)((uint16_t)a[k] * b[j] + product[k + j] + carry);
      product[k + j] = (uint8_t)(sum & 0xFF); // keep the low byte in place
      carry = (uint16_t)(sum >> 8);           // high byte moves to the next position
    }
    // Position k + 32 has not been written by any earlier row, so the final
    // carry of this row can be stored directly.
    product[k + 32] = (uint8_t)carry;
  }

  // Reduce the 64-byte product straight into the caller's 32-byte buffer.
  mod25519(result, product);
}

// Declared here because both helpers are defined further down in this file;
// without these the calls below only compile if the build tool generates
// prototypes.
void sub32(uint8_t *r, const uint8_t *a, uint8_t *borrow);
void exchange25519(uint8_t *r, const uint8_t *a, uint8_t borrow);

// Modular reduction of a 512-bit value by the modulus that sub32() subtracts.
//
// result - receives the 32-byte remainder, least-significant byte first.
// a      - 64-byte input, least-significant byte first (only read).
//
// The remainder is built by binary long division: the input is consumed one bit
// at a time from the most significant end, the running remainder is doubled,
// the bit is appended, and the modulus is subtracted once if the remainder has
// reached it. Every input bit costs the same work and the choice between the
// two candidates is made with a mask, not a branch.
void mod25519(uint8_t *result, uint8_t *a)
{
  uint8_t remainder[32] = {0}; // invariant: remainder < modulus after every step
  uint8_t reduced[32];         // remainder - modulus, candidate for the next value

  for (int bit = 511; bit >= 0; --bit)
  {
    // remainder = (remainder << 1) | next input bit
    uint8_t carry = (uint8_t)((a[bit >> 3] >> (bit & 7)) & 0x01);
    for (int j = 0; j < 32; ++j)
    {
      const uint8_t msb = (uint8_t)(remainder[j] >> 7);
      remainder[j] = (uint8_t)((remainder[j] << 1) | carry);
      carry = msb;
    }
    // carry now holds the bit shifted out of the top byte (bit 256 of the
    // doubled value). It can only be set when the modulus is >= 2^255.

    uint8_t borrow = 0;
    sub32(reduced, remainder, &borrow);

    // Take the subtracted value when there was no borrow (remainder >= modulus)
    // or when the doubled value overflowed 256 bits (it is then certainly
    // >= modulus and the wrapped subtraction result is the true difference).
    const uint8_t takeReduced = (uint8_t)((borrow ^ 0x01) | carry);
    exchange25519(remainder, reduced, takeReduced);
  }

  memcpy(result, remainder, 32);
}

// Byte-wise subtraction of the modulus from a 32-byte little-endian value.
//
// r      - receives (a - modulus) truncated to 32 bytes.
// a      - 32-byte minuend, least-significant byte first.
// borrow - output only: set to 1 when a < modulus (the subtraction wrapped),
//          otherwise 0. Any value it holds on entry is discarded.
void sub32(uint8_t *r, const uint8_t *a, uint8_t *borrow)
{
  // Modulus, least-significant byte first. Only the lowest byte is non-zero;
  // 17 is the example value used by this sketch.
  uint8_t m[32] = {17};

  *borrow = 0;
  for (int idx = 0; idx != 32; ++idx)
  {
    // A negative difference wraps in 16 bits, which sets bit 8.
    const uint16_t diff = (uint16_t)(a[idx] - m[idx] - *borrow);
    r[idx] = (uint8_t)diff;
    *borrow = (uint8_t)((diff >> 8) & 0x01);
  }
}

// Branch-free conditional copy over 32 bytes.
//
// With borrow == 1 every byte of r is overwritten with the matching byte of a;
// with borrow == 0 r is left unchanged. The selection is done with a bit mask
// derived from borrow instead of an if-statement.
void exchange25519(uint8_t *r, const uint8_t *a, uint8_t borrow)
{
  const uint8_t select = (uint8_t)(0 - borrow); // 0x00 for 0, 0xFF for 1

  int idx = 32;
  while (idx-- > 0)
  {
    const uint8_t delta = (uint8_t)(select & (a[idx] ^ r[idx]));
    r[idx] = (uint8_t)(r[idx] ^ delta);
  }
}

// Writes the 32 bytes of `bytes` to Serial from index 31 down to index 0 in the
// given numeric base, each value followed by one space. No line ending is sent.
static void dumpBytes(const uint8_t *bytes, int base)
{
  int idx = 32;
  while (idx-- > 0)
  {
    Serial.print(bytes[idx], base);
    Serial.print(" ");
  }
}

// Exponentiation by squaring, processing the exponent from its lowest bit up.
//
// Z0  - 32-byte base, least-significant byte first. It is overwritten: after
//       every step it holds the square of its previous value.
// exp - 32-byte exponent, least-significant byte first. It is overwritten: it
//       is shifted right by one bit per step, so it ends up all zero.
//
// All 256 exponent bits are always processed. For each bit the accumulator is
// multiplied by either Z0 or zero, and masks choose between the product and the
// old accumulator, so no branch depends on the exponent. Every intermediate
// value and the final accumulator are written to Serial; nothing is returned.
void exponent(uint8_t* Z0, uint8_t* exp)
{
  uint8_t bitWord[32] = {0};    // only byte 0 is ever set: the current exponent bit
  uint8_t selectMask[32] = {0}; // 0xFF in every byte when the bit is 1, else 0x00
  uint8_t keepMask[32] = {0};   // bitwise inverse of selectMask

  uint8_t acc[32] = {1};        // running result, starts at 1
  uint8_t term[32];
  uint8_t squared[32];

  int i = 256;
  while (i-- > 0)
  {
    // ---- current exponent bit ----
    bitWord[0] = exp[0] & 0x01;
    Serial.print("current_bit = ");
    dumpBytes(bitWord, BIN);
    Serial.println(" ");

    // ---- masks derived from the bit ----
    const uint8_t allOnesIfSet = (uint8_t)(0 - bitWord[0]);
    for (int j = 0; j < 32; ++j)
    {
      selectMask[j] = allOnesIfSet;
      keepMask[j] = (uint8_t)~allOnesIfSet;
    }
    Serial.print("complement_current_bit =");
    dumpBytes(selectMask, BIN);
    Serial.println(" ");
    Serial.print("Inv_complement_current_bit =");
    dumpBytes(keepMask, BIN);
    Serial.println(" ");

    // ---- shift the exponent right by one bit, from the top byte down ----
    uint8_t carryIn = 0;
    for (int j = 31; j >= 0; --j)
    {
      const uint8_t before = exp[j];
      exp[j] = (uint8_t)((carryIn << 7) | (before >> 1));
      carryIn = before & 0x01; // lowest bit drops into the byte below
    }
    Serial.print("At the [");
    Serial.print(i);
    Serial.print("] iteration, Shifted exp = ");
    dumpBytes(exp, BIN);
    Serial.println();
    Serial.println("------------------------------");
    Serial.println();

    // ---- acc = acc * (Z0 & selectMask) | (acc & keepMask) ----
    Serial.print("-------------> At the [");
    Serial.print(i);
    Serial.print("] Iteration, Z0 = ");
    dumpBytes(Z0, BIN);
    Serial.println();

    Serial.print("result[j] = ");
    dumpBytes(acc, HEX);
    Serial.println();

    // term is Z0 when the bit is set and zero otherwise.
    for (int j = 0; j < 32; ++j)
    {
      term[j] = Z0[j] & selectMask[j];
    }
    Serial.print("temp = Z0 & complement_current_bit = ");
    dumpBytes(term, HEX);
    Serial.println();

    // term = acc * term
    mul32_mod25519(acc, term, term);
    Serial.print("temp = result * temp = result * (Z0 & complement_current_bit) = ");
    dumpBytes(term, HEX);
    Serial.println();

    // Keep the old accumulator only when the bit is clear.
    for (int j = 0; j < 32; ++j)
    {
      acc[j] = acc[j] & keepMask[j];
    }
    Serial.print("result = Result & Inv_complement_current_bit = ");
    dumpBytes(acc, HEX);
    Serial.println();

    // Merge in the product.
    for (int j = 0; j < 32; ++j)
    {
      acc[j] |= term[j];
    }
    Serial.print("At the [");
    Serial.print(i);
    Serial.print("] Iteration, result[j] |= temp[j] =  result * (Z0 & complement_current_bit) | (result & Inv_complement_current_bit) =");
    dumpBytes(acc, HEX);
    Serial.println();

    // ---- Z0 = Z0 * Z0 for the next bit ----
    mul32_mod25519(Z0, Z0, squared);
    memcpy(Z0, squared, 32);

    Serial.print("Z0^2 = ");
    dumpBytes(Z0, HEX);
    Serial.println();
    Serial.println("===================================================================================================================");
  }

  // ---- report the accumulated power ----
  Serial.print("Final Result of Z0^(p): ");
  dumpBytes(acc, HEX);
  Serial.println();
}

// Arduino entry point: opens the serial port at 9600 baud, prints the base
// (2) and the exponent (5), then runs exponent() on them.
void setup()
{
  Serial.begin(9600);

  // Both operands are 32 bytes, least-significant byte first.
  uint8_t base[32] = {2};
  uint8_t power[32] = {5};

  // Base, most-significant byte first, in hexadecimal.
  Serial.print("Z0 = ");
  for (int k = 0; k < 32; ++k)
  {
    Serial.print(base[31 - k], HEX);
    Serial.print(" ");
  }
  Serial.println();

  // Exponent, most-significant byte first, in binary.
  Serial.print("exp = ");
  for (int k = 0; k < 32; ++k)
  {
    Serial.print(power[31 - k], BIN);
    Serial.print(" ");
  }
  Serial.println(" ");
  Serial.println("==============================================>");

  // exponent() overwrites both arrays and prints its own result.
  exponent(base, power);
}

// Intentionally empty: all work in this sketch is done once, in setup().
void loop()
{
  // Nothing to do here
}