#include <SPI.h>

// switch between the simple and the fast versions
// 63.7ms (15.7 FPS) vs 2.87ms (348 FPS)
// 1 = integer tunnel() driven by fixed-point oscillators,
// 0 = tunnel_float() using hypotf()/cosf()/sinf() (timings above as measured
// by the author for the simple vs the fast version)
#define OPTIMISED 1

// DIN and CLK are only passed to pinMode() in setup(); the bits are clocked
// out by the hardware SPI peripheral (SPI.transfer), so these must be the
// board's MOSI and SCK pins (11 and 13 on ATmega328P boards such as the Uno).
#define DIN 11
// chip-select / LOAD line, toggled with digitalWrite() around every transfer.
// NOTE(review): on ATmega328P boards pin 12 is the hardware MISO pin, which
// the SPI peripheral may override as an input while enabled in master mode —
// confirm this pin choice on the target board.
#define CS  12
#define CLK 13
// number of 8x8 modules across and down the display
#define X_SEGMENTS   4
#define Y_SEGMENTS   4
#define NUM_SEGMENTS (X_SEGMENTS * Y_SEGMENTS)

// a framebuffer to hold the state of the entire matrix of LEDs
// laid out in raster order, with (0, 0) at the top-left
// Each pixel row takes X_SEGMENTS bytes; within a byte the MSB is the leftmost
// pixel (emit_pixel() shifts pixels in from the right).
byte fb[8 * NUM_SEGMENTS];

// One-time initialisation: serial port, SPI bus and every MAX7219 in the chain.
void setup() {
  Serial.begin(115200);
  pinMode(CLK, OUTPUT);
  pinMode(DIN, OUTPUT);
  pinMode(CS, OUTPUT);

  // SPI.begin() initialises the bus: SCK and MOSI become outputs and the
  // hardware SS pin is made an output driven high. Without it SS may float as
  // an input, and on AVR a low level there knocks the peripheral out of
  // master mode.
  SPI.begin();

  // The MAX7219 serial interface is specified up to 10 MHz. SPISettings picks
  // the fastest clock not above the request, so a 16 MHz AVR still runs at
  // 8 MHz, while faster boards no longer overclock the chips.
  // The transaction is never ended (there is no endTransaction() call).
  SPI.beginTransaction(SPISettings(10000000, MSBFIRST, SPI_MODE0));

  // configure each MAX7219
  shiftAll(0x0f, 0x00); // display test register - test mode off
  shiftAll(0x0b, 0x07); // scan limit register - display digits 0 thru 7
  shiftAll(0x0c, 0x01); // shutdown register - normal operation
  shiftAll(0x0a, 0x0f); // intensity register - max brightness
  shiftAll(0x09, 0x00); // decode mode register - No decode
}

// Per-frame work: advance the animation, render into fb, wait for the next
// 60 Hz slot, then push fb to the displays.
void loop() {
  if (OPTIMISED) {
    // generate three different frequencies of sine/cosine waves
    // (Minsky circle algorithm in 8.8 fixed point: the pairs rotate by roughly
    // 1/64, 1/32 and 1/128 rad per frame respectively)
    static int16_t sx1 = 20 << 8, sx2 = sx1, sx3, sy1, sy2, sy3 = 127 << 8;
    sx1 -= sy1 >> 6, sy1 += sx1 >> 6;
    sx2 -= sy2 >> 5, sy2 += sx2 >> 5;
    sx3 -= sy3 >> 7, sy3 += sx3 >> 7;
    // move the origin in a Lissajous curve, and to-and-fro on a sine
    tunnel((sx1 >> 8) - X_SEGMENTS * 4, (sx2 >> 8) - Y_SEGMENTS * 4, sx3 >> 8);
  } else {
    // floating-point version: analogous waveforms derived from millis()
    int8_t sx1 = 20 * cosf(millis() / 1024.f);
    int8_t sx2 = 20 * cosf(millis() / 512.f);
    int8_t sx3 = 127 * sinf(millis() / 2048.f);
    tunnel_float(sx1 - X_SEGMENTS * 4, sx2 - Y_SEGMENTS * 4, sx3);
  }

  // cap the refresh rate to 60Hz
  // Everything is uint32_t, like micros(), so the modular subtraction below
  // stays correct when micros() wraps (about every 71.6 minutes).
  const uint32_t fps_goal_us = 1000000UL / 60;
  static uint32_t next_frame_us = 0;
  next_frame_us += fps_goal_us;
  for (;;) {
    // Time left until the deadline. If the deadline has already passed, the
    // unsigned subtraction wraps to a huge value and we stop waiting.
    uint32_t delay_us = next_frame_us - micros();
    if (delay_us >= fps_goal_us)
      break;
    if (delay_us >= 8192)
      delay(8);
    else if (delay_us >= 3)
      delayMicroseconds(delay_us);
  }

  show();

  if (0) { // show frame timings
    static uint32_t fps_ms;
    static uint16_t frame;
    if (++frame == 128) {
      uint32_t time_ms = millis() - fps_ms;
      Serial.print(time_ms / float(frame));
      Serial.print("ms\t");
      Serial.print(frame * 1000.f / time_ms);
      Serial.println("FPS");
      fps_ms = millis();
      frame = 0;
    }
  }
}


