#include <mbedtls/bignum.h>
// Affine point (x, y) on the Baby Jubjub curve. Both coordinates are mbedtls
// MPIs: callers mbedtls_mpi_init() them before use and mbedtls_mpi_free()
// them afterwards.
struct Point {
  mbedtls_mpi x;
  mbedtls_mpi y;
};

// Curve parameters. They are filled in at runtime by init_constants() (so they
// cannot be `const`) and released at the end of setup().
mbedtls_mpi P;         // modulus: add_bj() reduces every coordinate mod P
mbedtls_mpi Order;     // group order (decimal literal in init_constants())
mbedtls_mpi SubOrder;  // Order >> 3, i.e. Order / 8
mbedtls_mpi BjA;       // coefficient a in the addition formulas (168700)
mbedtls_mpi BjD;       // coefficient d in the addition formulas (168696)

// Forward declarations. setup() calls init_constants() and multiply_bj()
// before their definitions appear in this file; a plain C++ compiler requires
// these (only the Arduino .ino preprocessor would synthesize them).
void init_constants(void);
void add_bj(Point *result, const Point *p1, const Point *p2);
void add_bj_slow(Point *result, const Point *p1, const Point *p2);
void multiply_bj(Point *result, const Point *pt, const mbedtls_mpi *n);
// Arduino entry point, run once at boot:
//   1. initializes the curve constants,
//   2. multiplies the hard-coded point Base8 by the hard-coded scalar n with
//      multiply_bj(),
//   3. prints the elapsed milliseconds and the resulting x and y coordinates
//      (decimal) over Serial,
//   4. frees every MPI allocated here and in init_constants().
// NOTE(review): return codes of the mbedtls read/compute calls are not checked.
void setup() {
  Serial.begin(115200);
  Serial.println("Hello, ESP32-S3!");
  init_constants();

  Point Base8, result;
  mbedtls_mpi_init(&Base8.x); mbedtls_mpi_init(&Base8.y);
  mbedtls_mpi_init(&result.x); mbedtls_mpi_init(&result.y);
  mbedtls_mpi_read_string(&Base8.x, 10, "5299619240641551281634865583518297030282874472190772894086521144482721001553");
  mbedtls_mpi_read_string(&Base8.y, 10, "16950150798460657717958625567821834550301663161624707787222815936182638968203");

  mbedtls_mpi n;
  mbedtls_mpi_init(&n);
  mbedtls_mpi_read_string(&n, 10, "5439300000449022559275371417944541395116345860728577057593212065023951936589");

  // millis() returns unsigned long; using the same type keeps the elapsed-time
  // arithmetic independent of the platform's int width.
  const unsigned long startTime = millis();
  multiply_bj(&result, &Base8, &n);
  Serial.println(millis() - startTime);
  // Debug alternatives: a single doubling, or forcing the point (0, 1).
  // add_bj(&result, &Base8, &Base8);
  // mbedtls_mpi_lset(&result.x, 0);
  // mbedtls_mpi_lset(&result.y, 1);

  // Buffers start empty so that a failed conversion prints a blank line
  // rather than uninitialized stack bytes.
  char x_str[200] = "";
  char y_str[200] = "";
  size_t olen;
  if (mbedtls_mpi_write_string(&result.x, 10, x_str, sizeof(x_str), &olen) != 0) {
    x_str[0] = '\0';
  }
  if (mbedtls_mpi_write_string(&result.y, 10, y_str, sizeof(y_str), &olen) != 0) {
    y_str[0] = '\0';
  }
  Serial.println("Result:");
  Serial.println(x_str);
  Serial.println(y_str);

  // Cleanup: local points/scalar, then the globals set up by init_constants().
  mbedtls_mpi_free(&Base8.x); mbedtls_mpi_free(&Base8.y);
  mbedtls_mpi_free(&result.x); mbedtls_mpi_free(&result.y);
  mbedtls_mpi_free(&n);
  mbedtls_mpi_free(&P); mbedtls_mpi_free(&Order);
  mbedtls_mpi_free(&SubOrder); mbedtls_mpi_free(&BjA); mbedtls_mpi_free(&BjD);
}
// Arduino main loop. All the work happens once in setup(), so each iteration
// only sleeps for 10 ms; the original author notes the delay speeds up the
// simulator.
void loop()
{
  delay(10);
}
// Initialize constants
void init_constants(void) {
mbedtls_mpi_init(&P);
mbedtls_mpi_init(&Order);
mbedtls_mpi_init(&SubOrder);
mbedtls_mpi_init(&BjA);
mbedtls_mpi_init(&BjD);
mbedtls_mpi_read_string(&P, 10, "21888242871839275222246405745257275088548364400416034343698204186575808495617");
mbedtls_mpi_read_string(&Order, 10, "21888242871839275222246405745257275088614511777268538073601725287587578984328");
mbedtls_mpi_copy(&SubOrder, &Order);
mbedtls_mpi_shift_r(&SubOrder, 3);
mbedtls_mpi_lset(&BjA, 168700);
mbedtls_mpi_lset(&BjD, 168696);
}
// Point addition on the Baby Jubjub curve (affine coordinates, two inversions):
//   x3 = (x1*y2 + y1*x2)  / (1 + BjD*x1*x2*y1*y2)  mod P
//   y3 = (y1*y2 - BjA*x1*x2) / (1 - BjD*x1*x2*y1*y2)  mod P
// Doubling is add_bj(r, p, p).
//
// The shared term BjD*x1*x2*y1*y2 is computed once, and every intermediate
// product is reduced mod P so operands stay near the size of P.
// Both coordinates are computed into locals before *result is written, so
// result may alias p1 and/or p2. Outputs are in [0, P).
//
// Requires init_constants() to have run (uses the globals P, BjA, BjD).
// NOTE(review): mbedtls return codes are not checked; an allocation failure
// or a non-invertible denominator would go unnoticed.
// Earlier measurement: 7509 ms for a full multiply_bj() with the previous,
// unreduced version of this function -- re-measure after this change.
void add_bj(Point *result, const Point *p1, const Point *p2) {
  mbedtls_mpi x1x2, y1y2, tau, num, den, x3, y3;
  mbedtls_mpi_init(&x1x2); mbedtls_mpi_init(&y1y2); mbedtls_mpi_init(&tau);
  mbedtls_mpi_init(&num); mbedtls_mpi_init(&den);
  mbedtls_mpi_init(&x3); mbedtls_mpi_init(&y3);

  // x1*x2 and y1*y2, reduced mod P.
  mbedtls_mpi_mul_mpi(&x1x2, &p1->x, &p2->x);
  mbedtls_mpi_mod_mpi(&x1x2, &x1x2, &P);
  mbedtls_mpi_mul_mpi(&y1y2, &p1->y, &p2->y);
  mbedtls_mpi_mod_mpi(&y1y2, &y1y2, &P);

  // tau = BjD*x1*x2*y1*y2 mod P, shared by both denominators.
  mbedtls_mpi_mul_mpi(&tau, &x1x2, &y1y2);
  mbedtls_mpi_mod_mpi(&tau, &tau, &P);
  mbedtls_mpi_mul_mpi(&tau, &tau, &BjD);
  mbedtls_mpi_mod_mpi(&tau, &tau, &P);

  // x3 numerator: x1*y2 + y1*x2 (den is used as scratch here).
  mbedtls_mpi_mul_mpi(&num, &p1->x, &p2->y);
  mbedtls_mpi_mul_mpi(&den, &p1->y, &p2->x);
  mbedtls_mpi_add_mpi(&num, &num, &den);
  mbedtls_mpi_mod_mpi(&num, &num, &P);
  // x3 denominator: 1 + tau, then invert and multiply.
  mbedtls_mpi_add_int(&den, &tau, 1);
  mbedtls_mpi_mod_mpi(&den, &den, &P);
  mbedtls_mpi_inv_mod(&den, &den, &P);
  mbedtls_mpi_mul_mpi(&x3, &num, &den);
  mbedtls_mpi_mod_mpi(&x3, &x3, &P);

  // y3 numerator: y1*y2 - BjA*x1*x2 (mod_mpi maps negatives into [0, P)).
  mbedtls_mpi_mul_mpi(&num, &x1x2, &BjA);
  mbedtls_mpi_sub_mpi(&num, &y1y2, &num);
  mbedtls_mpi_mod_mpi(&num, &num, &P);
  // y3 denominator: 1 - tau, written as P + 1 - tau, then invert and multiply.
  mbedtls_mpi_sub_mpi(&den, &P, &tau);
  mbedtls_mpi_add_int(&den, &den, 1);
  mbedtls_mpi_mod_mpi(&den, &den, &P);
  mbedtls_mpi_inv_mod(&den, &den, &P);
  mbedtls_mpi_mul_mpi(&y3, &num, &den);
  mbedtls_mpi_mod_mpi(&y3, &y3, &P);

  // Publish only after every input read is done, so aliasing is safe.
  mbedtls_mpi_swap(&result->x, &x3);
  mbedtls_mpi_swap(&result->y, &y3);

  mbedtls_mpi_free(&x1x2); mbedtls_mpi_free(&y1y2); mbedtls_mpi_free(&tau);
  mbedtls_mpi_free(&num); mbedtls_mpi_free(&den);
  mbedtls_mpi_free(&x3); mbedtls_mpi_free(&y3);
}
// Point addition on the Baby Jubjub curve using a single modular inversion:
//   A = x1*x2, B = y1*y2, C = BjD*A*B,
//   D = 1 + C, G = 1 - C,
//   E = x1*y2 + y1*x2 = (x1 + y1)(x2 + y2) - A - B,
//   F = B - BjA*A,
//   H = (D*G)^-1,
//   x3 = E*G*H = E / D,  y3 = F*D*H = F / G   (all mod P).
//
// Every intermediate product is reduced mod P. The previous version left them
// unreduced, so operands grew to thousands of bits; that version was measured
// at 18034 ms per multiply, slower than the two-inversion add_bj().
// Re-measure after this change.
//
// All inputs are consumed into locals before *result is written, so result
// may alias p1 and/or p2. Outputs are in [0, P).
// Requires init_constants() to have run (uses the globals P, BjA, BjD).
// NOTE(review): mbedtls return codes are not checked.
void add_bj_slow(Point *result, const Point *p1, const Point *p2) {
  mbedtls_mpi A, B, C, D, E, F, G, H;
  mbedtls_mpi_init(&A); mbedtls_mpi_init(&B); mbedtls_mpi_init(&C); mbedtls_mpi_init(&D);
  mbedtls_mpi_init(&E); mbedtls_mpi_init(&F); mbedtls_mpi_init(&G); mbedtls_mpi_init(&H);

  // A = x1*x2 mod P, B = y1*y2 mod P
  mbedtls_mpi_mul_mpi(&A, &p1->x, &p2->x);
  mbedtls_mpi_mod_mpi(&A, &A, &P);
  mbedtls_mpi_mul_mpi(&B, &p1->y, &p2->y);
  mbedtls_mpi_mod_mpi(&B, &B, &P);

  // C = BjD*A*B mod P
  mbedtls_mpi_mul_mpi(&C, &A, &B);
  mbedtls_mpi_mod_mpi(&C, &C, &P);
  mbedtls_mpi_mul_mpi(&C, &C, &BjD);
  mbedtls_mpi_mod_mpi(&C, &C, &P);

  // D = 1 + C mod P
  mbedtls_mpi_add_int(&D, &C, 1);
  mbedtls_mpi_mod_mpi(&D, &D, &P);

  // E = (x1 + y1)*(x2 + y2) - A - B mod P  (== x1*y2 + y1*x2)
  mbedtls_mpi_add_mpi(&E, &p1->x, &p1->y);
  mbedtls_mpi_add_mpi(&F, &p2->x, &p2->y);
  mbedtls_mpi_mul_mpi(&E, &E, &F);
  mbedtls_mpi_sub_mpi(&E, &E, &A);
  mbedtls_mpi_sub_mpi(&E, &E, &B);
  mbedtls_mpi_mod_mpi(&E, &E, &P);

  // F = B - BjA*A mod P
  mbedtls_mpi_mul_mpi(&F, &A, &BjA);
  mbedtls_mpi_sub_mpi(&F, &B, &F);
  mbedtls_mpi_mod_mpi(&F, &F, &P);

  // G = 1 - C mod P, written as P + 1 - C to keep the operand non-negative
  mbedtls_mpi_sub_mpi(&G, &P, &C);
  mbedtls_mpi_add_int(&G, &G, 1);
  mbedtls_mpi_mod_mpi(&G, &G, &P);

  // H = (D*G)^-1 mod P
  mbedtls_mpi_mul_mpi(&H, &D, &G);
  mbedtls_mpi_mod_mpi(&H, &H, &P);
  mbedtls_mpi_inv_mod(&H, &H, &P);

  // x3 = E*G*H mod P
  mbedtls_mpi_mul_mpi(&E, &E, &G);
  mbedtls_mpi_mod_mpi(&E, &E, &P);
  mbedtls_mpi_mul_mpi(&result->x, &E, &H);
  mbedtls_mpi_mod_mpi(&result->x, &result->x, &P);

  // y3 = F*D*H mod P
  mbedtls_mpi_mul_mpi(&F, &F, &D);
  mbedtls_mpi_mod_mpi(&F, &F, &P);
  mbedtls_mpi_mul_mpi(&result->y, &F, &H);
  mbedtls_mpi_mod_mpi(&result->y, &result->y, &P);

  mbedtls_mpi_free(&A); mbedtls_mpi_free(&B); mbedtls_mpi_free(&C); mbedtls_mpi_free(&D);
  mbedtls_mpi_free(&E); mbedtls_mpi_free(&F); mbedtls_mpi_free(&G); mbedtls_mpi_free(&H);
}
// Scalar multiplication result = n * pt on the Baby Jubjub curve, using
// right-to-left (least-significant-bit first) double-and-add:
//   T accumulates the sum, starting at (0, 1);
//   R holds pt * 2^i at iteration i.
// n == 0 (bit length 0) yields (0, 1). Only the bits of n are read, via
// mbedtls_mpi_bitlen/get_bit.
// pt is copied before *result is written, so result may alias pt.
// NOTE(review): the branch on each bit of n makes the running time depend on
// n, so this is not constant-time; avoid it for secret scalars where timing
// leaks matter.
// NOTE(review): mbedtls return codes are not checked.
void multiply_bj(Point *result, const Point *pt, const mbedtls_mpi *n) {
  Point R, T, sum;
  mbedtls_mpi_init(&R.x); mbedtls_mpi_init(&R.y);
  mbedtls_mpi_init(&T.x); mbedtls_mpi_init(&T.y);
  mbedtls_mpi_init(&sum.x); mbedtls_mpi_init(&sum.y);

  // R = pt; T = (0, 1), the starting value of the accumulator.
  mbedtls_mpi_copy(&R.x, &pt->x);
  mbedtls_mpi_copy(&R.y, &pt->y);
  mbedtls_mpi_lset(&T.x, 0);
  mbedtls_mpi_lset(&T.y, 1);

  const size_t nbits = mbedtls_mpi_bitlen(n);
  for (size_t i = 0; i < nbits; ++i) {
    if (mbedtls_mpi_get_bit(n, i)) {
      // T = T + R. add_bj writes into a separate point; swapping moves the
      // new value into T without copying limbs.
      add_bj(&sum, &T, &R);
      mbedtls_mpi_swap(&T.x, &sum.x);
      mbedtls_mpi_swap(&T.y, &sum.y);
    }
    // R = 2R. Skipped after the top bit, where the doubled value is unused.
    if (i + 1 < nbits) {
      add_bj(&sum, &R, &R);
      mbedtls_mpi_swap(&R.x, &sum.x);
      mbedtls_mpi_swap(&R.y, &sum.y);
    }
  }

  mbedtls_mpi_copy(&result->x, &T.x);
  mbedtls_mpi_copy(&result->y, &T.y);

  mbedtls_mpi_free(&R.x); mbedtls_mpi_free(&R.y);
  mbedtls_mpi_free(&T.x); mbedtls_mpi_free(&T.y);
  mbedtls_mpi_free(&sum.x); mbedtls_mpi_free(&sum.y);
}