#include <Arduino.h>
#include <avr/io.h>
#include <stdlib.h>

// Define constants and function prototypes

// Map the generic field-arithmetic names onto avrnacl_-prefixed symbols.
// NOTE(review): presumably supplied by the avrnacl AVR assembly at link time — confirm.
#define fe25519_add avrnacl_fe25519_add
#define fe25519_sub avrnacl_fe25519_sub
#define fe25519_red avrnacl_fe25519_red

// Field element held as 32 raw bytes. The name and the curve25519 prototypes
// below suggest GF(2^255 - 19); byte order is assumed little-endian (avrnacl
// convention) — TODO confirm.
typedef struct {unsigned char v[32];} fe25519;

// Low-level routines with C linkage, defined outside this file.
// fe25519_red is fed a 64-byte product buffer by fe25519_square/fe25519_mul.
// NOTE(review): bigint_subp is assumed to store a - p in r and return the
// borrow (0 or 1); fe25519_freeze depends on that contract.
extern "C"
{
  void fe25519_sub(fe25519 *r, const fe25519 *x, const fe25519 *y);
  void fe25519_add(fe25519 *r, const fe25519 *x, const fe25519 *y);
  void fe25519_red(fe25519 *r, unsigned char *C);
  char bigint_subp(unsigned char* r, const unsigned char* a);
  char bigint_square256(unsigned char* r, const unsigned char* a);
  char bigint_mul256(unsigned char* r, const unsigned char* a, const unsigned char* b);
  void bigint_mul121666(unsigned char *r, const unsigned char *x);
}

// Forward declarations (C++ linkage) of the field/ladder helpers.
void fe25519_freeze(fe25519 *r);
void fe25519_unpack(fe25519 *r, const unsigned char x[32]);
void fe25519_pack(unsigned char r[32], const fe25519 *x);
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b);
void fe25519_setone(fe25519 *r);
void fe25519_setzero(fe25519 *r);
void fe25519_mul(fe25519 *r, const fe25519 *x, const fe25519 *y);
void fe25519_square(fe25519 *r, const fe25519 *x);
void fe25519_invert(fe25519 *r, const fe25519 *x);
void work_cswap(fe25519 *work, char b);
void mladder(fe25519 *xr, fe25519 *zr, const unsigned char s[32]);
void fe25519_mul121666(fe25519 *r, const fe25519 *x);

int crypto_scalarmult_curve25519(unsigned char *r, const unsigned char *s, const unsigned char *p);

// Enable or disable debug printing
#define DEBUG 1 // Set to 0 to disable debug prints

void setup() {
  Serial.begin(9600);

  // Initialize zr with the given value
  fe25519 zr = {{0x41, 0xB2, 0xEA, 0x05, 0x32, 0xBD, 0x51, 0x95, 0xAF, 0x2F, 0xB3, 0xDC, 0x35, 0x6A, 0xE9, 0x32, 0x41, 0xB5, 0x9D, 0xEB, 0x6D, 0x4A, 0xF0, 0x8D, 0xFF, 0x84, 0xB2, 0x23, 0xD6, 0x43, 0xA0, 0x11}};
  
  Serial.print("zr: ");
  for (int i = 0; i < 32; i++) {
    Serial.print(zr.v[i], HEX);
    Serial.print(" ");
  }
  Serial.println();

  // Initialize a variable to hold the inverse of zr
  fe25519 zr_inv;
  
  // Compute the multiplicative inverse of zr
  fe25519_invert(&zr_inv, &zr);
  
  // Print the inverse
  Serial.println(" ===============> Multiplicative inverse of zr: ");
  for (int i = 0; i < 32; i++) {
    Serial.print(zr_inv.v[i], HEX);
    Serial.print(" ");
  }
  Serial.println();
}

// Arduino main-loop hook. All work is done once in setup(), so this is a no-op.
void loop()
{
}

