/* Computes the multiplicative inverse of a number a in the prime field GF(p), with
   p = 2^255 - 19 (the constant mod25519_subtracter_32), using Fermat's Little Theorem.
   If p is prime and a is not divisible by p, then
       a^(p-1) = 1 (mod p),
   and therefore
       a * a^(p-2) = 1 (mod p).
   Thus a^(p-2) is the multiplicative inverse of a modulo p.
   Note: a^(p-2) is computed with the "exponentiation by squaring" (square-and-multiply) method. */
#include <Arduino.h>
// The field prime p = 2^255 - 19 = 0x7FFF...FFED as 32 little-endian bytes.
// mod25519() subtracts it conditionally to bring values into [0, p).
const uint8_t mod25519_subtracter_32[32] = {
    0xED, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F};
// Multiply the high part by 38
void mod25519_multiply_32(uint8_t* result, const uint8_t* a, uint8_t b, uint16_t& carry)
{
for (int i = 0; i < 32; i++)
{
uint16_t mul = a[i] * b + carry;
result[i] = mul & 0xFF;
carry = mul >> 8;
}
}
// Adds the 32-byte little-endian number `a` into `result` in place (result += a).
// `carry` is the incoming carry on entry and the carry out of byte 31 on return.
void mod25519_add_32(uint8_t* result, const uint8_t* a, uint16_t& carry) {
  for (int k = 0; k != 32; ++k) {
    const uint32_t acc = static_cast<uint32_t>(result[k]) + a[k] + carry;
    carry = static_cast<uint16_t>(acc >> 8);
    result[k] = static_cast<uint8_t>(acc & 0xFF);
  }
}
// Subtracts the 32-byte little-endian number b from a and stores the low 32 bytes of a - b
// in r (r may be the same buffer as a or b). On return *borrow is 1 when a < b, i.e. the
// difference wrapped modulo 2^256, and 0 otherwise.
void mod25519_sub_32(uint8_t* r, const uint8_t* a, const uint8_t* b, uint8_t* borrow)
{
  *borrow = 0;
  int idx = 0;
  while (idx < 32)
  {
    const int diff = static_cast<int>(a[idx]) - static_cast<int>(b[idx]) - static_cast<int>(*borrow);
    r[idx] = static_cast<uint8_t>(diff & 0xFF);
    *borrow = (diff < 0) ? 1 : 0;
    ++idx;
  }
}
// mod25519() is defined further down in this file. Declare it here so the call below also
// compiles as plain C++, without relying on the Arduino .ino prototype generator.
void mod25519(uint8_t* result, const uint8_t* tempResult);

