#include <Arduino.h>
#include <stdint.h>
#include <string.h>

// The prime p = 2^255 - 19 stored as 32 bytes, least-significant byte first.
// mod25519() subtracts this value to bring a result into the range [0, p).
const uint8_t mod25519_subtracter_32[32] = {
    0xED, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F
};

// The same prime widened to 64 bytes: bytes 0..31 hold p, bytes 32..63 are zero.
// Not referenced by any function in this file.
uint8_t subtracter[64] = {
    0xED, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
};

// Forward declarations for the helpers defined after fe25519_invert().
// All multi-byte values are byte arrays with the least-significant byte at index 0.

// Multiplies the 32-byte values a and b and stores the product, reduced by mod25519(), in the 32-byte result.
void mul32_mod25519(const uint8_t* a, const uint8_t* b, uint8_t* result);
// Reduces the 64-byte value tempResult into the 32-byte result using 2^256 = 38 (mod 2^255 - 19).
void mod25519(uint8_t* result, const uint8_t* tempResult);
// result = low 32 bytes of (a * b + carry); carry is read as the incoming carry and receives the outgoing one.
void mod25519_multiply_32(uint8_t* result, const uint8_t* a, uint8_t b, uint16_t& carry);
// result += a (32 bytes each) plus the incoming carry; carry receives the carry out of the top byte.
void mod25519_add_32(uint8_t* result, const uint8_t* a, uint16_t& carry);
// r = a - b over 32 bytes; *borrow is set to 1 when the subtraction underflows, otherwise 0.
void mod25519_sub_32(uint8_t* r, const uint8_t* a, const uint8_t* b, uint8_t* borrow);
// Replaces r with original when borrow is 1 and leaves r unchanged when borrow is 0 (mask-based, no branch).
void mod25519_exchange_32(uint8_t* r, const uint8_t* original, uint8_t borrow);

