// Import libraries

#include <SD.h>  // NR: SD card to read from file to simulate input

// #include "FlexiTimer2.h"
#include <util/atomic.h>

#include "ExpMovStats.h"
#include "PeakViaDev.h"
#include "Poisson.h"
#include "ThresEvent.h"
#include "FixedPeriodMicroscope.h"

// Define PINs
// NOTE: FLOW_INPUT_PIN is the value passed to SD.begin() in setup(); in this
// simulated-input build the flow data comes from a file on the SD card.
#define FLOW_INPUT_PIN 10
// Digital output driven from scope.states.output (see process_scope()).
#define SCOPE_OUTPUT_PIN 3
// Digital output mirroring scope.states.is_on (see simulate_scope()).
#define SCOPE_IS_ON_PIN 4 // NR, relevant to showcase ON state of scope
// Digital outputs for stimulus events A, B and C (see process_stimuli()).
#define A_OUTPUT_PIN 5
#define B_OUTPUT_PIN 6
#define C_OUTPUT_PIN 7

// Input variables
// Name of the SD-card file opened in setup() and read by read_flow_sensor().
String FLOW_INPUT_FILE = "flow.txt";
// Divisor applied to the raw flow value in loop() before centering and inverting.
// const int FLOW_NORM_FACTOR = 512;
const int FLOW_NORM_FACTOR = 2048;

// Most recent raw flow value parsed from the input file.
int flow;
// Inverted, normalized flow: -(flow / FLOW_NORM_FACTOR - 0.5), updated in loop().
double norm_flow;
// Handle to the opened flow input file.
File text_file; // NR

// oStates Timer Settings
// Number of samples per second; each pass through loop() is treated as one sample.
const int sample_per_sec = 1000; // 1ms
// const int sample_per_sec = 200; // 5ms
// Only referenced by the commented-out FlexiTimer2 setup in setup().
const float eval_per_sample = 1.0; // number of times to poll the oStates
// Duration of one sample in milliseconds (1000 / sample_per_sec).
const double ms_per_sample = 1000.0 / (double) sample_per_sec; // i.e. dt but in ms

// Sampling variables
// Index of the current sample; reset in setup() and incremented once per loop().
uint32_t sample;
// Largest value representable by uint32_t (2^32 - 1).
// NOTE: written as a literal because `^` is bitwise XOR in C++, not a power
// operator; the previous `2^32 - 1` evaluated to 2 XOR 31 == 29.
const uint32_t MAX_UINT = 0xFFFFFFFFUL; // May not need anymore
// Elapsed session time in ms, derived in loop() as ms_per_sample * sample.
double current_time;

// // Keep track of time in session
// // elapsedMillis session_time;  // type long unsigned int

// Objects and States
// Number of samples between target-rate updates of events B and C
// (see process_stimuli()).
// const uint32_t SAMPLE_PER_RATE_UPDATE = 90000; // 90 sec
// const uint32_t SAMPLE_PER_RATE_UPDATE = 10000;
const uint32_t SAMPLE_PER_RATE_UPDATE = 50000;
// Scale factor applied to event A's cumulative rate to obtain the target rate
// of event C. Uses floating-point operands: the previous `1/4` was an integer
// division and evaluated to 0, which forced C's target rate to 0.
const double LAMBDA_C = 1.0 / 4.0;
// Label of the stimulus delivered on the current sample; process_stimuli()
// sets it to "N" and then to "A", "B" or "C" when the matching event fires.
String stim_label = "";
// Cumulative rate computed from A_event.states.cum_pre_count in
// process_stimuli(), in mHz; used as the target rate for events B and C.
double cum_rate_pre_A;

// Fixed-period microscope state object; stepped with the current time in
// process_scope(), and its states.output / states.is_on drive the scope pins.
FixedPeriodMicroscope scope(
  10000, // tau_trial, ms
  5000, // tau_ITI, ms
  1000 // tau_delay, ms
);

