#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 18 12:20:41 2022

@author: saidurrahmanpavel
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,TensorDataset

import APBasicSimp as ap

def remove_repeat(n_source):
    """Generate a DOA set and keep only the rows whose entries are all distinct.

    The set is produced by ``ap.data_gen(-90, 90, 10000, n_source)``
    (presumably 10000 rows of angles drawn from [-90, 90] -- confirm against
    APBasicSimp). A row is kept when its number of unique values equals its
    length, i.e. no value appears twice in it.

    Parameters
    ----------
    n_source : int
        Forwarded unchanged to ``ap.data_gen``.

    Returns
    -------
    numpy.ndarray
        The rows of the generated set that contain no repeated values.

    Notes
    -----
    No RNG is seeded here; seed whatever RNG ``ap.data_gen`` relies on in the
    caller if a reproducible set is needed.
    """
    data = ap.data_gen(-90, 90, 10000, n_source)  # Generate DOA set

    # Compare against the actual row length instead of a hard-coded 9, so the
    # filter is correct for any number of columns (identical when it is 9).
    n_cols = data.shape[1]
    n_unique = np.array([len(np.unique(row)) for row in data])
    keep = n_unique == n_cols

    # Boolean indexing returns a new array with only the kept rows; unlike
    # np.delete, it never misreads the mask as integer indices 0/1 on old NumPy.
    return data[keep]

def array_received(theta, sigmas):
    """Build the received-signal tensor for every row of DOAs in *theta*.

    Each row ``theta[i, :]`` is converted to a float32 tensor and passed to
    ``ap.arr_received_tensor(50, row, 100, sigmas)``. The result is stored in
    slice ``i`` of a complex64 (``torch.cfloat``) output tensor.

    Parameters
    ----------
    theta : 2-D array-like
        One DOA set per row; it is indexed as ``theta[i, :]``.
    sigmas :
        Forwarded unchanged to ``ap.arr_received_tensor`` (meaning defined
        there -- presumably noise level(s); confirm in APBasicSimp).

    Returns
    -------
    torch.Tensor
        Shape ``(theta.shape[0], 50, 100)``, dtype ``torch.cfloat``. The 50/100
        sizes match the arguments given to ``arr_received_tensor`` (likely
        sensors x snapshots -- TODO confirm).
    """
    n_batch = theta.shape[0]
    received = torch.zeros((n_batch, 50, 100), dtype=torch.cfloat)

    for idx in range(n_batch):
        doa_row = torch.tensor(theta[idx, :]).float()
        # The assignment casts the helper's output into the cfloat buffer.
        received[idx] = ap.arr_received_tensor(50, doa_row, 100, sigmas)

    return received

def make_dataloader(theta, data, prior, label, batchsize, *, shuffle=True, drop_last=True):
    """Bundle DOAs, received signals, priors and labels into a DataLoader.

    Parameters
    ----------
    theta, prior, label : array-like
        Copied into float32 tensors via ``torch.tensor(...).float()``.
    data : torch.Tensor or array-like
        Received-signal samples, cast to ``torch.cfloat`` (complex64). A tensor
        is handled exactly as before (``.cfloat()`` semantics). NumPy arrays
        are now accepted too, consistent with the other arguments.
    batchsize : int
        Number of samples per batch.
    shuffle : bool, keyword-only
        Reshuffle the sample order every epoch (default True, as before).
    drop_last : bool, keyword-only
        Drop the final incomplete batch (default True, as before).

    Returns
    -------
    torch.utils.data.DataLoader
        Yields ``(theta, data, prior, label)`` batches.

    Raises
    ------
    AssertionError
        Raised by ``TensorDataset`` when the inputs do not share the same first
        (sample) dimension.
    """
    # Convert the inputs to tensors.
    thetaT = torch.tensor(theta).float()
    # as_tensor returns an existing tensor unchanged, so tensor input behaves
    # exactly like data.cfloat(); it also converts NumPy arrays.
    dataT = torch.as_tensor(data).to(torch.cfloat)
    labelT = torch.tensor(label).float()
    priorT = torch.tensor(prior).float()

    # Convert into a PyTorch dataset (sample i = the i-th row of each tensor).
    train_data = TensorDataset(thetaT, dataT, priorT, labelT)

    # Wrap the dataset in a DataLoader.
    return DataLoader(
        train_data,
        batch_size=batchsize,
        shuffle=shuffle,
        drop_last=drop_last,
    )
    