#include <Arduino.h>
#include <avr/io.h>
#include <stdlib.h>

// Define constants and function prototypes
// Map the generic fe25519_* names used throughout this sketch onto the
// avrnacl_-prefixed symbols. Those symbols are not defined in this file;
// presumably they come from AVRNaCl sources linked with the sketch — TODO confirm.
#define fe25519_add avrnacl_fe25519_add
#define fe25519_sub avrnacl_fe25519_sub
#define fe25519_red avrnacl_fe25519_red

// Element of the field GF(2^255 - 19), stored as 32 bytes in little-endian
// order (v[0] is the least-significant byte, v[31] the most significant).
// Values are not necessarily fully reduced until fe25519_freeze() is applied.
typedef struct {
  unsigned char v[32];
} fe25519;

// Low-level arithmetic routines with C linkage (names are not C++-mangled);
// none of them is defined in this file. fe25519_add/sub/red expand to the
// avrnacl_* names via the #defines above. The notes below are inferred from
// how this file calls them — TODO confirm against the implementation:
//   fe25519_red       reduces the 64-byte buffer C (a full product) into r.
//   bigint_subp       presumably writes a - (2^255 - 19) into r and returns the
//                     borrow; fe25519_freeze keeps the difference when it is 0.
//   bigint_square256  writes the 64-byte square of the 32-byte operand a into r.
//   bigint_mul256     writes the 64-byte product of a and b into r.
//   bigint_mul121666  is declared but not called anywhere in this file.
extern "C" {
  void fe25519_sub(fe25519 *r, const fe25519 *x, const fe25519 *y);
  void fe25519_add(fe25519 *r, const fe25519 *x, const fe25519 *y);
  void fe25519_red(fe25519 *r, unsigned char *C);
  char bigint_subp(unsigned char *r, const unsigned char *a);
  char bigint_square256(unsigned char *r, const unsigned char *a);
  char bigint_mul256(unsigned char *r, const unsigned char *a, const unsigned char *b);
  void bigint_mul121666(unsigned char *r, const unsigned char *x);
}

void fe25519_freeze(fe25519 *r);
void fe25519_unpack(fe25519 *r, const unsigned char x[32]);
void fe25519_pack(unsigned char r[32], const fe25519 *x);
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b);
void fe25519_setone(fe25519 *r);
void fe25519_setzero(fe25519 *r);
void fe25519_mul(fe25519 *r, const fe25519 *x, const fe25519 *y);
void fe25519_square(fe25519 *r, const fe25519 *x);
void fe25519_invert(fe25519 *r, const fe25519 *x);
void work_cswap(fe25519 *work, char b);
void mladder(fe25519 *xr, fe25519 *zr, const unsigned char s[32]);
void fe25519_mul121666(fe25519 *r, const fe25519 *x);

int crypto_scalarmult_curve25519(unsigned char *r, const unsigned char *s, const unsigned char *p);

// Field-element constant 121666 = 0x01DB42 (= (486662 + 2) / 4), little-endian;
// every byte not listed is zero-initialised.
static const fe25519 _121666 = { { 0x42, 0xDB, 0x01 } };

void setup() {
  Serial.begin(500000);

  // Scalar multiplier, 32 bytes little-endian: here simply 5.
  // A full-size test scalar kept for reference:
  //unsigned char n[32] = {0x56,0x2c,0x1e,0xb5,0xfd,0xb2,0x81,0x29, 0xbd,0x37,0x49,0x58,0x35,0xd4,0xb1,0x30, 0x7d,0xdb,0x57,0x38,0x80,0x12,0x17,0x42, 0xf7,0x13,0xf1,0x05,0x67,0x69,0xd5,0xbf};
  unsigned char scalar[32] = { 0x05 };

  // x-coordinate of the base point (the generator): 9.
  unsigned char base_x[32] = { 0x09 };

  // Receives the x-coordinate of scalar * base point.
  unsigned char product[32];

  crypto_scalarmult_curve25519(product, scalar, base_x);

  Serial.print("Result: ");
  for (unsigned int k = 0; k < sizeof product; ++k) {
    Serial.print(product[k], HEX);
    Serial.print(" ");
  }
  Serial.println();
}

