#include <avr/wdt.h>
// ---- Hardware pin assignments ----
const int rewardPin = 13; //Reward solenoid for CS+ trials (also primed/tested by P1/S1/M1/V1x serial commands)
const int rewardPin2 = 7; //Reward solenoid for CS- (primed/tested by P2/S2/M2/V2x; deliver_reward() does not drive it)
const int lickPin = A0; //Lick sensor input: digitalRead() in monitorLicks(), analogRead() in send_voltage_vals()
const int laserPin = 8; //Pin connected to laser (for opto or photometry(?))
const int cameraPin = 12; //Camera sync: pulsed HIGH for 10 ms on "R" and at session start
const int cuePin = 11; //Speaker; driven with tone()/noTone()
const float ref_voltage = 5.0; //ADC reference voltage used to convert analogRead() counts to volts
// ---- Trial state ----
int currentTrial = 0; //0-based index of the running trial; reset to 0 by "R" and stop_program()
float ITI = 0; //Inter-trial interval (ms) of the running trial; reset to 0 by stop_program()
// ---- Session parameters: set by the "R" serial command, read in this order with Serial.parseInt() ----
int cue_time; //Duration of cue period (ms)
int cueplus_frequency; //CS+ tone frequency, passed directly to tone() (Hz)
int cueminus_frequency;//CS- tone frequency, passed directly to tone() (Hz)
int trace_time; //End of cue to reward start duration (ms)
int reward_magnitude; //Time (ms) the reward solenoid is held open. Should be determined by calibration
int reward_period; //Period (ms) after reward delivery, available for stimulation; subtracted from the next trial's ITI
int num_trials; //Number trials
int num_cueplus_trials; //Number CS+ trials. Rest of trials default to CS-
float ITI_min; //Minimum value for ITI (in ms); received as an integer
float ITI_max; //Maximum value for ITI (in ms); received as an integer
int pulse_on; //Pulse on duration for laser (ms)
int pulse_off; //Pulse off duration for laser (ms). For example, a pulse_on of 10 and pulse_off of 40 gives a 20hz stim.
bool stim_csplus_only; //Stim only on CS+ trials
bool stim_csminus_only; //Stim only on CS- trials
bool stim_at_cue; //Stim during cue period
bool stim_at_reward; //Stim during reward period
bool stim_at_trace; //Stim during cue-reward interval
bool laser_on; //Is stimulation going to happen on this experiment
bool trial_by_trial; //Only choose half of all possible stimulations in which to stimulate. Ex: stim_at_cue + trial_by_trial stims on half of all trials at stim_at_cue.
bool play_csminus; //Whether to actually play the CS- tone (the CS- event code is logged either way)
// ---- Control flags, set by the "S" (start) and "T" (terminate) serial commands ----
bool stopProgram = false; //Control buttons to start/stop the experiment
bool startSession = false;
volatile bool lickDetected = false; //Not read or written anywhere else in this file (appears unused)
int* trial_type = nullptr; // Per-trial cue type after shuffling (1 = CS+, 0 = CS-); allocated in loop(), freed in stop_program()
int* trial_type_unordered = nullptr; // Unshuffled CS+/CS- list that trial_type is shuffled from
bool* stim_on_trial = nullptr; // Per-trial "stimulate on this trial" flags, filled by configure_stimulation()
bool licked; // Set by monitorLicks(): true when the latest sample was a LOW->HIGH transition
// One-time hardware initialisation: configure every output pin (in the same
// order as before), make the lick pin an input, and open the serial link to
// the controlling script.
void setup() {
  // Reward solenoid (TDT: PC1), camera, tone (TDT: PC3), CS- solenoid, laser.
  const int outputPins[] = {rewardPin, cameraPin, cuePin, rewardPin2, laserPin};
  for (int pin : outputPins) {
    pinMode(pin, OUTPUT);
  }
  pinMode(lickPin, INPUT);  // lick sensor
  Serial.begin(9600);
  delay(1000);  // settle time before the first loop() pass
}
// Solenoid calibration shared by all "V<solenoid><n>" commands.
// Reads two further integers from the serial stream: the number of openings
// and the open time in ms. It echoes both, then opens `solenoidPin` that many
// times for that long, waiting 1 s between openings.
void calibrateSolenoid(int solenoidPin) {
  const int number_reps = Serial.parseInt();
  const int time_dispensed = Serial.parseInt();
  Serial.println(number_reps);
  Serial.println(time_dispensed);
  for (int r = 0; r < number_reps; r++) {
    digitalWrite(solenoidPin, HIGH);
    delay(time_dispensed);
    digitalWrite(solenoidPin, LOW);
    delay(1000);
  }
}