// Broadcast one register write to every MAX7219 in the chain: the same
// (address, data) pair is clocked out once per module, all inside a single
// CS low ... high frame.
void shiftAll(byte send_to_address, byte send_this_data) {
  digitalWrite(CS, LOW);
  for (int remaining = NUM_SEGMENTS; remaining > 0; --remaining) {
    SPI.transfer(send_to_address);
    SPI.transfer(send_this_data);
  }
  digitalWrite(CS, HIGH);
}


// benchmarking 4,096 frames of show()
// level   O0    Os    O1    O2    O3
// millis  1.41  0.63  0.63  0.47  0.47
// FPS     710   1576  1575  2137  2137
#pragma GCC push_options
#pragma GCC optimize "-O2"

// send the raster order framebuffer in the correct order
// for the boustrophedon layout of daisy-chained MAX7219s
void show() {
  for (byte row = 0; row < 8; row++) {
    digitalWrite(CS, LOW);
    byte segment = NUM_SEGMENTS;
    while (segment--) {
      byte x = segment % X_SEGMENTS;
      byte y = segment / X_SEGMENTS * 8;
      byte addr = (row + y) * X_SEGMENTS;

      if (segment & X_SEGMENTS) { // odd rows of segments
        SPI.transfer(8 - row);
        byte c = fb[addr + x];
        // reverse the byte (LSB to MSB)
        c = ((c >> 1) & 0x55) | ((c << 1) & 0xAA);
        c = ((c >> 2) & 0x33) | ((c << 2) & 0xCC);
        c = (c >> 4) | (c << 4);
        SPI.transfer(c);
      } else { // even rows of segments
        SPI.transfer(1 + row);
        SPI.transfer(fb[addr - x + X_SEGMENTS - 1]);
      }
    }
    digitalWrite(CS, HIGH);
  }
}
#pragma GCC pop_options


// integer square root
// Returns floor(sqrt(x)) for any 16-bit x (result 0..255). The result is
// built one bit at a time from the MSB down, keeping each bit whose square
// still does not exceed x.
uint8_t isqrt16(uint16_t x) {
  uint8_t res = 0;
  uint8_t add = 0x80;
  do {
    uint8_t t = res | add;
    // Widen before multiplying: a plain t * t promotes uint8_t to int, which
    // is 16-bit signed on AVR, so t >= 182 (reached for x >= 16384) would
    // overflow int -- undefined behaviour.
    uint16_t t2 = (uint16_t)t * t;
    if (x >= t2) res = t;
  } while (add >>= 1);
  return res;
}


// benchmarking 4,096 frames of tunnel()
// level   O0     Os     O1     O2     O3     Ofast
// millis  7.23   2.42   2.18   2.57   2.51   2.51
// FPS     138.2  413.4  459.7  389.4  398.8  398.8
#pragma GCC push_options
#pragma GCC optimize "-O1"

// Shift one pixel into an 8-pixel accumulator and store the accumulator
// through dst whenever a byte is complete.
//   dst        framebuffer write cursor, advanced by one per stored byte
//   radius_pos ring offset added to the pixel's distance
//   xroot      integer distance of this pixel from the tunnel centre
//   screenx    pixels remaining in the row after this one (counts down), so a
//              byte is stored each time it reaches a multiple of 8
// A pixel is lit when bit 3 of (xroot + radius_pos) is set, which gives rings
// that alternate every 8 distance units. The first pixel of each byte ends up
// in the MSB. The accumulator is static and shared by every caller.
inline void __attribute__((always_inline)) emit_pixel(byte* &dst, byte radius_pos, uint8_t xroot, uint8_t screenx) {
  static byte pending = 0;
  const bool lit = ((xroot + radius_pos) & 8) != 0;
  pending = (byte)((pending << 1) | (lit ? 1 : 0));
  if ((screenx & 7) == 0)
    *dst++ = pending;
}


