#include <Arduino.h>
#include <stdint.h>
#include <string.h>
// The field prime 2^255 - 19 stored as 32 bytes, least-significant byte first
// (the same byte order the arithmetic helpers below iterate in). mod25519()
// subtracts this value to bring a result into range.
const uint8_t mod25519_subtracter_32[32] = {0xED, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F};
// 64-byte buffer whose first 32 bytes repeat the pattern above; every remaining
// element is zero. NOTE(review): nothing in the visible part of this file reads
// it -- confirm whether it is still needed before removing it.
uint8_t subtracter[64] = {0xED, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
// Forward declarations for the field-arithmetic helpers defined later in this
// file. All 32-byte operands are processed starting at index 0 as the
// least-significant byte.

// result = (a * b) reduced by mod25519(); a, b and result are 32 bytes each.
void mul32_mod25519(const uint8_t* a, const uint8_t* b, uint8_t* result);
// Reduces the 64-byte value tempResult into the 32-byte result.
void mod25519(uint8_t* result, const uint8_t* tempResult);
// result = low 32 bytes of (a * b + carry); carry receives the byte that overflows.
void mod25519_multiply_32(uint8_t* result, const uint8_t* a, uint8_t b, uint16_t& carry);
// result = low 32 bytes of (result + a + carry); carry receives the overflow.
void mod25519_add_32(uint8_t* result, const uint8_t* a, uint16_t& carry);
// r = a - b over 32 bytes; *borrow is set to 1 when the subtraction wrapped, else 0.
void mod25519_sub_32(uint8_t* r, const uint8_t* a, const uint8_t* b, uint8_t* borrow);
// Overwrites r with original when borrow is 1; leaves r untouched when borrow is 0.
void mod25519_exchange_32(uint8_t* r, const uint8_t* original, uint8_t borrow);
// Prints `label` followed by the 32 bytes of `value` in hexadecimal, separated
// by single spaces, and terminates the line. Used only for serial tracing.
static void fe25519_trace(const char* label, const uint8_t* value)
{
    Serial.print(label);
    for (int k = 0; k < 32; k++)
    {
        Serial.print(value[k], HEX);
        Serial.print(" ");
    }
    Serial.println();
}

// Replaces `value` with value^(2^count) by squaring it in place `count` times.
// In-place use is safe because mul32_mod25519() stages its output in local
// buffers and only copies it to the destination at the very end.
static void fe25519_square_times(uint8_t* value, uint8_t count)
{
    for (uint8_t k = 0; k < count; k++)
    {
        mul32_mod25519(value, value, value);
    }
}

// Computes the multiplicative inverse of `input` modulo 2^255 - 19 by raising
// it to the power p - 2 = 2^255 - 21, and writes the 32-byte result to `result`.
//
// input  - 32-byte field element, least-significant byte first.
// result - 32-byte output buffer; written only by the final multiplication, so
//          it may be the same buffer as `input`.
//
// The exponent is built with a fixed addition chain (254 squarings and 11
// multiplications). Each comment on the right gives the exponent of `input`
// held by the destination after that line. Milestone values are traced on
// Serial; the "input: " and "result: " lines keep their original format.
void fe25519_invert(uint8_t *result, const uint8_t *input)
{
    uint8_t x_exp_2[32];     // x^2
    uint8_t x_exp_11[32];    // x^11
    uint8_t x2_exp_10[32];   // x^9, then x^(2^5 - 1), finally x^(2^10 - 1)
    uint8_t x2_exp_50[32];   // x^(2^20 - 1), finally x^(2^50 - 1)
    uint8_t x2_exp_100[32];  // x^(2^100 - 1)
    uint8_t temp0[32];       // running accumulator

    fe25519_trace("input: ", input);

    /* Precomputation of the small powers */
    mul32_mod25519(input, input, x_exp_2);             // 2
    mul32_mod25519(x_exp_2, x_exp_2, temp0);           // 4
    mul32_mod25519(temp0, temp0, temp0);               // 8
    mul32_mod25519(temp0, input, x2_exp_10);           // 9
    mul32_mod25519(x2_exp_10, x_exp_2, x_exp_11);      // 11
    fe25519_trace("x^2: ", x_exp_2);
    fe25519_trace("x^9: ", x2_exp_10);
    fe25519_trace("x^11: ", x_exp_11);

    mul32_mod25519(x_exp_11, x_exp_11, temp0);         // 22
    mul32_mod25519(temp0, x2_exp_10, x2_exp_10);       // 31 = 2^5 - 1
    fe25519_trace("x^(2^5 - 1): ", x2_exp_10);

    mul32_mod25519(x2_exp_10, x2_exp_10, temp0);       // 2^6 - 2^1
    fe25519_square_times(temp0, 4);                    // 2^10 - 2^5
    mul32_mod25519(temp0, x2_exp_10, x2_exp_10);       // 2^10 - 1
    fe25519_trace("x^(2^10 - 1): ", x2_exp_10);

    mul32_mod25519(x2_exp_10, x2_exp_10, temp0);       // 2^11 - 2^1
    fe25519_square_times(temp0, 9);                    // 2^20 - 2^10
    mul32_mod25519(temp0, x2_exp_10, x2_exp_50);       // 2^20 - 1
    fe25519_trace("x^(2^20 - 1): ", x2_exp_50);

    mul32_mod25519(x2_exp_50, x2_exp_50, temp0);       // 2^21 - 2^1
    fe25519_square_times(temp0, 19);                   // 2^40 - 2^20
    mul32_mod25519(temp0, x2_exp_50, temp0);           // 2^40 - 1
    fe25519_trace("x^(2^40 - 1): ", temp0);

    fe25519_square_times(temp0, 10);                   // 2^50 - 2^10
    mul32_mod25519(temp0, x2_exp_10, x2_exp_50);       // 2^50 - 1
    fe25519_trace("x^(2^50 - 1): ", x2_exp_50);

    // The accumulator must restart from the fresh x^(2^50 - 1) value and be
    // squared 50 times in total to line up with the next multiplication.
    mul32_mod25519(x2_exp_50, x2_exp_50, temp0);       // 2^51 - 2^1
    fe25519_square_times(temp0, 49);                   // 2^100 - 2^50
    mul32_mod25519(temp0, x2_exp_50, x2_exp_100);      // 2^100 - 1
    fe25519_trace("x^(2^100 - 1): ", x2_exp_100);

    mul32_mod25519(x2_exp_100, x2_exp_100, temp0);     // 2^101 - 2^1
    fe25519_square_times(temp0, 99);                   // 2^200 - 2^100
    mul32_mod25519(temp0, x2_exp_100, temp0);          // 2^200 - 1
    fe25519_trace("x^(2^200 - 1): ", temp0);

    fe25519_square_times(temp0, 50);                   // 2^250 - 2^50
    mul32_mod25519(temp0, x2_exp_50, temp0);           // 2^250 - 1
    fe25519_trace("x^(2^250 - 1): ", temp0);

    /* Final five squarings and the multiplication by x^11 */
    fe25519_square_times(temp0, 5);                    // 2^255 - 2^5
    mul32_mod25519(temp0, x_exp_11, result);           // 2^255 - 32 + 11 = 2^255 - 21
    fe25519_trace("result: ", result);
}
// Multiplies the two 32-byte little-endian numbers `a` and `b` with the
// schoolbook method into a 64-byte product, reduces that product with
// mod25519(), and stores the 32-byte outcome in `result`.
// The output is staged in local buffers and copied out last, so `result` may
// point at the same storage as `a` and/or `b`.
void mul32_mod25519(const uint8_t* a, const uint8_t* b, uint8_t* result)
{
    uint8_t wide[64] = {};  // full 512-bit product, least-significant byte first

    for (int row = 0; row < 32; ++row)
    {
        uint8_t carryByte = 0;
        for (int col = 0; col < 32; ++col)
        {
            uint8_t& cell = wide[row + col];
            // At most 255*255 + 255 + 255 = 65535, so 32 bits are ample and
            // avoid overflowing a 16-bit int.
            const uint32_t acc = static_cast<uint32_t>(a[row]) * b[col] + cell + carryByte;
            cell = static_cast<uint8_t>(acc & 0xFF);
            carryByte = static_cast<uint8_t>((acc >> 8) & 0xFF);
        }
        // Deposit the carry that fell off the end of this row.
        wide[row + 32] = static_cast<uint8_t>((carryByte + wide[row + 32]) & 0xFF);
    }

    uint8_t reduced[32] = {};
    mod25519(reduced, wide);
    memcpy(result, reduced, sizeof(reduced));
}
//----------------------- Reduction modulo 2^255 - 19 ------------
// Adds carry * 38 to the 32-byte little-endian number `value` and leaves the
// byte that overflows past index 31 in `carry`.
static void mod25519_fold_carry(uint8_t* value, uint16_t& carry)
{
    uint32_t addend = carry * 38;
    uint16_t ripple = 0;
    for (int i = 0; i < 32; i++)
    {
        const uint32_t sum = value[i] + (addend & 0xFF) + ripple;
        value[i] = sum & 0xFF;
        addend >>= 8;
        ripple = sum >> 8;
    }
    carry = ripple;
}

// Reduces the 64-byte little-endian number `tempResult` to 32 bytes in `result`.
// The upper 32 bytes are multiplied by 38 (2^256 is congruent to 38 modulo
// 2^255 - 19) and added onto the lower 32 bytes; every overflow past byte 31 is
// folded back in the same way. Finally the prime is subtracted up to two times,
// each subtraction being undone when it borrows.
// Both halves of `tempResult` are copied before `result` is first written.
// The carry left over after the last fold is not used.
void mod25519(uint8_t* result, const uint8_t* tempResult)
{
    uint8_t upperHalf[32];
    uint8_t lowerHalf[32];
    uint8_t snapshot[32];
    uint8_t borrowOut = 0;
    uint16_t overflow = 0;

    // Split the 64-byte input into its two 32-byte halves.
    memcpy(lowerHalf, tempResult, 32);
    memcpy(upperHalf, tempResult + 32, 32);

    // result = upperHalf * 38, with the excess byte collected in `overflow`.
    mod25519_multiply_32(result, upperHalf, 38, overflow);

    // Fold that excess back in; the second pass absorbs a carry produced by the first.
    for (int pass = 0; pass < 2; pass++)
    {
        mod25519_fold_carry(result, overflow);
    }

    // Add the lower half, then fold the carry of that addition.
    mod25519_add_32(result, lowerHalf, overflow);
    mod25519_fold_carry(result, overflow);

    // Two rounds of "subtract the prime, restore the previous value on borrow".
    for (int round = 0; round < 2; round++)
    {
        memcpy(snapshot, result, 32);
        mod25519_sub_32(result, result, mod25519_subtracter_32, &borrowOut);
        mod25519_exchange_32(result, snapshot, borrowOut);
    }
}
//------------------------------
// Multiplies the 32-byte little-endian number `a` by the single byte `b`
// (mod25519() passes 38), adding the incoming `carry` at the lowest byte.
// The low 32 bytes of the product go to `result`; `carry` returns whatever
// overflowed past byte 31.
void mod25519_multiply_32(uint8_t* result, const uint8_t* a, uint8_t b, uint16_t& carry)
{
    uint16_t running = carry;
    for (uint8_t idx = 0; idx != 32; ++idx)
    {
        const uint16_t term = a[idx] * b + running;
        result[idx] = term & 0xFF;
        running = term >> 8;
    }
    carry = running;
}
//--------------------------
// Adds the 32-byte little-endian number `a` plus the incoming `carry` onto
// `result` in place. On return `carry` holds what overflowed past byte 31.
void mod25519_add_32(uint8_t* result, const uint8_t* a, uint16_t& carry)
{
    uint16_t running = carry;
    for (uint8_t idx = 0; idx != 32; ++idx)
    {
        const uint32_t total = result[idx] + a[idx] + running;
        result[idx] = total & 0xFF;
        running = total >> 8;
    }
    carry = running;
}
//--------------------------
// Computes r = a - b over 32 little-endian bytes. `*borrow` is cleared first
// and ends up as 1 when the subtraction wrapped below zero, otherwise 0.
// Each a[i] is read before r[i] is written, so `r` may be the same buffer as `a`.
void mod25519_sub_32(uint8_t* r, const uint8_t* a, const uint8_t* b, uint8_t* borrow)
{
    *borrow = 0;
    for (uint8_t idx = 0; idx != 32; ++idx)
    {
        // diff lies in [-256, 255]; a negative value means this byte borrowed.
        const int16_t diff = static_cast<int16_t>(static_cast<int16_t>(a[idx]) - b[idx] - *borrow);
        r[idx] = static_cast<uint8_t>(diff & 0xFF);
        *borrow = (diff < 0) ? 1 : 0;
    }
}
//-------------------------
// Branch-free select over 32 bytes: with borrow == 1 every byte of `r` is
// replaced by the matching byte of `original`; with borrow == 0 `r` is left
// unchanged. The mask is the 8-bit negation of `borrow`, and only the bits set
// in that mask are taken from `original`.
void mod25519_exchange_32(uint8_t* r, const uint8_t* original, uint8_t borrow)
{
    const uint8_t pick = static_cast<uint8_t>(-borrow);
    for (uint8_t idx = 0; idx != 32; ++idx)
    {
        // XOR-select: flips exactly the masked bits in which r differs from original.
        r[idx] = static_cast<uint8_t>(r[idx] ^ ((r[idx] ^ original[idx]) & pick));
    }
}
// Arduino start-up hook: opens the serial port at 9600 baud, runs
// fe25519_invert() once on a fixed test vector and prints the 32 result bytes
// in hexadecimal on one line.
void setup()
{
    Serial.begin(9600);

    // Active test vector: first byte 0x02, the remaining 31 bytes zero.
    uint8_t binaryArray[32] = {0x02};
    /* Alternative test vector:
    uint8_t binaryArray[32] = {
    0x11, 0xa0, 0x43, 0xd6, 0x23, 0xb2, 0x84, 0xff,
    0x8d, 0xf0, 0x4a, 0x6d, 0xeb, 0x9d, 0xb5, 0x41,
    0x32, 0xe9, 0x6a, 0x35, 0xdc, 0xb3, 0x2f, 0xaf,
    0x95, 0x51, 0xbd, 0x32, 0x05, 0xea, 0xb2, 0x41
    };*/

    uint8_t input[32];
    uint8_t result[32];
    memcpy(input, binaryArray, sizeof(input));

    fe25519_invert(result, input);

    Serial.print("Final result: ");
    for (uint8_t idx = 0; idx != 32; ++idx)
    {
        Serial.print(result[idx], HEX);
        Serial.print(" ");
    }
    Serial.println();
}
// Arduino main loop. Left empty on purpose: the sketch does all of its work
// once, inside setup().
void loop()
{
}