#include <Arduino.h>
#include <avr/io.h>
#include <stdlib.h>

// Route the generic fe25519_* names used in this sketch to the
// avrnacl_-prefixed symbols (presumably provided by the AVRNaCl library -- verify).
#define fe25519_add avrnacl_fe25519_add
#define fe25519_sub avrnacl_fe25519_sub
#define fe25519_red avrnacl_fe25519_red

// An element of GF(2^255 - 19), stored as 32 raw bytes.
// NOTE(review): byte order assumed little-endian (v[0] least significant), as in
// AVRNaCl -- confirm against the library.
typedef struct {unsigned char v[32];} fe25519;

// Low-level routines with C linkage that are implemented outside this file
// (presumably AVR assembly from AVRNaCl -- verify the return conventions there).
extern "C"
{
  void fe25519_sub(fe25519 *r, const fe25519 *x, const fe25519 *y);
  void fe25519_add(fe25519 *r, const fe25519 *x, const fe25519 *y);
  void fe25519_red(fe25519 *r, unsigned char *C);
  char bigint_subp(unsigned char* r, const unsigned char* a);
  char bigint_square256(unsigned char* r, const unsigned char* a);
  char bigint_mul256(unsigned char* r, const unsigned char* a, const unsigned char* b);
  void bigint_mul121666(unsigned char *r, const unsigned char *x);
}

// Field-arithmetic and Curve25519 helpers with C++ linkage.
void fe25519_freeze(fe25519 *r);
void fe25519_unpack(fe25519 *r, const unsigned char x[32]);
void fe25519_pack(unsigned char r[32], const fe25519 *x);
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b);
void fe25519_setone(fe25519 *r);
void fe25519_setzero(fe25519 *r);
void fe25519_mul(fe25519 *r, const fe25519 *x, const fe25519 *y);
void fe25519_square(fe25519 *r, const fe25519 *x);
void fe25519_invert(fe25519 *r, const fe25519 *x);
void work_cswap(fe25519 *work, char b);
void mladder(fe25519 *xr, fe25519 *zr, const unsigned char s[32]);
void fe25519_mul121666(fe25519 *r, const fe25519 *x);

int crypto_scalarmult_curve25519(unsigned char *r, const unsigned char *s, const unsigned char *p);

// Debug-dump helpers. fe25519_square() and fe25519_mul() call them before their
// definitions appear in the file; the Arduino IDE auto-generates such
// prototypes, but plain avr-g++ / PlatformIO .cpp builds do not, so declare them.
void print_fe25519_64(const char* name, const fe25519 *f);
void print_fe25519_64(const char* name, const unsigned char *f);
void print_fe25519(const char* name, const fe25519 *f);

/*
 * One-shot demo: print a fixed 32-byte field element, invert it with
 * fe25519_invert(), then print the result. Bytes are printed in array order
 * (v[0] first) in hexadecimal, without zero padding, separated by spaces.
 */
void setup() {
  Serial.begin(9600);

  // Fixed test operand.
  fe25519 zr = {{0x41, 0xB2, 0xEA, 0x05, 0x32, 0xBD, 0x51, 0x95, 0xAF, 0x2F, 0xB3, 0xDC, 0x35, 0x6A, 0xE9, 0x32, 0x41, 0xB5, 0x9D, 0xEB, 0x6D, 0x4A, 0xF0, 0x8D, 0xFF, 0x84, 0xB2, 0x23, 0xD6, 0x43, 0xA0, 0x11}};
  Serial.print("zr: ");
  for (const unsigned char *p = zr.v; p != zr.v + 32; ++p) {
    Serial.print(*p, HEX);
    Serial.print(" ");
  }
  Serial.println();

  // Destination for zr^-1.
  fe25519 zr_inv;
  fe25519_invert(&zr_inv, &zr);

  Serial.println(" ===============> ===============> ===============> ===============> Multiplicative inverse of zr: ");
  for (const unsigned char *p = zr_inv.v; p != zr_inv.v + 32; ++p) {
    Serial.print(*p, HEX);
    Serial.print(" ");
  }
  Serial.println();
}

void loop()
{
  /* Nothing to repeat: setup() performs the single inversion run. */
}

