#include <Arduino.h>
#include <avr/io.h>
#include <stdlib.h>

// Define constants and function prototypes
//
// These macros rewrite the short names fe25519_add / fe25519_sub / fe25519_red into the
// avrnacl_-prefixed symbols, so the extern "C" declarations below (and every call site in
// this file) bind to avrnacl_fe25519_add / _sub / _red at link time.
// NOTE(review): those symbols are presumably provided by the avrnacl library (assembly)
// linked into the build — confirm; they are not defined in this file.
#define fe25519_add avrnacl_fe25519_add
#define fe25519_sub avrnacl_fe25519_sub
#define fe25519_red avrnacl_fe25519_red

// A field element held as 32 raw bytes. v[0] is the least-significant byte, as shown by
// fe25519_setone (sets only v[0]) and the 121666 constant ({0x42, 0xDB, 0x01, ...}).
typedef struct {unsigned char v[32];} fe25519;

// Routines implemented outside this file (presumably avrnacl AVR assembly — confirm),
// declared with C linkage so their names are not C++-mangled.
extern "C"
{
  // Field add/sub. Whether the results are fully reduced is not visible here;
  // ladderstep calls fe25519_freeze after them.
  void fe25519_sub(fe25519 *r, const fe25519 *x, const fe25519 *y);
  void fe25519_add(fe25519 *r, const fe25519 *x, const fe25519 *y);
  // Reduces the buffer C into r. Every caller in this file passes a 64-byte buffer.
  // C is non-const. NOTE(review): the routine may use it as scratch — confirm.
  void fe25519_red(fe25519 *r, unsigned char *C);
  // fe25519_freeze treats the return value as a borrow flag and keeps r when it is 1.
  // Presumably this writes a - p into r and returns 1 when a < p — confirm.
  char bigint_subp(unsigned char* r, const unsigned char* a);
  // Square / product of 32-byte inputs. Callers supply a 64-byte r and ignore the
  // char return value.
  char bigint_square256(unsigned char* r, const unsigned char* a);
  char bigint_mul256(unsigned char* r, const unsigned char* a, const unsigned char* b);
  // r = 121666 * x. TODO confirm how many bytes of r this writes: the exact product
  // needs only 35, but fe25519_red later consumes 64.
  void bigint_mul121666(unsigned char *r, const unsigned char *x);
}

void fe25519_freeze(fe25519 *r);
void fe25519_unpack(fe25519 *r, const unsigned char x[32]);
void fe25519_pack(unsigned char r[32], const fe25519 *x);
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b);
void fe25519_setone(fe25519 *r);
void fe25519_setzero(fe25519 *r);
void fe25519_mul(fe25519 *r, const fe25519 *x, const fe25519 *y);
void fe25519_square(fe25519 *r, const fe25519 *x);
void fe25519_invert(fe25519 *r, const fe25519 *x);
void work_cswap(fe25519 *work, char b);
void mladder(fe25519 *xr, fe25519 *zr, const unsigned char s[32]);
void fe25519_mul121666(fe25519 *r, const fe25519 *x);

// The constant 121666 (0x01DB42), least-significant byte first. Aggregate initialisation
// zero-fills the 29 bytes that are not listed.
// NOTE(review): not referenced in this file (fe25519_mul121666 calls bigint_mul121666).
// Also, names beginning with '_' are reserved in the global namespace — consider renaming.
static const fe25519 _121666 = {{0x42, 0xDB, 0x01}};

// Set every byte of r to zero (the field element 0).
void fe25519_setzero(fe25519 *r) {
    for (unsigned char &limb : r->v) {
        limb = 0;
    }
}

// Set r to the field element 1: v[0], the least-significant byte, becomes 1 and every
// other byte becomes 0.
void fe25519_setone(fe25519 *r) {
    for (unsigned char &limb : r->v) {
        limb = 0;
    }
    r->v[0] = 1;
}

// Bring r into canonical form by conditionally replacing it with the output of
// bigint_subp, twice.
//
// Assumption (TODO confirm against avrnacl): bigint_subp writes r - p into its first
// argument and returns the borrow (1 when r < p). The difference is kept only when there
// was no borrow. Two passes are needed because a 32-byte value can be at least 2p.
// Both passes always run, and selection goes through the mask-based fe25519_cmov, so
// this function has no branch that depends on r's value.
void fe25519_freeze(fe25519 *r) {
    fe25519 reduced;
    for (unsigned char pass = 0; pass < 2; ++pass) {
        const unsigned char borrow = bigint_subp(reduced.v, r->v);
        fe25519_cmov(r, &reduced, 1 - borrow);
    }
}