void loop() {
  // Intentionally idle: the demo computation runs once in setup().
}

//---------------------------------------Define your functions here---------------------------------------
/*
 * Computes r = x-coordinate of s * P on Curve25519.
 *   r  32-byte output (little-endian x-coordinate, fully reduced by fe25519_pack)
 *   s  32-byte little-endian scalar (copied; the copy is wiped before returning)
 *   p  32-byte little-endian x-coordinate of P (bit 255 is ignored by unpack)
 * Always returns 0.
 *
 * NOTE(review): the usual X25519 scalar clamping is disabled below, apparently
 * so that raw test scalars such as 5 (see setup()) are used unchanged. mladder()
 * as written starts at bit 254, so bit 255 of s is never used.
 */
int crypto_scalarmult_curve25519(unsigned char *r, const unsigned char *s, const unsigned char *p) {
  unsigned char e[32];
  unsigned char i;
  for (i = 0; i < 32; i++) {
    e[i] = s[i];
  }
  // Standard X25519 clamping, intentionally left disabled for testing:
  //e[0] &= 248;
  //e[31] &= 127;
  //e[31] |= 64;

  fe25519 t;
  fe25519 z;
  fe25519_unpack(&t, p);
  mladder(&t, &z, e);        // (t : z) = projective x-coordinate of e * P
  fe25519_invert(&z, &z);
  fe25519_mul(&t, &t, &z);   // affine x = X / Z
  fe25519_pack(r, &t);

  // Scrub the local copy of the secret scalar. Writing through a volatile
  // pointer keeps the compiler from discarding these "dead" stores.
  volatile unsigned char *wipe = e;
  for (i = 0; i < 32; i++) {
    wipe[i] = 0;
  }
  return 0;
}

// Loads a 32-byte little-endian encoding into a field element, discarding the
// most-significant bit (bit 7 of byte 31).
void fe25519_unpack(fe25519 *r, const unsigned char x[32]) {
  unsigned char k = 0;
  while (k < 32) {
    r->v[k] = x[k];
    ++k;
  }
  r->v[31] &= 0x7F;
}

/*
 * Montgomery ladder.
 * On entry *xr holds the affine x-coordinate x0 of the input point.
 * On exit (*xr : *zr) is the projective x-coordinate of s * (input point).
 *
 * Scalar bits are consumed from bit 254 (bit 6 of s[31]) down to bit 0; bit 255
 * is not processed. Instead of swapping before and after every ladderstep, the
 * two swaps are merged (swap = bit ^ prevbit). After the loop the pair is
 * therefore still swapped iff the last bit was 1, so one final conditional
 * swap is required. Reading work[3..4] without it would yield (s + 1) * point
 * for every even s.
 */
void mladder(fe25519 *xr, fe25519 *zr, const unsigned char s[32]) {
  fe25519 work[5];
  unsigned char bit, prevbit = 0;
  unsigned char swap;
  signed char j = 6;  // skip bit 255 (bit 7 of s[31]); X25519 clamping clears it
  signed char i;

  work[0] = *xr;              // x0: fixed difference between the two ladder points
  fe25519_setone(work + 1);   // (X_P : Z_P) = (1 : 0), the point at infinity
  fe25519_setzero(work + 2);
  work[3] = *xr;              // (X_Q : Z_Q) = (x0 : 1), the input point
  fe25519_setone(work + 4);

  for (i = 31; i >= 0; i--) {
    while (j >= 0) {
      bit = 1 & (s[i] >> j);
      swap = bit ^ prevbit;   // undo the previous bit's swap and apply this one
      prevbit = bit;
      work_cswap(work, swap);
      ladderstep(work);       // (P, Q) -> (2P, P + Q)
      j -= 1;
    }
    j = 7;
  }

  // Undo the swap still pending from the last processed bit (branch-free).
  work_cswap(work, prevbit);
  *xr = work[1];
  *zr = work[2];
}

