import machine
# Operand memory blocks of the RSA accelerator, spaced 0x200 (512) bytes apart.
M_BASE = 0x3FF0_2000
Z_BASE = 0x3FF0_2200
Y_BASE = 0x3FF0_2400
X_BASE = 0x3FF0_2600

# RSA accelerator control / status registers.
RSA_MPRIME_REG = 0x3FF0_2800
RSA_MODEXP_MODE_REG = 0x3FF0_2804
RSA_MODEXP_START_REG = 0x3FF0_2808
RSA_MULT_MODE_REG = 0x3FF0_280C
RSA_MULT_START_REG = 0x3FF0_2810
RSA_INTERRUPT = 0x3FF0_2814
RSA_CLEAN_REG = 0x3FF0_2818

# DPORT peripheral clock-enable and reset registers.
DPORT_PERI_CLK_EN_REG = 0x3FF0_001C
DPORT_PERI_RST_EN_REG = 0x3FF0_0020
def enableRSA():
    """Clock, un-reset and power up the RSA accelerator, then wait until ready.

    Call once before rsaModExp() / rsaMult(). Blocks until RSA_CLEAN_REG
    reads non-zero.
    """
    rsa_clk_bit = 1 << 2          # RSA bit in the DPORT clock-enable register
    rsa_pd_ctrl_reg = 0x3ff00490  # presumably DPORT_RSA_PD_CTRL_REG (RSA power-down) — confirm in TRM

    # Read-modify-write: a plain store of `rsa_clk_bit` would clear every other
    # clock-enable bit in this shared register (on ESP32 it also gates other
    # crypto blocks such as SHA/AES — per TRM, confirm). The mask keeps the
    # value unsigned in case the raw read comes back sign-extended.
    clk_en = machine.mem32[DPORT_PERI_CLK_EN_REG]
    machine.mem32[DPORT_PERI_CLK_EN_REG] = (clk_en | rsa_clk_bit) & 0xffff_ffff

    # NOTE(review): still clears the whole reset register, as the hardware-
    # verified original did. The RSA unit may need other units (e.g. digital
    # signature) out of reset too, so narrowing this mask is left as a TODO.
    machine.mem32[DPORT_PERI_RST_EN_REG] = 0

    # Clear the RSA power-down control.
    machine.mem32[rsa_pd_ctrl_reg] = 0

    # Busy-wait until the accelerator reports it is ready.
    while not machine.mem32[RSA_CLEAN_REG]:
        pass
def store(offset, value, bytecnt):
    """Write integer *value* to memory at *offset* as little-endian 32-bit words.

    The least-significant word lands at *offset*, the next at offset + 4, and
    so on. ceil(bytecnt / 4) words are written, so a partial trailing word is
    still written in full. Bits of *value* beyond that width are dropped.
    Nothing is written when bytecnt <= 0.
    """
    for addr in range(offset, offset + bytecnt, 4):
        machine.mem32[addr] = value & 0xffff_ffff  # low 32 bits of what remains
        value >>= 32
def load(offset, bytecnt):
    """Read an unsigned integer stored as little-endian 32-bit words at *offset*.

    This is the inverse of store(). Reading starts at offset + bytecnt - 4,
    the most-significant word, and walks downward for ceil(bytecnt / 4)
    words. Returns 0 when bytecnt <= 0.
    """
    value = 0
    for addr in range(offset + bytecnt - 4, offset - 4, -4):
        # Mask to an unsigned 32-bit word; the raw mem32 read may be
        # sign-extended (presumably MicroPython behaviour — hence the mask).
        value = (value << 32) | (machine.mem32[addr] & 0xffffffff)
    return value
def _rsa_run(mode_reg, start_reg, mprime, mode, x, y, z, m):
    """Load operands, run one RSA-accelerator operation and return Z memory.

    Shared by rsaModExp() and rsaMult(), which differ only in the
    mode/start register pair they drive.
    """
    # Each operand block is written as 512 bytes (128 little-endian words).
    store(X_BASE, x, 512)
    store(Y_BASE, y, 512)
    store(Z_BASE, z, 512)
    store(M_BASE, m, 512)
    machine.mem32[RSA_MPRIME_REG] = mprime
    machine.mem32[mode_reg] = mode
    # Clear any completion flag left over from an earlier operation. Otherwise
    # the poll below can see a stale 1 and read Z before this run finishes.
    machine.mem32[RSA_INTERRUPT] = 1
    machine.mem32[start_reg] = 1
    while not machine.mem32[RSA_INTERRUPT]:
        pass
    # Acknowledge completion so the next operation starts from a clean flag.
    machine.mem32[RSA_INTERRUPT] = 1
    return load(Z_BASE, 512)