// Multiplicative inverse in the field: r = x^(p-2) = x^(2^255 - 21) with
// p = 2^255 - 19 (Fermat's little theorem), so x == 0 maps to 0.
// Uses the fixed NaCl/avrnacl addition chain: 254 squarings and 11
// multiplications, with no branches on the value of x.
// r may alias x: x is only read before r is first written.
// Comments give the exponent e of the value just computed (value = x^e).
void fe25519_invert(fe25519 *r, const fe25519 *x)
{
  fe25519 z2;
  fe25519 z9;
  fe25519 z11;
  fe25519 z2_5_0;
  fe25519 z2_10_0;
  fe25519 z2_20_0;
  fe25519 z2_50_0;
  fe25519 z2_100_0;
  fe25519 t0;
  fe25519 t1;
  unsigned char i;

  #if DEBUG
    Serial.println("Inside the fe25519_invert function.");
  #endif

  /* 2 */               fe25519_square(&z2, x);
  /* 4 */               fe25519_square(&t1, &z2);
  /* 8 */               fe25519_square(&t0, &t1);
  /* 9 */               fe25519_mul(&z9, &t0, x);
  /* 11 */              fe25519_mul(&z11, &z9, &z2);
  /* 22 */              fe25519_square(&t0, &z11);
  /* 2^5 - 2^0 = 31 */  fe25519_mul(&z2_5_0, &t0, &z9);

  /* 2^6 - 2^1 */       fe25519_square(&t0, &z2_5_0);
  /* 2^7 - 2^2 */       fe25519_square(&t1, &t0);
  /* 2^8 - 2^3 */       fe25519_square(&t0, &t1);
  /* 2^9 - 2^4 */       fe25519_square(&t1, &t0);
  /* 2^10 - 2^5 */      fe25519_square(&t0, &t1);
  /* 2^10 - 2^0 */      fe25519_mul(&z2_10_0, &t0, &z2_5_0);

  /* 2^11 - 2^1 */      fe25519_square(&t0, &z2_10_0);
  /* 2^12 - 2^2 */      fe25519_square(&t1, &t0);
  /* 2^20 - 2^10 */     for (i = 2; i < 10; i += 2) { fe25519_square(&t0, &t1); fe25519_square(&t1, &t0); }
  /* 2^20 - 2^0 */      fe25519_mul(&z2_20_0, &t1, &z2_10_0);

  /* 2^21 - 2^1 */      fe25519_square(&t0, &z2_20_0);
  /* 2^22 - 2^2 */      fe25519_square(&t1, &t0);
  /* 2^40 - 2^20 */     for (i = 2; i < 20; i += 2) { fe25519_square(&t0, &t1); fe25519_square(&t1, &t0); }
  /* 2^40 - 2^0 */      fe25519_mul(&t0, &t1, &z2_20_0);

  /* 2^41 - 2^1 */      fe25519_square(&t1, &t0);
  /* 2^42 - 2^2 */      fe25519_square(&t0, &t1);
  /* 2^50 - 2^10 */     for (i = 2; i < 10; i += 2) { fe25519_square(&t1, &t0); fe25519_square(&t0, &t1); }
  /* 2^50 - 2^0 */      fe25519_mul(&z2_50_0, &t0, &z2_10_0);

  /* 2^51 - 2^1 */      fe25519_square(&t0, &z2_50_0);
  /* 2^52 - 2^2 */      fe25519_square(&t1, &t0);
  /* 2^100 - 2^50 */    for (i = 2; i < 50; i += 2) { fe25519_square(&t0, &t1); fe25519_square(&t1, &t0); }
  /* 2^100 - 2^0 */     fe25519_mul(&z2_100_0, &t1, &z2_50_0);

  /* 2^101 - 2^1 */     fe25519_square(&t1, &z2_100_0);
  /* 2^102 - 2^2 */     fe25519_square(&t0, &t1);
  /* 2^200 - 2^100 */   for (i = 2; i < 100; i += 2) { fe25519_square(&t1, &t0); fe25519_square(&t0, &t1); }
  /* 2^200 - 2^0 */     fe25519_mul(&t1, &t0, &z2_100_0);

  /* 2^201 - 2^1 */     fe25519_square(&t0, &t1);
  /* 2^202 - 2^2 */     fe25519_square(&t1, &t0);
  /* 2^250 - 2^50 */    for (i = 2; i < 50; i += 2) { fe25519_square(&t0, &t1); fe25519_square(&t1, &t0); }
  /* 2^250 - 2^0 */     fe25519_mul(&t0, &t1, &z2_50_0);

  /* 2^251 - 2^1 */     fe25519_square(&t1, &t0);
  /* 2^252 - 2^2 */     fe25519_square(&t0, &t1);
  /* 2^253 - 2^3 */     fe25519_square(&t1, &t0);
  /* 2^254 - 2^4 */     fe25519_square(&t0, &t1);
  /* 2^255 - 2^5 */     fe25519_square(&t1, &t0);
  /* 2^255 - 21 */      fe25519_mul(r, &t1, &z11);
}