// Multiplication of two 32-byte numbers followed by modulus.
// Computes the 512-bit schoolbook product of the 32-byte little-endian numbers a and b, then
// reduces it with mod25519() into the 32-byte buffer result. The full product is accumulated
// in a local buffer before result is written, so result may alias a and/or b
// (exponent() calls mul32_mod25519(Z0, Z0, Z0)).
void mul32_mod25519(const uint8_t* a, const uint8_t* b, uint8_t* result)
{
  uint8_t tempResult[64] = {0};  // unreduced 512-bit product, little-endian
  for (int i = 0; i < 32; ++i)
  {
    uint32_t carry = 0;
    for (int j = 0; j < 32; ++j)
    {
      // At most 255*255 + 255 + 255 = 0xFFFF, so the carry always fits in one byte.
      const uint32_t product = static_cast<uint32_t>(a[i]) * b[j] + tempResult[i + j] + carry;
      tempResult[i + j] = static_cast<uint8_t>(product & 0xFF);
      carry = product >> 8;
    }
    // Rows 0..i only write bytes up to i+31, so byte i+32 is still zero here.
    tempResult[i + 32] = static_cast<uint8_t>(carry);
  }
  // tempResult is local, so writing result directly cannot disturb the reduction input.
  mod25519(result, tempResult);
}
// Selects between two 32-byte buffers without a branch: when borrow is 1, r is overwritten
// with original; when borrow is 0, r is left unchanged. The choice is made with a byte mask
// (0xFF for borrow == 1, 0x00 for borrow == 0).
void mod25519_exchange_32(uint8_t* r, const uint8_t* original, uint8_t borrow)
{
  const uint8_t take_original = static_cast<uint8_t>(0u - borrow);
  const uint8_t keep_r = static_cast<uint8_t>(~take_original);
  for (int i = 0; i < 32; ++i)
  {
    const uint8_t from_r = r[i] & keep_r;
    const uint8_t from_original = original[i] & take_original;
    r[i] = from_r | from_original;
  }
}
//-------------------------------------------------------Modulo-mod25519 function---------------------------------------------------------------------------------------------------
void mod25519(uint8_t* result, const uint8_t* tempResult)
{
uint8_t high[32];
uint8_t low[32];
uint8_t mod25519_originalResult[32];
uint8_t borrow = 0;
uint16_t carry = 0;
// Divide tempResult into high and low parts
memcpy(high, tempResult + 32, 32);
memcpy(low, tempResult, 32);
// Multiply high part by 38
mod25519_multiply_32(result, high, 38, carry);
// Propagate carry
for (int k = 0; k < 2; k++)
{
uint32_t carryProduct = carry * 38;
uint32_t tempSum = 0;
for (int i = 0; i < 32; i++)
{
tempSum += result[i] + (carryProduct & 0xFF);
result[i] = tempSum & 0xFF;
carryProduct >>= 8;
tempSum >>= 8;
}
carry = tempSum;
}
// Add low part
mod25519_add_32(result, low, carry);
// Final carry propagation
uint32_t carryProduct = carry * 38;
carry = 0;
for (int i = 0; i < 32; i++) {
uint32_t sum = result[i] + (carryProduct & 0xFF) + carry;
result[i] = sum & 0xFF;
carryProduct >>= 8;
carry = sum >> 8;
}
// Save original result
memcpy(mod25519_originalResult, result, 32);
// Perform subtraction and exchange to ensure result is within range
mod25519_sub_32(result, result, mod25519_subtracter_32, &borrow);
mod25519_exchange_32(result, mod25519_originalResult, borrow);
memcpy(mod25519_originalResult, result, 32);
mod25519_sub_32(result, result, mod25519_subtracter_32, &borrow);
mod25519_exchange_32(result, mod25519_originalResult, borrow);
}
//------------------------------------------------------------------------------------------------------------------------------------------------------------
// Conditionally replaces r with a: when borrow is 1 every byte of r becomes the matching byte
// of a; when borrow is 0, r is left unchanged. Uses an XOR-and-mask select instead of a branch,
// walking the 32 bytes from index 31 down to 0.
void exchange25519(uint8_t *r, const uint8_t *a, uint8_t borrow)
{
  const uint8_t sel = static_cast<uint8_t>(0u - borrow);
  for (int k = 32; k-- > 0;)
  {
    const uint8_t diff = a[k] ^ r[k];
    r[k] = static_cast<uint8_t>(r[k] ^ (diff & sel));
  }
}
// Prints a 32-byte little-endian value most-significant byte first (each byte in HEX followed
// by a space), then ends the line with println(" ") or println() to match the debug format.
static void exponent_print32(const uint8_t* value, bool trailingSpace)
{
  for (int j = 31; j >= 0; j--)
  {
    Serial.print(value[j], HEX);
    Serial.print(" ");
  }
  if (trailingSpace)
    Serial.println(" ");
  else
    Serial.println();
}

// Shifts a 32-byte little-endian value right by one bit in place.
static void exponent_shift_right_1(uint8_t* value)
{
  uint8_t carryIn = 0;
  for (int j = 31; j >= 0; j--)
  {
    const uint8_t cur = value[j];
    value[j] = static_cast<uint8_t>(((carryIn & 0x01) << 7) | (cur >> 1));
    carryIn = cur;
  }
}

// Fills mask with 0xFF (bit == 1) or 0x00 (bit == 0) and invMask with its complement.
static void exponent_make_masks(uint8_t bit, uint8_t* mask, uint8_t* invMask)
{
  for (int j = 31; j >= 0; j--)
  {
    mask[j] = static_cast<uint8_t>(-bit);
    invMask[j] = static_cast<uint8_t>(~mask[j]);
  }
}

