#include <Arduino.h>
#include <avr/io.h>
#include <stdlib.h>

// Define constants and function prototypes
#define fe25519_add avrnacl_fe25519_add
#define fe25519_sub avrnacl_fe25519_sub
#define fe25519_red avrnacl_fe25519_red

typedef struct {unsigned char v[32];} fe25519;

extern "C"
{
  void fe25519_sub(fe25519 *r, const fe25519 *x, const fe25519 *y);
  void fe25519_add(fe25519 *r, const fe25519 *x, const fe25519 *y);
  void fe25519_red(fe25519 *r, unsigned char *C);
  char bigint_subp(unsigned char* r, const unsigned char* a);
  char bigint_square256(unsigned char* r, const unsigned char* a);
  char bigint_mul256(unsigned char* r, const unsigned char* a, const unsigned char* b);
  void bigint_mul121666(unsigned char *r, const unsigned char *x);
}

// Forward declarations for the helpers defined in this file. Declaring them
// up front lets definitions appear in any order when the file is compiled as
// plain C++ (e.g. a PlatformIO main.cpp). Unlike an Arduino .ino, no
// prototypes are generated automatically there.
void fe25519_freeze(fe25519 *r);
void fe25519_unpack(fe25519 *r, const unsigned char x[32]);
void fe25519_pack(unsigned char r[32], const fe25519 *x);
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b);
void fe25519_setone(fe25519 *r);
void fe25519_setzero(fe25519 *r);
void fe25519_mul(fe25519 *r, const fe25519 *x, const fe25519 *y);
void fe25519_square(fe25519 *r, const fe25519 *x);
void fe25519_invert(fe25519 *r, const fe25519 *x);
void work_cswap(fe25519 *work, char b);
void mladder(fe25519 *xr, fe25519 *zr, const unsigned char s[32]);
void fe25519_mul121666(fe25519 *r, const fe25519 *x);

// Serial debug printers. They are declared here so that functions defined
// above their definitions may call them.
void print_fe25519(const char* name, const fe25519 *f);
void print_fe25519_64(const char* name, const fe25519 *f);
void print_fe25519_64(const char* name, const unsigned char *f);

// The constant 121666 (= 0x01DB42) as a field element, least-significant
// byte first. Bytes not listed are zero-initialised.
static const fe25519 _121666 = {{0x42, 0xDB, 0x01}};

/* Set every byte of r to zero. */
void fe25519_setzero(fe25519 *r) {
    unsigned char *cursor = r->v;
    unsigned char *const end = r->v + sizeof r->v;
    while (cursor != end) {
        *cursor++ = 0;
    }
}

/* Set r to the element 1: lowest byte 1, every higher byte 0. */
void fe25519_setone(fe25519 *r) {
    for (unsigned char k = sizeof r->v; k-- > 1; ) {
        r->v[k] = 0;
    }
    r->v[0] = 1;
}

/*
 * Bring r into canonical (fully reduced) form, in place.
 *
 * bigint_subp presumably writes r - p into diff and returns the borrow
 * (1 when r < p) — TODO confirm against its implementation. The difference
 * is kept only when there was no borrow.
 *
 * Two passes are made because a 32-byte value can exceed the modulus
 * more than once (if p = 2^255 - 19, then 2^256 - 1 = 2p + 37). The selection
 * goes through fe25519_cmov, a masked select, instead of an if statement.
 */
void fe25519_freeze(fe25519 *r) {
    fe25519 diff;
    for (unsigned char pass = 0; pass < 2; ++pass) {
        unsigned char borrow = bigint_subp(diff.v, r->v);
        fe25519_cmov(r, &diff, 1 - borrow);
    }
}

/*
 * Conditional move: r = x when b == 1, r unchanged when b == 0.
 *
 * b is stretched into an all-ones / all-zeros byte mask. The copy is
 * therefore done with the same XOR/AND sequence either way, with no branch
 * on b in the source. Only the low byte of -b ever matters because the
 * result is stored into a byte. Callers should pass exactly 0 or 1.
 */
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b) {
    const unsigned char mask = (unsigned char)(0u - b);
    unsigned char *dst = r->v;
    const unsigned char *src = x->v;
    for (unsigned char left = sizeof r->v; left != 0; --left, ++dst, ++src) {
        *dst ^= mask & (*src ^ *dst);
    }
}