// Sets r to the field element 1: v[0] = 1 and every higher byte 0.
void fe25519_setone(fe25519 *r) {
  size_t i;
  r->v[0] = 1;
  // Bound by the real array size (32 bytes). The previous hard-coded bound of
  // 64 wrote 32 bytes past the end of *r, corrupting adjacent stack data.
  for (i = 1; i < sizeof r->v; i++) {
    r->v[i] = 0;
  }
}

// Sets r to the field element 0 (all 32 bytes cleared).
void fe25519_setzero(fe25519 *r) {
  size_t i;
  // Bound by the real array size (32 bytes). The previous hard-coded bound of
  // 64 wrote 32 bytes past the end of *r.
  for (i = 0; i < sizeof r->v; i++) {
    r->v[i] = 0;
  }
}

// Conditional swap for the Montgomery ladder: when b == 1, exchanges
// work[1] <-> work[3] and work[2] <-> work[4]; when b == 0 nothing changes.
// Built only from fe25519_cmov, so there is no branch on b.
// (ladderstep() is defined further below.)
void work_cswap(fe25519 *work, unsigned char b) {
  for (unsigned char k = 1; k <= 2; ++k) {
    fe25519 tmp = {{0x00}};
    fe25519_cmov(&tmp, work + k, b);
    fe25519_cmov(work + k, work + k + 2, b);
    fe25519_cmov(work + k + 2, &tmp, b);
  }
}

// Branch-free conditional move: r = x when b == 1, r unchanged when b == 0.
// x may alias r.
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b) {
  // All bits set when b == 1, all clear when b == 0.
  const unsigned long mask = 0UL - (unsigned long)b;
  for (unsigned char k = 0; k < 32; ++k) {
    const unsigned char diff = x->v[k] ^ r->v[k];
    r->v[k] ^= (unsigned char)(mask & diff);
  }
}

void fe25519_pack(unsigned char r[32], const fe25519 *x) {
  unsigned char i;
  fe25519 y = *x;
  fe25519_freeze(&y);
  for (i = 0; i < 32; i++) {
    r[i] = y.v[i];
  }
}

/*
 * Reduction modulo 2^255 - 19: two rounds of "subtract p if that does not
 * borrow", each done branch-free with fe25519_cmov.
 * Assumes bigint_subp writes r - p into its first argument and returns the
 * borrow (1 when r < p) — TODO confirm against its implementation.
 */
void fe25519_freeze(fe25519 *r) {
  fe25519 rt;
  for (unsigned char pass = 0; pass < 2; ++pass) {
    unsigned char borrow = bigint_subp(rt.v, r->v);
    fe25519_cmov(r, &rt, 1 - borrow);
  }
}

/*
 * r = x * y, reduced with fe25519_red and then fe25519_freeze.
 * r may alias x or y: the full 64-byte product is formed in t before r is written.
 *
 * Interrupts are masked around bigint_mul256 (presumably it must not be
 * interrupted — TODO confirm why). The caller's interrupt state is saved and
 * restored from SREG. Ending with sei() instead would silently enable
 * interrupts for a caller that had them disabled.
 */
void fe25519_mul(fe25519 *r, const fe25519 *x, const fe25519 *y) {
  unsigned char t[64] = { 0x00 };
  const unsigned char sreg = SREG;  // includes the global interrupt-enable flag
  cli();
  bigint_mul256(t, x->v, y->v);
  SREG = sreg;                      // restore, rather than force-enable, interrupts
  fe25519_red(r, t);
  fe25519_freeze(r);
  fe25519_freeze(r);  // NOTE(review): likely redundant after the first freeze; kept as-is
}

/*
 * r = x^2, reduced with fe25519_red and then fe25519_freeze.
 * r may alias x: the 64-byte square is formed in t before r is written.
 *
 * Interrupts are masked around bigint_square256. The caller's interrupt state
 * is saved and restored from SREG instead of being forced on with sei().
 */