// Moving statistics of the normalized flow signal; stepped in loop() with
// norm_flow and also handed by pointer to the peak detector `flow_peaks`.
ExpMovStats flow_ems(
  0.01 * ms_per_sample, // alpha
  100, // init_sample, in sample
  true // init_with_stats
);

// Peak detector on the normalized flow signal; stepped in loop() before
// flow_ems, and its states.peak feeds the sniff-rate estimate.
PeakViaDev flow_peaks(
  ms_per_sample, // dt, in ms
  0.8, // theta, note: 1.0 is good for Fs = 5ms, 0.8 seems ok for Fs = 1ms
  5, // tau_delay, in ms
  120, // tau_refrac, in ms
  &flow_ems // pointer to the moving-statistics object of the flow signal
);

// Moving statistics of the peak train scaled by sample_per_sec (see loop());
// states.mean is reported as the estimated sniff rate and thresholded by A_event.
ExpMovStats sniff_rate(
  0.0005 * ms_per_sample, // alpha
  100, // init_sample, in sample
  false // init_with_stats
);

// Event A: threshold event on the estimated sniff rate (sniff_rate.states.mean).
ThresEvent A_event(
  ms_per_sample, // dt, in ms
  3.5, // thres of value, unit Hz
  5000, // tau_refrac, in ms
  200, // tau_pre_refrac, in ms, smaller than tau_refrac
  8, // max_consec, in ms, use 0 to suppress this condition
  2000, // init_time, in ms, keep this small ~ 1-2 sec
  false, // whether event count is independent from other events
  &sniff_rate.states.mean // value source to threshold
);

// Event B: Poisson event whose target rate is set to A's cumulative rate
// in process_stimuli().
Poisson B_event(
  ms_per_sample, // dt, in ms
  5000, // tau_refrac, in ms
  0.15 * 1000.0, // init_rate, in mHz
  5, // max_consec, in ms, use 0 to suppress this condition
  2000, // init_time, in ms, keep this small ~ 1-2 sec
  false // whether event count is independent from other events
);

// Event C: Poisson event whose target rate is set to A's cumulative rate
// scaled by LAMBDA_C in process_stimuli().
Poisson C_event(
  ms_per_sample, // dt, in ms
  5000, // tau_refrac, in ms
  0.10 * 1000.0, // init_rate, in mHz
  3, // max_consec, in ms, use 0 to suppress this condition
  2000, // init_time, in ms, keep this small ~ 1-2 sec
  false // whether event count is independent from other events
);



/*
  One-time initialization: opens the serial port, opens the simulated flow
  input file on the SD card, configures all output pins (driven LOW), resets
  the sample counter and prints the CSV header.

  SD failures are reported on serial but do not halt the sketch; in that case
  read_flow_sensor() finds no data available and reports "ENDED!".
*/
void setup() {
  // Serial.begin(9600);
  Serial.begin(19200);
  Serial.println("start");

  // delay(2000);

  // NR: Simulate flow sensor inputs
  if (!SD.begin(FLOW_INPUT_PIN)) {
    Serial.println("Card initialization failed!");
  }
  text_file = SD.open(FLOW_INPUT_FILE);
  if (!text_file) {
    // Report a missing/unreadable input file explicitly instead of letting it
    // surface only as an immediate end-of-input later on.
    Serial.println("Failed to open flow input file!");
  }

  // Define output pins (microscope trigger, microscope ON indicator, stimuli
  // A/B/C) and make sure each starts LOW.
  const int output_pins[] = {
    SCOPE_OUTPUT_PIN,
    SCOPE_IS_ON_PIN, // NR
    A_OUTPUT_PIN,
    B_OUTPUT_PIN,
    C_OUTPUT_PIN
  };
  const unsigned int n_output_pins = sizeof(output_pins) / sizeof(output_pins[0]);
  for (unsigned int i = 0; i < n_output_pins; ++i) {
    pinMode(output_pins[i], OUTPUT);
    digitalWrite(output_pins[i], LOW);
  }

  // Initialize
  sample = 0;

  // Print header
  print_header();

  // Start Sample Loop Timer
  // FlexiTimer2::set(1, eval_per_sample / (sample_per_sec * 1), oStates);
  // FlexiTimer2::start();

}