void fe25519_pack(unsigned char r[32], const fe25519 *x) {
    fe25519 y = *x;
    fe25519_freeze(&y);
    for (int i = 0; i < 32; i++) {
        r[i] = y.v[i];
    }
}

/*
 * r = x * y, returned in canonical form.
 *
 * bigint_mul256 writes the double-width product into the 64-byte scratch
 * buffer t. fe25519_red folds it back to 32 bytes, and fe25519_freeze makes
 * the result canonical. r may alias x or y, because both operands are fully
 * consumed into t before r is written.
 *
 * This is a pure arithmetic primitive: no Serial tracing and no busy-wait
 * delays. It is called for many different operands, so any trace belongs
 * in the caller, where the right label is known.
 */
void fe25519_mul(fe25519 *r, const fe25519 *x, const fe25519 *y) {
    // Zero-filled so fe25519_red can never see uninitialised stack bytes.
    unsigned char t[64] = {0};
    bigint_mul256(t, x->v, y->v);
    fe25519_red(r, t);
    fe25519_freeze(r);
}

/*
 * r = x^2, returned in canonical form.
 *
 * bigint_square256 writes the double-width square into the 64-byte scratch
 * buffer t. fe25519_red folds it back to 32 bytes, and fe25519_freeze makes
 * the result canonical. r may alias x, because x is fully consumed into t
 * before r is written.
 *
 * This is a pure arithmetic primitive: no Serial tracing and no busy-wait
 * delays. Trace intermediate values from the caller instead.
 */
void fe25519_square(fe25519 *r, const fe25519 *x) {
    // Zero-filled so fe25519_red can never see uninitialised stack bytes.
    unsigned char t[64] = {0};
    bigint_square256(t, x->v);
    fe25519_red(r, t);
    fe25519_freeze(r);
}

/*
 * r = 121666 * x, returned in canonical form.
 *
 * t is zero-filled before the multiply, as in fe25519_mul and
 * fe25519_square. fe25519_red is handed the whole 64-byte buffer, but a
 * 256-bit by 17-bit product needs only 35 bytes. If bigint_mul121666 writes
 * only the bytes it needs (NOTE(review): confirm against its
 * implementation), an uninitialised t would leak stack garbage into the
 * reduction.
 */
void fe25519_mul121666(fe25519 *r, const fe25519 *x) {
    unsigned char t[64] = {0};
    bigint_mul121666(t, x->v);
    fe25519_red(r, t);
    fe25519_freeze(r);
}

/*
 * Debug dump over Serial: "<name>: ", then each of the 32 bytes of f in hex
 * (lowest byte first, each followed by a space), then a newline.
 * Serial's HEX formatting does not zero-pad, so bytes below 0x10 print as a
 * single digit.
 */
void print_fe25519(const char* name, const fe25519 *f) {
    Serial.print(name);
    Serial.print(": ");

    const unsigned char *byte = f->v;
    const unsigned char *const stop = f->v + sizeof f->v;
    while (byte != stop) {
        Serial.print(*byte++, HEX);
        Serial.print(" ");
    }

    Serial.println();
}

/*
 * Debug dump of a fe25519 over Serial: "<name>: ", then its bytes in hex
 * (lowest first, each followed by a space), then a newline.
 *
 * A fe25519 holds exactly 32 bytes, so only those are printed. Walking 64
 * bytes here would read past the end of f->v, which is undefined behaviour;
 * for a local or the last array element it prints unrelated stack memory.
 * To inspect a genuine 64-byte intermediate, pass that buffer to the
 * `const unsigned char *` overload.
 */
void print_fe25519_64(const char* name, const fe25519 *f)
{
    Serial.print(name);
    Serial.print(": ");
    for (unsigned int i = 0; i < sizeof f->v; i++)
    {
        Serial.print(f->v[i], HEX);
        Serial.print(" ");
    }
    Serial.println();
}

/*
 * Debug dump of a raw 64-byte buffer (e.g. an unreduced product) over
 * Serial: "<name>: ", then each byte in hex followed by a space, then a
 * newline. f must point to at least 64 readable bytes.
 */
void print_fe25519_64(const char* name, const unsigned char *f)
{
    Serial.print(name);
    Serial.print(": ");
    for (const unsigned char *end = f + 64; f != end; ++f)
    {
        Serial.print(*f, HEX);
        Serial.print(" ");
    }
    Serial.println();
}