void fe25519_square(fe25519 *r, const fe25519 *x) {
  unsigned char t[64];
  const unsigned char sreg = SREG;  // includes the global interrupt-enable flag
  cli();
  bigint_square256(t, x->v);
  SREG = sreg;                      // restore, rather than force-enable, interrupts
  fe25519_red(r, t);
  fe25519_freeze(r);
}

/*
 * One Montgomery-ladder step, in place on work[0..4]:
 *   work[0]            x0, affine x-coordinate of Q - P (read only)
 *   work[1], work[2]   (X_P : Z_P)  -> replaced by 2P
 *   work[3], work[4]   (X_Q : Z_Q)  -> replaced by P + Q
 * The doubling uses the constant 121666 (= (486662 + 2) / 4).
 */
void ladderstep(fe25519 *work) {
  // Zero-initialised temporaries. Aggregate initialisation avoids relying on
  // a helper that might write beyond these 32-byte objects.
  fe25519 t[2] = { { { 0 } } };
  fe25519 *t1 = &t[0];
  fe25519 *t2 = &t[1];

  fe25519 *x0 = work;
  fe25519 *xp = work + 1;
  fe25519 *zp = work + 2;
  fe25519 *xq = work + 3;
  fe25519 *zq = work + 4;

  fe25519_add(t1, xq, zq);    // t1 = X_Q + Z_Q
  fe25519_freeze(t1);

  fe25519_sub(xq, xq, zq);    // xq = X_Q - Z_Q
  fe25519_freeze(xq);

  fe25519_add(zq, xp, zp);    // zq = X_P + Z_P
  fe25519_freeze(zq);

  fe25519_sub(xp, xp, zp);    // xp = X_P - Z_P

  fe25519_mul(t1, t1, xp);    // t1 = (X_Q + Z_Q)(X_P - Z_P)

  fe25519_mul(xq, xq, zq);    // xq = (X_Q - Z_Q)(X_P + Z_P)
  fe25519_freeze(xq);

  fe25519_square(zq, zq);     // zq = (X_P + Z_P)^2

  fe25519_square(xp, xp);     // xp = (X_P - Z_P)^2

  // t2 = (X_P + Z_P)^2 - (X_P - Z_P)^2 = 4 X_P Z_P.
  // Do not pass t2 to fe25519_red(): it expects a 64-byte product, so reading
  // t2->v as 64 bytes would run past the end of t[] and mix stack contents
  // into t2. t2 is already a 32-byte field element; freezing is sufficient.
  fe25519_sub(t2, zq, xp);
  fe25519_freeze(t2);

  fe25519_mul121666(zp, t2);  // zp = 121666 * t2

  fe25519_add(zp, zp, xp);    // zp = (X_P - Z_P)^2 + 121666 * t2
  fe25519_freeze(zp);

  fe25519_mul(zp, zp, t2);    // Z(2P)

  fe25519_mul(xp, zq, xp);    // X(2P) = (X_P + Z_P)^2 (X_P - Z_P)^2

  fe25519_sub(zq, xq, t1);    // zq = xq - t1
  fe25519_freeze(zq);

  fe25519_square(zq, zq);
  fe25519_freeze(zq);

  fe25519_mul(zq, zq, x0);    // Z(P + Q) = x0 (xq - t1)^2

  fe25519_add(xq, t1, xq);    // xq = xq + t1

  fe25519_square(xq, xq);     // X(P + Q) = (xq + t1)^2
  fe25519_freeze(xq);
}

// Debug helper: prints "<name>= " followed by the 32 bytes of f in hex
// (v[0] first, space-separated), then a newline.
void print_fe25519(const char *name, const fe25519 *f) {
  Serial.print(name);
  Serial.print("= ");
  const unsigned char *cursor = f->v;
  const unsigned char *const stop = f->v + 32;
  while (cursor != stop) {
    Serial.print(*cursor, HEX);
    Serial.print(" ");
    ++cursor;
  }
  Serial.println();
}