void fe25519_invert(uint8_t *result, const uint8_t *input)
{
  uint8_t x_exp_2[32], x_exp_11[32], x2_exp_10[32], x2_exp_50[32], x2_exp_100[32], x2_exp_200[32], temp0[32], temp1[32];
  unsigned char i;

  Serial.print("input: ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(input[i], HEX);
    Serial.print(" ");
  }
  Serial.println();

  /* Precomputation */
  mul32_mod25519(input, input, x_exp_2);           // x_exp_2 = input^2
  Serial.print("1. x_exp_2 = input^2 ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(x_exp_2[i], HEX);
    Serial.print(" ");
  }
  Serial.println();
    
  mul32_mod25519(x_exp_2, x_exp_2, temp1);         // temp1 = x_exp_2^2 = input^4
  Serial.print("2. temp1 = x_exp_2^2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp1[i], HEX);
    Serial.print(" ");
  }
  Serial.println();
    
  mul32_mod25519(temp1, temp1, temp0);             // temp0 = temp1^2 = input^8
  Serial.print("3. temp0 = temp1^2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX); Serial.print(" ");
  }
  Serial.println();
    
  mul32_mod25519(temp0, input, x2_exp_10);           // x2_exp_10 = temp0 * input
  Serial.print("4. x2_exp_10 = temp0 * input = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(x2_exp_10[i], HEX);
    Serial.print(" ");
  }
  Serial.println();
    
  mul32_mod25519(x2_exp_10, x_exp_2, x_exp_11);      // x_exp_11 = x2_exp_10 * x_exp_2
  Serial.print("5. x_exp_11 = x2_exp_10 * x_exp_2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(x_exp_11[i], HEX);
    Serial.print(" ");
  }
  Serial.println();

  /* Compute higher powers */
  mul32_mod25519(x_exp_11, x_exp_11, temp0);       // temp0 = x_exp_11 ^ 2
  Serial.print("6. temp0 = x_exp_11 ^ 2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX); Serial.print(" ");
  }
  Serial.println();
    
  mul32_mod25519(temp0, x2_exp_10, x2_exp_10);       // x2_exp_10 = temp0 * x2_exp_10
  Serial.print("7. x2_exp_10 = temp0 * x2_exp_10 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(x2_exp_10[i], HEX); Serial.print(" ");
  }
  Serial.println();
    
  mul32_mod25519(x2_exp_10, x2_exp_10, temp0);     // temp0 = x2_exp_10 ^ 2
  Serial.print("8. temp0 = x2_exp_10^2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX); Serial.print(" ");
  }
  Serial.println();
    
  mul32_mod25519(temp0, temp0, temp1);             // temp1 = temp0^2
  Serial.print("9. temp1 = temp0^2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp1[i], HEX); Serial.print(" ");
  }
  Serial.println();

  mul32_mod25519(temp1, temp1, temp0);     // temp0 = temp1^2
  Serial.print("10. temp0 = temp1^2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX); Serial.print(" ");
  }
  Serial.println();
  //----------------------------------
  mul32_mod25519(temp0, temp0, temp1);     // temp1 = temp0^2
  Serial.print("11. temp1 = temp0^2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp1[i], HEX); Serial.print(" ");
  }
  Serial.println();
  //----------------------------------
    mul32_mod25519(temp1, temp1, temp0);     // temp0 = temp1^2
  Serial.print("12. temp0 = temp1^2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX); Serial.print(" ");
  }
  Serial.println();
  //--------------------------------
  mul32_mod25519(temp0, x2_exp_10, x2_exp_10);  // x2_exp_10 = temp0 * x2_exp_10
  Serial.print("13. x2_exp_10: ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(x2_exp_10[i], HEX); Serial.print(" ");
  }
  Serial.println();
  //----------------------------------
  mul32_mod25519(x2_exp_10, x2_exp_10, temp0);     // temp0 = x2_exp_10^2
  Serial.print("14. temp0 (x2_exp_10^2): ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX); Serial.print(" ");
  }
  Serial.println();
  //----------------------------------
    mul32_mod25519(temp0, temp0, temp1);     // temp1 = temp0^2
  Serial.print("15. temp1 (temp0^2): ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp1[i], HEX);
    Serial.print(" ");
  }
  Serial.println();

  Serial.println(" =======================================> Start of the first for loop : ");
  for (i = 2; i < 10; i += 2)
  {
    mul32_mod25519(temp1, temp1, temp0);         // temp0 = temp1^2
    Serial.print("16. temp0 (iteration ");
    Serial.print(i);
    Serial.print("): ");
    for (int j = 0; j < 32; j++)
    {
      Serial.print(temp0[j], HEX); Serial.print(" ");
    }
    Serial.println();
        
    mul32_mod25519(temp0, temp0, temp1);         // temp1 = temp0^2
    Serial.print("16. temp1 (iteration ");
    Serial.print(i);
    Serial.print("): ");
    for (int j = 0; j < 32; j++)
    {
      Serial.print(temp1[j], HEX); Serial.print(" ");
    }
    Serial.println();
  }
  Serial.println(" =======================================> End of the first for loop : ");

  mul32_mod25519(temp1, x2_exp_10, x2_exp_50);     // x2_exp_50 = temp1 * x2_exp_10
  Serial.print("17. x2_exp_50: ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(x2_exp_50[i], HEX);
    Serial.print(" ");
  }
  Serial.println();

  mul32_mod25519(x2_exp_50, x2_exp_50, temp0);     // temp0 = x2_exp_50^2
  Serial.print("18. temp0 = x2_exp_50^2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX);
    Serial.print(" ");
  }
  Serial.println();

  mul32_mod25519(temp0, temp0, temp1);     // temp1 = temp0^2
  Serial.print("19. temp1 (temp0^2): ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp1[i], HEX);
    Serial.print(" ");
  }
  Serial.println();
  
  Serial.println(" =======================================> Start of the Second for loop : ");
  for (i = 2; i < 20; i += 2)
  {
    mul32_mod25519(temp1, temp1, temp0);
    Serial.print("20. temp0 (iteration ");
    Serial.print(i);
    Serial.print("): ");
    for (int j = 0; j < 32; j++)
    {
      Serial.print(temp0[j], HEX); Serial.print(" ");
    }
    Serial.println();
        
    mul32_mod25519(temp0, temp0, temp1);
    Serial.print("20. temp1 (iteration ");
    Serial.print(i);
    Serial.print("): ");
    for (int j = 0; j < 32; j++)
    {
      Serial.print(temp1[j], HEX);
      Serial.print(" ");
    }
    Serial.println();
  }
  Serial.println(" =======================================> End of the Second for loop ========= ");
  
  mul32_mod25519(temp1, x2_exp_50, temp0);    // temp0 = temp1 * x2_exp_50
  Serial.print("21. x2_exp_100: ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX);
    Serial.print(" ");
  }
  Serial.println();

  mul32_mod25519(temp0, temp0, temp1);     // temp1 = temp0^2
  Serial.print("22. temp1 = temp0^2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp1[i], HEX);
    Serial.print(" ");
  }
  Serial.println();

  mul32_mod25519(temp1, temp1, temp0);     // temp1 = temp0^2
  Serial.print("23. temp0 = temp1^2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX);
    Serial.print(" ");
  }
  Serial.println();  

  Serial.println(" =======================================> Start of the third for loop : ");

  for (i = 2; i < 10; i += 2)
  {
    mul32_mod25519(temp0, temp0, temp1);
    Serial.print("24. temp1 (iteration ");
    Serial.print(i);
    Serial.print("): ");
    for (int j = 0; j < 32; j++)
    {
      Serial.print(temp1[j], HEX); Serial.print(" ");
    }
    Serial.println();

    mul32_mod25519(temp1, temp1, temp0);
    Serial.print("24. temp0 (iteration ");
    Serial.print(i);
    Serial.print("): ");
    for (int j = 0; j < 32; j++)
    {
      Serial.print(temp0[j], HEX); Serial.print(" ");
    }
    Serial.println();
  }
  Serial.println(" =======================================> End of the third for loop ========= ");

  mul32_mod25519(temp0, x2_exp_10, x2_exp_50);   // x2_exp_50 = temp0 * x2_exp_10
  Serial.print("25. x2_exp_50 = temp0 * x2_exp_10 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(x2_exp_50[i], HEX);
    Serial.print(" "); 
  }
  Serial.println();
  // --------------------------- we reached here -------------------
    mul32_mod25519(x2_exp_50, x2_exp_50, temp0);   // temp0 = x2_exp_50 ^ 2
  Serial.print("26. temp0 = x2_exp_10 ^ 2 = ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX);
    Serial.print(" "); 
  }
  Serial.println();
  // --------------------------- we reached here --------------------
  // Final steps to reach 2^255 - 21
  for (i = 2; i < 50; i += 2)
  {
    mul32_mod25519(temp1, temp1, temp0);
    Serial.print("temp0 (iteration ");
    Serial.print(i);
    Serial.print("): ");
    for (int j = 0; j < 32; j++)
    {
      Serial.print(temp0[j], HEX); Serial.print(" ");
    }
    Serial.println();
        
    mul32_mod25519(temp0, temp0, temp1);
    Serial.print("temp1 (iteration ");
    Serial.print(i + 1);
    Serial.print("): ");
    for (int j = 0; j < 32; j++)
    {
      Serial.print(temp1[j], HEX); Serial.print(" ");
    }
    Serial.println();
  }
  mul32_mod25519(temp1, x2_exp_50, temp0);         // temp0 = temp1 * x2_exp_50 = input^(2^250 - 1)
  Serial.print("temp0 (input^(2^250 - 1)): ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX); Serial.print(" ");
  }
  Serial.println();

  // Final multiplications
  mul32_mod25519(temp0, temp0, temp1);
  Serial.print("temp1 (final stage 1): ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp1[i], HEX); Serial.print(" ");
  }
  Serial.println();

  mul32_mod25519(temp1, temp1, temp0);
  Serial.print("temp0 (final stage 2): ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX); Serial.print(" ");
  }
  Serial.println();

  mul32_mod25519(temp0, temp0, temp1);
  Serial.print("temp1 (final stage 3): ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp1[i], HEX); Serial.print(" ");
  }
  Serial.println();

  mul32_mod25519(temp1, temp1, temp0);
  Serial.print("temp0 (final stage 4): ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp0[i], HEX); Serial.print(" ");
  }
  Serial.println();

  mul32_mod25519(temp0, temp0, temp1);
  Serial.print("temp1 (final stage 5): ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(temp1[i], HEX); Serial.print(" ");
  }
  Serial.println();

  mul32_mod25519(temp1, x_exp_11, result);         // result = temp1 * x_exp_11 = input^(2^255 - 21)
  Serial.print("result: ");
  for (int i = 0; i < 32; i++)
  {
    Serial.print(result[i], HEX); Serial.print(" ");
  }
  Serial.println();
}

