#include <Arduino.h>
#include <stdint.h>
#include <string.h>

// Field prime p = 2^255 - 19 as 32 little-endian bytes (index 0 is the least
// significant byte). mod25519() conditionally subtracts this value to bring a
// reduced result into the canonical range [0, p).
const uint8_t mod25519_subtracter_32[32] = {0xED, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F};
// The same prime zero-extended to 64 bytes; elements not listed in the
// initializer are zero-filled by aggregate initialization.
// NOTE(review): no function in this file references this array — confirm it is
// unused elsewhere before removing it.
uint8_t subtracter[64] = {0xED, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};

// Forward declarations, so fe25519_invert() can call helpers defined below it.
// Numbers are little-endian byte arrays:
//   mul32_mod25519        result = a * b mod p        (32-byte in, 32-byte out)
//   mod25519              result = tempResult mod p   (64-byte in, 32-byte out)
//   mod25519_multiply_32  result = a * b (byte scalar); carry is in/out
//   mod25519_add_32       result += a; carry is in/out
//   mod25519_sub_32       r = a - b; *borrow receives the final borrow bit
//   mod25519_exchange_32  r = original when borrow is 1, unchanged when 0
void mul32_mod25519(const uint8_t* a, const uint8_t* b, uint8_t* result);
void mod25519(uint8_t* result, const uint8_t* tempResult);
void mod25519_multiply_32(uint8_t* result, const uint8_t* a, uint8_t b, uint16_t& carry);
void mod25519_add_32(uint8_t* result, const uint8_t* a, uint16_t& carry);
void mod25519_sub_32(uint8_t* r, const uint8_t* a, const uint8_t* b, uint8_t* borrow);
void mod25519_exchange_32(uint8_t* r, const uint8_t* original, uint8_t borrow);

// Raises `in` to the power 2^n modulo p = 2^255 - 19 by squaring it n times and
// writes the result to `out`. Two scratch buffers are used so that
// mul32_mod25519 never writes into one of its own inputs; `out` may alias `in`.
// n == 0 simply copies `in` to `out`.
static void fe25519_square_times(uint8_t* out, const uint8_t* in, uint8_t n)
{
  uint8_t cur[32];
  uint8_t next[32];

  memcpy(cur, in, 32);
  for (uint8_t k = 0; k < n; ++k)
  {
    mul32_mod25519(cur, cur, next);
    memcpy(cur, next, 32);
  }
  memcpy(out, cur, 32);
}

// Computes the multiplicative inverse of `input` modulo p = 2^255 - 19 via
// Fermat's little theorem: input^(p - 2) = input^(2^255 - 21).
//
// `input` and `result` are 32-byte little-endian field elements. `result` is
// written only by the final multiplication, so it may alias `input`. An input
// that is 0 mod p has no inverse; the exponentiation then yields 0.
//
// This is the standard Curve25519 inversion addition chain (as in ref10):
// 254 squarings and 11 multiplications. Each trailing comment gives the power
// of `input` held in the destination buffer after that step.
void fe25519_invert(uint8_t *result, const uint8_t *input)
{
  uint8_t z2[32], z9[32], z11[32];
  uint8_t z2_5_0[32], z2_10_0[32], z2_20_0[32], z2_50_0[32], z2_100_0[32];
  uint8_t t0[32], t1[32];

  /* Small powers */
  mul32_mod25519(input, input, z2);          // 2
  fe25519_square_times(t0, z2, 2);           // 8
  mul32_mod25519(t0, input, z9);             // 9
  mul32_mod25519(z9, z2, z11);               // 11
  mul32_mod25519(z11, z11, t0);              // 22
  mul32_mod25519(t0, z9, z2_5_0);            // 2^5 - 1

  /* Build input^(2^k - 1) for growing k: square k/2 times, multiply back in. */
  fe25519_square_times(t0, z2_5_0, 5);       // 2^10 - 2^5
  mul32_mod25519(t0, z2_5_0, z2_10_0);       // 2^10 - 1

  fe25519_square_times(t0, z2_10_0, 10);     // 2^20 - 2^10
  mul32_mod25519(t0, z2_10_0, z2_20_0);      // 2^20 - 1

  fe25519_square_times(t0, z2_20_0, 20);     // 2^40 - 2^20
  mul32_mod25519(t0, z2_20_0, t1);           // 2^40 - 1

  fe25519_square_times(t0, t1, 10);          // 2^50 - 2^10
  mul32_mod25519(t0, z2_10_0, z2_50_0);      // 2^50 - 1

  fe25519_square_times(t0, z2_50_0, 50);     // 2^100 - 2^50
  mul32_mod25519(t0, z2_50_0, z2_100_0);     // 2^100 - 1

  fe25519_square_times(t0, z2_100_0, 100);   // 2^200 - 2^100
  mul32_mod25519(t0, z2_100_0, t1);          // 2^200 - 1

  fe25519_square_times(t0, t1, 50);          // 2^250 - 2^50
  mul32_mod25519(t0, z2_50_0, t1);           // 2^250 - 1

  /* Final step: 2^255 - 2^5 + 11 = 2^255 - 21 = p - 2 */
  fe25519_square_times(t0, t1, 5);           // 2^255 - 2^5
  mul32_mod25519(t0, z11, result);           // 2^255 - 21
}