void fe25519_invert(fe25519 *r, const fe25519 *x)
{
	fe25519 z2;
	fe25519 z11;
	fe25519 z2_10_0;
	fe25519 z2_50_0;
	fe25519 z2_100_0;
	fe25519 t0;
	fe25519 t1;
	unsigned char i;
  Serial.println(" 1 -------------> Inside the fe25519_invert function, at the (/* 2 */ fe25519_square(&z2,x);) step ");
	/* 2 */ fe25519_square(&z2,x);
  Serial.println(" 2 -------------> Inside the fe25519_invert function, at the (/* 4 */ fe25519_square(&t1,&z2);) step ");
	/* 4 */ fe25519_square(&t1,&z2);
  Serial.println(" 3 -------------> Inside the fe25519_invert function, at the (/* 8 */ fe25519_square(&t0,&t1);) step ");
	/* 8 */ fe25519_square(&t0,&t1);
  Serial.println(" 4 -------------> Inside the fe25519_invert function, at the (/* 9 */ fe25519_mul(&z2_10_0,&t0,x);) step ");
	/* 9 */ fe25519_mul(&z2_10_0,&t0,x);
  Serial.println(" 5 -------------> Inside the fe25519_invert function, at the (/* 11 */ fe25519_mul(&z11,&z2_10_0,&z2);) step ");
	/* 11 */ fe25519_mul(&z11,&z2_10_0,&z2);
  Serial.println(" 6 -------------> Inside the fe25519_invert function, at the (/* 22 */ fe25519_square(&t0,&z11);) step ");
	/* 22 */ fe25519_square(&t0,&z11);
  Serial.println(" 7 -------------> Inside the fe25519_invert function, at the (/* 2^5 - 2^0 = 31 */ fe25519_mul(&z2_10_0,&t0,&z2_10_0);) step ");
	/* 2^5 - 2^0 = 31 */ fe25519_mul(&z2_10_0,&t0,&z2_10_0);
  
  Serial.println(" 8 -------------> Inside the fe25519_invert function, at the (/* 2^6 - 2^1 */ fe25519_square(&t0,&z2_10_0);) step ");
	/* 2^6 - 2^1 */ fe25519_square(&t0,&z2_10_0);
  Serial.println(" 9 -------------> Inside the fe25519_invert function, at the (/* 2^7 - 2^2 */ fe25519_square(&t1,&t0);) step ");
	/* 2^7 - 2^2 */ fe25519_square(&t1,&t0);
  Serial.println(" 10 -------------> Inside the fe25519_invert function, at the (/* 2^8 - 2^3 */ fe25519_square(&t0,&t1);) step ");
	/* 2^8 - 2^3 */ fe25519_square(&t0,&t1);
  Serial.println(" 11 -------------> Inside the fe25519_invert function, at the (/* 2^9 - 2^4 */ fe25519_square(&t1,&t0);) step ");
	/* 2^9 - 2^4 */ fe25519_square(&t1,&t0);
  Serial.println(" 12 -------------> Inside the fe25519_invert function, at the (/* 2^10 - 2^5 */ fe25519_square(&t0,&t1);) step ");
	/* 2^10 - 2^5 */ fe25519_square(&t0,&t1);
  Serial.println(" 13 -------------> Inside the fe25519_invert function, at the (/* 2^10 - 2^0 */ fe25519_mul(&z2_10_0,&t0,&z2_10_0);) step ");
	/* 2^10 - 2^0 */ fe25519_mul(&z2_10_0,&t0,&z2_10_0);
  
  Serial.println(" 14 -------------> Inside the fe25519_invert function, at the (/* 2^11 - 2^1 */ fe25519_square(&t0,&z2_10_0);) step ");
	/* 2^11 - 2^1 */ fe25519_square(&t0,&z2_10_0);
  Serial.println(" 15 -------------> Inside the fe25519_invert function, at the (/* 2^12 - 2^2 */ fe25519_square(&t1,&t0);) step ");
	/* 2^12 - 2^2 */ fe25519_square(&t1,&t0);
  Serial.println(" 16 -------------> Inside the fe25519_invert function, at the (/* 2^20 - 2^10 */ for (i = 2;i < 10;i += 2){ fe25519_square(&t0,&t1); fe25519_square(&t1,&t0); }) step ");
	/* 2^20 - 2^10 */ for (i = 2;i < 10;i += 2){ fe25519_square(&t0,&t1); fe25519_square(&t1,&t0); }
  Serial.println(" 17 -------------> Inside the fe25519_invert function, at the (/* 2^20 - 2^0 */ fe25519_mul(&z2_50_0,&t1,&z2_10_0);) step ");
	/* 2^20 - 2^0 */ fe25519_mul(&z2_50_0,&t1,&z2_10_0);
  
  Serial.println(" 18 -------------> Inside the fe25519_invert function, at the (/* 2^21 - 2^1 */ fe25519_square(&t0,&z2_50_0);) step ");
	/* 2^21 - 2^1 */ fe25519_square(&t0,&z2_50_0);
  Serial.println(" 19 -------------> Inside the fe25519_invert function, at the (/* 2^22 - 2^2 */ fe25519_square(&t1,&t0);) step ");
	/* 2^22 - 2^2 */ fe25519_square(&t1,&t0);
  Serial.println(" 20 -------------> Inside the fe25519_invert function, at the (/* 2^40 - 2^20 */ for (i = 2;i < 20;i += 2) { fe25519_square(&t0,&t1); fe25519_square(&t1,&t0); }) step ");
	/* 2^40 - 2^20 */ for (i = 2;i < 20;i += 2) { fe25519_square(&t0,&t1); fe25519_square(&t1,&t0); }
  Serial.println(" 21 -------------> Inside the fe25519_invert function, at the (/* 2^40 - 2^0 */ fe25519_mul(&t0,&t1,&z2_50_0);) step ");
	/* 2^40 - 2^0 */ fe25519_mul(&t0,&t1,&z2_50_0);
  
  Serial.println(" 22 -------------> Inside the fe25519_invert function, at the (/* 2^41 - 2^1 */ fe25519_square(&t1,&t0);) step ");
	/* 2^41 - 2^1 */ fe25519_square(&t1,&t0);
  Serial.println(" 23 -------------> Inside the fe25519_invert function, at the (/* 2^42 - 2^2 */ fe25519_square(&t0,&t1);) step ");
	/* 2^42 - 2^2 */ fe25519_square(&t0,&t1);
  Serial.println(" 24 -------------> Inside the fe25519_invert function, at the (/* 2^50 - 2^10 */ for (i = 2;i < 10;i += 2) { fe25519_square(&t1,&t0); fe25519_square(&t0,&t1); }) step ");
	/* 2^50 - 2^10 */ for (i = 2;i < 10;i += 2) { fe25519_square(&t1,&t0); fe25519_square(&t0,&t1); }
  Serial.println(" 25 -------------> Inside the fe25519_invert function, at the (/* 2^50 - 2^0 */ fe25519_mul(&z2_50_0,&t0,&z2_10_0);) step ");
	/* 2^50 - 2^0 */ fe25519_mul(&z2_50_0,&t0,&z2_10_0);
  
  Serial.println(" 26 -------------> Inside the fe25519_invert function, at the (/* 2^51 - 2^1 */ fe25519_square(&t0,&z2_50_0);) step ");
	/* 2^51 - 2^1 */ fe25519_square(&t0,&z2_50_0);
  Serial.println(" 27 -------------> Inside the fe25519_invert function, at the (/* 2^52 - 2^2 */ fe25519_square(&t1,&t0);) step ");
	/* 2^52 - 2^2 */ fe25519_square(&t1,&t0);
  Serial.println(" 28 -------------> Inside the fe25519_invert function, at the (/* 2^100 - 2^50 */ for (i = 2;i < 50;i += 2) { fe25519_square(&t0,&t1); fe25519_square(&t1,&t0); }) step ");
	/* 2^100 - 2^50 */ for (i = 2;i < 50;i += 2) { fe25519_square(&t0,&t1); fe25519_square(&t1,&t0); }
  Serial.println(" 29 -------------> Inside the fe25519_invert function, at the (/* 2^100 - 2^0 */ fe25519_mul(&z2_100_0,&t1,&z2_50_0);) step ");
	/* 2^100 - 2^0 */ fe25519_mul(&z2_100_0,&t1,&z2_50_0);
  
  Serial.println(" 30 -------------> Inside the fe25519_invert function, at the (/* 2^101 - 2^1 */ fe25519_square(&t1,&z2_100_0);) step ");
	/* 2^101 - 2^1 */ fe25519_square(&t1,&z2_100_0);
  Serial.println(" 31 -------------> Inside the fe25519_invert function, at the (/* 2^102 - 2^2 */ fe25519_square(&t0,&t1); ");
	/* 2^102 - 2^2 */ fe25519_square(&t0,&t1);
  Serial.println(" 32 -------------> Inside the fe25519_invert function, at the (/* 2^200 - 2^100 */ for (i = 2;i < 100;i += 2) { fe25519_square(&t1,&t0); fe25519_square(&t0,&t1); }) step ");
	/* 2^200 - 2^100 */ for (i = 2;i < 100;i += 2) { fe25519_square(&t1,&t0); fe25519_square(&t0,&t1); }
  Serial.println(" 33 -------------> Inside the fe25519_invert function, at the (/* 2^200 - 2^0 */ fe25519_mul(&t1,&t0,&z2_100_0);) step ");
	/* 2^200 - 2^0 */ fe25519_mul(&t1,&t0,&z2_100_0);
  
  Serial.println(" 34 -------------> Inside the fe25519_invert function, at the (/* 2^201 - 2^1 */ fe25519_square(&t0,&t1);) step ");
	/* 2^201 - 2^1 */ fe25519_square(&t0,&t1);
  Serial.println(" 35 -------------> Inside the fe25519_invert function, at the (/* 2^202 - 2^2 */ fe25519_square(&t1,&t0);) step ");
	/* 2^202 - 2^2 */ fe25519_square(&t1,&t0);
  Serial.println(" 36 -------------> Inside the fe25519_invert function, at the (/* 2^250 - 2^50 */ for (i = 2;i < 50;i += 2) { fe25519_square(&t0,&t1); fe25519_square(&t1,&t0); }) step ");
	/* 2^250 - 2^50 */ for (i = 2;i < 50;i += 2) { fe25519_square(&t0,&t1); fe25519_square(&t1,&t0); }
  Serial.println(" 37 -------------> Inside the fe25519_invert function, at the (/* 2^250 - 2^0 */ fe25519_mul(&t0,&t1,&z2_50_0);) step ");
	/* 2^250 - 2^0 */ fe25519_mul(&t0,&t1,&z2_50_0);
  
  Serial.println(" 38 -------------> Inside the fe25519_invert function, at the (/* 2^251 - 2^1 */ fe25519_square(&t1,&t0);) step ");
	/* 2^251 - 2^1 */ fe25519_square(&t1,&t0);
  Serial.println(" 39 -------------> Inside the fe25519_invert function, at the (/* 2^252 - 2^2 */ fe25519_square(&t0,&t1);) step ");
	/* 2^252 - 2^2 */ fe25519_square(&t0,&t1);
  Serial.println(" 40 -------------> Inside the fe25519_invert function, at the (/* 2^253 - 2^3 */ fe25519_square(&t1,&t0);) step ");
	/* 2^253 - 2^3 */ fe25519_square(&t1,&t0);
  Serial.println(" 41 -------------> Inside the fe25519_invert function, at the (/* 2^254 - 2^4 */ fe25519_square(&t0,&t1);) step ");
	/* 2^254 - 2^4 */ fe25519_square(&t0,&t1);
  Serial.println(" 42 -------------> Inside the fe25519_invert function, at the (/* 2^255 - 2^5 */ fe25519_square(&t1,&t0);) step ");
	/* 2^255 - 2^5 */ fe25519_square(&t1,&t0);
  Serial.println(" 43 -------------> Inside the fe25519_invert function, at the (/* 2^255 - 21 */ fe25519_mul(r,&t1,&z11);) step ");
	/* 2^255 - 21 */ fe25519_mul(r,&t1,&z11);
}