// Schoolbook product of two 32-byte values (least-significant byte first),
// reduced through mod25519().
//
//   a, b    32-byte operands; both are fully consumed before result is written,
//           so result may be the same buffer as a or b.
//   result  receives the 32-byte reduced product.
void mul32_mod25519(const uint8_t* a, const uint8_t* b, uint8_t* result)
{
    uint8_t wide[64] = {0}; // 512-bit product, accumulated one row at a time

    for (int row = 0; row < 32; ++row)
    {
        uint8_t* const window = wide + row; // window[col] aliases wide[row + col]
        const long multiplier = a[row];
        int overflow = 0;

        for (int col = 0; col < 32; ++col) {
            // At most 255*255 + 255 + 255 = 65535, so the high byte fits in 8 bits.
            const long term = multiplier * b[col] + window[col] + overflow;
            window[col] = static_cast<uint8_t>(term & 0xFF);
            overflow = static_cast<int>((term >> 8) & 0xFF);
        }
        // The byte above this row has not been written by earlier rows yet.
        window[32] = static_cast<uint8_t>((overflow + window[32]) & 0xFF);
    }

    uint8_t reduced[32] = {0};
    mod25519(reduced, wide);

    memcpy(result, reduced, 32);
}

//-----------------------Modulo-mod25519 function------------
void mod25519(uint8_t* result, const uint8_t* tempResult)
{
    uint8_t high[32];
    uint8_t low[32];
    uint8_t mod25519_originalResult[32];
    uint8_t borrow = 0;
    uint16_t carry = 0;

    // Divide tempResult into high and low parts
    memcpy(high, tempResult + 32, 32);
    memcpy(low, tempResult, 32);

    // Multiply high part by 38
    mod25519_multiply_32(result, high, 38, carry);

    // Propagate carry
    for (int k = 0; k < 2; k++)
    {
        uint32_t carryProduct = carry * 38;
        uint32_t tempSum = 0;

        for (int i = 0; i < 32; i++)
        {
            tempSum += result[i] + (carryProduct & 0xFF);
            result[i] = tempSum & 0xFF;
            carryProduct >>= 8;
            tempSum >>= 8;
        }
        carry = tempSum;
    }

    // Add low part
    mod25519_add_32(result, low, carry);

    // Final carry propagation
    uint32_t carryProduct = carry * 38;
    carry = 0;
    for (int i = 0; i < 32; i++) {
        uint32_t sum = result[i] + (carryProduct & 0xFF) + carry;
        result[i] = sum & 0xFF;
        carryProduct >>= 8;
        carry = sum >> 8;
    }

    // Save original result
    memcpy(mod25519_originalResult, result, 32);
    // Perform subtraction and exchange to ensure result is within range
    mod25519_sub_32(result, result, mod25519_subtracter_32, &borrow);
    mod25519_exchange_32(result, mod25519_originalResult, borrow);

    memcpy(mod25519_originalResult, result, 32);
    mod25519_sub_32(result, result, mod25519_subtracter_32, &borrow);
    mod25519_exchange_32(result, mod25519_originalResult, borrow);
}

