#include <Wire.h>
#include <sensor.h>
#include <FFT_signal.h>

// Build-time configuration, as typed constexpr constants (scoped, type-checked)
// rather than preprocessor macros.

// 7-bit bus address passed to sensor.begin() together with the Wire bus.
// NOTE(review): ST documents the IIS3DWB with an SPI host interface; confirm
// that the sensor.h driver in use really reaches it over I2C at this address.
constexpr uint8_t SENSOR_ADDR = 0x6B;
// Sampling rate in Hz (double): requested from the sensor as its ODR and used
// to pace acquisition and to convert FFT bin indices to frequencies.
constexpr double SAMPLE_RATE = 1600.0;
// Number of samples per FFT frame.
constexpr int FFT_SIZE = 512;
// Exclusive upper bounds of the velocity-RMS severity classes, in the units
// the sketch reports (mm/s).
// NOTE(review): these values are not traceable to a specific ISO 10816/20816
// machine-class table from this file alone — confirm their source.
constexpr double CLASSIFICATION_THRESHOLD_LOW = 0.1;
constexpr double CLASSIFICATION_THRESHOLD_MEDIUM = 0.4;
constexpr double CLASSIFICATION_THRESHOLD_HIGH = 1.6;
constexpr double CLASSIFICATION_THRESHOLD_VERY_HIGH = 6.3;

// Accelerometer driver: configured once in setup(), read sample-by-sample in loop().
IIS3DWB sensor;
// FFT engine instance shared by the measurement code in loop() (value-initialised).
arduinoFFT FFT{};

/*
 * One-time initialisation: starts the I2C bus (Wire) and the serial console at
 * 115200 baud, then configures the accelerometer step by step. Each sensor
 * call that reports failure prints a diagnostic and parks the MCU in an
 * endless loop, because nothing downstream can run without the sensor.
 *
 * NOTE(review): ST's IIS3DWB datasheet describes a fixed output data rate
 * (~26.7 kHz) and no low-power mode; confirm the sensor.h driver actually
 * honours setPowerMode(LOW_POWER) and setODR(SAMPLE_RATE) as intended.
 */
void setup() {
  // Fatal-error guard: a false status prints `message` and never returns.
  auto requireOk = [](bool ok, const char *message) {
    if (ok) {
      return;
    }
    Serial.println(message);
    while (true) {
    }
  };

  Wire.begin();
  Serial.begin(115200);

  requireOk(sensor.begin(Wire, SENSOR_ADDR), "Failed to initialize IIS3DWB sensor!");
  requireOk(sensor.setPowerMode(LOW_POWER), "Failed to set power mode!");
  requireOk(sensor.setFullScaleRange(FSR_2g), "Failed to set full-scale range!");
  requireOk(sensor.setODR(SAMPLE_RATE), "Failed to set sampling rate!");
}

void loop() {
  // Collect data
  int16_t xData[FFT_SIZE];
  int16_t yData[FFT_SIZE];
  int16_t zData[FFT_SIZE];
  
  for (int i = 0; i < FFT_SIZE; i++) {
    if (!sensor.getAccelerationRaw(xData[i], yData[i], zData[i])) {
      Serial.println("Failed to read sensor data!");
      while (1);
    }
    delayMicroseconds(1000000 / SAMPLE_RATE); // Delay based on sampling rate
  }

  // Compute FFT
  double xFFT[FFT_SIZE];
  double yFFT[FFT_SIZE];
  double zFFT[FFT_SIZE];
  
  FFT.Windowing(xData, FFT_SIZE, FFT_WIN_TYP_HAMMING, FFT_FORWARD);
  FFT.Compute(xData, FFT_SIZE, FFT_FORWARD);
  FFT.GetResults(xFFT);
  
  FFT.Windowing(yData, FFT_SIZE, FFT_WIN_TYP_HAMMING, FFT_FORWARD);
  FFT.Compute(yData, FFT_SIZE, FFT_FORWARD);
  FFT.GetResults(yFFT);
  
  FFT.Windowing(zData, FFT_SIZE, FFT_WIN_TYP_HAMMING, FFT_FORWARD);
  FFT.Compute(zData, FFT_SIZE, FFT_FORWARD);
  FFT.GetResults(zFFT);

  // Compute PSD
  double xPSD[FFT_SIZE / 2];
  double yPSD[FFT_SIZE / 2];
  double zPSD[FFT_SIZE / 2];
  
  for (int i = 0; i < FFT_SIZE / 2; i++) {
    xPSD[i] = (2.0 * pow(xFFT[i], 2)) / (FFT_SIZE * SAMPLE_RATE);
    yPSD[i] = (2.0 * pow(yFFT[i], 2)) / (FFT_SIZE * SAMPLE_RATE);
    zPSD[i] = (2.0 * pow(zFFT[i], 2)) / (FFT_SIZE * SAMPLE_RATE);
  }

  // Compute velocity RMS
  double velocityRMS = sqrt(xPSD[1] + yPSD[1] + zPSD[1]);

  // Classify velocity RMS based on ISO standards
  String classification;
  
  if (velocityRMS < CLASSIFICATION_THRESHOLD_LOW)
  classification = "Very low";
else if (velocityRMS < CLASSIFICATION_THRESHOLD_MEDIUM)
classification = "Low";
else if (velocityRMS < CLASSIFICATION_THRESHOLD_HIGH)
classification = "Medium";
else if (velocityRMS < CLASSIFICATION_THRESHOLD_VERY_HIGH)
classification = "High";
else
classification = "Very high";

// Print results
Serial.print("Velocity RMS: ");
Serial.print(velocityRMS, 4);
Serial.print(" mm/s (Classification: ");
Serial.println(classification);

delay(1000); // Delay between measurements
}