/*
  www.aifes.ai
  https://github.com/Fraunhofer-IMS/AIfES_for_Arduino
  Copyright (C) 2020-2022  Fraunhofer Institute for Microelectronic Circuits and Systems.
  All rights reserved.

  AIfES is free software: you can redistribute it and/or modify
  it under the terms of the GNU Affero General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU Affero General Public License for more details.

  You should have received a copy of the GNU Affero General Public License
  along with this program.  If not, see <https://www.gnu.org/licenses/>.
   
*/

//Serial keyword: "inference"

#include <aifes.h>                  // include the AIfES library
#include "weights.h"

// ATTENTION!
// The array with the weights is defined in the "weights.h" file.
// The example uses the FlatWeights method, i.e. all weights are stored in one flat array, one after the other. A tutorial can be found here:
// https://create.arduino.cc/projecthub/aifes_team/aifes-inference-tutorial-f44d96
// Add your own weights to the "weights.h" file.
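// For illustration only (the layer-by-layer ordering is assumed from the linked tutorial,
// not verified here): "weights.h" is expected to define a flat float array named FlatWeights
// that holds every parameter of the 50-10-10-5 network in sequence, roughly like this:
//
//   float FlatWeights[] = {
//       /* dense_layer_1: 50*10 weights, then 10 biases */
//       /* dense_layer_2: 10*10 weights, then 10 biases */
//       /* dense_layer_3: 10*5  weights, then  5 biases */
//   };
//
// sizeof(FlatWeights) is compared against the model's parameter memory below, so FlatWeights
// must be a real array definition, not a pointer.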

// Network: 50-10-10-5
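// Expected parameter count for this 50-10-10-5 network (dense layers with weights + bias):
//   (50*10 + 10) + (10*10 + 10) + (10*5 + 5) = 510 + 110 + 55 = 675 floats
//   675 * 4 bytes = 2700 bytes (assuming 4-byte floats and no extra alignment padding),
//   which is what the sizeof(FlatWeights) check below should see.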

void setup() {
  Serial.begin(115200); // 115200 baud rate (set the serial monitor accordingly)
  while (!Serial);

  delay(100);
  Serial.println(F("AIfES demo"));
  Serial.println(F("Type >inference< to start"));
}

void loop() {

  while(Serial.available() > 0 ){
    String str = Serial.readString();
    if(str.indexOf("inference") > -1){        // Keyword "inference"
      Serial.println(F("AIfES"));
      Serial.println();
      
      // Tensor for the input data
      // Dummy input: element 0 is 1.0f, the remaining 49 elements are zero-initialized.
      // Replace this with your own input data.
      float input_data[50] = {1.0f};
      uint16_t input_shape[] = {1, 50};                                     // 1 sample with 50 features
      aitensor_t input_tensor = AITENSOR_2D_F32(input_shape, input_data);   // 2D float32 tensor wrapping input_data

      // ---------------------------------- Layer definition ---------------------------------------
    
      uint16_t input_layer_shape[] = {1, 50};
      ailayer_input_f32_t   input_layer    = AILAYER_INPUT_F32_A(2, input_layer_shape);
      ailayer_dense_f32_t   dense_layer_1  = AILAYER_DENSE_F32_A(10);
      ailayer_relu_f32_t    relu_layer_1   = AILAYER_RELU_F32_A();
      ailayer_dense_f32_t   dense_layer_2  = AILAYER_DENSE_F32_A(10);
      ailayer_relu_f32_t    relu_layer_2   = AILAYER_RELU_F32_A();
      ailayer_dense_f32_t   dense_layer_3  = AILAYER_DENSE_F32_A(5);
      ailayer_relu_f32_t    relu_layer_3   = AILAYER_RELU_F32_A();
    
      // --------------------------- Define the structure of the model ----------------------------
    
      aimodel_t model;  // AIfES model
      ailayer_t *x;     // Layer object from AIfES to connect the layers

      // Connect the layers to an AIfES model
      model.input_layer = ailayer_input_f32_default(&input_layer);
      x = ailayer_dense_f32_default(&dense_layer_1, model.input_layer);
      x = ailayer_relu_f32_default(&relu_layer_1, x);
      x = ailayer_dense_f32_default(&dense_layer_2, x);
      x = ailayer_relu_f32_default(&relu_layer_2, x);
      x = ailayer_dense_f32_default(&dense_layer_3, x);
      x = ailayer_relu_f32_default(&relu_layer_3, x);
      model.output_layer = x;
    
      aialgo_compile_model(&model); // Compile the AIfES model
      

      uint32_t parameter_memory_size = aialgo_sizeof_parameter_memory(&model);

      Serial.print(F("Flat weight memory [bytes]: "));
      Serial.println(parameter_memory_size);
      Serial.print(F(" "));

      if(parameter_memory_size != sizeof(FlatWeights))
      {
          Serial.println(F("Error: number of weights wrong!"));
          return;
      }

      aialgo_distribute_parameter_memory(&model, (void *) FlatWeights, parameter_memory_size);
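      // The call above hands the FlatWeights array to the model: consecutive sections of it are
      // assigned to the weights and biases of the three dense layers (the model points into
      // FlatWeights rather than copying it), so the array must stay valid while the model is used.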

      // ------------------------------------- Print the model structure ------------------------------------
      
      Serial.println(F("-------------- Model structure ---------------"));
      aialgo_print_model_structure(&model);
      Serial.println(F("----------------------------------------------\n"));
    
      // -------------------------------- Allocate and schedule the working memory for inference ---------
    
      // Allocate memory for result and temporal data
      uint32_t memory_size = aialgo_sizeof_inference_memory(&model);
      Serial.print(F("Required memory for intermediate results: "));
      Serial.print(memory_size);
      Serial.print(F(" bytes"));
      Serial.println();
      
      byte *memory_ptr = (byte *) malloc(memory_size);

      if(memory_ptr == NULL)
      {
          Serial.print(F("Not enough Memory"));
          for(;;) {}
      }
      
      // Alternative if you do not want to use "malloc": declare a fixed buffer instead.
      // Do not forget to comment out "free(memory_ptr);" at the end if you use this variant.
      //byte memory_ptr[400];
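      // A minimal sketch of that static-buffer variant (the 400-byte size is only the example
      // from above; size it using the "Required memory for intermediate results" value printed
      // for your model):
      //
      //   static byte memory_buffer[400];
      //   aialgo_schedule_inference_memory(&model, memory_buffer, sizeof(memory_buffer));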
    
      // Schedule the memory over the model
      aialgo_schedule_inference_memory(&model, memory_ptr, memory_size);
    
      // ------------------------------------- Run the inference ------------------------------------

      // Create an empty output tensor for the inference result
      uint16_t output_shape[] = {1, 5};
      float output_data[5] = {0.0f};                 // Empty data array of size output_shape
      aitensor_t output_tensor = AITENSOR_2D_F32(output_shape, output_data);
      
      aialgo_inference_model(&model, &input_tensor, &output_tensor); // Inference / forward pass
    
      // ------------------------------------- Print result ------------------------------------

      Serial.println(F(""));
      Serial.println(F("Results:"));
      Serial.println(output_data[0]);
      Serial.println(output_data[1]);
      Serial.println(output_data[2]);
      Serial.println(output_data[3]);
      Serial.println(output_data[4]);

      free(memory_ptr);

      Serial.println(F(""));
      Serial.println(F("Type >inference< to restart"));
    } else {
      Serial.println(F("unknown"));
    }
  }
}