//------------------------------
// Multiplies the 32-byte little-endian value a by the single byte b.
//
//   result  receives the low 32 bytes of a * b + (incoming carry).
//   carry   in: value added into the lowest byte; out: overflow above byte 31.
//
// mod25519() calls this with b = 38 to scale the high half of a product.
void mod25519_multiply_32(uint8_t* result, const uint8_t* a, uint8_t b, uint16_t& carry)
{
    int idx = 0;
    while (idx < 32)
    {
        const uint16_t term = a[idx] * b + carry;
        result[idx] = static_cast<uint8_t>(term);
        carry = term >> 8;
        ++idx;
    }
}

//--------------------------
// Adds the 32-byte little-endian value a into result, in place.
//
//   result  in: first addend; out: low 32 bytes of the sum.
//   a       second addend.
//   carry   in: carry added into the lowest byte; out: carry out of byte 31.
void mod25519_add_32(uint8_t* result, const uint8_t* a, uint16_t& carry) {
    uint8_t* const end = result + 32;
    for (uint8_t* out = result; out != end; ++out, ++a) {
        const uint32_t total = *out + *a + carry;
        *out = static_cast<uint8_t>(total);
        carry = static_cast<uint16_t>(total >> 8);
    }
}

//--------------------------
// Subtracts the 32-byte little-endian value b from a.
//
//   r       receives the low 32 bytes of a - b; may be the same buffer as a.
//   a, b    minuend and subtrahend.
//   borrow  set to 1 when a < b (the result wrapped around), otherwise 0.
void mod25519_sub_32(uint8_t* r, const uint8_t* a, const uint8_t* b, uint8_t* borrow)
{
    uint8_t pending = 0; // borrow carried into the next byte
    int idx = 0;
    do
    {
        // Ranges over -256..255; a negative value means this byte borrowed.
        const int16_t diff = static_cast<int16_t>(a[idx] - b[idx] - pending);
        r[idx] = static_cast<uint8_t>(diff & 0xFF);
        pending = (diff < 0) ? 1 : 0;
        ++idx;
    } while (idx < 32);
    *borrow = pending;
}