// Branch-free conditional move. When b == 1, r becomes a copy of x; when b == 0, r is
// left unchanged. Callers are expected to pass only 0 or 1: any other b copies only the
// bits of x selected by the mask (unsigned char)(-b).
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b) {
    // -b is 0x00 for b == 0 and 0xFF for b == 1. Only the low 8 bits of the mask can
    // affect a byte, so an 8-bit mask gives exactly the same result as a wider one.
    const unsigned char mask = static_cast<unsigned char>(-b);
    for (unsigned char idx = 0; idx < sizeof r->v; ++idx) {
        const unsigned char diff = x->v[idx] ^ r->v[idx];
        r->v[idx] ^= mask & diff;
    }
}

void fe25519_pack(unsigned char r[32], const fe25519 *x) {
    fe25519 y = *x;
    fe25519_freeze(&y);
    for (int i = 0; i < 32; i++) {
        r[i] = y.v[i];
    }
}

// r = x * y, reduced by fe25519_red and then frozen.
// r may alias x or y: both inputs are consumed into the local product buffer before
// r is written.
void fe25519_mul(fe25519 *r, const fe25519 *x, const fe25519 *y) {
    unsigned char product[64];  // double-width product of two 32-byte operands
    bigint_mul256(product, x->v, y->v);
    fe25519_red(r, product);
    fe25519_freeze(r);
}

// r = x^2, with a Serial debug trace. It prints, in order:
//   1. the 64-byte unreduced square;
//   2. r after fe25519_red;
//   3. r after fe25519_freeze.
// The delay(10) calls presumably give the serial output time to drain; keep their
// position relative to the prints. r may alias x, because x is consumed before r is
// written.
void fe25519_square(fe25519 *r, const fe25519 *x) {
    unsigned char wide[64] = {};  // zero-filled, so the dump below never shows stale stack bytes
    bigint_square256(wide, x->v);
    print_fe25519_64("zq_64", wide);
    delay(10);

    fe25519_red(r, wide);
    delay(10);
    print_fe25519("r = ", r);

    fe25519_freeze(r);
    print_fe25519("r = ", r);
}

// r = 121666 * x, reduced by fe25519_red and then frozen. r may alias x.
//
// Why the scratch buffer is zero-initialised: fe25519_red consumes all 64 bytes of t.
// A 256-bit x times the 17-bit constant 121666 fits in 35 bytes, however.
// NOTE(review): the output width of bigint_mul121666 is not visible here. If it writes
// fewer than 64 bytes, the remaining bytes would otherwise be uninitialised stack data
// fed into the reduction. If it does write all 64 bytes, the zeroing is harmless.
// (fe25519_square zero-initialises its buffer the same way.)
void fe25519_mul121666(fe25519 *r, const fe25519 *x) {
    unsigned char t[64] = {0};
    bigint_mul121666(t, x->v);
    fe25519_red(r, t);
    fe25519_freeze(r);
}

// Print "<name>: " over Serial, then the 32 bytes of f starting at v[0] (the
// least-significant byte). Each byte is printed as unpadded hex followed by a space,
// and the line ends with a newline.
void print_fe25519(const char* name, const fe25519 *f) {
    Serial.print(name);
    Serial.print(": ");
    for (const unsigned char byte : f->v) {
        Serial.print(byte, HEX);
        Serial.print(" ");
    }
    Serial.println();
}

// Print "<name>: " over Serial, then the bytes of the field element f as unpadded,
// space-separated hex, then a newline.
//
// Only the 32 bytes f actually holds (sizeof(f->v)) are printed. Reading 64 bytes, as
// the "_64" name suggests, would run past the end of the object, which is undefined
// behaviour. To dump a genuine 64-byte intermediate (such as an unreduced product),
// use the `const unsigned char *` overload on that buffer instead.
void print_fe25519_64(const char* name, const fe25519 *f)
{
    Serial.print(name);
    Serial.print(": ");
    for (size_t i = 0; i < sizeof(f->v); i++)
    {
        Serial.print(f->v[i], HEX);
        Serial.print(" ");
    }
    Serial.println();
}

// Print "<name>: " over Serial, then exactly 64 raw bytes starting at f as unpadded,
// space-separated hex, then a newline. The caller must supply a buffer of at least
// 64 bytes, for example the double-width result of bigint_square256.
void print_fe25519_64(const char* name, const unsigned char *f)
{
    Serial.print(name);
    Serial.print(": ");
    for (const unsigned char *p = f, *end = f + 64; p != end; ++p)
    {
        Serial.print(*p, HEX);
        Serial.print(" ");
    }
    Serial.println();
}