// Render concentric rings ("tunnel") into fb using only integer arithmetic.
//
// Each pixel's distance from the origin, floor(sqrt(x*x + y*y)), is tracked
// incrementally instead of calling isqrt16() per pixel. One step in x changes
// x*x by 2*x + 1, and the floored root moves by at most one per unit step. A
// single compare against the next (or previous) perfect square therefore keeps
// it exact. The screen is walked in raster order and split into quadrants
// around the origin: while a coordinate is negative, stepping towards zero can
// only shrink the root; once it is non-negative, the root can only grow.
//
//   x_pos, y_pos  coordinates of the top-left pixel relative to the tunnel
//                 centre (the first pixel emitted into fb)
//   radius_pos    offset added to every distance by emit_pixel(), which
//                 animates the rings
//
// NOTE(review): the uint16_t sums and the int arithmetic in the *nextsquare
// products assume distances stay well below ~181. Extreme int8_t inputs such
// as (-128, -128) would overflow 16-bit int on AVR. The ranges loop() passes
// appear to stay far below that -- confirm if callers change.
void tunnel(int8_t x_pos, int8_t y_pos, uint8_t radius_pos) {
  byte* dst = fb;
  uint8_t  screenx, screeny, xroot, yroot;
  uint16_t xsumsquares, ysumsquares, xnextsquare, ynextsquare;
  int8_t   x, y;

  // offset the origin in screen space
  x = x_pos;
  y = y_pos;
  // x_pos^2 + y^2 for the current row; each row's x walk starts from it
  ysumsquares = x * x + y * y;
  yroot = isqrt16(ysumsquares);
  // yroot^2: once ysumsquares drops below this, yroot must decrease
  ynextsquare = yroot * yroot;

  // Quadrant II (top-left)
  screeny = Y_SEGMENTS * 8;
  while (y < 0 && screeny) {
    screeny--;
    x = x_pos;
    screenx = X_SEGMENTS * 8;
    xsumsquares = ysumsquares;
    xroot = yroot;
    if (x < 0) {
      // moving right towards x == 0: the root can only shrink
      xnextsquare = xroot * xroot;
      while (x < 0 && screenx) {
        screenx--;
        emit_pixel(dst, radius_pos, xroot, screenx);
        // (x+1)^2 - x^2 = 2x + 1
        xsumsquares += 2 * x++ + 1;
        // (r-1)^2 = r^2 - (2r - 1)
        if (xsumsquares < xnextsquare)
          xnextsquare -= 2 * xroot-- - 1;
      }
    }
    // Quadrant I (top-right)
    if (screenx) {
      // x >= 0 from here on: the root can only grow; track (r+1)^2
      xnextsquare = (xroot + 1) * (xroot + 1);
      while (screenx) {
        screenx--;
        emit_pixel(dst, radius_pos, xroot, screenx);
        xsumsquares += 2 * x++ + 1;
        // with r the incremented root: (r+1)^2 = r^2 + 2r + 1
        if (xsumsquares >= xnextsquare)
          xnextsquare += 2 * ++xroot + 1;
      }
    }
    // advance to the next row, same incremental update as for x
    ysumsquares += 2 * y++ + 1;
    if (ysumsquares < ynextsquare)
      ynextsquare -= 2 * yroot-- - 1;
  }
  // Quadrant III (bottom-left)
  // y >= 0 (or the screen is exhausted): the row root can only grow now
  ynextsquare = (yroot + 1) * (yroot + 1);
  while (screeny) {
    screeny--;
    x = x_pos;
    screenx = X_SEGMENTS * 8;
    xsumsquares = ysumsquares;
    xroot = yroot;
    if (x < 0) {
      xnextsquare = xroot * xroot;
      while (x < 0 && screenx) {
        screenx--;
        emit_pixel(dst, radius_pos, xroot, screenx);
        xsumsquares += 2 * x++ + 1;
        if (xsumsquares < xnextsquare)
          xnextsquare -= 2 * xroot-- - 1;
      }
    }
    // Quadrant IV (bottom-right)
    if (screenx) {
      xnextsquare = (xroot + 1) * (xroot + 1);
      // post-decrement: the body sees the same screenx values as Quadrant I
      while (screenx--) {
        emit_pixel(dst, radius_pos, xroot, screenx);
        xsumsquares += 2 * x++ + 1;
        if (xsumsquares >= xnextsquare)
          xnextsquare += 2 * ++xroot + 1;
      }
    }
    ysumsquares += 2 * y++ + 1;
    if (ysumsquares >= ynextsquare)
      ynextsquare += 2 * ++yroot + 1;
  }
}
#pragma GCC pop_options


// benchmarking 512 frames of tunnel_float()
// level   O0     Os     O1     O2     O3     Ofast
// millis  70.99  66.74  66.72  62.99  62.99  62.99
// FPS     14.09  14.98  14.99  15.87  15.88  15.88
#pragma GCC push_options
#pragma GCC optimize "-Ofast"
// Floating-point version of tunnel(): each pixel's distance is computed
// directly with hypotf(). The float-to-uint8_t conversion truncates, which
// matches isqrt16()'s floor for these non-negative distances.
// The first pixel rendered is (x_pos, y_pos), the same as tunnel(), so both
// versions produce the same frame layout for the same arguments.
void tunnel_float(int8_t x_pos, int8_t y_pos, uint8_t radius_pos) {
  byte* dst = fb;
  int8_t y = y_pos;
  uint8_t screeny = Y_SEGMENTS * 8;
  while (screeny--) {
    int8_t x = x_pos;
    uint8_t screenx = X_SEGMENTS * 8;
    // emit_pixel() expects the already-decremented screenx (row width - 1 ... 0)
    while (screenx--) {
      emit_pixel(dst, radius_pos, hypotf(x, y), screenx);
      x++;
    }
    y++;
  }
}
#pragma GCC pop_options