//-------------------------
// Branch-free select over 32 bytes: with borrow == 1 every byte of r is
// replaced by the matching byte of original; with borrow == 0 r is unchanged.
//
// mod25519() uses this to undo a trial subtraction that underflowed.
void mod25519_exchange_32(uint8_t* r, const uint8_t* original, uint8_t borrow)
{
    // 0x00 for borrow == 0, 0xFF for borrow == 1.
    const uint8_t take_original = static_cast<uint8_t>(0u - borrow);
    for (uint8_t* out = r; out != r + 32; ++out, ++original) {
        // Flip exactly the bits selected by the mask towards `original`.
        *out = static_cast<uint8_t>(*out ^ ((*out ^ *original) & take_original));
    }
}

// Arduino entry point: opens the serial port at 9600 baud, inverts one fixed
// test value with fe25519_invert() and prints the 32 result bytes in hex.
void setup()
{
  Serial.begin(9600);

  // Test value: byte 0 is 0x02, the remaining 31 bytes are zero.
  uint8_t binaryArray[32] = {0x02};

  /* Alternative test vector:
     0x11, 0xa0, 0x43, 0xd6, 0x23, 0xb2, 0x84, 0xff,
     0x8d, 0xf0, 0x4a, 0x6d, 0xeb, 0x9d, 0xb5, 0x41,
     0x32, 0xe9, 0x6a, 0x35, 0xdc, 0xb3, 0x2f, 0xaf,
     0x95, 0x51, 0xbd, 0x32, 0x05, 0xea, 0xb2, 0x41 */

  uint8_t input[32];
  uint8_t result[32];
  memcpy(input, binaryArray, 32);

  fe25519_invert(result, input);

  Serial.print("Final result: ");
  for (const uint8_t *cursor = result; cursor != result + 32; ++cursor)
  {
    Serial.print(*cursor, HEX);
    Serial.print(" ");
  }
  Serial.println();
}

// Arduino main loop. Intentionally empty: all work is done once in setup().
void loop()
{
  // Nothing to do here
}