// Known issue: this code works correctly up to iteration i = 29, j = 0 (iteration 232).
// At that iteration x0, z0 and x1 come out correct, but z1 is wrong.
#include <Arduino.h>
#include <avr/io.h>
#include <stdlib.h>

// Declarations for the field-arithmetic routines used by this sketch.
//
// The three macros below rename the fe25519 add/sub/red entry points, so the
// extern "C" declarations further down (and every call in this file) bind to
// the avrnacl_-prefixed symbols. The macros must stay ABOVE those declarations,
// otherwise the unprefixed names would be declared instead.
#define fe25519_add avrnacl_fe25519_add
#define fe25519_sub avrnacl_fe25519_sub
#define fe25519_red avrnacl_fe25519_red

// Field-element container: 32 raw bytes (256 bits). The byte order is whatever the
// assembly routines expect -- presumably little-endian as in avrnacl; TODO confirm.
typedef struct {unsigned char v[32];} fe25519;

// C-linkage routines that are not defined in this file. setup() calls
// bigint_mul256 "the assembly function", so these are presumably AVR assembly.
// NOTE(review): the bigint_* names are NOT remapped by the macros above;
// confirm they match the symbols the assembly actually exports, otherwise the
// link fails for any of them that is called.
extern "C"
{
  void fe25519_sub(fe25519 *r, const fe25519 *x, const fe25519 *y);
  void fe25519_add(fe25519 *r, const fe25519 *x, const fe25519 *y);
  // Reduces a 64-byte product to 32 bytes (setup() passes bigint_mul256's output).
  // C is taken through a non-const pointer, so the routine may use it as
  // scratch -- confirm before reusing C after the call.
  void fe25519_red(fe25519 *r, unsigned char *C);
  // fe25519_freeze() treats the return value as a borrow flag: presumably
  // r = a - (2^255 - 19), returning 1 when a < p and 0 otherwise -- TODO confirm.
  char bigint_subp(unsigned char* r, const unsigned char* a);
  char bigint_square256(unsigned char* r, const unsigned char* a);
  // 32-byte x 32-byte -> 64-byte product (buffer sizes as used in setup()).
  char bigint_mul256(unsigned char* r, const unsigned char* a, const unsigned char* b);
  // Presumably r = x * 121666 (a Curve25519 ladder constant, judging by the name) -- confirm.
  void bigint_mul121666(unsigned char *r, const unsigned char *x);
}

// C++-linkage helpers. fe25519_freeze and fe25519_cmov are defined later in
// this file; the others are not defined in the visible code and must be
// provided elsewhere if they are called.
void fe25519_freeze(fe25519 *r);
void fe25519_unpack(fe25519 *r, const unsigned char x[32]);
void fe25519_pack(unsigned char r[32], const fe25519 *x);
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b);
void fe25519_setone(fe25519 *r);
void fe25519_setzero(fe25519 *r);
void fe25519_mul(fe25519 *r, const fe25519 *x, const fe25519 *y);
void fe25519_square(fe25519 *r, const fe25519 *x);
void fe25519_invert(fe25519 *r, const fe25519 *x);
void work_cswap(fe25519 *work, char b);
void mladder(fe25519 *xr, fe25519 *zr, const unsigned char s[32]);
void fe25519_mul121666(fe25519 *r, const fe25519 *x);

int crypto_scalarmult_curve25519(unsigned char *r, const unsigned char *s, const unsigned char *p);

/* Print `label`, then `len` bytes in array order (index 0 first) as
 * space-separated two-digit hex, then a newline.
 * Serial.print(x, HEX) drops leading zeros (0x08 prints as "8"), which makes
 * multi-byte dumps ambiguous, so single-digit bytes are padded with '0'. */
static void print_hex_bytes(const char *label, const unsigned char *bytes, size_t len)
{
    Serial.print(label);
    for (size_t i = 0; i < len; i++)
    {
        if (bytes[i] < 0x10)
        {
            Serial.print('0');
        }
        Serial.print(bytes[i], HEX);
        Serial.print(" ");
    }
    Serial.println();
}

/* One-shot self-test: multiply two 256-bit operands with the assembly
 * bigint_mul256, reduce the 512-bit product with fe25519_red, bring it into
 * canonical form with fe25519_freeze, and print every intermediate value. */
