/**************************************
/ Arduino sketch using the Posit library to demonstrate Reinforcement Learning algorithms 
/ on an Arduino board.
/
/ The target is to evaluate the feasibility of these algorithms on a simple Arduino Uno, despite its limitations
/ in terms of code size (32 kB of flash) and RAM (2 kB). Only simple use cases are expected to fit, but that
/ would demonstrate the efficiency of Posits!
/
/ Contextual bandits: maximizing reward by choosing the best action given some context variables
/ MDP: consider 32 states, up to 8 actions and next states (8x32 = 256 bytes max)
/      storing Q? V? (see the storage sketch below this header)
/ SARSA: see the update sketch at the end of this file
/ 
/ 
/ ***************************************/
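
// Illustrative only, not used by this sketch: the MDP sizing in the header assumes a Q-table
// of one-byte Posit8 values, which could be declared roughly as follows (all names here are
// placeholders, not part of the Posit library):
/*
#define MDP_STATES 32
#define MDP_ACTIONS 8
Posit8 Q[MDP_STATES][MDP_ACTIONS]; // 32 x 8 = 256 bytes, the budget quoted above
Posit8 V[MDP_STATES];              // optional state values: 32 more bytes
*/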

#define ES8 2 // Needed for integer temps
#include <Posit.h> // My library made it to the Arduino ecosystem!

// First step: contextual bandits. With help from ChatGPT

#define ARMS 8            // Number of arms 
#define CONTEXT 2          // Number of context variables
#define VARIATION 10       // Maximum variation in rewards
#define RUNS (6*24)        // Number of runs
#define EPSILON 0.3        // Exploration probability (epsilon-greedy)

// Global variables and arrays
int counts[ARMS];                     // Count of selections for each arm
int secret_rewards[ARMS];             // The "secret" average rewards to find
int contexts[CONTEXT];                // Array to hold context variables
Posit8 context_weights[ARMS][CONTEXT]; // Weight matrix to link contexts to arms
Posit8 rewards[ARMS];                  // Cumulative rewards for each arm

int freeRam() { // Routine from https://docs.arduino.cc/learn/programming/memory-guide/
  extern int __heap_start,*__brkval;
  int v;
  return (int)&v - (__brkval ? (int) __brkval : (int)&__heap_start);  
}

void setup() {
  Serial.begin(9600);
  randomSeed(millis()+analogRead(0));      // Seed the random number generator

  Serial.print(F("Free RAM at setup() : ")); Serial.println(freeRam());

  Serial.print(F("Secret rewards : "));
  
  // Initialize rewards, counts, and weights
  for (int i = 0; i < ARMS; i++) {
    rewards[i] = 0;
    secret_rewards[i] = random(100);
    Serial.print(secret_rewards[i]);
    Serial.print(i==ARMS-1 ? "\n" : ", ");
    counts[i] = 0;
    for (int j = 0; j < CONTEXT; j++) {
      context_weights[i][j] = random(1, 11) / 10.0; // Random weights between 0.1 and 1.0
    }
  }
  
  Serial.println("Starting Contextual Bandit with Epsilon-Greedy...\n");
  Serial.println("Run | Context | Arm | Reward | Averages");
}

void loop() {
  for (int run = 0; run < RUNS; run++) {
    // Generate random context
    for (int i = 0; i < CONTEXT; i++) {
      contexts[i] = random(0, 10); // Example: context values between 0 and 9
    }

    // Compute expected rewards based on context
    Posit8  expected_rewards[ARMS] = {0};
    for (int i = 0; i < ARMS; i++) {
      for (int j = 0; j < CONTEXT; j++) {
        expected_rewards[i] = expected_rewards[i] + context_weights[i][j] * contexts[j];
      }
      // Add average reward if the arm has been selected at least once
      if (counts[i] > 0) {
        expected_rewards[i] = expected_rewards[i] + rewards[i] / counts[i];
      }
    }

    // Select arm using epsilon-greedy policy
    int selected_arm;
    if (random(0, 100) < EPSILON * 100) {      
      selected_arm = random(0, ARMS); // Explore: select a random arm
    } else { // Exploit: select the arm with the highest expected reward
      selected_arm = 0;
      float max_reward = expected_rewards[0];
      for (int i = 1; i < ARMS; i++) {
        if (expected_rewards[i] > max_reward) {
          max_reward = expected_rewards[i];
          selected_arm = i;
        }
      }
    }

    // Calculate reward for the selected arm
    int reward = secret_rewards[selected_arm] + random(VARIATION) - VARIATION/2;
    rewards[selected_arm] += reward;
    counts[selected_arm]++;
    
    // Print the system's evolution
    //Serial.print("Run: ");
    Serial.print(run);
    Serial.print(" | "); //Serial.print(" | Context: ");
    for (int i = 0; i < CONTEXT; i++) {
      Serial.print(contexts[i]);
      if (i < CONTEXT - 1) Serial.print(", ");
    }
    Serial.print(" | "); //Serial.print(" | Selected Arm: ");
    Serial.print(selected_arm);
    Serial.print(" | "); //Serial.print(" | Reward: ");
    Serial.print(reward);
    Serial.print(" | "); //Serial.print(" | Average Rewards: ");
    for (int i = 0; i < ARMS; i++) {
      if (counts[i] > 0) {
        Serial.print(rewards[i] / counts[i], 2);
      } else {
        Serial.print("N/A");
      }
      if (i < ARMS - 1) Serial.print(", ");
    }
    Serial.println();
  }

  while (1); // End the program after RUNS runs
}

/* // TODO adapt for HEATING
Posit8 wantedTemp = 22, internalTemp = 22; // The "context" variables
Posit8 externalTemp = 15;
Posit16 generatedHeat = 0; // The result of heating
Posit16 energyUsed = 0;    // The consumption of the system
uint8_t action = 0;        // The number of the selected action
Posit16

void actionNop() { // do nothing
  generatedHeat = energyUsed = 0;
  internalTemp = (internalTemp * 7 + externalTemp) / 8; // averaging towards the external temperature
}
void actionHeat() {
  generatedHeat = Posit16(21000); //
  energyUsed = Posit16(7000); // Heatpump
  internalTemp++; // effect of heating
}

Posit16 calculateReward() {
  if (internalTemp < wantedTemp) return generatedHeat - energyUsed;
  return -energyUsed; // If the temperature is high enough, no reward for the result
} //*/
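
// Illustrative only, not part of this sketch: one way the SARSA note in the header could be
// filled in for the heating example above, assuming a small discretised state index and the
// same Posit8 operators (+, -, *) already used in loop(). All names below are placeholders.
/*
#define N_STATES 32              // e.g. internalTemp bucketed into 32 levels
#define N_ACTIONS 2              // actionNop, actionHeat
Posit8 Q[N_STATES][N_ACTIONS];   // Q-table: 32 x 2 = 64 bytes in Posit8
Posit8 learnRate = 0.1;          // alpha
Posit8 discount = 0.9;           // gamma

// SARSA update: Q(s,a) += alpha * (r + gamma * Q(s',a') - Q(s,a))
void sarsaUpdate(uint8_t s, uint8_t a, Posit8 r, uint8_t s2, uint8_t a2) {
  Q[s][a] = Q[s][a] + learnRate * (r + discount * Q[s2][a2] - Q[s][a]);
}
*/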