// One ladder step, with a full Serial trace of every intermediate.
//
// work must point at 5 field elements, used as
//   work[0] = x0, work[1] = xp, work[2] = zp, work[3] = xq, work[4] = zq.
// The 18 numbered field operations below update xp, zp, xq and zq in place. x0 is only
// read (step 16). Every intermediate result is frozen so the trace shows canonical values.
// NOTE(review): the operation sequence matches the x-only Montgomery-ladder step used
// for Curve25519 (cf. RFC 7748) — confirm against the reference implementation.
//
// Trace fixes relative to the earlier version:
//   - Step 17 now prints the operands it actually uses (t1, xq) and the updated xq.
//     It previously printed x0 and zq.
//   - Step 3 labels the pre-step zq honestly, and freezes its result like every other
//     add/sub.
//   - The print_fe25519_64(label, fe25519*) dumps were removed. They read 64 bytes from
//     32-byte objects, which is out of bounds. The 32-byte value is printed right after
//     each step anyway, and fe25519_square still dumps its real 64-byte intermediate.
void ladderstep(fe25519 *work) {
    fe25519 t1, t2;
    // Zeroed only so the "Initialization" dump below prints defined values.
    fe25519_setzero(&t1);
    fe25519_setzero(&t2);

    fe25519 *x0 = &work[0];
    fe25519 *xp = &work[1];
    fe25519 *zp = &work[2];
    fe25519 *xq = &work[3];
    fe25519 *zq = &work[4];

    Serial.println(" --------------- Initialization: ");
    print_fe25519("t1", &t1);
    print_fe25519("t2", &t2);
    print_fe25519("x0", x0);
    print_fe25519("xp", xp);
    print_fe25519("zp", zp);
    print_fe25519("xq", xq);
    print_fe25519("zq", zq);

    // 1. t1 = xq + zq
    Serial.println(" ---------------> // 1. t1 = xq + zq  ");
    print_fe25519("t1", &t1);
    print_fe25519("xq", xq);
    print_fe25519("zq", zq);

    fe25519_add(&t1, xq, zq);
    fe25519_freeze(&t1);

    print_fe25519("t1", &t1);

    // 2. xq = xq - zq
    Serial.println(" ---------------> // 2. xq = xq - zq  ");
    print_fe25519("xq", xq);
    print_fe25519("zq", zq);

    fe25519_sub(xq, xq, zq);
    fe25519_freeze(xq);

    print_fe25519("xq", xq);

    // 3. zq = xp + zp
    Serial.println(" ---------------> // 3. zq = xp + zp  ");
    print_fe25519("xp", xp);
    print_fe25519("zp", zp);
    // zq still holds its pre-step value here, so label it as such.
    print_fe25519("zq (before)", zq);

    fe25519_add(zq, xp, zp);
    fe25519_freeze(zq);  // canonical form, like every other add/sub in this trace
    print_fe25519("zq", zq);

    // 4. xp = xp - zp
    Serial.println(" ---------------> // 4. xp = xp - zp  ");
    print_fe25519("xp", xp);
    print_fe25519("zp", zp);

    fe25519_sub(xp, xp, zp);
    fe25519_freeze(xp);

    print_fe25519("xp", xp);

    // 5. t1 = t1 * xp
    Serial.println(" ---------------> // 5. t1 = t1 * xp");
    print_fe25519("xp", xp);
    print_fe25519("t1", &t1);
    delay(10);

    fe25519_mul(&t1, &t1, xp);
    fe25519_freeze(&t1);

    print_fe25519("t1", &t1);

    // 6. xq = xq * zq
    Serial.println(" ---------------> // 6. xq = xq * zq");
    print_fe25519("xq", xq);
    print_fe25519("zq", zq);
    delay(10);

    fe25519_mul(xq, xq, zq);
    fe25519_freeze(xq);

    print_fe25519("xq", xq);

    // 7. zq = zq^2
    Serial.println(" ---------------> // 7. zq = zq^2");
    print_fe25519("zq", zq);

    fe25519_square(zq, zq);
    fe25519_freeze(zq);

    print_fe25519("zq", zq);

    // 8. xp = xp^2
    Serial.println(" ---------------> // 8. xp = xp^2");
    print_fe25519("xp", xp);

    fe25519_square(xp, xp);
    fe25519_freeze(xp);

    print_fe25519("xp", xp);

    // 9. t2 = zq - xp
    Serial.println(" ---------------> // 9. t2 = zq - xp");
    print_fe25519("xp", xp);
    print_fe25519("zq", zq);
    print_fe25519("t2", &t2);

    fe25519_sub(&t2, zq, xp);
    fe25519_freeze(&t2);

    print_fe25519("Result _ t2", &t2);

    // 10. zp = 121666 * t2
    Serial.println(" ---------------> // 10. zp = 121666 * t2");
    print_fe25519("t2", &t2);

    fe25519_mul121666(zp, &t2);
    fe25519_freeze(zp);

    print_fe25519("zp", zp);

    // 11. zp = zp + xp
    Serial.println(" ---------------> // 11. zp = zp + xp");
    print_fe25519("zp", zp);
    print_fe25519("xp", xp);

    fe25519_add(zp, zp, xp);
    fe25519_freeze(zp);

    print_fe25519("zp", zp);

    // 12. zp = zp * t2
    Serial.println(" ---------------> // 12. zp = zp * t2");
    print_fe25519("t2", &t2);
    print_fe25519("zp", zp);

    fe25519_mul(zp, zp, &t2);
    fe25519_freeze(zp);

    print_fe25519("zp", zp);

    // 13. xp = zq * xp
    Serial.println(" ---------------> // 13. xp = zq * xp");
    print_fe25519("zq", zq);
    print_fe25519("xp", xp);

    fe25519_mul(xp, zq, xp);
    fe25519_freeze(xp);

    print_fe25519("xp", xp);

    // 14. zq = xq - t1
    Serial.println(" ---------------> // 14. zq = xq - t1");
    print_fe25519("t1", &t1);
    print_fe25519("xq", xq);

    fe25519_sub(zq, xq, &t1);
    fe25519_freeze(zq);

    print_fe25519("zq", zq);

    // 15. zq = zq^2
    Serial.println(" ---------------> // 15. zq = zq^2");
    print_fe25519("zq", zq);
    delay(10);

    fe25519_square(zq, zq);
    fe25519_freeze(zq);

    print_fe25519("zq", zq);

    // 16. zq = zq * x0
    Serial.println(" ---------------> // 16. zq = zq * x0");
    print_fe25519("x0", x0);
    print_fe25519("zq", zq);

    fe25519_mul(zq, zq, x0);
    fe25519_freeze(zq);

    print_fe25519("zq", zq);

    // 17. xq = t1 + xq
    Serial.println(" ---------------> // 17. xq = t1 + xq");
    print_fe25519("t1", &t1);
    print_fe25519("xq", xq);

    fe25519_add(xq, &t1, xq);
    fe25519_freeze(xq);

    print_fe25519("xq", xq);

    // 18. xq = xq^2
    Serial.println(" ---------------> // 18. xq = xq^2");
    print_fe25519("xq", xq);

    fe25519_square(xq, xq);
    fe25519_freeze(xq);

    print_fe25519("xq", xq);

    delay(100);
    // Print final values
    Serial.println(" ---------------> // Print final values");
    print_fe25519("x0", x0);
    print_fe25519("xp", xp);
    print_fe25519("zp", zp);
    print_fe25519("xq", xq);
    print_fe25519("zq", zq);
}

