// Credit to xdd (MIT M.Arch., Year 3) for the hardware setup of the pieces, on top of which I developed the code.
// Thanks, man: the LED dot matrices worked better than the LCD screen!


#include <math.h> // For sqrtf

#define CLK 13  // clock pin used by shiftOut() for the MAX7219 chain
#define DIN 11  // data pin used by shiftOut() for the MAX7219 chain
#define CS  10  // chip-select pin, held LOW while words are shifted out
#define X_SEGMENTS   4  // 8x8 matrices per row of the display
#define Y_SEGMENTS   4  // rows of 8x8 matrices
#define NUM_SEGMENTS (X_SEGMENTS * Y_SEGMENTS)

// The Game of Life grid maps 1:1 onto the LEDs: 4x4 matrices of 8x8 pixels
// give a 32x32 grid. Deriving the size from the segment counts keeps it in
// sync with the display bounds that safe_pixel() checks against.
#define GRID_WIDTH  (X_SEGMENTS * 8)
#define GRID_HEIGHT (Y_SEGMENTS * 8)

// Framebuffer: one bit per LED, X_SEGMENTS bytes per pixel row.
uint8_t fb[8 * NUM_SEGMENTS];

// Cell states for the Game of Life (0 = dead, 1 = alive).
// uint8_t rather than int: a cell only ever holds 0 or 1. On AVR boards
// (16-bit int), two 32x32 int grids alone would need 4 KB of SRAM.
uint8_t grid[GRID_WIDTH][GRID_HEIGHT];
uint8_t nextGrid[GRID_WIDTH][GRID_HEIGHT];

// Function to shift data to all LED matrices
// Broadcast one register write to every driver in the chain: the same
// (address, data) word is clocked out NUM_SEGMENTS times while CS is held
// LOW, and CS is raised once at the end so all drivers take it together.
void shiftAll(byte send_to_address, byte send_this_data) {
  digitalWrite(CS, LOW);
  for (int remaining = NUM_SEGMENTS; remaining > 0; --remaining) {
    shiftOut(DIN, CLK, MSBFIRST, send_to_address);  // register address first
    shiftOut(DIN, CLK, MSBFIRST, send_this_data);   // then its value
  }
  digitalWrite(CS, HIGH);
}

void setup() {
  Serial.begin(115200);
  pinMode(CLK, OUTPUT);
  pinMode(DIN, OUTPUT);
  pinMode(CS, OUTPUT);

  // Setup each MAX7219
  shiftAll(0x0f, 0x00); // Display test register - test mode off
  shiftAll(0x0b, 0x07); // Scan limit register - display digits 0 through 7
  shiftAll(0x0c, 0x01); // Shutdown register - normal operation
  shiftAll(0x0a, 0x0f); // Intensity register - max brightness
  shiftAll(0x09, 0x00); // Decode mode register - no decode

  // Initialize the grid with the R-pentomino pattern
  initializeGrid();

  // Clear the framebuffer before drawing
  clear();
}

// Function to initialize the grid with the R-pentomino at the center
// Kill every cell, then place an R-pentomino around the grid centre
// (cx = GRID_WIDTH / 2, cy = GRID_HEIGHT / 2):
//
//          cx-1 cx cx+1
//   cy   :  .   #   #
//   cy+1 :  #   #   .
//   cy+2 :  .   #   .
void initializeGrid() {
  Serial.println(F("Initializing grid with the R-pentomino pattern"));

  for (int col = 0; col < GRID_WIDTH; col++) {
    for (int row = 0; row < GRID_HEIGHT; row++) {
      grid[col][row] = 0;
    }
  }

  // (dx, dy) offsets of the five live cells relative to the centre.
  const int8_t shape[5][2] = {
    { 0, 0}, { 1, 0},
    {-1, 1}, { 0, 1},
    { 0, 2},
  };
  const int cx = GRID_WIDTH / 2;
  const int cy = GRID_HEIGHT / 2;
  for (int k = 0; k < 5; k++) {
    grid[cx + shape[k][0]][cy + shape[k][1]] = 1;
  }
}

