#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 18 00:07:39 2022

@author: saidurrahmanpavel
"""

import os 

# Working directory for this experiment: a machine-specific absolute path.
# os.chdir raises an OSError subclass (e.g. FileNotFoundError) if the
# directory does not exist on the current machine.
# NOTE(review): the chdir runs before the project-local imports further down,
# presumably so those modules (and any relative data files) resolve from this
# folder -- confirm before moving or removing it.
path = '/Users/saidurrahmanpavel/Library/CloudStorage/OneDrive-TempleUniversity/My Google Drive/Temple Material/ASP Lab/new_test'
os.chdir(path)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,TensorDataset
from IPython import display
display.set_matplotlib_formats('svg')

import APBasicSimp as ap
import icassp_data_weval as dt
import icassp_ModelArch_weval as M
import t_loop as t

# Compute device: first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)


# Random seeding: fix both the PyTorch and the NumPy global generators so
# that repeated runs draw the same random numbers.
seed = 10
torch.manual_seed(seed)
np.random.seed(seed)


# Experiment constants

# Angular grid: -90 to 90 inclusive in unit steps (the 90.1 stop keeps the
# end point), stored as floats.
grid = np.arange(-90, 90.1).round()
Nx = len(grid)  # Number of angular grid points

SNR_db = 20  # SNR in dB
sigmas = 10 ** (SNR_db / 20)  # dB value converted to a linear amplitude ratio
sigmas2 = sigmas * sigmas  # Square of the amplitude ratio

N_antenna = 50  # Number of antennas
MM = 10  # Compressed dimensions
samples = 30  # presumably the number of snapshots per observation -- TODO confirm


# Training data

obs = 10000  # Number of observations requested from the generator
theta = dt.remove_repeat(obs, 9)  # Generated DOA sets
theta = theta[:1000, :]  # Keep only the first 1000 rows for training
X = dt.array_received(theta, sigmas, samples)  # Array received signal

N_data = theta.shape[0]  # Number of training examples
NUser = theta.shape[1]  # Number of users (columns of theta)

# Uniform posterior over the angular grid: every entry equals 1/Nx.
# (Disabled alternative kept from the original script: a constant 0.5.)
posterior = np.full((N_data, Nx), 1 / Nx)

# Label: second element returned by ap.label_gen, scaled by 1/NUser.
# (Disabled alternative kept from the original script: the unscaled label.)
label = (1 / NUser) * ap.label_gen(theta, 1)[1]

# Last argument is 32 -- presumably the mini-batch size; TODO confirm in
# icassp_data_weval.make_dataloader.
train_loader = dt.make_dataloader(theta, X, posterior, label, 32)


# Test data

test_obs = 20  # Number of observations requested from the generator
test_theta = dt.remove_repeat(test_obs, 9)  # Generated DOA sets
test_theta = test_theta[:5, :]  # Keep only the first 5 rows for testing
test_X = dt.array_received(test_theta, sigmas, samples)  # Array received signal

N_test = test_theta.shape[0]  # Number of test examples

# Uniform posterior over the angular grid: every entry equals 1/Nx.
test_posterior = np.full((N_test, Nx), 1 / Nx)

# Label: second element returned by ap.label_gen, scaled by 1/NUser.
# NOTE(review): NUser is taken from the training set above; this assumes
# test_theta has the same number of columns -- confirm.
test_label = (1 / NUser) * ap.label_gen(test_theta, 1)[1]

# Last argument is N_test -- presumably so the whole test set forms a single
# batch; TODO confirm in icassp_data_weval.make_dataloader.
test_loader = dt.make_dataloader(
    test_theta, test_X, test_posterior, test_label, N_test
)


# Model construction: unpack what M.createModel() returns and move the
# network to the selected device.
(net, lossfun, optimizer) = M.createModel()
net.to(device)


# Training.
# NOTE(review): trainModel is not handed the net / lossfun / optimizer created
# above, and its second return value rebinds `net` -- presumably it builds its
# own model internally; confirm in icassp_ModelArch_weval.
# The literal 10 is presumably the number of epochs -- TODO confirm.
losses, net, out, ang = M.trainModel(
    train_loader,
    test_loader,
    10,
    samples,
    sigmas,
    grid,
    device,
)


# Evaluation
# net.eval()
# test_theta, test_X_data, test_posterior, test_label = next(iter(test_loader))
# test_X_data = test_X_data.to(device)
# test_posterior = test_posterior.to(device)
# test_batch = test_X_data.shape[0]
# R_test = torch.zeros((test_batch,10,10))
# test_out_posterior = t.net_loop(test_X_data,test_posterior,net,grid,samples,test_batch,R_test)