def rsaModExp(mprime, mode, x, y, z, m):
    """Run a modular-exponentiation operation on the RSA accelerator.

    x, y, z and m are written to the X, Y, Z and M operand memories. mprime
    goes to RSA_MPRIME_REG and mode to RSA_MODEXP_MODE_REG. Returns the
    contents of Z memory (512 bytes) as an integer once the hardware signals
    completion.
    """
    return _rsa_run(RSA_MODEXP_MODE_REG, RSA_MODEXP_START_REG,
                    mprime, mode, x, y, z, m)


def rsaMult(mprime, mode, x, y, z, m):
    """Run a multiplication operation on the RSA accelerator.

    x, y, z and m are written to the X, Y, Z and M operand memories. mprime
    goes to RSA_MPRIME_REG and mode to RSA_MULT_MODE_REG. Returns the
    contents of Z memory (512 bytes) as an integer once the hardware signals
    completion.
    """
    return _rsa_run(RSA_MULT_MODE_REG, RSA_MULT_START_REG,
                    mprime, mode, x, y, z, m)
def _check(label, actual, expected):
    """Raise AssertionError naming *label* if a hardware result is wrong.

    Used instead of `assert`, which MicroPython compiles out at higher
    optimisation levels and would let these tests silently pass.
    """
    if actual != expected:
        raise AssertionError(
            "{}: got {}, expected {}".format(label, actual, expected))


print("Testing ESP32 RSA accelerator...")
enableRSA()
# Bunch of tests, verified on physical hardware
print("MULT_MODE=4")
_check(
    "MULT_MODE=4",
    rsaMult(mode=4, mprime=1, x=4294967295, y=0, z=43426255979245076895123449803164665624385884479292600189996695592603436457649156942855188464627747359038288402395476695924548608045140655973904177400980393802437373871830731776326113897248680817652967420242138482600661515541378621241941876306491015093298641094192634547473554842522770138752893254223325695280868679982829553147269404487988102224823493947277611314629630550067688537416142431066941975037992922604403652393165931592147010374540056444980250986384073310916849643180812963502674291521712269891886914397047506507407497134585167231513152618097962973651722598023941011285581605019048671703679777886565772920922239401984, m=433300210274926779301235722995130529126851924312253566276831366547097655953268278873591999368453759375352459120642109169274240110620843517970102311224458427564046616227161115801662152359395451958893835480905010246703239443701925453743683726938789767940577524555724362397039358587448248740993955708382317083531851133758531532844694304737183093418571190053292997716413815714417740434254315670267335553209606545389900958653541702218767662872949109999460596999125719155768466778443229719919179735890978537026576050107500869340314188887101095263339865901080787145771915722368720844521851040375582291274824574444462419841916186875058374404329003250698574832235855871168399140082282900718430116575617097058691735266917705341386171618452673063526784039252631042036988932549246975),
    186514349175155804054315361892166426356348128298209822551546593728302403490213781653885218376637439256592071341066372064930596059191653001082764815401088892317689212065200514755237914443128074593713353729640505763630767754615355397446332641027314306037069037167500379811286042926024173038746288613515304487364466819716074092326986410830035119404733643651187794882056218250473582304448715546494307738450885985027379910007196157696478688390675243098644144907411085179270235681894011449756014727123049981585867483173783685010614875430849506651454683857074376077955095278925258270480798898610421896980496577175664214561757639469681638113280,
)
print("MULT_MODE=9")
_check(
    "MULT_MODE=9 (z=2568893366<<512)",
    rsaMult(mode=9, mprime=0, x=5771399392846532744, y=0, z=2568893366 << 512, m=0),
    14826109612819885822163376304,
)
_check(
    "MULT_MODE=9 (z=486459899<<512)",
    rsaMult(mode=9, mprime=0, x=5771399392846532744, y=0, z=486459899 << 512, m=0),
    2807554365732785641146432856,
)
print("All tests passed!")