// void loop() {}
// void oStates() {}

/*
  Processes one sample per call: derives the session time from the sample
  index, reads and normalizes the flow value, updates the peak detector,
  the flow statistics and the sniff-rate estimate, then handles stimuli,
  the microscope outputs and the serial report before advancing the index.
*/
void loop() {
  // TODO: this can be shared across all state objects that need time as reference
  // TODO: also consider using `elapsedMillis` instead
  current_time = ms_per_sample * static_cast<double>(sample);

  // Fetch the next raw flow value (simulated from file)
  read_flow_sensor();

  // Scale by the normalization factor, center around 0.5 and flip the sign
  norm_flow = -(static_cast<double>(flow) / FLOW_NORM_FACTOR - 0.5);

  // Peak detector is stepped first, then the moving statistics it points to
  flow_peaks.step(norm_flow, sample);
  flow_ems.step(norm_flow, sample);

  // Feed the peak indicator, scaled by the sampling rate, into the sniff-rate estimate
  const double scaled_peak =
    static_cast<double>(flow_peaks.states.peak) * static_cast<double>(sample_per_sec);
  sniff_rate.step(scaled_peak, sample);

  // Events A/B/C
  process_stimuli();

  // Microscope trigger output
  process_scope();

  // NR: microscope ON indicator
  simulate_scope();

  // Emit one record on serial, then move to the next sample
  report_data();
  ++sample;
}


/*
  Steps the fixed-period microscope with the current session time and
  mirrors its `output` state onto SCOPE_OUTPUT_PIN.
*/
void process_scope() {
  scope.step(current_time);

  // HIGH while the scope output state is set, LOW otherwise
  if (!scope.states.output) {
    digitalWrite(SCOPE_OUTPUT_PIN, LOW);
    return;
  }
  digitalWrite(SCOPE_OUTPUT_PIN, HIGH);
}


/*
  NR: Mirrors the microscope `is_on` state onto SCOPE_IS_ON_PIN so the ON
  state of the scope can be observed.
*/
void simulate_scope() {
  if (!scope.states.is_on) {
    digitalWrite(SCOPE_IS_ON_PIN, LOW);
    return;
  }
  digitalWrite(SCOPE_IS_ON_PIN, HIGH);
}

/*
  Emulates a flow sensor reading by parsing the next integer from the input
  file into `flow`. When the file has no more data available, "ENDED!" is
  printed and `flow` keeps its previous value.
*/
void read_flow_sensor() {
  if (text_file.available()) {
    flow = text_file.parseInt();
  } else {
    Serial.println("ENDED!");
    // exit(0);
  }
}

/*
  Stamps the current session time as the last output time of all three
  stimulus events (A, B and C).
*/
void update_last_stim_output() {
  const double stamp = current_time;
  A_event.aux.last_output_time = stamp;
  B_event.aux.last_output_time = stamp;
  C_event.aux.last_output_time = stamp;
}