// Multiplies two 32-byte little-endian integers a and b (schoolbook, one byte of
// `a` per row) into a 64-byte product, then reduces that product modulo
// p = 2^255 - 19 with mod25519(). The reduced 32 bytes are written to `result`
// only after the whole product has been formed, so `result` may alias a or b.
void mul32_mod25519(const uint8_t* a, const uint8_t* b, uint8_t* result)
{
  uint8_t wide[64] = {};

  for (int row = 0; row < 32; ++row)
  {
    const uint32_t multiplier = a[row];
    uint32_t rowCarry = 0;

    for (int col = 0; col < 32; ++col)
    {
      // Worst case 255 * 255 + 255 + 255 = 65535, so the carry stays below 256.
      // 32-bit arithmetic keeps the multiply out of 16-bit int on AVR.
      const uint32_t acc = multiplier * b[col] + wide[row + col] + rowCarry;
      wide[row + col] = (uint8_t)acc;
      rowCarry = acc >> 8;
    }

    // Byte row + 32 has not been touched by any earlier row yet.
    wide[row + 32] = (uint8_t)(wide[row + 32] + rowCarry);
  }

  uint8_t reduced[32] = {};
  mod25519(reduced, wide);
  memcpy(result, reduced, 32);
}

//-----------------------Modulo-mod25519 function------------
// Adds carry * 38 into the 32-byte little-endian value `result` and returns the
// carry out of bit 256. This folds overflow back in because
// 2^256 = 2 * 2^255 ≡ 2 * 19 = 38 (mod p). Fixed-length loop, no
// data-dependent branches.
static uint16_t mod25519_fold_carry(uint8_t* result, uint16_t carry)
{
  uint32_t addend = (uint32_t)carry * 38;
  uint32_t acc = 0;
  for (int i = 0; i < 32; i++)
  {
    acc += (uint32_t)result[i] + (addend & 0xFF);
    result[i] = acc & 0xFF;
    addend >>= 8;
    acc >>= 8;
  }
  return (uint16_t)acc;
}

// Reduces the 64-byte little-endian integer `tempResult` (e.g. a 512-bit
// product) modulo p = 2^255 - 19 and writes the canonical value in [0, p) to the
// 32-byte `result`.
//
// tempResult = high * 2^256 + low ≡ 38 * high + low (mod p). Both halves are
// copied before `result` is written, so `result` may alias `tempResult`.
void mod25519(uint8_t* result, const uint8_t* tempResult)
{
  uint8_t high[32];
  uint8_t low[32];
  uint8_t mod25519_originalResult[32];
  uint8_t borrow = 0;
  uint16_t carry = 0;

  // Divide tempResult into high and low parts
  memcpy(high, tempResult + 32, 32);
  memcpy(low, tempResult, 32);

  // result + carry * 2^256 = 38 * high, with carry <= 37.
  mod25519_multiply_32(result, high, 38, carry);

  // Fold the carry back in. The first fold adds at most 37 * 38 = 1406; if it
  // wraps past 2^256 the value left is < 1406, so the second fold (adding at
  // most 38) cannot wrap again and carry ends at 0.
  carry = mod25519_fold_carry(result, carry);
  carry = mod25519_fold_carry(result, carry);

  // Add the low half. carry is 0 here, so mod25519_add_32 sees a zero carry-in
  // and leaves the 0/1 carry out of bit 256 in `carry`.
  mod25519_add_32(result, low, carry);

  // Fold that carry back in. One fold is not always enough: after a wrap the
  // value can be as large as 2^256 - 2, so adding 38 may wrap a second time.
  // In that case the remainder is < 38 and the second fold cannot wrap.
  carry = mod25519_fold_carry(result, carry);
  carry = mod25519_fold_carry(result, carry);

  // result < 2^256 = 2p + 38, so two conditional subtractions of p bring it into
  // [0, p). Each subtraction is undone (constant-time select) when it borrows,
  // i.e. when the value was already below p.
  for (int pass = 0; pass < 2; pass++)
  {
    memcpy(mod25519_originalResult, result, 32);
    mod25519_sub_32(result, result, mod25519_subtracter_32, &borrow);
    mod25519_exchange_32(result, mod25519_originalResult, borrow);
  }
}