// Debug helper: prints "<name>: " followed by the bytes of f in hex
// (v[0] first, space-separated), then a newline.
// An fe25519 holds only 32 bytes, so the loop is bounded by sizeof f->v; the
// former bound of 64 read 32 bytes past the end of *f. To dump a 64-byte
// product buffer, use the unsigned char * overload instead.
void print_fe25519_64(const char *name, const fe25519 *f) {
  Serial.print(name);
  Serial.print(": ");
  for (size_t i = 0; i < sizeof f->v; i++) {
    Serial.print(f->v[i], HEX);
    Serial.print(" ");
  }
  Serial.println();
}

// Debug helper: prints "<name>: " followed by 64 bytes starting at f in hex,
// space-separated, then a newline. The caller must supply at least 64 readable
// bytes (e.g. a double-width multiplication result).
void print_fe25519_64(const char *name, const unsigned char *f) {
  Serial.print(name);
  Serial.print(": ");
  const unsigned char *const stop = f + 64;
  for (const unsigned char *cursor = f; cursor != stop; ++cursor) {
    Serial.print(*cursor, HEX);
    Serial.print(" ");
  }
  Serial.println();
}

/*
 * r = x * 121666, reduced with fe25519_red. Unlike fe25519_mul, no
 * fe25519_freeze is applied here.
 * The generic bigint_mul256 with the constant _121666 is used; the dedicated
 * bigint_mul121666 routine declared above is not called.
 *
 * Interrupts are masked around the multiplication. The caller's interrupt
 * state is saved and restored from SREG instead of being forced on with sei().
 */
void fe25519_mul121666(fe25519 *r, const fe25519 *x) {
  unsigned char t[64];
  const unsigned char sreg = SREG;  // includes the global interrupt-enable flag
  cli();
  bigint_mul256(t, x->v, _121666.v);
  SREG = sreg;                      // restore, rather than force-enable, interrupts
  fe25519_red(r, t);
}

