// https://wokwi.com/projects/413202725219174401
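//
// A tiny integer neural network drives a servo so that the sensed tilt settles
// at roughly 45 degrees. The MPU6050 is read every cycle, but its x/z
// accelerations are then overwritten by a simple servo simulation
// (simulationStep), so the sketch also runs without a physically tilting
// sensor. Training is online random hill climbing: every 2 seconds the weights
// receive a small random mutation, which is kept only if the accumulated
// squared error against the 45 degree target beats the best fitness so far.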

#include <Adafruit_MPU6050.h>
#include <Adafruit_Sensor.h>
#include <Wire.h>
#include <Servo.h>

Adafruit_MPU6050 mpu;
Servo myservo;


sensors_event_t event;

const unsigned int loopPeriod = 20;  // ms
unsigned long loopTimeStamp = 0;  // millis() value of the next scheduled loop iteration

int simServoPos = 90;
int simServoPosTarget = 90;
const unsigned int simServoSpeed = 270;  // degree / s

long trainingIteration = 0;
unsigned int trainingIterationStep = 0;
const unsigned int trainingStepsPerIteration = 2000 / loopPeriod; // 2 seconds

long trainingFitnessBestIteration = 0;
double trainingFitnessBest = 1e30;  // start with a huge (worst) fitness so the first iteration always becomes the new best
double trainingFitnessCurrent = 0;


void servoSetPosition(int pos) {
  myservo.write(pos);
  simServoPosTarget = pos;
}

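// Advance the simulated servo by one control period: slew the current position
// toward the commanded target at simServoSpeed (plus a little noise), then
// derive the accelerometer x/z values the sensor would report at that angle.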
void simulationStep() {
  const int simServoSpeedStep = (simServoSpeed * loopPeriod) / 1000;
  // Simulate servo
  int servoDiff = simServoPosTarget - simServoPos;

  // Add some noise
  servoDiff += random(-2, 3);  // random() excludes the upper bound, so this yields -2..+2

  if (abs(servoDiff) > simServoSpeedStep) {
    if (servoDiff < 0) {
      servoDiff = -simServoSpeedStep;
    } else {
      servoDiff = simServoSpeedStep;
    }
  }
  simServoPos = simServoPos + servoDiff;
  /*Serial.print("Servo: [");
  Serial.print(simServoPosTarget);
  Serial.print("] [");
  Serial.print(simServoPos);
  Serial.print("] [");
  Serial.print(servoDiff);
  Serial.print("]\n");*/


  // Derive acceleration values from the simulated servo angle
  // (at 90 degrees the z axis sees the full 9.81 m/s^2 and x sees none)
  const double factor = 3.14159265 / 180;
  event.acceleration.z = sin(simServoPos * factor) * 9.81;
  event.acceleration.x = cos(simServoPos * factor) * 9.81;
}

/**
 * Neural network design:
 *
 * Input layer: a_x    a_z
 *              / \ / \ / \
 * layer 0:    X   X   X   X
 *             |/ \|/ \|/ \|
 * layer 1:    X   X   X   X
 *             |/ \|/ \|/ \|
 * layer 2:    X   X   X   X
 *             |/ \|/ \|/ \|
 * layer 3:    X   X   X   X
 *              \   \ /   /   
 * Output layer:   servo
 */
union InputLayer {
  struct
  {
    signed char a_x;
    signed char a_z;
  } values;
  signed char raw[2];
};

union OutputLayer {
  struct
  {
    signed char servo;
  } values;
  signed char raw[1];
};

const unsigned int hiddenLayerCount = 4;
const unsigned int hiddenLayerSize = 4;
const unsigned int inputLayerSize = sizeof(InputLayer::raw) / sizeof(InputLayer::raw[0]);
const unsigned int outputLayerSize = sizeof(OutputLayer::raw) / sizeof(OutputLayer::raw[0]);
signed char weightsInput[inputLayerSize][hiddenLayerSize];                               // Random initialization!
signed char weightsHiddenLayer[hiddenLayerCount - 1][hiddenLayerSize][hiddenLayerSize];  // Random initialization!
signed char weightsOutput[outputLayerSize][hiddenLayerSize];                             // Random initialization!

signed char weightsInputBackup[inputLayerSize][hiddenLayerSize];
signed char weightsHiddenLayerBackup[hiddenLayerCount - 1][hiddenLayerSize][hiddenLayerSize];
signed char weightsOutputBackup[outputLayerSize][hiddenLayerSize];

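// Fixed-point multiply: weights are stored as signed char with 128 acting as
// the unit (1.0), so (x * weight) / 128 scales the activation by the weight.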
static inline signed char calculateWeight(signed char x, signed char weight) {
  signed long temp = x;
  temp = (temp * weight) / 128;
  return temp;
}

static inline signed char neuralNetworkSaturatedCast(signed int i)
{
  // Saturate to the signed char range instead of letting the value wrap around
  if (i > 127)
  {
    return 127;
  }
  else if (i < -128)
  {
    return -128;
  }
  else
  {
    return i;
  }
}

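// Forward pass: propagate the two accelerometer inputs through the hidden
// layers to the single servo output. Each neuron is the sum of the previous
// layer's activations times its weights, saturated to the signed char range
// (no bias terms, no other non-linearity).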
OutputLayer neuralNetworkIterate(InputLayer inputValues) {
  signed char previousLayer[hiddenLayerSize];
  signed char currentLayer[hiddenLayerSize];

  // Calculate input layer
  //Serial.print("Input: ");
  for (int i = 0; i < hiddenLayerSize; i++) {
    signed int temp = 0;
    for (unsigned int input = 0; input < inputLayerSize; input++) {
      temp += calculateWeight(inputValues.raw[input], weightsInput[input][i]);

      /*if (input != 0)
      {
        Serial.print(" + ");
      }
      Serial.print(inputValues.raw[input]);
      Serial.print(" * ");
      Serial.print(weightsInput[input][i]);*/
    }
    previousLayer[i] = neuralNetworkSaturatedCast(temp);
    //Serial.print(" = ");
    //Serial.print((int)previousLayer[i]);
    //Serial.print("; ");
  }
  //Serial.print("\n");

  // Calculate hidden layers
  for (int layer = 0; layer < (hiddenLayerCount - 1); layer++) {
    //Serial.print(layer);
    //Serial.print(": ");
    for (int i = 0; i < hiddenLayerSize; i++) {
      signed int temp = 0;
      for (int j = 0; j < hiddenLayerSize; j++) {
        temp += calculateWeight(previousLayer[j], weightsHiddenLayer[layer][i][j]);

        /*if (j != 0)
        {
          Serial.print(" + ");
        }
        Serial.print(previousLayer[j]);
        Serial.print(" * ");
        Serial.print(weightsHiddenLayer[layer][i][j]);*/
      }
      currentLayer[i] = neuralNetworkSaturatedCast(temp);
      //Serial.print(" = ");
      //Serial.print((int)currentLayer[i]);
      //Serial.print("; ");
    }

    // Move current layer into previous layer
    for (int i = 0; i < hiddenLayerSize; i++)
    {
      previousLayer[i] = currentLayer[i];
    }
    //Serial.print("\n");
  }

  // Calculate output layer
  OutputLayer outputValues;
  //Serial.print("Output: ");
  for (unsigned int output = 0; output < outputLayerSize; output++) {
    signed int temp = 0;
    for (int i = 0; i < hiddenLayerSize; i++) {
      temp += calculateWeight(previousLayer[i], weightsOutput[output][i]);

      /*if (i != 0)
      {
        Serial.print(" + ");
      }
      Serial.print(previousLayer[i]);
      Serial.print(" * ");
      Serial.print(weightsOutput[output][i]);*/
    }

    outputValues.raw[output] = neuralNetworkSaturatedCast(temp);
    //Serial.print(" = ");
    //Serial.print((int)outputValues.raw[output]);
    //Serial.print("; ");
  }
  //Serial.print("\n");

  return outputValues;
}

void neuralNetworkInit() {
  // Fill with random numbers
  // Input layer
  for (unsigned int input = 0; input < inputLayerSize; input++) {
    for (int i = 0; i < (hiddenLayerSize); i++) {
      weightsInput[input][i] = random(-127, 128);
    }
  }

  // Hidden layers
  for (int layer = 0; layer < (hiddenLayerCount - 1); layer++) {
    for (int i = 0; i < hiddenLayerSize; i++) {
      for (int j = 0; j < hiddenLayerSize; j++) {
        weightsHiddenLayer[layer][i][j] = random(-127, 128);
      }
    }
  }

  // Output layer
  for (unsigned int output = 0; output < outputLayerSize; output++) {
    for (int i = 0; i < (hiddenLayerSize); i++) {
      weightsOutput[output][i] = random(-127, 128);
    }
  }
}

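// Save a copy of all weights so a mutation that turns out to be worse can be
// rolled back by neuralNetworkRestore().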
void neuralNetworkBackup()
{
  // Input layer
  for (unsigned int input = 0; input < inputLayerSize; input++) {
    for (int i = 0; i < (hiddenLayerSize); i++) {
      weightsInputBackup[input][i] = weightsInput[input][i];
    }
  }

  // Hidden layers
  for (int layer = 0; layer < (hiddenLayerCount - 1); layer++) {
    for (int i = 0; i < hiddenLayerSize; i++) {
      for (int j = 0; j < hiddenLayerSize; j++) {
        weightsHiddenLayerBackup[layer][i][j] = weightsHiddenLayer[layer][i][j];
      }
    }
  }

  // Output layer
  for (unsigned int output = 0; output < outputLayerSize; output++) {
    for (int i = 0; i < (hiddenLayerSize); i++) {
      weightsOutputBackup[output][i] = weightsOutput[output][i];
    }
  }
}

void neuralNetworkRestore()
{
  // Input layer
  for (unsigned int input = 0; input < inputLayerSize; input++) {
    for (int i = 0; i < (hiddenLayerSize); i++) {
      weightsInput[input][i] = weightsInputBackup[input][i];
    }
  }

  // Hidden layers
  for (int layer = 0; layer < (hiddenLayerCount - 1); layer++) {
    for (int i = 0; i < hiddenLayerSize; i++) {
      for (int j = 0; j < hiddenLayerSize; j++) {
        weightsHiddenLayer[layer][i][j] = weightsHiddenLayerBackup[layer][i][j];
      }
    }
  }

  // Output layer
  for (unsigned int output = 0; output < outputLayerSize; output++) {
    for (int i = 0; i < (hiddenLayerSize); i++) {
      weightsOutput[output][i] = weightsOutputBackup[output][i];
    }
  }
}

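// Mutate the network: add a small random offset of up to +/-range to every
// weight, saturating the result to the signed char range.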
void neuralNetworkRandomize(unsigned char range) {
  // Add a small random offset to every weight
  // Input layer
  for (unsigned int input = 0; input < inputLayerSize; input++) {
    for (int i = 0; i < (hiddenLayerSize); i++) {
      signed int temp = weightsInput[input][i] + random(-range, range + 1);  // upper bound of random() is exclusive
      weightsInput[input][i] = neuralNetworkSaturatedCast(temp);
    }
  }

  // Hidden layers
  for (int layer = 0; layer < (hiddenLayerCount - 1); layer++) {
    for (int i = 0; i < hiddenLayerSize; i++) {
      for (int j = 0; j < hiddenLayerSize; j++) {
        signed int temp = weightsHiddenLayer[layer][i][j] + random(-range, range + 1);
        weightsHiddenLayer[layer][i][j] = neuralNetworkSaturatedCast(temp);
      }
    }
  }

  // Output layer
  for (unsigned int output = 0; output < outputLayerSize; output++) {
    for (int i = 0; i < (hiddenLayerSize); i++) {
      signed int temp = weightsOutput[output][i] + random(-range, range + 1);
      weightsOutput[output][i] = neuralNetworkSaturatedCast(temp);
    }
  }
}

void neuralNetworkPrint() {
  // Input layer
  Serial.print("w_i[");
  for (unsigned int input = 0; input < inputLayerSize; input++) {
    for (int i = 0; i < (hiddenLayerSize); i++) {
      if (input != 0 || i != 0) {
        Serial.print(", ");
      }
      Serial.print(weightsInput[input][i]);
    }
  }
  Serial.print("]\n");

  // Hidden layers
  Serial.print("w_h[\n");
  for (int layer = 0; layer < (hiddenLayerCount - 1); layer++) {
    Serial.print("  [");
    for (int i = 0; i < hiddenLayerSize; i++) {
      Serial.print("[");
      for (int j = 0; j < hiddenLayerSize; j++) {
        if (j != 0) {
          Serial.print(", ");
        }
        Serial.print(weightsHiddenLayer[layer][i][j]);
      }
      Serial.print("]");
    }
    Serial.print("]\n");
  }
  Serial.print("]\n");

  // Output layer
  Serial.print("w_o[");
  for (unsigned int output = 0; output < outputLayerSize; output++) {
    for (int i = 0; i < (hiddenLayerSize); i++) {
      Serial.print(weightsOutput[output][i]);
      if (i < (hiddenLayerSize - 1)) {
        Serial.print(", ");
      }
    }
  }
  Serial.print("]\n");
}

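// Training is simple random hill climbing: back up the current weights and
// apply a small mutation (start), then after one evaluation period keep the
// mutation only if the accumulated fitness (lower is better) beats the best
// result so far, otherwise restore the backup (stop).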
void neuralNetworkStartTraining()
{
  neuralNetworkBackup();
  neuralNetworkRandomize(5);
}

bool neuralNetworkStopTraining(double fitnessBest, double fitnessCurrent)
{
  if (abs(fitnessBest) < abs(fitnessCurrent))
  {
    neuralNetworkRestore();
    return false;
  }
  return true;
}


void setup() {
  Serial.begin(115200);

  while (!mpu.begin()) {
    Serial.println("MPU6050 not connected!");
    delay(1000);
  }
  Serial.println("MPU6050 ready!");

  myservo.attach(2);
  Serial.println("Servo ready!");

  neuralNetworkInit();
  Serial.println("Neural network initialized!");

  neuralNetworkPrint();
  delay(2000);
}

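// One call of loop() is one 20 ms control step. At the start of each training
// iteration (every trainingStepsPerIteration steps) the previous iteration's
// fitness is evaluated and a fresh weight mutation is tried.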
void loop() {
  if (trainingIterationStep == 0)
  {
    if (trainingIteration != 0)
    {
      Serial.print("Completed training iteration ");
      Serial.print(trainingIteration);
      Serial.print(".\n");

      Serial.print("Current fitness [");
      Serial.print(trainingFitnessCurrent);
      Serial.print("], best fitness [");
      Serial.print(trainingFitnessBest);
      Serial.print("] of iteration ");
      Serial.print(trainingFitnessBestIteration);
      Serial.print(".");
      if (neuralNetworkStopTraining(trainingFitnessBest, trainingFitnessCurrent))
      {
        trainingFitnessBest = trainingFitnessCurrent;
        trainingFitnessBestIteration = trainingIteration;
        Serial.print("\n\n");
      }
      else
      {
        Serial.print(" Restoring neural network!\n\n");
      }
    }

    trainingIteration++;

    Serial.print("Starting training iteration ");
    Serial.print(trainingIteration);
    Serial.print(".\n");
    neuralNetworkStartTraining();
    trainingFitnessCurrent = 0;
  }

  trainingIterationStep = (trainingIterationStep + 1) % trainingStepsPerIteration;

  mpu.getAccelerometerSensor()->getEvent(&event);

  simulationStep();
  /*
  Serial.print("[");
  Serial.print(millis());
  Serial.print("] X: ");
  Serial.print(event.acceleration.x);
  Serial.print(", Y: ");
  Serial.print(event.acceleration.y);
  Serial.print(", Z: ");
  Serial.print(event.acceleration.z);

  Serial.println(" m/s^2");
*/

  // Calculate fitness of last step
  // Target is 0.707 for both x and z (45 degrees, since sin(45°) = cos(45°) ≈ 0.707)
  double diff_x = (event.acceleration.x / 9.81) - 0.707;
  double diff_z = (event.acceleration.z / 9.81) - 0.707;
  trainingFitnessCurrent += (diff_x * diff_x) + (diff_z * diff_z);

  OutputLayer outputValues;
  InputLayer inputValues;
  // Scale the accelerations (+/-9.81 m/s^2) onto the signed char input range,
  // saturating so that a full 9.81 m/s^2 cannot overflow past 127
  inputValues.raw[0] = neuralNetworkSaturatedCast((event.acceleration.x * 128) / 9.81);
  inputValues.raw[1] = neuralNetworkSaturatedCast((event.acceleration.z * 128) / 9.81);
  outputValues = neuralNetworkIterate(inputValues);

  /*Serial.print("[");
  Serial.print(inputValues.raw[0]);
  Serial.print(", ");
  Serial.print(inputValues.raw[1]);
  Serial.print("] => [");
  Serial.print(outputValues.raw[0]);
  Serial.print("](");
  Serial.print(trainingFitnessCurrent);
  Serial.print(")\n");*/

  // Calculate servo value to set
  signed long setServo = outputValues.raw[0];
  setServo = ((setServo * 90) / 128) + 90;
  servoSetPosition(setServo);

  // Wait until the next scheduled loop iteration
  // (signed difference so the comparison stays correct when millis() overflows)
  loopTimeStamp += loopPeriod;
  while ((long)(millis() - loopTimeStamp) < 0) {
    // Wait
  }
}