/*
 * One Montgomery-ladder step on x-only projective coordinates, with a
 * Serial trace of every intermediate value.
 *
 * Layout of work (updated in place):
 *   work[0] = x0  difference point, only read (step 16)
 *   work[1] = xp, work[2] = zp  -> replaced by 2*(xp:zp)        (steps 7-13)
 *   work[3] = xq, work[4] = zq  -> replaced by (xp:zp)+(xq:zq)  (steps 14-18)
 * t1 and t2 are scratch.
 *
 * Every result is frozen before it is printed, so the trace can be compared
 * byte-for-byte with a reference implementation. After mul/square/mul121666
 * that freeze is redundant, since those already return canonical values,
 * but it is harmless.
 *
 * Only the 32-byte printer is used on field elements. A fe25519 has no
 * 64-byte form to show, and reading 64 bytes from one runs past its end.
 */
void ladderstep(fe25519 *work) {
    fe25519 t1, t2;
    fe25519_setzero(&t1);
    fe25519_setzero(&t2);

    fe25519 *x0 = &work[0];
    fe25519 *xp = &work[1];
    fe25519 *zp = &work[2];
    fe25519 *xq = &work[3];
    fe25519 *zq = &work[4];

    Serial.println(" --------------- Initialization: ");
    print_fe25519("t1", &t1);
    print_fe25519("t2", &t2);
    print_fe25519("x0", x0);
    print_fe25519("xp", xp);
    print_fe25519("zp", zp);
    print_fe25519("xq", xq);
    print_fe25519("zq", zq);

    // 1. t1 = xq + zq
    Serial.println(" ---------------> // 1. t1 = xq + zq  ");
    print_fe25519("t1", &t1);
    print_fe25519("xq", xq);
    print_fe25519("zq", zq);
    fe25519_add(&t1, xq, zq);
    fe25519_freeze(&t1);
    print_fe25519("t1", &t1);

    // 2. xq = xq - zq
    Serial.println(" ---------------> // 2. xq = xq - zq  ");
    print_fe25519("xq", xq);
    print_fe25519("zq", zq);
    fe25519_sub(xq, xq, zq);
    fe25519_freeze(xq);
    print_fe25519("xq", xq);

    // 3. zq = xp + zp
    Serial.println(" ---------------> // 3. zq = xp + zp  ");
    print_fe25519("xp", xp);
    print_fe25519("zp", zp);
    print_fe25519("zq (old value)", zq);  // printed before it is overwritten
    fe25519_add(zq, xp, zp);
    print_fe25519("zq before Freeze", zq);
    fe25519_freeze(zq);
    print_fe25519("zq after Freeze", zq);

    // 4. xp = xp - zp
    Serial.println(" ---------------> // 4. xp = xp - zp  ");
    print_fe25519("xp", xp);
    print_fe25519("zp", zp);
    fe25519_sub(xp, xp, zp);
    fe25519_freeze(xp);
    print_fe25519("xp", xp);

    // 5. t1 = t1 * xp
    Serial.println(" ---------------> // 5. t1 = t1 * xp");
    print_fe25519("xp", xp);
    print_fe25519("t1", &t1);
    fe25519_mul(&t1, &t1, xp);
    fe25519_freeze(&t1);
    print_fe25519("t1", &t1);

    // 6. xq = xq * zq
    Serial.println(" ---------------> // 6. xq = xq * zq");
    print_fe25519("xq", xq);
    print_fe25519("zq", zq);
    fe25519_mul(xq, xq, zq);
    fe25519_freeze(xq);
    print_fe25519("xq", xq);

    // 7. zq = zq^2
    Serial.println(" ---------------> // 7. zq = zq^2");
    print_fe25519("zq", zq);
    fe25519_square(zq, zq);
    fe25519_freeze(zq);
    print_fe25519("zq", zq);

    // 8. xp = xp^2
    Serial.println(" ---------------> // 8. xp = xp^2");
    print_fe25519("xp", xp);
    fe25519_square(xp, xp);
    fe25519_freeze(xp);
    print_fe25519("xp", xp);

    // 9. t2 = zq - xp
    Serial.println(" ---------------> // 9. t2 = zq - xp");
    print_fe25519("xp", xp);
    print_fe25519("zq", zq);
    print_fe25519("t2", &t2);
    fe25519_sub(&t2, zq, xp);
    fe25519_freeze(&t2);
    print_fe25519("Result _ t2", &t2);

    // 10. zp = 121666 * t2
    Serial.println(" ---------------> // 10. zp = 121666 * t2");
    print_fe25519("t2", &t2);
    fe25519_mul121666(zp, &t2);
    fe25519_freeze(zp);
    print_fe25519("zp", zp);

    // 11. zp = zp + xp
    Serial.println(" ---------------> // 11. zp = zp + xp");
    print_fe25519("zp", zp);
    print_fe25519("xp", xp);
    fe25519_add(zp, zp, xp);
    fe25519_freeze(zp);
    print_fe25519("zp", zp);

    // 12. zp = zp * t2
    Serial.println(" ---------------> // 12. zp = zp * t2");
    print_fe25519("t2", &t2);
    print_fe25519("zp", zp);
    fe25519_mul(zp, zp, &t2);
    fe25519_freeze(zp);
    print_fe25519("zp", zp);

    // 13. xp = zq * xp
    Serial.println(" ---------------> // 13. xp = zq * xp");
    print_fe25519("zq", zq);
    print_fe25519("xp", xp);
    fe25519_mul(xp, zq, xp);
    fe25519_freeze(xp);
    print_fe25519("xp", xp);

    // 14. zq = xq - t1
    Serial.println(" ---------------> // 14. zq = xq - t1");
    print_fe25519("t1", &t1);
    print_fe25519("xq", xq);
    fe25519_sub(zq, xq, &t1);
    print_fe25519("zq before freeze", zq);
    fe25519_freeze(zq);
    print_fe25519("zq after freeze", zq);

    // 15. zq = zq^2
    Serial.println(" ---------------> // 15. zq = zq^2");
    print_fe25519("zq", zq);
    fe25519_square(zq, zq);
    fe25519_freeze(zq);
    print_fe25519("zq", zq);

    // 16. zq = zq * x0
    Serial.println(" ---------------> // 16. zq = zq * x0");
    print_fe25519("x0", x0);
    print_fe25519("zq", zq);
    fe25519_mul(zq, zq, x0);
    fe25519_freeze(zq);
    print_fe25519("zq", zq);

    // 17. xq = t1 + xq
    Serial.println(" ---------------> // 17. xq = t1 + xq");
    print_fe25519("t1 = ", &t1);
    print_fe25519("xq = ", xq);
    fe25519_add(xq, &t1, xq);
    print_fe25519("xq before freeze = ", xq);
    fe25519_freeze(xq);
    print_fe25519("xq after freeze = ", xq);

    // 18. xq = xq^2
    Serial.println(" ---------------> // 18. xq = xq^2");
    print_fe25519("xq", xq);
    fe25519_square(xq, xq);
    fe25519_freeze(xq);
    print_fe25519("xq", xq);

    // Print final values
    Serial.println(" ---------------> // Print final values");
    print_fe25519("x0", x0);
    print_fe25519("xp", xp);
    print_fe25519("zp", zp);
    print_fe25519("xq", xq);
    print_fe25519("zq", zq);
}

