/*
 * This file is part of the Diwa library.
 * Copyright (c) 2024 Nathanne Isip
 *
 * See Diwa project repository at https://github.com/nthnn/Diwa
 */
#include "diwa.h"

void setup() {
    // Training schedule. Kept as named constants so the loop, the progress
    // report, and the comments below cannot drift apart.
    const uint32_t lastEpoch = 5000;       // epochs 0..lastEpoch (inclusive) are run
    const uint32_t reportInterval = 1000;  // print metrics every reportInterval epochs
    // First argument passed to Diwa::train(); presumably the learning rate.
    // NOTE(review): confirm the parameter meaning against diwa.h.
    const double learningRate = 6;

    // Initialize serial communication with a baud rate of 115200
    Serial.begin(115200);

    #if defined(ARDUINO_ARCH_ESP32)
    // If the board reports PSRAM, it must initialize successfully; otherwise halt.
    if(psramFound() && !psramInit()) {
        Serial.println(F("Cannot initialize PSRAM."));
        while(true);
    }
    #endif

    // XOR truth table: the output is 1 exactly when the two inputs differ.
    double trainingInput[4][2] = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
    double trainingOutput[4][1] = {{0}, {1}, {1}, {0}};
    const uint8_t sampleCount = sizeof(trainingInput) / sizeof(trainingInput[0]);

    // Create an instance of the Diwa neural network with 2 input neurons,
    // 1 hidden layer with 3 neurons, and 1 output neuron
    Diwa network;
    if(network.initialize(2, 1, 3, 1) != NO_ERROR) {
        Serial.println(F("Something went wrong initializing neural network."));
        while(true);
    }

    // Train over the whole XOR data set once per epoch, for epochs 0..lastEpoch.
    Serial.println(F("Starting training..."));
    for(uint32_t epoch = 0; epoch <= lastEpoch; epoch++) {
        // Feed every (input, target) pair to the network, in table order.
        for(uint8_t i = 0; i < sampleCount; i++) {
            network.train(learningRate, trainingInput[i], trainingOutput[i]);
        }

        // Report metrics every reportInterval epochs, and always on the final
        // epoch (even if lastEpoch is not a multiple of reportInterval).
        if((epoch % reportInterval == 0) || epoch == lastEpoch) {
            double accuracy = 0.0, loss = 0.0;

            // Sum accuracy and loss over each training sample.
            // NOTE(review): the meaning of the third argument (3) is defined in
            // diwa.h and is not visible here; confirm it is valid when a single
            // sample row is passed.
            for(uint8_t i = 0; i < sampleCount; i++) {
                accuracy += network.calculateAccuracy(trainingInput[i], trainingOutput[i], 3);
                loss += network.calculateLoss(trainingInput[i], trainingOutput[i], 3);
            }

            // Average accuracy and loss over all samples
            accuracy /= sampleCount;
            loss /= sampleCount;

            // Print the accuracy and loss (both scaled by 100)
            Serial.print(F("Epoch: "));
            Serial.print(epoch);
            Serial.print(F("\t| Accuracy: "));
            Serial.print(accuracy * 100);
            Serial.print(F("%\t| Loss: "));
            Serial.print(loss * 100);
            Serial.println(F("%"));
        }
    }
    Serial.println(F("Training done!\r\n"));

    // Run inference on each training input and print the raw output together
    // with its 0/1 classification (threshold 0.5).
    Serial.println(F("Testing inferences..."));
    for(uint8_t i = 0; i < sampleCount; i++) {
        double* row = trainingInput[i];
        double* inferred = network.inference(row);

        // Serial.print is used instead of sprintf("%g"): the default AVR
        // printf has no floating-point support and would print "?".
        Serial.print(F("\t["));
        Serial.print(row[0]);
        Serial.print(F(", "));
        Serial.print(row[1]);
        Serial.print(F("]: "));
        Serial.print(inferred[0] >= 0.5 ? 1 : 0);
        Serial.print(F(" ("));
        Serial.print(inferred[0], 6);
        Serial.print(F(")\n"));
    }
}

void loop() {
    // All work happens once in setup(); each pass here just idles briefly.
    const unsigned long idleMillis = 10;
    delay(idleMillis);
}