#include <Arduino.h>
#include <stdint.h>
#include <string.h>
// The field prime p = 2^255 - 19 as 32 little-endian bytes (least significant
// byte first). mod25519() subtracts this value during its final reduction.
const uint8_t mod25519_subtracter_32[32] = {
    0xED, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F
};

// The same 32 bytes widened to 64 bytes; elements 32..63 are zero (the
// initialiser leaves them to zero-initialisation).
// NOTE(review): nothing visible in this file reads this array - confirm whether
// another translation unit uses it before removing it.
uint8_t subtracter[64] = {
    0xED, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F
};

// result = (a * b) reduced by mod25519(); all operands are 32-byte little-endian arrays.
void mul32_mod25519(const uint8_t* a, const uint8_t* b, uint8_t* result);

// Reduces the 64-byte little-endian value `tempResult` into the 32-byte `result`.
void mod25519(uint8_t* result, const uint8_t* tempResult);

// result = low 32 bytes of (a * b + carry); `carry` receives the byte carried out of the top.
void mod25519_multiply_32(uint8_t* result, const uint8_t* a, uint8_t b, uint16_t& carry);

// result += a (+ incoming carry) over 32 bytes; `carry` receives the carry out.
void mod25519_add_32(uint8_t* result, const uint8_t* a, uint16_t& carry);

// r = a - b over 32 bytes; *borrow is set to 1 when a < b, otherwise 0.
void mod25519_sub_32(uint8_t* r, const uint8_t* a, const uint8_t* b, uint8_t* borrow);

// Branch-free select: borrow == 1 copies `original` into r, borrow == 0 leaves r unchanged.
void mod25519_exchange_32(uint8_t* r, const uint8_t* original, uint8_t borrow);
// Computes out = in^(2^count) mod 2^255-19, i.e. squares `in` `count` times.
// `count` must be at least 1. `out` may alias `in`: mul32_mod25519 only writes
// its output after both operands have been fully consumed, so squaring in
// place is safe.
static void fe25519_square_n(uint8_t* out, const uint8_t* in, uint8_t count)
{
    mul32_mod25519(in, in, out);
    for (uint8_t done = 1; done < count; ++done)
    {
        mul32_mod25519(out, out, out);
    }
}