void setup()
{
  Serial.begin(9600);

  // Ladder state in the order ladderstep() expects: x0, xp, zp, xq, zq.
  // Each element is 32 bytes, least-significant first. Bytes not listed
  // are zero (aggregate initialisation zero-fills the remainder).
  fe25519 work[5] = {
    {{0x09}},                                                                                     // x0
    {{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0xF8, 0xD0, 0x7B, 0x22, 0x44, 0x58, 0xEF, 0xA0, 0x76,
      0x1C, 0xF0, 0xB2, 0xE8}},                                                                   // xp
    {{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x90, 0x35, 0xC0, 0x08, 0x02, 0x8D, 0x55, 0x9A, 0x3D, 0x17,
      0x67, 0x85, 0x06, 0x20, 0x97, 0x02}},                                                       // zp
    {{0x00, 0x00, 0x00, 0x00, 0x00, 0x99, 0x1A, 0xF4, 0x14, 0xCD, 0x94, 0xA4}},                   // xq
    {{0x00, 0x00, 0x00, 0x00, 0x20, 0x57, 0xE3, 0x87, 0x6A, 0x01, 0xAF, 0x7B, 0xDF, 0x3D}}        // zq
  };

  ladderstep(work);
}

void loop()
{
    // Intentionally idle: the single traced ladder step runs once from setup().
}