// Field inversion by exponentiation: r = x^(2^255 - 21) = x^(p - 2) with
// p = 2^255 - 19. By Fermat's little theorem this is the multiplicative inverse
// of x for x != 0; x = 0 maps to 0.
// The label comment before each call gives the exponent e such that the value
// just computed equals x^e. Temporaries are reused to save RAM, so some names
// do not describe what they hold at every step (see the declarations below).
// r may alias x: x is only read in the first steps, and r is written last.
void fe25519_invert(fe25519 *r, const fe25519 *x) {
  fe25519 z2;
  fe25519 z11;
  // Reused: holds x^9, then x^(2^5 - 1), and finally x^(2^10 - 1).
  fe25519 z2_10_0;
  // Reused: holds x^(2^20 - 1) until it is overwritten with x^(2^50 - 1).
  fe25519 z2_50_0;
  fe25519 z2_100_0;
  fe25519 t0;
  fe25519 t1;
  unsigned char i;

  /* 2 */ fe25519_square(&z2, x);
  /* 4 */ fe25519_square(&t1, &z2);
  /* 8 */ fe25519_square(&t0, &t1);
  // z2_10_0 temporarily holds x^9 here.
  /* 9 */ fe25519_mul(&z2_10_0, &t0, x);
  /* 11 */ fe25519_mul(&z11, &z2_10_0, &z2);
  /* 22 */ fe25519_square(&t0, &z11);
  // z2_10_0 now holds x^(2^5 - 1).
  /* 2^5 - 2^0 = 31 */ fe25519_mul(&z2_10_0, &t0, &z2_10_0);

  /* 2^6 - 2^1 */ fe25519_square(&t0, &z2_10_0);
  /* 2^7 - 2^2 */ fe25519_square(&t1, &t0);
  /* 2^8 - 2^3 */ fe25519_square(&t0, &t1);
  /* 2^9 - 2^4 */ fe25519_square(&t1, &t0);
  /* 2^10 - 2^5 */ fe25519_square(&t0, &t1);
  /* 2^10 - 2^0 */ fe25519_mul(&z2_10_0, &t0, &z2_10_0);

  /* 2^11 - 2^1 */ fe25519_square(&t0, &z2_10_0);
  /* 2^12 - 2^2 */ fe25519_square(&t1, &t0);
  // Each iteration performs two squarings (4 iterations -> 8 squarings).
  /* 2^20 - 2^10 */ for (i = 2; i < 10; i += 2) {
    fe25519_square(&t0, &t1);
    fe25519_square(&t1, &t0);
  }
  // Stored in z2_50_0, which here holds x^(2^20 - 1).
  /* 2^20 - 2^0 */ fe25519_mul(&z2_50_0, &t1, &z2_10_0);

  /* 2^21 - 2^1 */ fe25519_square(&t0, &z2_50_0);
  /* 2^22 - 2^2 */ fe25519_square(&t1, &t0);
  // 9 iterations -> 18 squarings.
  /* 2^40 - 2^20 */ for (i = 2; i < 20; i += 2) {
    fe25519_square(&t0, &t1);
    fe25519_square(&t1, &t0);
  }
  /* 2^40 - 2^0 */ fe25519_mul(&t0, &t1, &z2_50_0);

  /* 2^41 - 2^1 */ fe25519_square(&t1, &t0);
  /* 2^42 - 2^2 */ fe25519_square(&t0, &t1);
  // 4 iterations -> 8 squarings.
  /* 2^50 - 2^10 */ for (i = 2; i < 10; i += 2) {
    fe25519_square(&t1, &t0);
    fe25519_square(&t0, &t1);
  }
  // z2_50_0 now really holds x^(2^50 - 1).
  /* 2^50 - 2^0 */ fe25519_mul(&z2_50_0, &t0, &z2_10_0);

  /* 2^51 - 2^1 */ fe25519_square(&t0, &z2_50_0);
  /* 2^52 - 2^2 */ fe25519_square(&t1, &t0);
  // 24 iterations -> 48 squarings.
  /* 2^100 - 2^50 */ for (i = 2; i < 50; i += 2) {
    fe25519_square(&t0, &t1);
    fe25519_square(&t1, &t0);
  }
  /* 2^100 - 2^0 */ fe25519_mul(&z2_100_0, &t1, &z2_50_0);

  /* 2^101 - 2^1 */ fe25519_square(&t1, &z2_100_0);
  /* 2^102 - 2^2 */ fe25519_square(&t0, &t1);
  // 49 iterations -> 98 squarings.
  /* 2^200 - 2^100 */ for (i = 2; i < 100; i += 2) {
    fe25519_square(&t1, &t0);
    fe25519_square(&t0, &t1);
  }
  /* 2^200 - 2^0 */ fe25519_mul(&t1, &t0, &z2_100_0);

  /* 2^201 - 2^1 */ fe25519_square(&t0, &t1);
  /* 2^202 - 2^2 */ fe25519_square(&t1, &t0);
  // 24 iterations -> 48 squarings.
  /* 2^250 - 2^50 */ for (i = 2; i < 50; i += 2) {
    fe25519_square(&t0, &t1);
    fe25519_square(&t1, &t0);
  }
  /* 2^250 - 2^0 */ fe25519_mul(&t0, &t1, &z2_50_0);

  /* 2^251 - 2^1 */ fe25519_square(&t1, &t0);
  /* 2^252 - 2^2 */ fe25519_square(&t0, &t1);
  /* 2^253 - 2^3 */ fe25519_square(&t1, &t0);
  /* 2^254 - 2^4 */ fe25519_square(&t0, &t1);
  /* 2^255 - 2^5 */ fe25519_square(&t1, &t0);
  // (2^255 - 32) + 11 = 2^255 - 21 = p - 2.
  /* 2^255 - 21 */ fe25519_mul(r, &t1, &z11);
}