// Drain and execute every newline-terminated command waiting on the serial
// port (sent by the controlling Python script). Each command is echoed back as
// "Received Command: <cmd>". Some commands ("R", "V..") read further integers
// from the stream with Serial.parseInt(). Commands are executed synchronously,
// so timed tests (Q1/Q2/L/V..) block until they finish.
void readSerialCommands() {
  while (Serial.available() > 0) {
    String command = Serial.readStringUntil('\n');
    command.trim();  // tolerate "\r\n" line endings and stray whitespace
    Serial.print("Received Command: ");
    Serial.println(command);  // Debugging print

    if (command == "P1") {
      // Prime/empty solenoid 1: stays open until "S1".
      digitalWrite(rewardPin, HIGH);
      Serial.println("Priming started for Solenoid 1.");
    } else if (command == "S1") {
      digitalWrite(rewardPin, LOW);
      Serial.println("Priming stopped for Solenoid 1.");
    } else if (command == "P2") {
      // Prime/empty solenoid 2: stays open until "S2".
      digitalWrite(rewardPin2, HIGH);
      Serial.println("Priming started for Solenoid 2.");
    } else if (command == "S2") {
      digitalWrite(rewardPin2, LOW);
      Serial.println("Priming stopped for Solenoid 2.");
    } else if (command == "M1") {
      // Test manual reward delivery from solenoids.
      giveReward(rewardPin);
      Serial.println("Manual reward delivered to Solenoid 1.");
    } else if (command == "M2") {
      giveReward(rewardPin2);
      Serial.println("Manual reward delivered to Solenoid 2.");
    } else if (command == "R") {
      // Session parameters follow as integers, in exactly this order.
      cue_time = Serial.parseInt();
      cueplus_frequency = Serial.parseInt();
      cueminus_frequency = Serial.parseInt();
      trace_time = Serial.parseInt();
      reward_magnitude = Serial.parseInt();
      reward_period = Serial.parseInt();
      num_trials = Serial.parseInt();
      num_cueplus_trials = Serial.parseInt();
      ITI_min = Serial.parseInt();
      ITI_max = Serial.parseInt();
      pulse_on = Serial.parseInt();
      pulse_off = Serial.parseInt();
      stim_csplus_only = Serial.parseInt();
      stim_csminus_only = Serial.parseInt();
      stim_at_cue = Serial.parseInt();
      stim_at_reward = Serial.parseInt();
      stim_at_trace = Serial.parseInt();
      laser_on = Serial.parseInt();
      trial_by_trial = Serial.parseInt();
      play_csminus = Serial.parseInt();
      Serial.println("Parameters received.");
      currentTrial = 0;  // Start trials from the beginning
      // 10 ms camera sync pulse marking parameter load.
      digitalWrite(cameraPin, HIGH);
      delay(10);
      digitalWrite(cameraPin, LOW);
      delay(50);
    } else if (command == "T") {
      // Terminate: loop() performs the actual shutdown on its next stop check.
      Serial.println("Received stop command");
      stopProgram = true;
      startSession = false;
      Serial.println("Stopping Program");
      Serial.println("End of Trials.");
    } else if (command == "L") {
      // Turn the laser on for 1 s to test it.
      digitalWrite(laserPin, HIGH);
      delay(1000);
      digitalWrite(laserPin, LOW);
    } else if (command == "Q1") {
      // Test CS+ tone: 2 s on, then 3 s silence.
      Serial.print("Testing CSplus tone at ");
      Serial.println(cueplus_frequency);
      tone(cuePin, cueplus_frequency);
      delay(2000);
      noTone(cuePin);
      delay(3000);
    } else if (command == "Q2") {
      // Test CS- tone: 2 s on, then 3 s silence.
      tone(cuePin, cueminus_frequency);
      delay(2000);
      noTone(cuePin);
      delay(3000);
    } else if (command == "S") {
      // Start experiment.
      Serial.println("Starting Experiment");
      startSession = true;
      stopProgram = false;
    } else if (command == "V11" || command == "V12" || command == "V13") {
      // Calibration of solenoid 1: V + solenoid number + open-time test index.
      calibrateSolenoid(rewardPin);
    } else if (command == "V21" || command == "V22" || command == "V23") {
      // Calibration of solenoid 2.
      calibrateSolenoid(rewardPin2);
    } else {
      Serial.println("Unknown command.");
    }
    delay(100);  // Small delay to handle serial input properly
  }
}
void loop() {
//Read in serial commands, then start experiment if signalled to do so
readSerialCommands();
send_voltage_vals();
int laser_frequency = 1000 / (pulse_on + pulse_off);
stop_program(stopProgram);
if(stopProgram) return;
while (startSession) {
readSerialCommands();
stop_program(stopProgram);
if(stopProgram) return;
//Order and shuffle CS+ and CS- trials
trial_type_unordered = new int[num_trials];
trial_type = new int[num_trials];
for (int i = 0; i < num_trials; i++) {
trial_type_unordered[i] = false;
}
for (int i = 0; i < num_cueplus_trials; i++) {
trial_type_unordered[i] = true;
}
bool cue_flagged = false;
shuffleArray(trial_type_unordered, trial_type, num_trials);
float ITI_vals[num_trials];
generate_exp_rand_numbers(ITI_min,ITI_max,num_trials,ITI_vals);
digitalWrite(cameraPin,HIGH);
delay(10);
digitalWrite(cameraPin,LOW);
//Loop through each trial, go through ITI, cue period, trace period, and reward period
for (int trial = 0; trial < num_trials; trial++) {
readSerialCommands();
send_voltage_vals();
stop_program(stopProgram);
if(stopProgram) return;
currentTrial = trial;
cue_flagged = false;
bool curr_trial_type = trial_type[trial];
if(trial == 0) {
ITI = ITI_vals[trial];
} else if (trial > 0) {
ITI = ITI_vals[trial]-reward_period;
}
Serial.println(ITI);
Serial.print("Trial: ");
Serial.print(currentTrial + 1);
Serial.print(", ITI: ");
Serial.println(ITI / 1000); // Convert ITI to seconds for readability
stim_on_trial = new bool[num_trials];
configure_stimulation(stim_on_trial, num_cueplus_trials, num_trials, trial_type, trial_by_trial, stim_csplus_only, stim_csminus_only, laser_on);
unsigned long ITI_starttime = millis();
while (millis() - ITI_starttime < ITI) {
readSerialCommands();
send_voltage_vals();
stop_program(stopProgram);
if(stopProgram) return;
monitorLicks();
}
unsigned long cue_starttime = millis();
while (millis() - cue_starttime < cue_time) {
readSerialCommands();
send_voltage_vals();
stop_program(stopProgram);
if(stopProgram) return;
monitorLicks();
if (trial_type[currentTrial]) {
tone(cuePin, cueplus_frequency);
if (!cue_flagged) {
Serial.print("CODE:4,TIME:");
Serial.println(millis());
cue_flagged = true;
}
} else if (!trial_type[currentTrial]) {
if (play_csminus) {
tone(cuePin, cueminus_frequency);
}
if (!cue_flagged) {
Serial.print("CODE:5,TIME:");
Serial.println(millis());
cue_flagged = true;
}
}
if (stim_on_trial[currentTrial] && stim_at_cue) {
laser(cue_time);
logStimulation(laser_frequency);
}
}
noTone(cuePin);
digitalWrite(cuePin, LOW);
unsigned long trace_starttime = millis();
while (millis() - trace_starttime < trace_time) {
readSerialCommands();
send_voltage_vals();
stop_program(stopProgram);
if(stopProgram) return;
monitorLicks();
if (stim_on_trial[currentTrial] && stim_at_trace) {
laser(trace_time);
logStimulation(laser_frequency);
}
}
deliver_reward(curr_trial_type, rewardPin, rewardPin2, reward_magnitude);
unsigned long reward_start_time = millis();
while (millis() - reward_start_time < reward_period) {
readSerialCommands();
send_voltage_vals();
stop_program(stopProgram);
if(stopProgram) return;
monitorLicks();
if (stim_on_trial[currentTrial] && stim_at_reward) {
laser(reward_period);
logStimulation(laser_frequency);
}
}
}
Serial.println("End of Trials.");
stopProgram = true;
stop_program(stopProgram);
}
}
// Sample the lick input once and report a lick on each LOW -> HIGH edge as
// "CODE:2,TIME:<millis>". The global `licked` is left true only for the call
// that saw the edge. Every call ends with a fixed 20 ms pause, whether or not
// a lick was seen (presumably also a crude debounce).
void monitorLicks() {
  static int lastState = LOW;  // level seen on the previous call
  const int state = digitalRead(lickPin);
  const bool risingEdge = (lastState == LOW) && (state == HIGH);
  lastState = state;
  licked = risingEdge;
  if (risingEdge) {
    Serial.print("CODE:2,TIME:");
    Serial.println(millis());
  }
  delay(20);
}
// Select `count` distinct entries of indices[0..size) uniformly at random and
// set stim_on_trial[] to true for each selected trial number. Uses a partial
// Fisher-Yates shuffle, so `indices` is reordered in place. `count` is clamped
// to `size`.
void mark_random_stim_trials(int indices[], int size, int count, bool stim_on_trial[]) {
  if (count > size) count = size;
  for (int i = 0; i < count; i++) {
    const int j = static_cast<int>(random(i, size));  // uniform in [i, size)
    const int picked = indices[j];
    indices[j] = indices[i];
    indices[i] = picked;
    stim_on_trial[picked] = true;
  }
}

