#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Oct  4 21:20:18 2022

@author: saidurrahmanpavel
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,TensorDataset
from IPython import display
display.set_matplotlib_formats('svg')  # deprecated in newer IPython; there, use matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Local project modules: APBasicSimp is imported for use alongside this
# script but is not referenced directly below; t_loop provides net_loop.
import APBasicSimp as ap
import icassp_data_v2 as dt
import t_loop as t
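
# trainModel below expects a torch.device chosen by the caller; a typical
# selection looks like this (an illustrative default, not part of the
# original file, which takes the device as a function argument):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')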

def createModel(initialLR, stepsize, L2lambda, dropoutRate):
    """Build the network, loss, optimizer, and StepLR scheduler.

    The net maps a 182-dim feature vector to a 10*100 = 1000-dim output;
    the output layout is interpreted downstream by t_loop.net_loop.
    """

    class model(nn.Module):
        def __init__(self, dropoutRate):
            super().__init__()

            self.input = nn.Linear(182, 500)
            self.fc1 = nn.Linear(500, 500)
            # self.bnorm1 = nn.BatchNorm1d(500)

            self.fc2 = nn.Linear(500, 500)
            # self.bnorm2 = nn.BatchNorm1d(500)

            self.fc3 = nn.Linear(500, 500)
            # self.bnorm3 = nn.BatchNorm1d(500)

            self.fc4 = nn.Linear(500, 500)  # newly added layer

            self.out = nn.Linear(500, 10*100)
            self.dr = dropoutRate

        def forward(self, x):
            x = F.relu(self.input(x))
            x = F.dropout(x, p=self.dr, training=self.training)
            # x = self.bnorm1(x)

            x = F.relu(self.fc1(x))
            x = F.dropout(x, p=self.dr, training=self.training)
            # x = self.bnorm2(x)

            x = F.relu(self.fc2(x))
            x = F.dropout(x, p=self.dr, training=self.training)
            # x = self.bnorm3(x)

            x = F.relu(self.fc3(x))
            x = F.dropout(x, p=self.dr, training=self.training)

            x = F.relu(self.fc4(x))  # newly added layer
            x = F.dropout(x, p=self.dr, training=self.training)

            return self.out(x)  # linear outputs; trained with MSE below

    net = model(dropoutRate)
    # lossfun = nn.BCEWithLogitsLoss()
    # lossfun = nn.BCELoss()
    lossfun = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=initialLR, weight_decay=L2lambda)
    # Halve the learning rate every `stepsize` scheduler steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=stepsize, gamma=0.5)

    return net, lossfun, optimizer, scheduler
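
# A minimal shape check for the factory above (a sketch: the hyperparameter
# values are arbitrary placeholders, and the 182-in / 1000-out sizes simply
# restate the layer definitions).
if __name__ == '__main__':
    _net, _, _, _ = createModel(initialLR=1e-3, stepsize=100,
                                L2lambda=1e-5, dropoutRate=0.5)
    _net.eval()  # disable dropout for the check
    with torch.no_grad():
        print(_net(torch.randn(4, 182)).shape)  # expected: torch.Size([4, 1000])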

def trainModel(train_loader, test_loader, numepochs, samples, grid, device,
               initialLR, toggleLR, batchsize, L2lambda, dropoutRate):
    """Train the network; evaluate on one test batch after every epoch."""

    # scheduler.step() runs once per batch below, so with this step size the
    # learning rate halves every batchsize*len(train_loader) batches.
    stepsize = batchsize*len(train_loader)
    net, lossfun, optimizer, scheduler = createModel(initialLR, stepsize, L2lambda, dropoutRate)
    net.to(device)
    losses = np.zeros(numepochs)
    test_losses = np.zeros(numepochs)
    currentLR = []
    for epochi in range(numepochs):

        net.train()
        batchLoss = []

        for theta, X_data, posterior, label in train_loader:
            print('------------------')
            print(f'Epoch {epochi}, batch {len(batchLoss)}')
            print('------------------')
            X_data = X_data.to(device)
            posterior = posterior.to(device)
            label = label.to(device)

            # Actual batch size (the final batch may be smaller).
            batchsize = X_data.shape[0]
            R = torch.zeros((batchsize, 10, 10), dtype=torch.cfloat).to(device)

            # Run the network via t_loop.net_loop to get the output posterior.
            posterior = t.net_loop(X_data, posterior, net, grid, samples, batchsize, R, device)

            loss = lossfun(posterior, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Per-batch learning-rate schedule (optional).
            if toggleLR:
                scheduler.step()

            batchLoss.append(loss.item())
            currentLR.append(scheduler.get_last_lr()[0])

        losses[epochi] = np.mean(batchLoss)
        print(f'Loss in epoch {epochi} is {losses[epochi]:.6f}')
        
        # Testing: one batch from the test loader, without gradient tracking.
        net.eval()
        test_theta, test_X_data, test_posterior, test_label = next(iter(test_loader))
        test_X_data = test_X_data.to(device)
        test_posterior = test_posterior.to(device)
        test_label = test_label.to(device)
        test_batch = test_X_data.shape[0]
        R_test = torch.zeros((test_batch, 10, 10), dtype=torch.cfloat).to(device)
        with torch.no_grad():
            test_out_posterior = t.net_loop(test_X_data, test_posterior, net, grid, samples, test_batch, R_test, device)
            test_losses[epochi] = lossfun(test_out_posterior, test_label).item()

    return losses, net, posterior, theta, test_out_posterior, test_theta, test_losses, currentLR
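
# A hedged sketch of how the 4-tuple loaders unpacked above might be built
# from pre-generated tensors using the TensorDataset/DataLoader imports at
# the top of this file. The helper name and keyword choices are illustrative
# assumptions; the original pipeline builds its data in icassp_data_v2.
def make_loader(theta, X_data, posterior, label, batchsize, shuffle=True):
    """Wrap aligned tensors in the (theta, X_data, posterior, label) layout
    that trainModel unpacks on every iteration."""
    ds = TensorDataset(theta, X_data, posterior, label)
    return DataLoader(ds, batch_size=batchsize, shuffle=shuffle)

# The losses/test_losses arrays returned by trainModel can then be inspected
# with the matplotlib import above, e.g. plt.plot(losses) vs. plt.plot(test_losses).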