/*
 * r = x^2, followed by fe25519_freeze(). Every stage is traced over Serial:
 * the input bytes, the 64-byte raw square, the value after fe25519_red(), and
 * the value after freezing. The square is formed in a local buffer before r is
 * written, so r may alias x.
 */
void fe25519_square(fe25519 *r, const fe25519 *x)
{
  Serial.print("Before squaring, x: ");
  for (unsigned char k = 0; k != 32; ++k)
  {
    Serial.print(x->v[k], HEX);
    Serial.print(" ");
  }
  Serial.println();

  // Full-width (512-bit) square of the 256-bit operand.
  unsigned char wide[64] = {0};
  bigint_square256(wide, x->v);
  print_fe25519_64("square intermidate Result: ", wide);

  // Fold the 64-byte square into a 32-byte element; the freeze below then
  // canonicalises it (fe25519_red's exact output range is not visible here).
  fe25519_red(r, wide);
  print_fe25519("square intermidate Result, after reduction: ", r);

  fe25519_freeze(r);
  print_fe25519("After freeze: ", r);
}

/*
 * r = x * y, followed by fe25519_freeze(). Both operands, the 64-byte raw
 * product, the value after fe25519_red(), and the frozen result are traced
 * over Serial. The product is formed in a local buffer before r is written,
 * so r may alias x or y.
 */
