#include <MD_MAX72xx.h>

// ---- Hardware configuration -------------------------------------------------
// Number of chained MAX72xx modules; passed to the MD_MAX72XX constructor below.
#define MAX_DEVICES 4
// NOTE(review): CLK_PIN and DATA_PIN are not passed to the constructor below,
// which is given only the chip-select pin. Presumably the library is using the
// board's hardware SPI pins here -- confirm against the wiring.
#define CLK_PIN     13
#define DATA_PIN    11
// Chip-select pin handed to the MD_MAX72XX constructor.
#define CS_PIN      10
// Joystick axes, sampled with analogRead() in parseJoystick().
#define VERT_PIN    A0
#define HORZ_PIN    A1
// Joystick push button; setup() configures it as INPUT_PULLUP.
#define SEL_PIN     2

// Display driver instance for the module chain (PAROLA_HW module type).
MD_MAX72XX mx = MD_MAX72XX(MD_MAX72XX::PAROLA_HW, CS_PIN, MAX_DEVICES);

// Largest valid coordinate on each axis of the 16x16 virtual play field
// (coordinates run 0..15).
const int maxX = 15;
const int maxY = 15;
// Current heading of the snake: 1 = up, 2 = right, 3 = down, 4 = left.
// Starts at 4 (left) and is rewritten by parseJoystick().
int direction = 4; // 1 = up, 2 = right, 3 = down, 4 = left

// Snake logic
//
// Body segments live in a fixed ring of MAX_SEGMENTS slots (one per cell of
// the 16x16 board). Slot i stores the cell (snakex[i], snakey[i]) and age[i],
// the number of frames since the head was written there; age 0 marks a free
// slot.
#define MAX_SEGMENTS 256
// Cell coordinates are always in 0..15, so one byte per coordinate is enough.
// Compared with `int` this saves 2 * MAX_SEGMENTS bytes of RAM on targets with
// a 16-bit int, which matters on small microcontrollers.
unsigned char snakex[MAX_SEGMENTS];
unsigned char snakey[MAX_SEGMENTS];
// Ages can grow to length + 1 (up to 257), so they stay a full int.
int age[MAX_SEGMENTS];
// Number of frames a segment stays lit, i.e. the current snake length.
int length = 4;
// Ring slot that receives the next head position.
int headIndex = 0;

// Apple position
// Cell that currently holds the apple, in virtual-board coordinates.
// spawnApple() overwrites both values; loop() lights this cell every frame.
int appleX = 0;
int appleY = 0;

// Snake head position
// Starts at the right edge (x = 15), row 8. loop() copies these into the
// candidate position, lets parseJoystick() advance it, and stores the result
// back here each frame.
int x = 15;
int y = 8;

void setup() {
  spawnApple();
  mx.begin();
  mx.control(MD_MAX72XX::INTENSITY, MAX_INTENSITY / 2);
  mx.clear();
  Serial.begin(9600);

  pinMode(VERT_PIN, INPUT);
  pinMode(HORZ_PIN, INPUT);
  pinMode(SEL_PIN, INPUT_PULLUP);
}

// One game frame: read the joystick, test the target cell for a collision,
// advance the snake, redraw apple and body, then wait 250 ms.
void loop() {
  // Candidate head position for this frame; parseJoystick() advances it by one
  // cell, clamped to the board.
  int nextX = x;
  int nextY = y;
  parseJoystick(nextX, nextY);

  // The move is fatal if the target cell already holds a live segment.
  // Because parseJoystick() clamps at the edges, steering into a wall leaves
  // the target on the cell of the previously recorded head, so walls are
  // caught by this same check.
  bool collided = false;
  for (int seg = 0; seg < MAX_SEGMENTS && !collided; ++seg) {
    collided = (age[seg] > 0) && (snakex[seg] == nextX) && (snakey[seg] == nextY);
  }

  if (collided) {
    // Game over: blank the display and stay here until the board is reset.
    mx.clear();
    mx.update();
    for (;;) {
      delay(10000);
    }
  }

  x = constrain(nextX, 0, 15);
  y = constrain(nextY, 0, 15);

  updateSnake(x, y);

  // Light the apple, then every segment that is still within the snake's length.
  drawVirtual16x16(appleX, appleY, true);
  for (int seg = 0; seg < MAX_SEGMENTS; ++seg) {
    const int segAge = age[seg];
    if (segAge > 0 && segAge <= length) {
      drawVirtual16x16(snakex[seg], snakey[seg], true);
    }
  }

  mx.update();
  delay(250);  // frame period
}