// Arduino entry point. Opens Serial at 9600 baud and runs a single traced ladder step
// on fixed test inputs.
void setup()
{
  Serial.begin(9600);

  // Ladder state in the order ladderstep expects: x0, xp, zp, xq, zq.
  // Aggregate initialisation zero-fills any byte not listed, so {{0x09}} is the 32-byte
  // little-endian value 9. NOTE(review): x0 = 9 is presumably the Curve25519 base point.
  fe25519 work[5] = {
    {{0x09}},  // x0
    {{0x01}},  // xp
    {{0x00}},  // zp
    {{0x28, 0x79, 0xAD, 0xD7, 0x04, 0x26, 0x4D, 0x7B,
      0x89, 0xE2, 0x8C, 0x60, 0xDB, 0x3E, 0xE6, 0x6D,
      0x05, 0xE6, 0x63, 0xCB, 0x3C, 0xDC, 0x10, 0x27,
      0x01, 0x82, 0xA9, 0xE2, 0xCE, 0x8C, 0x4B, 0x0B}},  // xq
    {{0x37, 0x7F, 0xA1, 0xC2, 0x8E, 0x59, 0x7A, 0x7F,
      0x9D, 0x35, 0xF3, 0x98, 0xA6, 0xB1, 0xE0, 0xB6,
      0xC7, 0xA7, 0xEE, 0x6B, 0xB1, 0xA6, 0x73, 0xCB,
      0xAA, 0x9C, 0x84, 0x6E, 0x33, 0xF3, 0xEB, 0x2B}}   // zq
  };

  ladderstep(work);
}

// Arduino main loop. Intentionally idle: all work happens once, in setup().
void loop()
{
}