// Function to count live neighbors for Game of Life
// Return how many of the eight cells around (x, y) are alive.
// Edges wrap: a neighbour past one border is taken from the opposite
// border, so the grid behaves like a torus.
int countNeighbors(int x, int y) {
  int live = 0;
  for (int dx = -1; dx <= 1; dx++) {
    const int nx = (x + dx + GRID_WIDTH) % GRID_WIDTH;
    for (int dy = -1; dy <= 1; dy++) {
      if (dx == 0 && dy == 0) {
        continue;  // the cell itself is not a neighbour
      }
      const int ny = (y + dy + GRID_HEIGHT) % GRID_HEIGHT;
      live += grid[nx][ny];
    }
  }
  return live;
}

// Function to update the grid according to Game of Life rules
void updateGrid() {
  for (int x = 0; x < GRID_WIDTH; x++) {
    for (int y = 0; y < GRID_HEIGHT; y++) {
      int neighbors = countNeighbors(x, y);

      // Apply the Game of Life rules
      if (grid[x][y] == 1) {
        // Cell dies if underpopulated or overpopulated
        if (neighbors < 2 || neighbors > 3) {
          nextGrid[x][y] = 0;
        } else {
          // Cell survives
          nextGrid[x][y] = 1;
        }
      } else {
        // Cell becomes alive if it has exactly 3 neighbors
        if (neighbors == 3) {
          nextGrid[x][y] = 1;
        } else {
          nextGrid[x][y] = 0;
        }
      }
    }
  }

  // Copy nextGrid to grid
  for (int x = 0; x < GRID_WIDTH; x++) {
    for (int y = 0; y < GRID_HEIGHT; y++) {
      grid[x][y] = nextGrid[x][y];
    }
  }
}

// Function to plot the grid on the LED matrix
// Render the current generation: blank the framebuffer, light one LED per
// live cell, then push the finished frame out to the matrices.
void drawGrid() {
  clear();
  for (int col = 0; col < GRID_WIDTH; ++col) {
    for (int row = 0; row < GRID_HEIGHT; ++row) {
      if (grid[col][row] != 1) {
        continue;  // dead cell: its LED stays off
      }
      safe_pixel(col, row, 1);
    }
  }
  show();
}

// Main loop to draw the Game of Life and update generations
// Called repeatedly by the Arduino runtime. Each pass displays the current
// generation, computes the next one, and then waits 100 ms so the evolution
// is visible.
void loop() {
  drawGrid();
  updateGrid();
  delay(100);
}

// Function to set a pixel in the framebuffer
void set_pixel(uint8_t x, uint8_t y, uint8_t mode) {
  byte *addr = &fb[x / 8 + y * X_SEGMENTS];
  byte mask = 128 >> (x % 8);
  switch (mode) {
    case 0: // clear pixel
      *addr &= ~mask;
      break;
    case 1: // plot pixel
      *addr |= mask;
      break;
    case 2: // XOR pixel
      *addr ^= mask;
      break;
  }
}

// Safely plot a pixel, ensuring it is within the grid bounds
// Bounds-checked front end to set_pixel(). Coordinates outside the
// (X_SEGMENTS*8) x (Y_SEGMENTS*8) display are silently ignored.
void safe_pixel(uint8_t x, uint8_t y, uint8_t mode) {
  const bool inside = (x < X_SEGMENTS * 8) && (y < Y_SEGMENTS * 8);
  if (inside) {
    set_pixel(x, y, mode);
  }
}

// Clear all LEDs in the framebuffer
// Turn every LED off in the framebuffer. The hardware is not updated until
// show() is called.
// The loop runs over sizeof(fb) with a size_t counter. A byte counter would
// wrap at 255 and never terminate once the framebuffer reaches 256 bytes
// (32 or more segments).
void clear() {
  for (size_t i = 0; i < sizeof(fb); i++) {
    fb[i] = 0;
  }
}

// Send the framebuffer to the LED matrix
void show() {
  for (byte row = 0; row < 8; row++) {
    digitalWrite(CS, LOW);
    byte segment = NUM_SEGMENTS;
    while (segment--) {
      byte x = segment % X_SEGMENTS;
      byte y = segment / X_SEGMENTS * 8;
      byte addr = (row + y) * X_SEGMENTS;

      if (segment & X_SEGMENTS) { // odd rows of segments
        shiftOut(DIN, CLK, MSBFIRST, 8 - row);
        shiftOut(DIN, CLK, LSBFIRST, fb[addr + x]);
      } else { // even rows of segments
        shiftOut(DIN, CLK, MSBFIRST, 1 + row);
        shiftOut(DIN, CLK, MSBFIRST, fb[addr - x + X_SEGMENTS - 1]);
      }
    }
    digitalWrite(CS, HIGH);
  }
}