#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri May 13 20:46:22 2022

@author: saidurrahmanpavel
"""

import os

# path = '/Users/saidurrahmanpavel/My Drive/Temple Material/ASP Lab/Condtional Entropy Minimization- Journal/ML'

# os.chdir(path)
import numpy as np
import matplotlib.pyplot as plt
import torch
import math


def data_gen(l, h, N, n, decimals=0):
    """Draw random DOA scenarios uniformly from ``[l, h)`` degrees, then round.

    Parameters
    ----------
    l : float
        Lower bound of the angular range, in degrees.
    h : float
        Upper bound of the angular range, in degrees.
    N : int
        Number of observations (rows of the output).
    n : int
        Number of angles in each observation (columns of the output).
    decimals : int, optional
        Number of decimal places the angles are rounded to. The default ``0``
        snaps every angle onto a 1-degree grid, which is the behaviour this
        function has always had. Pass ``1`` for a 0.1-degree grid.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(N, n)`` holding the rounded angles.
    """
    theta = np.random.uniform(low=l, high=h, size=(N, n))
    # Rounding places every angle on a discrete angular grid. label_gen
    # matches angles against its grid by exact equality, so the grid spacing
    # used there must be compatible with `decimals`.
    return np.round(theta, decimals)


def label_gen(theta, interval):
    """Encode DOA scenarios as multi-hot vectors over a fixed angular grid.

    The grid starts at -90 degrees and steps by ``interval``. The stop value
    is 90.1, so 90 is included whenever it lies on the grid. Grid values are
    rounded to one decimal. Each output row holds 1 at every grid point that
    equals one of that row's angles, and 0 elsewhere.

    Parameters
    ----------
    theta : numpy.ndarray
        2-D array with one row of angles (degrees) per observation.
    interval : float
        Grid spacing in degrees.

    Returns
    -------
    theta_r : numpy.ndarray
        The 1-D angular grid.
    label : numpy.ndarray
        Array of shape ``(theta.shape[0], len(theta_r))`` filled with 0.0/1.0.

    Raises
    ------
    IndexError
        If an angle is not exactly equal to any grid point.
    """
    grid = np.round(np.arange(-90, 90.1, interval), 1)
    label = np.zeros((theta.shape[0], grid.shape[0]))

    for row in range(theta.shape[0]):
        for angle in theta[row, :]:
            # First grid index equal to this angle. An empty match makes the
            # trailing [0] raise IndexError.
            hit = np.nonzero(grid == angle)[0][0]
            label[row, hit] = 1

    return grid, label

def steer(n, d, phi):
    """Return the steering matrix of an ``n``-element uniform linear array.

    Parameters
    ----------
    n : int
        Number of antenna elements.
    d : float
        Element spacing in wavelengths. The phase term has no explicit
        wavelength; callers in this file pass 0.5.
    phi : float or numpy.ndarray
        Angle(s) in degrees. A scalar gives an ``(n, 1)`` result. A 1-D array
        of ``k`` angles gives an ``(n, k)`` matrix through broadcasting.

    Returns
    -------
    numpy.ndarray
        Complex matrix with entries ``exp(-j*2*pi*m*d*sin(phi))`` for the
        element index ``m = 0 .. n-1``.
    """
    deg2rad = np.pi / 180
    element_idx = np.arange(n).reshape(n, 1)  # column vector of element indices
    sin_term = np.sin(deg2rad * phi)
    return np.exp(-1j * 2 * np.pi * element_idx * d * sin_term)

def steer_tensor(n, d, phi):
    """Torch version of the ULA steering matrix; returns a tensor, not an ndarray.

    Parameters
    ----------
    n : int
        Number of antenna elements.
    d : float
        Element spacing in wavelengths. Callers in this file pass 0.5.
    phi : torch.Tensor
        Angle(s) in degrees. It must be a tensor because ``torch.sin`` is
        applied to it. A 0-dim tensor gives ``(n, 1)``; a 1-D tensor of ``k``
        angles gives ``(n, k)``.

    Returns
    -------
    torch.Tensor
        Complex tensor with entries ``exp(-j*2*pi*m*d*sin(phi))`` for
        ``m = 0 .. n-1``.
    """
    pi_t = torch.tensor(math.pi)
    element_idx = torch.arange(n).reshape(n, 1)  # column of element indices
    sin_term = torch.sin((pi_t / 180) * phi)
    return torch.exp(-1j * 2 * pi_t * element_idx * d * sin_term)


def arr_received(n_antenna, theta, T, snr):
    """Simulate ``T`` snapshots of the array model ``X = A s + noise``.

    Sources and noise are complex Gaussian. Each real and imaginary component
    has unit variance, and the sources are scaled by ``10**(snr/20)``. As a
    result, the per-source to noise power ratio at each element is
    ``10**(snr/10)``.

    Parameters
    ----------
    n_antenna : int
        Total number of antenna elements.
    theta : array-like
        DOAs in degrees, one per source. ``len(theta)`` is the source count.
    T : int
        Number of snapshots.
    snr : float
        Signal-to-noise ratio in dB.

    Returns
    -------
    numpy.ndarray
        Complex array of shape ``(n_antenna, T)``.
    """
    amp = 10 ** (snr / 20)  # dB -> linear amplitude scale for the sources
    num_src = len(theta)

    # Real parts are drawn before imaginary parts, sources before noise.
    src_re = np.random.randn(num_src, T)
    src_im = np.random.randn(num_src, T)
    s = (src_re + 1j * src_im) * amp

    A = steer(n_antenna, .5, theta)  # array manifold, half-wavelength spacing

    noise_re = np.random.randn(n_antenna, T)
    noise_im = np.random.randn(n_antenna, T)
    return A @ s + (noise_re + 1j * noise_im)

def arr_received_tensor(n_antenna, theta, T, sigma):
    """Torch version of the array model ``X = A s + noise``.

    Unlike the NumPy variant, the source scale is given directly as a linear
    amplitude ``sigma`` rather than an SNR in dB.

    Parameters
    ----------
    n_antenna : int
        Total number of antenna elements.
    theta : torch.Tensor
        DOAs in degrees, one per source. ``len(theta)`` is the source count.
    T : int
        Number of snapshots.
    sigma : float or torch.Tensor
        Multiplier applied to the unit-variance complex Gaussian sources.

    Returns
    -------
    torch.Tensor
        Complex tensor of shape ``(n_antenna, T)``.
    """
    n_src = len(theta)

    # Draw order matters for reproducibility: source real, source imaginary,
    # then noise real, noise imaginary.
    re_part = torch.randn(n_src, T)
    im_part = torch.randn(n_src, T)
    s = (re_part + 1j * im_part) * sigma

    A = steer_tensor(n_antenna, .5, theta)  # half-wavelength spacing

    w_re = torch.randn(n_antenna, T)
    w_im = torch.randn(n_antenna, T)
    return A @ s + (w_re + 1j * w_im)

def op_angle(low, high, interval, R, n):
    """Compute the normalized Capon spectrum and return the ``n`` largest angles.

    For every grid angle the Capon value is ``|1 / (a^H R^{-1} a)|``, where
    ``a`` is the steering vector with half-wavelength spacing. The spectrum is
    then scaled to sum to one.

    Parameters
    ----------
    low, high : float
        Angular range in degrees. The stop is padded by one step, so ``high``
        is included when it lies on the grid.
    interval : float
        Grid spacing in degrees. Grid values are rounded to one decimal, so
        the spacing should be a multiple of 0.1.
    R : array-like
        Square covariance matrix of the array output. Its size is the number
        of antennas.
    n : int
        Number of sources, i.e. how many largest spectrum values to select.
        These are simply the top-``n`` grid values, not separated peaks.
        ``n <= 0`` selects nothing.

    Returns
    -------
    theta_r : numpy.ndarray
        The 1-D angular grid.
    theta_o : numpy.ndarray
        The selected angles, sorted ascending (at most ``n`` of them).
    p : numpy.ndarray
        The normalized spectrum, one value per grid angle.
    """
    theta_r = np.round(np.arange(low, high + interval, interval), 1)
    n_antenna = R.shape[0]

    # R^{-1} does not depend on the look direction, so invert it once.
    R_inv = np.linalg.inv(R)

    # Steering vectors for all grid angles, one per column: (n_antenna, G).
    A = steer(n_antenna, .5, theta_r)
    # a^H R^{-1} a evaluated for every column at once.
    denom = np.sum(A.conj() * (R_inv @ A), axis=0)
    p = np.abs(1 / denom)
    p = p / np.sum(p)

    # Indices of the n largest values. Clamp n so that n == 0 selects nothing;
    # a plain [-0:] slice would select the whole grid.
    n_sel = min(max(n, 0), p.size)
    top = np.argsort(p)[p.size - n_sel:]
    theta_o = np.sort(theta_r[top])

    return theta_r, theta_o, p

def op_angle_tensor(grid, Phi, n_antenna, R, n):
    """Compute a normalized Capon-type spectrum of compressed measurements (torch).

    For each angle in ``grid``, the steering vector ``a`` (half-wavelength
    spacing) is compressed to ``b = Phi a``. The spectrum value is::

        | Re( (b^H b) / (b^H R^{-1} b) ) |

    Each row of the result is then scaled to sum to one.

    Parameters
    ----------
    grid : torch.Tensor
        Angles in degrees. It must hold exactly ``grid.shape[0]`` elements.
    Phi : torch.Tensor
        Compressive sampling matrix, or a batch of them. The result has
        ``Phi.shape[0]`` rows. The matrix is moved to the CPU before use.
    n_antenna : int
        Number of antennas, i.e. the length of the uncompressed steering
        vector.
    R : torch.Tensor
        Covariance matrix (or batch) in the compressed domain. It is inverted
        on its own device, then moved to the CPU.
    n : int
        Number of sources. Not used by the computation.

    Returns
    -------
    torch.Tensor
        Default-dtype CPU tensor of shape ``(Phi.shape[0], grid.shape[0])``.
        Each row sums to one.
    """
    num_points = grid.shape[0]
    spectrum = torch.zeros((Phi.shape[0], num_points))

    R_inv = torch.inverse(R).cpu()
    Phi_cpu = Phi.cpu()

    for col, angle in enumerate(grid.reshape(num_points)):
        bb = Phi_cpu @ steer_tensor(n_antenna, .5, angle)
        bb_H = bb.conj().mT
        ratio = (bb_H @ bb) / (bb_H @ R_inv @ bb)
        spectrum[:, col] = ratio.real.abs().squeeze()

    # Normalize each row (one per leading entry of Phi) to sum to one.
    return spectrum / spectrum.sum(dim=1, keepdim=True)


def update_posterior(Phi, y, prior, low, high, interval, N, sigmas, seed):
    """Update a DOA prior over an angular grid with a compressed measurement (Bayes).

    For every grid angle ``theta``, the measurement is modelled as a zero-mean
    circular complex Gaussian with covariance::

        Cyy(theta) = Phi (sigmas^2 a(theta) a(theta)^H + I_N) Phi^H

    Here ``a(theta)`` is the ``N``-element steering vector with
    half-wavelength spacing. The log-likelihood of ``y`` is evaluated at each
    grid point and exponentiated. The result is multiplied by ``prior`` and
    normalized to sum to one.

    Parameters
    ----------
    Phi : torch.Tensor
        Compressive sampling matrix, assumed ``(m, N)``. Presumably it shares
        the complex dtype of the steering vectors; TODO confirm with callers.
    y : torch.Tensor
        Measurements of shape ``(m, T)``, one column per snapshot. A 1-D
        ``(m,)`` vector is treated as a single snapshot.
    prior : torch.Tensor
        Prior PMF with one entry per grid point.
    low, high : float
        Angular range of the grid in degrees. ``high`` is included when it
        lies on the grid.
    interval : float
        Grid spacing in degrees. Values are rounded to one decimal.
    N : int
        Number of antennas.
    sigmas : float
        Source amplitude. Its square scales the rank-one signal term.
    seed : int
        Reseeds NumPy's and torch's global RNGs. This is kept for backward
        compatibility; nothing here draws random numbers.

    Returns
    -------
    torch.Tensor
        Posterior PMF over the grid.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    theta_r = torch.round(torch.arange(low, high + interval, interval), decimals=1)

    Y = y.reshape(y.shape[0], -1)  # (m, T); a 1-D y becomes one snapshot
    m, T = Y.shape
    sigmas2 = sigmas ** 2
    eye_N = torch.eye(N)

    # One log-likelihood per grid point (the original overwrote a single value).
    loglik = torch.empty(theta_r.shape[0], dtype=torch.float64)
    for index, angle in enumerate(theta_r):
        aa = steer_tensor(N, .5, angle)
        SI_mat = sigmas2 * aa @ aa.conj().T + eye_N
        Cyy = Phi @ SI_mat @ Phi.conj().T
        # Complex Gaussian log-density of T i.i.d. snapshots:
        #   -m*T*log(pi) - T*log|Cyy| - tr(Y^H Cyy^{-1} Y)
        _, logabsdet = torch.linalg.slogdet(Cyy)
        quad = torch.trace(Y.conj().T @ torch.linalg.solve(Cyy, Y)).real
        loglik[index] = -m * T * math.log(math.pi) - T * logabsdet - quad

    # Bayes' rule needs the likelihood, not its log. Shifting by the maximum
    # keeps exp() from under/overflowing; the shift cancels in the normalization.
    likelihood = torch.exp(loglik - loglik.max())
    unnormalized = prior * likelihood
    return unnormalized / torch.sum(unnormalized)
    