/*
  Steps the three stimulus events in the order A, B, C and drives their
  output pins.

  Every SAMPLE_PER_RATE_UPDATE samples, the cumulative rate of A's pre-count
  (in mHz) is recomputed and used as the target rate of B, and scaled by
  LAMBDA_C as the target rate of C.

  When an event fires, its pin goes HIGH, `stim_label` is set to its letter,
  the consecutive counters of the other two events are reset and the last
  output time of all events is updated; otherwise its pin goes LOW.
*/
void process_stimuli() {

  // Update target rate
  const bool rate_update_due =
    (sample > 1) && (sample % SAMPLE_PER_RATE_UPDATE == 0);
  if (rate_update_due) {
    const double elapsed_sec = current_time / 1000.0;
    // count per second, scaled by 1000 -> [mHz]
    cum_rate_pre_A = 1000.0 * A_event.states.cum_pre_count / elapsed_sec;
    B_event.set_target_rate(cum_rate_pre_A);
    C_event.set_target_rate(cum_rate_pre_A * LAMBDA_C);
  }

  // // TODO: consider defining a superclass/abstract
  // //        in order to do the following if-statements inside
  // "N" unless one of the events below fires on this sample
  stim_label = "N";

  // Event A ~ thresholded sniff rate
  A_event.step(sample);
  if (!A_event.states.output) {
    digitalWrite(A_OUTPUT_PIN, LOW);
  } else {
    digitalWrite(A_OUTPUT_PIN, HIGH); // actual stim output
    stim_label = "A";

    // other stims restart their consecutive counts
    B_event.reset_consec();
    C_event.reset_consec();

    // refractory period reference time
    update_last_stim_output();
  }

  // Event B ~ Poisson, matched A
  B_event.step(sample);
  if (!B_event.states.output) {
    digitalWrite(B_OUTPUT_PIN, LOW);
  } else {
    digitalWrite(B_OUTPUT_PIN, HIGH); // actual stim output
    stim_label = "B";

    // other stims restart their consecutive counts
    A_event.reset_consec();
    C_event.reset_consec();

    // refractory period reference time
    update_last_stim_output();
  }

  // Event C ~ Poisson, sham
  C_event.step(sample);
  if (!C_event.states.output) {
    digitalWrite(C_OUTPUT_PIN, LOW);
  } else {
    digitalWrite(C_OUTPUT_PIN, HIGH); // actual stim output
    stim_label = "C";

    // other stims restart their consecutive counts
    A_event.reset_consec();
    B_event.reset_consec();

    // refractory period reference time
    update_last_stim_output();
  }
}


// Prints the CSV header line on serial. Exactly one header variant should be
// enabled, and it should match the record function enabled in `report_data`.
void print_header() {
  // TODO: print all or most configurations and parameters
  //    NOTE: something is weird with SD when attempting to print out configs, maybe memory issues ?
  // TODO: if possible, match with `report_data` automatically
  // Debug variant (pairs with `record_data_debug`):
  // print_header_debug();

  // Minimal variant (pairs with `record_data_minimal`):
  print_header_minimal();
}

// Prints one data record for the current sample on serial, terminated by a
// newline. The active record function should match the header variant
// enabled in `print_header`.
void report_data() {
  // Use different `inspect_XXX` functions to tune parameters
  // as well as checking whether the outputs are correct
  // (the `inspect_XXX` helpers are commented out at the end of this file)

  // Peak detection from flow sensor
  // inspect_flow_peak();

  // Sniff rate calculation
  // inspect_sniff_rate();

  // Event A delivery
  // inspect_event_A();

  // Event B delivery
  // inspect_event_B();

  // Events A + B delivery
  // inspect_events_A_B();

  // Record data debug, i.e almost everything
  // record_data_debug();

  // Record minimal data
  record_data_minimal();

  // End of record
  Serial.println();
}

/*
  Prints the minimal CSV header: each column name followed by a comma,
  then a newline. Column order matches `record_data_minimal`.
*/
void print_header_minimal() {
  static const char* const columns[] = {
    "time[sec]",
    "scope_trial",
    "scope_is_on",
    "flow",
    "est_sniff_rate[Hz]",
    "stim_label"
  };
  const unsigned int n_columns = sizeof(columns) / sizeof(columns[0]);

  for (unsigned int i = 0; i < n_columns; ++i) {
    Serial.print(columns[i]);
    Serial.print(",");
  }
  Serial.println();
}