//------------------------------
// Multiplies the 32-byte little-endian integer `a` by the single byte `b` and
// writes the low 32 bytes of the product to `result`.
// `carry` is both the carry-in (added at byte 0) and, on return, the carry-out
// (the product bits above 2^256). With b = 38 and a zero carry-in, as used by
// mod25519(), the carry-out is at most 37.
void mod25519_multiply_32(uint8_t* result, const uint8_t* a, uint8_t b, uint16_t& carry)
{
  for (int i = 0; i < 32; i++)
  {
    // Widen before multiplying. Otherwise a[i] * b is evaluated in `int`, which
    // is 16-bit on AVR and overflows (undefined behavior) once the product
    // exceeds 32767, e.g. 255 * 255.
    uint32_t mul = (uint32_t)a[i] * b + carry;
    result[i] = mul & 0xFF;
    carry = (uint16_t)(mul >> 8);
  }
}

//--------------------------
// Adds the 32-byte little-endian integer `a` into `result` in place.
// On entry `carry` is a carry-in added at byte 0. On return it holds the
// carry out of the most significant byte.
void mod25519_add_32(uint8_t* result, const uint8_t* a, uint16_t& carry)
{
  uint32_t running = carry;
  for (int idx = 0; idx < 32; ++idx)
  {
    running += (uint32_t)result[idx] + a[idx];
    result[idx] = (uint8_t)running;
    running >>= 8;
  }
  carry = (uint16_t)running;
}

//--------------------------
// Computes r = a - b over 32-byte little-endian integers (mod 2^256) and stores
// the final borrow bit (1 when a < b, else 0) in *borrow. Each byte of `a` and
// `b` is read before r at that index is written, so r may alias a or b.
void mod25519_sub_32(uint8_t* r, const uint8_t* a, const uint8_t* b, uint8_t* borrow)
{
  uint8_t bit = 0;
  for (int idx = 0; idx < 32; ++idx)
  {
    // A negative difference wraps to 0xFF?? in 16 bits, so bit 8 is the borrow.
    const uint16_t diff = (uint16_t)((uint16_t)a[idx] - b[idx] - bit);
    r[idx] = (uint8_t)diff;
    bit = (uint8_t)((diff >> 8) & 1u);
  }
  *borrow = bit;
}

//-------------------------
// Branch-free conditional copy over 32 bytes. When borrow is 1, every byte of r
// is replaced by the matching byte of `original`. When borrow is 0, r is left
// unchanged. The selection mask is (0 - borrow) truncated to a byte, so only
// 0 and 1 give an all-or-nothing select.
void mod25519_exchange_32(uint8_t* r, const uint8_t* original, uint8_t borrow)
{
  const uint8_t select = (uint8_t)(0u - borrow);
  for (int idx = 0; idx < 32; ++idx)
  {
    r[idx] ^= (uint8_t)((r[idx] ^ original[idx]) & select);
  }
}

// Arduino entry point. Inverts a fixed test value modulo 2^255 - 19 and prints
// the 32-byte result over serial (9600 baud) as 64 hex digits, least
// significant byte first, with each byte zero-padded to two digits.
void setup()
{
  Serial.begin(9600);

  // The field element 1 (32 bytes, little-endian; remaining bytes are
  // zero-initialized). Its inverse is 1.
  uint8_t input[32] = {0x01};
  uint8_t result[32];

  fe25519_invert(result, input);

  Serial.print("Final result: ");
  for (int i = 0; i < 32; ++i)
  {
    // print(x, HEX) omits leading zeros. Pad so byte boundaries stay unambiguous.
    if (result[i] < 0x10)
    {
      Serial.print('0');
    }
    Serial.print(result[i], HEX);
  }
  Serial.println();
}

// Arduino main loop. Intentionally empty: setup() performs the one-off
// computation and printout.
void loop()
{
}