// Computes the multiplicative inverse of `input` modulo p = 2^255 - 19.
//
// Uses Fermat's little theorem: input^(p - 2) = input^(2^255 - 21) is the
// inverse of any non-zero field element. The exponentiation follows the
// standard addition chain (254 squarings, 11 multiplications) that builds
// input^(2^k - 1) for k = 5, 10, 20, 40, 50, 100, 200, 250.
//
// Parameters:
//   result - 32-byte little-endian output; may alias `input` because it is
//            written only by the very last multiplication.
//   input  - 32-byte little-endian field element.
//
// An input of zero has no inverse; the result is then zero.
// The trailing comment on each line gives the exponent of `input` held in the
// destination of that call.
void fe25519_invert(uint8_t *result, const uint8_t *input)
{
    uint8_t x_exp_2[32];    // input^2
    uint8_t x_exp_9[32];    // input^9
    uint8_t x_exp_11[32];   // input^11
    uint8_t x2_exp_5[32];   // input^(2^5 - 1)
    uint8_t x2_exp_10[32];  // input^(2^10 - 1)
    uint8_t x2_exp_20[32];  // input^(2^20 - 1)
    uint8_t x2_exp_50[32];  // input^(2^50 - 1)
    uint8_t x2_exp_100[32]; // input^(2^100 - 1)
    uint8_t acc[32];        // running value

    /* Small powers */
    mul32_mod25519(input, input, x_exp_2);        // 2
    fe25519_square_n(acc, x_exp_2, 2);            // 8
    mul32_mod25519(acc, input, x_exp_9);          // 9
    mul32_mod25519(x_exp_9, x_exp_2, x_exp_11);   // 11
    mul32_mod25519(x_exp_11, x_exp_11, acc);      // 22
    mul32_mod25519(acc, x_exp_9, x2_exp_5);       // 31 = 2^5 - 1

    /* Doubling ladder: (2^k - 1) -> (2^2k - 1) by k squarings and one multiply */
    fe25519_square_n(acc, x2_exp_5, 5);           // 2^10 - 2^5
    mul32_mod25519(acc, x2_exp_5, x2_exp_10);     // 2^10 - 1

    fe25519_square_n(acc, x2_exp_10, 10);         // 2^20 - 2^10
    mul32_mod25519(acc, x2_exp_10, x2_exp_20);    // 2^20 - 1

    fe25519_square_n(acc, x2_exp_20, 20);         // 2^40 - 2^20
    mul32_mod25519(acc, x2_exp_20, acc);          // 2^40 - 1

    fe25519_square_n(acc, acc, 10);               // 2^50 - 2^10
    mul32_mod25519(acc, x2_exp_10, x2_exp_50);    // 2^50 - 1

    fe25519_square_n(acc, x2_exp_50, 50);         // 2^100 - 2^50
    mul32_mod25519(acc, x2_exp_50, x2_exp_100);   // 2^100 - 1

    fe25519_square_n(acc, x2_exp_100, 100);       // 2^200 - 2^100
    mul32_mod25519(acc, x2_exp_100, acc);         // 2^200 - 1

    fe25519_square_n(acc, acc, 50);               // 2^250 - 2^50
    mul32_mod25519(acc, x2_exp_50, acc);          // 2^250 - 1

    /* Final steps: five squarings, then multiply by input^11 */
    fe25519_square_n(acc, acc, 5);                // 2^255 - 2^5
    mul32_mod25519(acc, x_exp_11, result);        // 2^255 - 32 + 11 = 2^255 - 21
}
// Multiplies two 32-byte little-endian integers and reduces the 64-byte
// product with mod25519().
//
//   a, b   - 32-byte little-endian operands.
//   result - receives the 32-byte reduced product. It is written only after
//            both operands have been fully read, so it may alias a and/or b.
void mul32_mod25519(const uint8_t* a, const uint8_t* b, uint8_t* result)
{
    uint8_t wide[64] = {0}; // full 512-bit product, least significant byte first

    // Schoolbook multiplication: add a[row] * b into the product at byte offset `row`.
    for (int row = 0; row < 32; ++row)
    {
        uint8_t* const window = wide + row;
        const long multiplier = (long)a[row];
        int carry = 0;

        for (int col = 0; col < 32; ++col)
        {
            const long partial = multiplier * (long)b[col] + (long)window[col] + (long)carry;
            window[col] = partial & 0xFF;
            carry = (partial >> 8) & 0xFF;
        }

        // The byte just past this row has not been written by earlier rows,
        // so the row's final carry lands there.
        window[32] = (carry + window[32]) & 0xFF;
    }

    uint8_t reduced[32] = {0};
    mod25519(reduced, wide);
    memcpy(result, reduced, 32);
}
//-----------------------Modulo-mod25519 function------------
// Adds carry * 38 into the 32-byte little-endian `value` (arithmetic wraps at
// 2^256). Because 2^256 = 38 (mod 2^255 - 19), this folds `carry` units of
// 2^256 back into the low 32 bytes.
//
// Two passes are always executed: if the first addition overflows, the wrapped
// value is smaller than carry * 38, so adding the extra 38 in the second pass
// cannot overflow again. The fixed pass count keeps the work independent of
// the data.
static void mod25519_fold_carry(uint8_t* value, uint16_t carry)
{
    for (uint8_t pass = 0; pass < 2; ++pass)
    {
        uint32_t addend = static_cast<uint32_t>(carry) * 38u;
        uint32_t acc = 0;
        for (uint8_t i = 0; i < 32; ++i)
        {
            acc += value[i] + (addend & 0xFF);
            value[i] = acc & 0xFF;
            addend >>= 8;
            acc >>= 8;
        }
        carry = static_cast<uint16_t>(acc); // 0 or 1: overflow out of byte 31
    }
}