void fe25519_mul(fe25519 *r, const fe25519 *x, const fe25519 *y)
{
  Serial.print("Before multiplication, x: ");
  for (unsigned char k = 0; k != 32; ++k)
  {
    Serial.print(x->v[k], HEX);
    Serial.print(" ");
  }
  Serial.println();

  Serial.print("Before multiplication, y: ");
  for (unsigned char k = 0; k != 32; ++k)
  {
    Serial.print(y->v[k], HEX);
    Serial.print(" ");
  }
  Serial.println();

  // Full-width (512-bit) product of the two 256-bit operands.
  unsigned char product[64] = {0};
  // NOTE(review): the purpose of this 10 ms pause is not evident from the code
  // (perhaps to let serial output drain) -- kept unchanged.
  delay(10);
  bigint_mul256(product, x->v, y->v);
  print_fe25519_64("Intermidiate multiplication 64 result", product);

  fe25519_red(r, product);
  print_fe25519("Intermidiate multiplication 64 result, after reduction: ", r);

  fe25519_freeze(r);
  print_fe25519("Intermidiate multiplication 32 result, After freeze: ", r);
}

/*
 * Canonical reduction modulo p = 2^255 - 19.
 *
 * Runs two rounds of the same step: compute a tentative difference with
 * bigint_subp, then conditionally copy it into r via fe25519_cmov (selector
 * 1 - returned value).
 * NOTE(review): assumes bigint_subp(rt, r) writes r - p into rt and returns 1
 * on borrow (r < p), else 0. Under that assumption two rounds bring any
 * 256-bit input (< 2^256 = 2p + 38 < 3p) into [0, p).
 */