// Fill stim_on_trial[0..num_trials) with which trials get stimulation.
//
// - laser_on == false: no trial is stimulated (the array is explicitly
//   cleared, so callers never read uninitialised flags).
// - trial_by_trial == false: every trial is stimulated, restricted to CS+
//   trials if stim_csplus_only, or to CS- trials if stim_csminus_only.
// - trial_by_trial == true: a random half of the CS+ trials and a random half
//   of the CS- trials are chosen. When both counts are odd, a coin flip gives
//   one side an extra trial, so the total is num_trials / 2. The same
//   stim_csplus_only / stim_csminus_only restrictions apply.
//
// trial_type[t] != 0 marks a CS+ trial. CS+/CS- counts are taken from
// trial_type itself; num_cueplus_trials is kept only for interface
// compatibility.
void configure_stimulation(bool stim_on_trial[], int num_cueplus_trials,int num_trials,const int trial_type[],bool trial_by_trial,bool stim_csplus_only,bool stim_csminus_only,bool laser_on) {
  (void)num_cueplus_trials;
  if (num_trials <= 0) return;
  for (int t = 0; t < num_trials; t++) {
    stim_on_trial[t] = false;
  }
  if (!laser_on) return;

  if (!trial_by_trial) {
    for (int t = 0; t < num_trials; t++) {
      const bool is_csplus = (trial_type[t] != 0);  // index the array, not test the pointer
      if (stim_csplus_only) {
        stim_on_trial[t] = is_csplus;
      } else if (stim_csminus_only) {
        stim_on_trial[t] = !is_csplus;
      } else {
        stim_on_trial[t] = true;
      }
    }
    return;
  }

  // Partition trial numbers into one buffer: CS+ from the front, CS- from the back.
  int indices[num_trials];
  int n_csplus = 0;
  int back = num_trials;
  for (int t = 0; t < num_trials; t++) {
    if (trial_type[t] != 0) {
      indices[n_csplus++] = t;
    } else {
      indices[--back] = t;
    }
  }
  const int n_csminus = num_trials - n_csplus;
  int* csplus_idx = indices;
  int* csminus_idx = indices + n_csplus;

  int n_csplus_stim = n_csplus / 2;
  int n_csminus_stim = n_csminus / 2;
  if (n_csplus_stim + n_csminus_stim < num_trials / 2) {
    // A single coin flip decides which side gets the extra trial.
    if (random(0, 2) == 0) {
      n_csplus_stim++;
    } else {
      n_csminus_stim++;
    }
  }

  if (stim_csplus_only) {
    mark_random_stim_trials(csplus_idx, n_csplus, n_csplus_stim, stim_on_trial);
  } else if (stim_csminus_only) {
    mark_random_stim_trials(csminus_idx, n_csminus, n_csminus_stim, stim_on_trial);
  } else {
    mark_random_stim_trials(csplus_idx, n_csplus, n_csplus_stim, stim_on_trial);
    mark_random_stim_trials(csminus_idx, n_csminus, n_csminus_stim, stim_on_trial);
  }
}
// Return true if `value` occurs in arr[0..size); a non-positive size yields false.
bool isinarray(int value, const int arr[],int size) {
  int pos = 0;
  while (pos < size && arr[pos] != value) {
    ++pos;
  }
  return pos < size;
}
// Pulse the laser for about `duration` ms using pulse_on / pulse_off (ms) as
// the on/off times. Blocks the caller for the whole duration.
// - Does nothing if laser_on is false.
// - If either pulse parameter is 0, no light is emitted; the function just
//   waits out `duration`.
//
// Elapsed time is checked against the current clock after each cycle. The
// previous version compared against the time sampled before the last cycle,
// so it always ran one extra on/off cycle past `duration`. All waits use
// "millis() - start < span", which stays correct across millis() rollover.
void laser(unsigned long duration) {
  if (!laser_on) {
    return;
  }
  const unsigned long startTime = millis();
  while (millis() - startTime < duration) {
    if (pulse_on != 0 && pulse_off != 0) {
      const unsigned long onStart = millis();
      digitalWrite(laserPin, HIGH);
      while (millis() - onStart < static_cast<unsigned long>(pulse_on)) {
        // busy-wait: laser on
      }
      const unsigned long offStart = millis();
      digitalWrite(laserPin, LOW);
      while (millis() - offStart < static_cast<unsigned long>(pulse_off)) {
        // busy-wait: laser off
      }
    }
  }
}
// In-trial reward (distinct from the manual "M1"/"M2" test reward).
// - CS+ trials (trial_type true): open `rewardPin` for `magnitude_reward` ms.
// - CS- trials: only wait the same time; the CS- solenoid (`rewardPin2`) is
//   not driven, so nothing is dispensed.
// Then print "CODE:1,TIME:" (CS+) or "CODE:6,TIME:" (CS-), pause 10 ms, and
// finish the line with the millis() timestamp.
void deliver_reward(bool trial_type, int rewardPin, int rewardPin2,int magnitude_reward) {
  (void)rewardPin2;  // currently unused
  const bool isCsPlus = trial_type;
  if (isCsPlus) {
    digitalWrite(rewardPin, HIGH);
  }
  delay(magnitude_reward);
  if (isCsPlus) {
    digitalWrite(rewardPin, LOW);
  }
  Serial.print(isCsPlus ? "CODE:1,TIME:" : "CODE:6,TIME:");
  delay(10);
  Serial.println(millis());
}
// Write a uniformly random permutation of originalArray[0..size) into
// shuffledArray[0..size).
// Uses a Fisher-Yates shuffle with Arduino random(), so it needs exactly
// size - 1 random draws. The previous rejection-sampling version needed an
// unbounded number of draws and a variable-length array with an initializer,
// which is non-standard and rejected by some compilers.
// A non-positive `size` is a no-op. The two arrays may be the same buffer.
void shuffleArray(const int originalArray[], int shuffledArray[], int size) {
  if (size <= 0) {
    return;
  }
  for (int i = 0; i < size; i++) {
    shuffledArray[i] = originalArray[i];
  }
  for (int i = size - 1; i > 0; i--) {
    const int j = static_cast<int>(random(0, i + 1));  // uniform in [0, i]
    const int tmp = shuffledArray[i];
    shuffledArray[i] = shuffledArray[j];
    shuffledArray[j] = tmp;
  }
}
// Report a stimulation event: first a human-readable line with the pulse
// frequency, then the machine-parsed "CODE:3,TIME:<millis>" event line.
void logStimulation(int laser_frequency) {
  Serial.print("Delivering stimulation at frequency ");
  Serial.println(laser_frequency);
  const char* const eventLabel = "CODE:3,TIME:";
  const unsigned long stamp = millis();
  Serial.print(eventLabel);
  Serial.println(stamp);
}
// Manual test reward (serial commands "M1"/"M2"). Opens `solenoidPin` for a
// fixed 30 ms, then closes it and prints a confirmation line.
void giveReward(int solenoidPin) {
  const unsigned long kManualOpenMs = 30;  // 30 milliseconds (the old comment wrongly said 1 second)
  digitalWrite(solenoidPin, HIGH);
  delay(kManualOpenMs);
  digitalWrite(solenoidPin, LOW);
  Serial.println("Manual reward delivered.");
}
void send_voltage_vals() {
int analogval = analogRead(lickPin);
float voltage = analogval * (ref_voltage / 1023.0);
Serial.print("VOLTAGE: ");
Serial.println(voltage,3);
delay(10);
}
void stop_program(bool stopProgram) {
if(stopProgram) {
startSession = false;
digitalWrite(laserPin, LOW);
digitalWrite(rewardPin, LOW);
digitalWrite(rewardPin2, LOW);
digitalWrite(cuePin,LOW);
// ✅ Free dynamically allocated memory
delete[] trial_type;
trial_type = nullptr;
delete[] stim_on_trial;
stim_on_trial = nullptr;
delete[] trial_type_unordered;
trial_type_unordered = nullptr;
// ✅ Reset variables
licked = false;
currentTrial = 0;
ITI = 0;
// ✅ Delay before reset (Optional: Adjust timing)
delay(500);
// ✅ Arduino Soft Reset Methods (Choose One)
readSerialCommands();
}
}
// Fill ITI_vals[0..num_trials) with ITIs (ms, rounded to whole numbers).
// Each value is drawn from an exponential distribution with mean
// (ITI_min + ITI_max) / 2, rejecting draws outside [ITI_min, ITI_max].
//
// Robustness, since the old version could spin forever:
// - If the bounds are reversed, they are swapped.
// - If they are equal, every ITI is that value.
// - If a draw keeps being rejected (range narrower than the sampler's
//   resolution), a uniform draw from the range is used after kMaxAttempts.
void generate_exp_rand_numbers(float ITI_min, float ITI_max, int num_trials, float ITI_vals[]) {
  if (ITI_max < ITI_min) {
    const float tmp = ITI_min;
    ITI_min = ITI_max;
    ITI_max = tmp;
  }
  if (ITI_max == ITI_min) {
    for (int i = 0; i < num_trials; i++) {
      ITI_vals[i] = round(ITI_min);
    }
    return;
  }

  const float ITI_mean = (ITI_min + ITI_max) / 2;
  const int kMaxAttempts = 1000;
  for (int i = 0; i < num_trials; i++) {
    float x = 0;
    bool accepted = false;
    for (int attempt = 0; attempt < kMaxAttempts && !accepted; attempt++) {
      const float U = random(1, 10001) / 10000.0;  // U in (0, 1]
      x = -ITI_mean * log(U);                      // exponential sample
      accepted = (x >= ITI_min && x <= ITI_max);
    }
    if (!accepted) {
      x = ITI_min + (ITI_max - ITI_min) * (random(0, 10001) / 10000.0);
    }
    ITI_vals[i] = round(x);
  }
}