// Field squaring: r = x * x.
// The double-width square is written into a 64-byte buffer by
// bigint_square256 (presumably the full 512-bit product), reduced into r by
// fe25519_red, then passed through fe25519_freeze.
// r may alias x: x is consumed before r is written.
void fe25519_square(fe25519 *r, const fe25519 *x)
{
  #if DEBUG
    // Dump the operand. Serial.print(b, HEX) drops leading zeros, so bytes
    // below 0x10 are padded by hand to keep the dump unambiguous.
    Serial.print("Before squaring, x: ");
    for (int i = 0; i < 32; i++)
    {
      if (x->v[i] < 0x10)
      {
        Serial.print("0");
      }
      Serial.print(x->v[i], HEX);
      Serial.print(" ");
    }
    Serial.println();
  #endif

  unsigned char t[64] = {0x00};
  bigint_square256(t, x->v);
  fe25519_red(r, t);
  fe25519_freeze(r);
}

// Field multiplication: r = x * y.
// The double-width product is written into a 64-byte buffer by bigint_mul256,
// reduced into r by fe25519_red, then passed through fe25519_freeze.
// r may alias x or y: both inputs are consumed before r is written.
void fe25519_mul(fe25519 *r, const fe25519 *x, const fe25519 *y)
{
  #if DEBUG
    // Dump one operand. Serial.print(b, HEX) drops leading zeros, so bytes
    // below 0x10 are padded by hand to keep the dump unambiguous.
    auto dump = [](const char *label, const fe25519 *fe) {
      Serial.print(label);
      for (int i = 0; i < 32; i++)
      {
        if (fe->v[i] < 0x10)
        {
          Serial.print("0");
        }
        Serial.print(fe->v[i], HEX);
        Serial.print(" ");
      }
      Serial.println();
    };
    dump("Before multiplication, x: ", x);
    dump("Before multiplication, y: ", y);
  #endif

  unsigned char t[64] = {0x00};
  bigint_mul256(t, x->v, y->v);
  fe25519_red(r, t);
  fe25519_freeze(r);
}

// Bring r into canonical form by conditionally subtracting p, twice.
// NOTE(review): assumes bigint_subp (defined outside this file) stores a - p
// and returns 1 on borrow, i.e. when a < p. Under that contract each round
// keeps the difference only when no borrow occurred, and two rounds suffice
// for any 256-bit value because 2^256 < 3p.
void fe25519_freeze(fe25519 *r)
{
  fe25519 reduced;
  for (unsigned char round = 0; round < 2; ++round)
  {
    const unsigned char borrow = bigint_subp(reduced.v, r->v);
    fe25519_cmov(r, &reduced, 1 - borrow);
  }
}

// Conditional move without branching on b: when b == 1 every byte of x is
// copied into r; when b == 0, r is left unchanged. b is expected to be 0 or 1.
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b)
{
  // 0x00 for b == 0, 0xFF for b == 1 (only the low byte of -b is ever used).
  const unsigned char keep_mask = static_cast<unsigned char>(0u - b);
  for (unsigned char k = 0; k < 32; ++k)
  {
    const unsigned char diff = static_cast<unsigned char>(x->v[k] ^ r->v[k]);
    r->v[k] = static_cast<unsigned char>(r->v[k] ^ (keep_mask & diff));
  }
}