void setup()
{
    Serial.begin(9600); // baud rate: line speed in bits per second
    Serial.println();

    // Operands (a == b, so the product is a squared).
    uint8_t a[32] = {0x88, 0x41, 0x08, 0xF2, 0xDC, 0xE3, 0x96, 0x93, 0xAA, 0x39, 0x21, 0xAF, 0xC1, 0x01, 0x5B, 0x32, 0xE7, 0xF9, 0x1A, 0xD4, 0xFD, 0x36, 0x85, 0x63, 0x0F, 0xA8, 0x66, 0x63, 0xB3, 0x7A, 0xBE, 0x37};
    uint8_t b[32] = {0x88, 0x41, 0x08, 0xF2, 0xDC, 0xE3, 0x96, 0x93, 0xAA, 0x39, 0x21, 0xAF, 0xC1, 0x01, 0x5B, 0x32, 0xE7, 0xF9, 0x1A, 0xD4, 0xFD, 0x36, 0x85, 0x63, 0x0F, 0xA8, 0x66, 0x63, 0xB3, 0x7A, 0xBE, 0x37};
    uint8_t result[64] = {0}; // full 512-bit product of two 256-bit operands
    fe25519 final_result;

    print_hex_bytes(" a = ", a, sizeof a);
    print_hex_bytes(" b = ", b, sizeof b);

    // Serial output is buffered. flush() blocks until it has actually been
    // transmitted. A fixed short delay is not enough at 9600 baud, where each
    // byte takes about 1 ms.
    Serial.flush();

    // Multiply the big integers with the assembly routine.
    bigint_mul256(result, a, b);
    print_hex_bytes(" result = ", result, sizeof result);

    // Reduce the 64-byte product to a 32-byte field element. fe25519_red takes
    // `result` through a non-const pointer, so treat `result` as possibly
    // clobbered from here on.
    fe25519_red(&final_result, result);
    // Reduced, but not necessarily canonical yet (the value may still be >= p).
    print_hex_bytes(" result after reducation = ", final_result.v, sizeof final_result.v);

    // Bring the value into the canonical range [0, p), p = 2^255 - 19.
    fe25519_freeze(&final_result);
    print_hex_bytes(" final_result = ", final_result.v, sizeof final_result.v);
}

/* Arduino main-loop hook. Nothing runs repeatedly: the whole
 * multiply/reduce/freeze test executes once in setup(). */
void loop() {}

/* Reduce r to its canonical representative in [0, p), where p = 2^255 - 19.
 *
 * An fe25519 holds 32 bytes, so r may be any value below 2^256 = 2p + 38.
 * One conditional subtraction of p maps that range only to [0, p + 38), which
 * can still be >= p. A second conditional subtraction is therefore required:
 *   pass 1: [0, 2p + 38) -> [0, p + 38)
 *   pass 2: [0, p + 38)  -> [0, p)
 *
 * bigint_subp(rt, r) is assumed to compute rt = r - p and return the borrow
 * (1 if r < p, 0 otherwise) -- TODO confirm against the assembly. Both passes
 * always perform the subtraction and select the result with fe25519_cmov, so
 * there is no branch on the value of r. */
void fe25519_freeze(fe25519 *r)
{
    unsigned char c;
    fe25519 rt;

    /* Pass 1: take r - p unless the subtraction borrowed (r < p). */
    c = bigint_subp(rt.v, r->v);
    fe25519_cmov(r, &rt, 1 - c);

    /* Pass 2: handles inputs in [2p, 2p + 38) that are still >= p after pass 1. */
    c = bigint_subp(rt.v, r->v);
    fe25519_cmov(r, &rt, 1 - c);
}

/* Conditional move: when b == 1 overwrite r with x, when b == 0 leave r as is.
 * The selection is done with a bit mask rather than a branch on b. b is
 * expected to be 0 or 1; any other value mixes bits of r and x.
 *
 * Only the low 8 bits of the mask ever reach r->v[], so an 8-bit mask
 * (b = 1 -> 0xFF, b = 0 -> 0x00) is all that is needed. */
void fe25519_cmov(fe25519 *r, const fe25519 *x, unsigned char b)
{
    const unsigned char keep_or_take = static_cast<unsigned char>(0u - b);

    for (unsigned char idx = 0; idx < 32; ++idx)
    {
        const unsigned char delta = static_cast<unsigned char>(x->v[idx] ^ r->v[idx]);
        r->v[idx] = static_cast<unsigned char>(r->v[idx] ^ (keep_or_take & delta));
    }
}