// Reduces a 64-byte little-endian integer modulo p = 2^255 - 19.
//
//   result     - receives the 32-byte little-endian value in [0, p).
//   tempResult - 64-byte little-endian input; it is copied before `result`
//                is written.
//
// Method: split the input into high * 2^256 + low, replace 2^256 by 38
// (high * 38 + low), fold every carry out of byte 31 back in as another
// multiple of 38, then conditionally subtract p twice.
void mod25519(uint8_t* result, const uint8_t* tempResult)
{
    uint8_t high[32];
    uint8_t low[32];
    uint8_t candidate[32];
    uint8_t borrow = 0;
    uint16_t carry = 0;

    // Divide tempResult into high and low parts
    memcpy(high, tempResult + 32, 32);
    memcpy(low, tempResult, 32);

    // result = high * 38; `carry` holds the part that spilled past 2^256.
    mod25519_multiply_32(result, high, 38, carry);
    mod25519_fold_carry(result, carry);

    // Add low part; the addition may again carry out of byte 31.
    carry = 0;
    mod25519_add_32(result, low, carry);
    // Fold that carry completely. A single pass is not enough here: adding 38
    // to the wrapped sum can itself overflow, and that overflow must be folded
    // as well or the result ends up 38 too small.
    mod25519_fold_carry(result, carry);

    // result is now below 2^256 < 3p, so two conditional subtractions of p
    // bring it into [0, p). Each round keeps the subtracted value unless the
    // subtraction borrowed, in which case the saved copy is restored.
    for (uint8_t round = 0; round < 2; ++round)
    {
        memcpy(candidate, result, 32);
        mod25519_sub_32(result, result, mod25519_subtracter_32, &borrow);
        mod25519_exchange_32(result, candidate, borrow);
    }
}
//------------------------------
// Multiplies the 32-byte little-endian integer `a` by the single byte `b`.
// (mod25519() uses this to multiply the high half of a product by 38.)
//
//   result - receives the low 32 bytes of a * b + carry.
//   a      - 32-byte little-endian multiplicand.
//   b      - single-byte multiplier.
//   carry  - in: value added at byte 0; out: the byte that overflowed past
//            byte 31.
void mod25519_multiply_32(uint8_t* result, const uint8_t* a, uint8_t b, uint16_t& carry)
{
    for (int i = 0; i < 32; i++)
    {
        // Force the multiplication into unsigned 16-bit arithmetic. With plain
        // integer promotion it would be done in `int`, which is only 16 bits
        // wide on AVR targets, so 255 * 255 would overflow a signed int.
        const uint16_t mul = static_cast<uint16_t>(static_cast<uint16_t>(a[i]) * b + carry);
        result[i] = mul & 0xFF;
        carry = mul >> 8;
    }
}
//--------------------------
// Adds the 32-byte little-endian integer `a` into `result`, in place.
//
//   result - in/out accumulator, 32 bytes.
//   a      - 32-byte addend.
//   carry  - in: value added at byte 0; out: the carry out of byte 31.
void mod25519_add_32(uint8_t* result, const uint8_t* a, uint16_t& carry)
{
    int pos = 0;
    while (pos < 32)
    {
        const uint32_t total = result[pos] + a[pos] + carry;
        result[pos] = total & 0xFF; // keep the low byte
        carry = total >> 8;         // everything above it moves to the next byte
        ++pos;
    }
}
//--------------------------
// Computes r = a - b over 32 little-endian bytes (wrapping modulo 2^256).
//
//   r      - 32-byte output; may be the same buffer as `a`, since each byte is
//            read before it is written.
//   a, b   - 32-byte minuend and subtrahend.
//   borrow - receives 1 when the subtraction underflowed (a < b), else 0.
void mod25519_sub_32(uint8_t* r, const uint8_t* a, const uint8_t* b, uint8_t* borrow)
{
    *borrow = 0;
    int pos = 0;
    while (pos < 32)
    {
        // A negative difference wraps in 16 bits, which sets bit 8.
        const uint16_t diff = (uint16_t)a[pos] - b[pos] - *borrow;
        r[pos] = diff & 0xFF;
        *borrow = (diff >> 8) & 1;
        ++pos;
    }
}
//-------------------------
// Branch-free select over 32 bytes.
//
//   r        - in/out buffer.
//   original - bytes to restore into r.
//   borrow   - 1: r becomes a copy of `original`; 0: r is left unchanged.
//
// The choice is made with a bit mask rather than an `if`, so the same
// instructions run for either value of `borrow`.
void mod25519_exchange_32(uint8_t* r, const uint8_t* original, uint8_t borrow)
{
    // 0x00 when borrow == 0, 0xFF when borrow == 1.
    const uint8_t select = static_cast<uint8_t>(0u - borrow);
    for (int pos = 0; pos != 32; ++pos)
    {
        // Flip exactly the selected bits in which r differs from original.
        const uint8_t differing = static_cast<uint8_t>(r[pos] ^ original[pos]);
        r[pos] = static_cast<uint8_t>(r[pos] ^ (differing & select));
    }
}
void setup()
{
Serial.begin(9600);
uint8_t binaryArray[32] = {0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
uint8_t input[32], result[32];
memcpy(input, binaryArray, 32);
fe25519_invert(result, input);
Serial.print("Final result: ");
for (int i = 0; i < 32; ++i)
{
Serial.print(result[i], HEX);
}
Serial.println();
}
// Arduino main-loop hook. Intentionally empty: the sketch does all of its
// work once in setup().
void loop()
{
}