#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Oct  8 22:46:23 2022

@author: saidurrahmanpavel
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,TensorDataset

import APBasicSimp as ap

def net_loop(X_data,posterior,net,grid,samples,batchsize,R,device):
    """Run the network sample-by-sample and refine the posterior.

    For each sample index ``samp`` in ``range(samples)`` the network is fed
    the current posterior with the normalized sample index appended as an
    extra column. Its output is reshaped into a complex matrix ``phi`` of
    shape ``(batch, 10, 50)``: the first 50 output columns form the real
    part and the last 50 the imaginary part. ``phi`` is applied to the
    ``samp``-th slice of ``X_data``. The outer product ``y @ y^H`` is
    accumulated into ``R``. Once ``samp > 10``, the posterior is recomputed
    by ``ap.op_angle_tensor``.

    Parameters
    ----------
    X_data : torch.Tensor
        Indexed as ``X_data[:, :, samp]``; each slice must reshape to
        ``(-1, 50, 1)``. Presumably ``(batch, 50, samples)`` and complex to
        match ``phi`` (matmul does not promote dtypes) -- TODO confirm
        against caller.
    posterior : torch.Tensor
        2-D tensor with ``batchsize`` rows. It is concatenated along dim 1
        with the normalized sample index to form the network input.
    net : callable
        Takes the concatenated input. Its output must be viewable as
        ``(-1, 10, 100)``.
    grid :
        Passed through unchanged to ``ap.op_angle_tensor``.
    samples : int
        Number of samples (last-axis slices of ``X_data``) to process.
    batchsize : int
        Number of rows in ``posterior``.
    R : torch.Tensor
        Initial accumulator. It must broadcast with ``(batch, 10, 10)``.
        It is updated out of place, so the caller's tensor is not modified
        and the final accumulated value is not returned.
    device : torch.device or str
        Device for the index column and the returned posterior.

    Returns
    -------
    torch.Tensor
        The posterior after the last sample. If ``samples <= 11`` the
        refinement branch never runs, and the input posterior is returned
        (moved to ``device`` when ``samples >= 1``).
    """
    # The sample index is spread over [0, 1]. With a single sample,
    # samp / (samples - 1) would divide by zero, so 0.0 is used instead.
    denom = samples - 1

    for samp in range(samples):
        # Network input: the current posterior plus the normalized sample
        # index as an extra column. The column is built directly on `device`
        # to avoid a host-to-device copy every iteration. A Python float
        # fill value gives the default float dtype, same as torch.ones.
        samp_value = samp / denom if denom > 0 else 0.0
        samp_normalized = torch.full((batchsize, 1), samp_value, device=device)
        inp = torch.cat((posterior, samp_normalized), 1)

        # Network output -> complex matrix phi of shape (batch, 10, 50).
        # Columns [:50] hold the real part and [50:] the imaginary part.
        phi_shaped = net(inp).view(-1, 10, 100)
        phi = phi_shaped[:, :, :50] + 1j * phi_shaped[:, :, 50:]

        # Apply phi to the current sample column: (batch, 50, 1) -> y of
        # shape (batch, 10, 1). reshape tolerates the non-contiguous slice.
        X = X_data[:, :, samp].reshape(-1, 50, 1)
        y = phi @ X

        # Accumulate y y^H. This is out of place, so the caller's R is
        # left untouched.
        R = R + y @ y.conj().mT

        # The posterior is only refined after the first 11 samples
        # (samp 0..10). The meaning of the literal 50 and 9 arguments is
        # defined by APBasicSimp.op_angle_tensor, which is not visible here.
        if samp > 10:
            posterior = ap.op_angle_tensor(grid, phi, 50, R, 9)

        posterior = posterior.to(device)

    return posterior