/*
  Prints the minimal data record for the current sample as comma-terminated
  fields (no newline; `report_data` ends the line). Field order matches
  `print_header_minimal`.
*/
void record_data_minimal() {
  static const char delim[] = ",";

  // session time, converted from ms to seconds
  Serial.print(current_time / 1E3, 3);
  Serial.print(delim);

  // microscope states
  Serial.print(scope.states.trial);
  Serial.print(delim);
  Serial.print(scope.states.is_on);
  Serial.print(delim);

  // normalized flow and estimated sniff rate
  Serial.print(norm_flow);
  Serial.print(delim);
  Serial.print(sniff_rate.states.mean, 3);
  Serial.print(delim);

  // stimulus delivered on this sample
  Serial.print(stim_label);
  Serial.print(delim);
}


/*
  Prints the debug CSV header: each column name followed by a comma, then
  a newline. Column order matches `record_data_debug`.
*/
void print_header_debug() {
  static const char* const columns[] = {
    // microscope
    "scope.states.trial",
    "scope.states.is_on",

    // flow and its moving statistics
    "norm_flow",
    "flow_ems.states.mean",
    "flow_ems.states.std",

    // peak detection
    "flow_peaks.aux.thres_cross",
    "flow_peaks.aux.sustained_cross[sec]",
    "flow_peaks.states.peak",

    // sniff rate
    "sniff_rate.states.mean[Hz]",

    "stim_label",

    // event A
    "A_event.value[Hz]",
    "A_event.states.output",
    "A_event.states.cum_count",
    "A_event.states.cum_pre_count[K]",

    // event B
    "B_event.states.output",
    "B_event.states.cum_count",
    "B_event.states.cum_pre_count[K]",
    "B_event.states.r[Hz]",
    "B_event.states.target_rate[Hz]",

    // event C
    "C_event.states.output",
    "C_event.states.cum_count",
    "C_event.states.cum_pre_count[K]",
    "C_event.states.r[Hz]",
    "C_event.states.target_rate[Hz]"
  };
  const unsigned int n_columns = sizeof(columns) / sizeof(columns[0]);

  for (unsigned int i = 0; i < n_columns; ++i) {
    Serial.print(columns[i]);
    Serial.print(",");
  }
  Serial.println();
}

/*
  Prints the debug data record for the current sample as comma-terminated
  fields (no newline; `report_data` ends the line). Field order matches
  `print_header_debug`. Values labelled [sec], [K] or [Hz] in the header
  are divided by 1000 here.
*/
void record_data_debug() {
  static const char delim[] = ",";

  // microscope
  Serial.print(scope.states.trial);
  Serial.print(delim);
  Serial.print(scope.states.is_on);
  Serial.print(delim);

  // flow and its moving statistics
  Serial.print(norm_flow, 3);
  Serial.print(delim);
  Serial.print(flow_ems.states.mean, 3);
  Serial.print(delim);
  Serial.print(flow_ems.states.std, 3);
  Serial.print(delim);

  // peak detection
  Serial.print(static_cast<double>(flow_peaks.aux.thres_cross));
  Serial.print(delim);
  Serial.print(flow_peaks.aux.sustained_cross / 1000.0, 3);
  Serial.print(delim);
  Serial.print(static_cast<double>(flow_peaks.states.peak));
  Serial.print(delim);

  // sniff rate
  Serial.print(sniff_rate.states.mean, 3);
  Serial.print(delim);

  Serial.print(stim_label);
  Serial.print(delim);

  // event A
  Serial.print(*A_event.value);
  Serial.print(delim);
  Serial.print(static_cast<double>(A_event.states.output));
  Serial.print(delim);
  Serial.print(static_cast<double>(A_event.states.cum_count));
  Serial.print(delim);
  Serial.print(static_cast<double>(A_event.states.cum_pre_count) / 1000.0, 3);
  Serial.print(delim);

  // event B
  Serial.print(static_cast<double>(B_event.states.output));
  Serial.print(delim);
  Serial.print(static_cast<double>(B_event.states.cum_count));
  Serial.print(delim);
  Serial.print(static_cast<double>(B_event.states.cum_pre_count) / 1000.0, 3);
  Serial.print(delim);
  Serial.print(static_cast<double>(B_event.states.r) / 1000.0, 6);
  Serial.print(delim);
  Serial.print(static_cast<double>(B_event.states.target_rate) / 1000.0, 3);
  Serial.print(delim);

  // event C
  Serial.print(static_cast<double>(C_event.states.output));
  Serial.print(delim);
  Serial.print(static_cast<double>(C_event.states.cum_count));
  Serial.print(delim);
  Serial.print(static_cast<double>(C_event.states.cum_pre_count) / 1000.0, 3);
  Serial.print(delim);
  Serial.print(static_cast<double>(C_event.states.r) / 1000.0, 6);
  Serial.print(delim);
  Serial.print(static_cast<double>(C_event.states.target_rate) / 1000.0, 3);
  Serial.print(delim);
}