def label_processing(yHat, label_shape):
    """Turn raw network outputs into Capon spectra.

    Each row of ``yHat`` (2000 values) is viewed as a 10 x 200 real matrix.
    Its first 100 columns are the real parts and its last 100 columns the
    imaginary parts of a 10 x 100 complex matrix ``Y``. The 10 x 10 product
    ``Y Y^H`` is passed to ``op_angle`` over -90..90 degrees with a spacing
    of 0.1 and a source count of 9. Only the returned spectrum is kept.

    Parameters
    ----------
    yHat : torch.Tensor
        Network output with one row per sample; each row holds 2000 values.
    label_shape : int
        Length of each spectrum. It must match the grid length ``op_angle``
        produces.

    Returns
    -------
    torch.Tensor
        Float64 tensor of shape ``(yHat.shape[0], label_shape)``.
    """
    n_samples = yHat.shape[0]
    spectra = np.zeros((n_samples, label_shape))

    for k in range(n_samples):
        block = yHat[k, :].view(10, 200)
        Y = block[:, :100] + 1j * block[:, 100:]  # real half + j * imaginary half
        cov = (Y @ Y.conj().T).detach()          # 10 x 10, detached from autograd
        spectrum = op_angle(-90, 90, .1, cov, 9)[2]
        spectra[k, :] = spectrum

    return torch.tensor(spectra)
        

        