void fe25519_freeze(fe25519 *r)
{
  fe25519 candidate;
  for (unsigned char round = 0; round < 2; ++round)
  {
    const unsigned char borrow = bigint_subp(candidate.v, r->v);
    fe25519_cmov(r, &candidate, 1 - borrow);
  }
}

/*
 * Print `name`, then ": ", then the bytes of a field element in hex (array
 * order, space-separated, no zero padding), then a newline.
 *
 * Despite the "_64" in its name, this overload receives an fe25519, which holds
 * only sizeof(f->v) == 32 bytes. Printing exactly that many avoids reading past
 * the end of the object. For 64-byte intermediate products, use the
 * (const char*, const unsigned char*) overload instead.
 */
void print_fe25519_64(const char* name, const fe25519 *f)
{
    Serial.print(name);
    Serial.print(": ");
    for (unsigned char i = 0; i < sizeof(f->v); i++)
    {
        Serial.print(f->v[i], HEX);
        Serial.print(" ");
    }
    Serial.println();
}

/*
 * Print `name`, then ": ", then exactly 64 bytes starting at f in hex (in
 * order, space-separated, no zero padding), then a newline. The caller must
 * supply a buffer of at least 64 bytes (e.g. a 512-bit product).
 */
void print_fe25519_64(const char* name, const unsigned char *f)
{
    Serial.print(name);
    Serial.print(": ");
    const unsigned char *const end = f + 64;
    for (const unsigned char *p = f; p != end; ++p)
    {
        Serial.print(*p, HEX);
        Serial.print(" ");
    }
    Serial.println();
}

/*
 * Print `name`, then ": ", then the 32 bytes of f in hex (array order,
 * space-separated, no zero padding), then a newline. Labels that already end
 * in ": " will therefore show the separator twice.
 */
void print_fe25519(const char* name, const fe25519 *f)
{
    Serial.print(name);
    Serial.print(": ");
    const unsigned char *cursor = f->v;
    const unsigned char *const stop = f->v + 32;
    while (cursor != stop)
    {
        Serial.print(*cursor++, HEX);
        Serial.print(" ");
    }
    Serial.println();
}

/*
 * Conditional move without branching on b: r = x when b == 1, r unchanged
 * when b == 0. Every byte of r is rewritten either way. Values of b other than
 * 0 or 1 yield a bitwise blend of r and x, so callers must pass 0 or 1.
 */
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b)
{
  // 0x00 for b == 0, 0xFF for b == 1. Only the low 8 bits of the mask can
  // affect an unsigned char byte, so an 8-bit mask is sufficient.
  const unsigned char mask = static_cast<unsigned char>(-b);
  for (unsigned char k = 0; k < 32; ++k)
  {
    const unsigned char diff = x->v[k] ^ r->v[k];
    r->v[k] ^= mask & diff;
  }
}