#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Oct  4 21:20:18 2022

@author: saidurrahmanpavel
"""

# import os

# path = '/Users/saidurrahmanpavel/My Drive/Temple Material/ASP Lab/Condtional Entropy Minimization- Journal/ML'

# os.chdir(path)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from IPython import display
display.set_matplotlib_formats('svg')
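# Note: in newer IPython versions set_matplotlib_formats has moved to
# matplotlib_inline.backend_inline; the IPython.display version still works
# but may emit a DeprecationWarning.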

import APBasicSimp as ap      # local project module: provides op_angle_tensor
import icassp_data_v2 as dt   # local project module (not used directly in this file)

def createModel():
    """Build the MLP, loss function, and optimizer.

    Input:  182 = 181-point posterior over the angle grid + 1 normalized sample index.
    Output: 1000 = flattened 10x100 matrix; the first 50 columns of each row are
            the real part and the last 50 the imaginary part of a complex 10x50
            measurement matrix (see trainModel).
    """

    class model(nn.Module):
        def __init__(self):
            super().__init__()
            self.input = nn.Linear(182, 500)
            self.fc1 = nn.Linear(500, 500)
            self.fc2 = nn.Linear(500, 500)
            self.out = nn.Linear(500, 10 * 100)

        def forward(self, x):
            x = F.relu(self.input(x))
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.out(x)

    net = model()
    # lossfun = nn.BCEWithLogitsLoss()
    # lossfun = nn.BCELoss()
    lossfun = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

    return net, lossfun, optimizer
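
# A quick shape sanity check for createModel (illustrative only: the batch size
# of 4 is arbitrary; 182 matches the 181-point posterior plus the sample index).
def _sanityCheckModel():
    net, _, _ = createModel()
    dummy = torch.randn(4, 182)
    out = net(dummy)
    print(out.shape)  # expected: torch.Size([4, 1000])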

def trainModel(train_loader, numepochs, samples, sigmas, grid, device):
    """Train the network by unrolling the sequential measurement loop.

    For each batch, the net maps the current posterior (plus a normalized
    sample index) to a complex 10x50 matrix phi; measurements y = phi @ X are
    accumulated into the covariance R, and the posterior is re-estimated from
    R via ap.op_angle_tensor. (`sigmas` is currently unused.)
    """
    net, lossfun, optimizer = createModel()
    net.to(device)
    losses = np.zeros(numepochs)
    for epochi in range(numepochs):

        batchLoss = []

        for theta, X_data, posterior, label in train_loader:
            print(f'Epoch {epochi}, batch {len(batchLoss)}')

            X_data = X_data.to(device)
            posterior = posterior.to(device)
            label = label.to(device)
            batchsize = X_data.shape[0]

            # Accumulated measurement covariance; complex because y is complex,
            # and kept on the training device to avoid device mismatches.
            R = torch.zeros((batchsize, 10, 10), dtype=torch.cfloat, device=device)
            for samp in range(samples):  # iterate through the snapshots in X

                # Network input: current posterior + normalized sample index.
                samp_normalized = samp / (samples - 1)
                inp = torch.cat((posterior,
                                 samp_normalized * torch.ones(posterior.shape[0], 1, device=device)), 1)

                # NN output reshaped into a complex 10x50 matrix: the first 50
                # columns are the real part, the last 50 the imaginary part.
                phi_tmp = net(inp)
                phi_shaped = phi_tmp.view(-1, 10, 100)
                phi_r = phi_shaped[:, :, :50]
                phi_i = phi_shaped[:, :, 50:]
                phi = phi_r + 1j * phi_i

                # Current snapshot (50 sensors) and compressed measurement.
                X = X_data[:, :, samp].view(-1, 50, 1)
                y = phi @ X

                # Update the running covariance and, after a short warm-up,
                # re-estimate the posterior over the angle grid from R.
                R = R + y @ y.conj().mT
                if samp > 10:
                    posterior = ap.op_angle_tensor(grid, phi, 50, R, 9)

                print(f'Sample No: {samp}\n')

            # Loss between the final posterior and the ground-truth label.
            loss = lossfun(posterior, label)
 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batchLoss.append(loss.item())

        losses[epochi] = np.mean(batchLoss)
        print(f'Loss in epoch {epochi} is {losses[epochi]}')

    # Note: posterior and theta are returned from the last processed batch.
    return losses, net, posterior, theta
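

# A minimal, hypothetical driver (sketch only). The synthetic tensors below are
# placeholders with shapes inferred from trainModel -- X_data: batch x 50 sensors
# x `samples` snapshots (complex); posterior/label: batch x 181 over the angle
# grid -- and the 181-point grid itself is an assumption; the real data pipeline
# lives in icassp_data_v2.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    N, samples = 64, 100
    grid = np.linspace(-90, 90, 181)                 # assumed angle grid
    theta = torch.rand(N, 1)                         # placeholder source angles
    X_data = torch.randn(N, 50, samples, dtype=torch.cfloat)
    posterior = torch.softmax(torch.randn(N, 181), dim=1)
    label = torch.softmax(torch.randn(N, 181), dim=1)
    train_loader = DataLoader(TensorDataset(theta, X_data, posterior, label),
                              batch_size=16, shuffle=True)

    losses, net, posterior, theta = trainModel(train_loader, numepochs=2,
                                               samples=samples, sigmas=None,
                                               grid=grid, device=device)

    plt.plot(losses)
    plt.xlabel('epoch')
    plt.ylabel('MSE loss')
    plt.show()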