void parseJoystick(int &x, int &y) {
  int vert = analogRead(VERT_PIN);
  int horz = analogRead(HORZ_PIN);
  Serial.println(direction);

  if (vert > 700 && direction != 3) {
    y = max(0, y - 1);
    direction = 1;
  } else if (vert < 300 && direction != 1) {
    y = min(15, y + 1);
    direction = 3;
  } else if (horz > 700 && direction != 4) {
    x = min(15, x + 1);
    direction = 2;
  } else if (horz < 300 && direction != 2) {
    x = max(0, x - 1);
    direction = 4;
  } else {
    switch (direction) {
      case 1: y = max(0, y - 1); break;
      case 2: x = min(15, x + 1); break;
      case 3: y = min(15, y + 1); break;
      case 4: x = max(0, x - 1); break;
    }
  }
}

// Moves the apple to a uniformly random cell that holds no live snake segment,
// writing the result to appleX / appleY.
//
// If every cell of the 16x16 board is occupied there is nowhere to put the
// apple; the function then returns with the apple position unchanged instead
// of retrying forever.
void spawnApple() {
  // One bit per cell: bit x of occupiedRows[y] is set when (x, y) holds a
  // live segment. Built once so each candidate is checked in constant time.
  unsigned int occupiedRows[16] = {0};
  int freeCells = 16 * 16;

  for (int i = 0; i < MAX_SEGMENTS; i++) {
    if (age[i] <= 0) continue;  // unused slot

    const int segX = snakex[i];
    const int segY = snakey[i];
    if (segX < 0 || segX > 15 || segY < 0 || segY > 15) continue;  // off-board, cannot block a cell

    const unsigned int cellBit = 1u << segX;
    if ((occupiedRows[segY] & cellBit) == 0) {
      occupiedRows[segY] |= cellBit;
      freeCells--;
    }
  }

  if (freeCells == 0) {
    return;  // board is full; sampling below could never succeed
  }

  // Rejection sampling: draw cells until one is free.
  do {
    appleX = random(0, 16);
    appleY = random(0, 16);
  } while ((occupiedRows[appleY] & (1u << appleX)) != 0);
}

void updateSnake(int newX, int newY) {
  snakex[headIndex] = newX;
  snakey[headIndex] = newY;
  age[headIndex] = 1;

  if (newX == appleX && newY == appleY) {
    length++;
    spawnApple();
  }

  for (int i = 0; i < MAX_SEGMENTS; i++) {
    if (i != headIndex && age[i] > 0) {
      age[i]++;
    }
    if (age[i] > length) {
      drawVirtual16x16(snakex[i], snakey[i], false);
      age[i] = 0;
    }
  }

  headIndex++;
  if (headIndex >= MAX_SEGMENTS) headIndex = 0;
}

// Sets one pixel of the 16x16 virtual board in the display buffer.
//
// x, y:  virtual coordinates.
// state: true to light the pixel, false to clear it.
//
// Virtual rows 0..7 map directly to physical rows 0..7 at column x.
// Virtual rows 8 and above map to physical row (y - 8) at column x + 16.
void drawVirtual16x16(int x, int y, bool state) {
  const bool lowerHalf = (y >= 8);
  const int row = lowerHalf ? (y - 8) : y;
  const int col = lowerHalf ? (x + 16) : x;
  mx.setPoint(row, col, state);
}

// Lights every cell of the 16x16 virtual board, holds it for 200 ms, then
// clears the display again.
void flashMatrixOnce() {
  // Walk all 256 cells, column by column (x is the slow index, y the fast one).
  for (int cell = 0; cell < 16 * 16; ++cell) {
    drawVirtual16x16(cell / 16, cell % 16, true);
  }
  mx.update();

  delay(200);

  mx.clear();
  mx.update();
}