/*
  Optional per-step printing helpers, used to troubleshoot or tune the parameters of each step.
  Note: these helpers are commented out and may be out of date with respect to the code above.
*/

// void inspect_flow_peak() {
//   // Serial.print(flow); Serial.print(",");
//   Serial.print(norm_flow); Serial.print(",");
//   Serial.print(flow_ems.states.mean); Serial.print(",");
//   Serial.print(flow_ems.states.std); Serial.print(",");
//   Serial.print((double) flow_peaks.aux.thres_cross / 5); Serial.print(",");
//   Serial.print(flow_peaks.aux.sustained_cross / 100); Serial.print(",");
//   Serial.print((double) flow_peaks.states.peak / 2); Serial.print(",");
// }

// void inspect_sniff_rate() {
//   Serial.print((double) flow_peaks.states.peak / 10); Serial.print(",");
//   Serial.print(sniff_rate.states.mean); Serial.print(",");
// }

// void inspect_event_A() {
//   Serial.print(sniff_rate.states.mean); Serial.print(",");
//   // Serial.print(*A_event.value); Serial.print(","); // should be exactly same as above
//   Serial.print((double) A_event.states.output*5); Serial.print(",");
//   // Serial.print(A_event.states.cum_count); Serial.print(",");
// }

// void inspect_event_B() {
//   Serial.print((double) B_event.states.r); Serial.print(",");
//   Serial.print((double) B_event.states.target_rate); Serial.print(",");
//   Serial.print((double) B_event.states.output); Serial.print(",");
//   Serial.print((double) B_event.states.cum_count / 10); Serial.print(",");
//   Serial.print((double) B_event.states.cum_pre_count / 10); Serial.print(",");
// }

// void inspect_events_B_C() {
//   Serial.print((double) B_event.states.r); Serial.print(",");
//   Serial.print((double) B_event.states.target_rate); Serial.print(",");
//   Serial.print((double) B_event.states.output); Serial.print(",");
//   Serial.print((double) B_event.states.cum_count / 10); Serial.print(",");
//   Serial.print((double) B_event.states.cum_pre_count / 10);
//   Serial.print((double) C_event.states.r); Serial.print(",");
//   Serial.print((double) C_event.states.target_rate); Serial.print(",");
//   Serial.print((double) C_event.states.output); Serial.print(",");
//   Serial.print((double) C_event.states.cum_count / 10); Serial.print(",");
//   Serial.print((double) C_event.states.cum_pre_count / 10); Serial.print(",");

// }

// void inspect_events_A_B() {
//   Serial.print(stim_label); Serial.print(",");
//   Serial.print(*A_event.value); Serial.print(",");
//   Serial.print((double) A_event.states.output); Serial.print(",");
//   Serial.print((double) A_event.states.cum_count / 10); Serial.print(",");
//   Serial.print((double) A_event.states.cum_pre_count / 10); Serial.print(",");
//   Serial.print((double) B_event.states.output); Serial.print(",");
//   Serial.print((double) B_event.states.cum_count / 10); Serial.print(",");
//   Serial.print((double) B_event.states.cum_pre_count / 10); Serial.print(",");
// }