// Exponentiation function.
// Computes Z0^expon (reduced by mul32_mod25519) with the right-to-left binary method and prints
// every intermediate value over Serial for debugging. Z0 and expon are 32-byte little-endian
// numbers. All 256 exponent bits are processed: bit 0 before the loop, bits 1..255 in the
// 255 loop iterations.
//
// Each "multiply or keep" decision uses byte masks instead of a branch:
//   mask   = 0xFF if the bit is 1, else 0x00          (complement_current_bit)
//   result = (result * (Z0 & mask)) | (result & ~mask)
// When the bit is 0 the product is result * 0 = 0, and the OR keeps the old result.
//
// Side effects: Z0 is squared in place 255 times (ending as Z0^(2^255)) and expon is shifted
// right down to 0, so callers must not reuse either buffer afterwards.
//
// out: optional 32-byte buffer that receives the final result; pass nullptr to discard it.
void exponent(uint8_t* Z0, uint8_t* expon, uint8_t* out)
{
  uint8_t current_bit[32] = {0};  // only byte 0 is used; printed as a 32-byte value
  uint8_t complement_current_bit[32] = {0};
  uint8_t Inv_complement_current_bit[32] = {0};
  uint8_t complement_Z0[32] = {0};
  uint8_t Constant_Result[32] = {0};
  uint8_t result[32] = {0};
  result[0] = 0x01;  // result starts at 1
  //------------------ exponent bit 0 (handled before the loop)
  current_bit[0] = expon[0] & 0x01;
  Serial.print("current_bit = ");
  exponent_print32(current_bit, true);
  exponent_make_masks(current_bit[0], complement_current_bit, Inv_complement_current_bit);
  Serial.print("complement_current_bit =");
  exponent_print32(complement_current_bit, true);
  Serial.print("Inv_complement_current_bit =");
  exponent_print32(Inv_complement_current_bit, true);
  exponent_shift_right_1(expon);
  Serial.print("Before the for loop, Shifted expon = ");
  exponent_print32(expon, false);
  Serial.print("Before for loop, Z0 = ");
  exponent_print32(Z0, false);
  Serial.print("Before the for loop, result[j] = ");
  exponent_print32(result, false);
  // result = Z0 if bit 0 is set, otherwise it stays 1. Taking only Z0 & mask here would make
  // result 0 for every even exponent, and 0 * anything stays 0 for the rest of the loop.
  for (int j = 31; j >= 0; j--)
  {
    result[j] = (Z0[j] & complement_current_bit[j]) | (result[j] & Inv_complement_current_bit[j]);
  }
  Serial.print("Before the for loop, result = Z0 & complement_current_bit = ");
  exponent_print32(result, false);
  //------------------ exponent bits 1..255
  Serial.println("-----------------------Start for loop here------------------- ");
  for (int i = 0; i <= 254; i++)
  {
    current_bit[0] = expon[0] & 0x01;
    Serial.print("Inside the for loop, current_bit = ");
    exponent_print32(current_bit, true);
    exponent_make_masks(current_bit[0], complement_current_bit, Inv_complement_current_bit);
    Serial.print("complement_current_bit =");
    exponent_print32(complement_current_bit, true);
    Serial.print("Inv_complement_current_bit =");
    exponent_print32(Inv_complement_current_bit, true);
    // Move the next exponent bit into position 0 for the following iteration.
    exponent_shift_right_1(expon);
    Serial.print("Inside the for loop, At the [");
    Serial.print(i);
    Serial.print("] iteration, Shifted expon = ");
    exponent_print32(expon, false);
    Serial.println("------------------------------");
    Serial.println();
    Serial.print("-------------> At the [");
    Serial.print(i);
    Serial.print("] Iteration, First of all, Z0 = ");
    exponent_print32(Z0, false);
    Serial.print("First of all in the for loop, result[j] = ");
    exponent_print32(result, false);
    // Z0 = Z0 * Z0 (mod p): Z0 now equals the base raised to 2^(i+1).
    mul32_mod25519(Z0, Z0, Z0);
    Serial.print("Z0 = Z0 * Z0 (mod p) = ");
    exponent_print32(Z0, true);
    // complement_Z0 = Z0 when the bit is 1, 0 otherwise.
    for (int j = 31; j >= 0; j--)
    {
      complement_Z0[j] = Z0[j] & complement_current_bit[j];
    }
    Serial.print("complement_Z0 = Z0 & complement_current_bit = ");
    exponent_print32(complement_Z0, false);
    mul32_mod25519(result, complement_Z0, complement_Z0);
    Serial.print("complement_Z0[j] = result * complement_Z0[j] = ");
    exponent_print32(complement_Z0, false);
    // Constant_Result = result when the bit is 0, 0 otherwise.
    for (int j = 31; j >= 0; j--)
    {
      Constant_Result[j] = result[j] & Inv_complement_current_bit[j];
    }
    Serial.print("Constant_Result = result & Inv_complement_current_bit = ");
    exponent_print32(Constant_Result, false);
    // Exactly one of the two operands is nonzero, so OR selects the right value.
    for (int j = 31; j >= 0; j--)
    {
      result[j] = complement_Z0[j] | Constant_Result[j];
    }
    Serial.print("At the [");
    Serial.print(i);
    Serial.print("] Iteration, result = complement_Z0 | Constant_Result =");
    exponent_print32(result, false);
    Serial.println("===================================================================================================================");
  }
  if (out != nullptr)
  {
    for (int j = 0; j < 32; j++)
    {
      out[j] = result[j];
    }
  }
}

// Two-argument form kept for existing callers: runs the computation and discards the result
// (it is still printed over Serial).
void exponent(uint8_t* Z0, uint8_t* expon)
{
  exponent(Z0, expon, nullptr);
}
void setup()
{
Serial.begin(9600);
uint8_t Z0[32] = {0};
Z0[0] = 2;
// ----------------------------- print Z0-------------------
Serial.print("Z0 = ");
for (int j = 31; j >= 0; j--)
{
Serial.print(Z0[j], HEX);
Serial.print(" ");
}
Serial.println();
//-------------------------- print exp ---------------------
uint8_t expon[32] = {0};
expon[0] = 8;
Serial.print("expon = ");
for (int j = 31; j >= 0; j--)
{
Serial.print(expon[j], HEX);
Serial.print(" ");
}
Serial.println(" ");
Serial.println("==============================================>");
//--------------------------- Call expon function--------------------
exponent(Z0, expon);
}
// Arduino main loop: intentionally empty, since the whole demo runs once from setup().
void loop()
{
}