# -*- coding: utf-8 -*-
"""Copy of previous_main.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1GpAop0kBJmQLiPJtnhTrXgrixNZlkm2P
"""

from google.colab import drive
drive.mount('/content/drive')

# Query the attached GPU (IPython shell command; only valid inside Colab/Jupyter)
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')


import sys
sys.path.insert(0,'/content/drive/MyDrive/ICASSP2022')

import numpy as np
import scipy.io  # used later to export results as .mat files
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,TensorDataset
from IPython import display
display.set_matplotlib_formats('svg')

import APBasicSimp as ap
import icassp_data_weval as dt
import icassp_ModelArch_weval_droppout as M
import t_loop as t

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Random seeding for reproducibility
seed = 10
np.random.seed(seed)
torch.manual_seed(seed)


# Defining variables

grid = np.round(np.arange(-90, 90.1))  # Angular grid: -90 to 90 degrees in 1-degree steps
Nx = grid.shape[0]  # Number of angular grid points

N_antenna = 50  # Number of antennas
MM = 10  # Compressed dimensions
samples = 30  # Number of samples per observation
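
# A minimal sanity check (added sketch): a 1-degree grid from -90 to 90 degrees
# should contain 181 points, so Nx is expected to be 181.
assert Nx == 181, f"Expected 181 grid points, got {Nx}"
assert grid[0] == -90 and grid[-1] == 90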


# Training data

obs = 10000  # Number of observations generated
theta = dt.remove_repeat(obs,9)  # Generating data
theta = theta[:2000,:]  # DOAs for training
batchsize = 64
# Extra----------------------
new_theta = np.array([[-8,-6,-4,-2,0,2,4,6,8],[-55,-48,-44,-20,8,20,31,41,45]])
theta[-2:,:] = new_theta  # Pin the last two training examples to these fixed DOA sets
# ----------------------
SNR_db_train = np.random.randint(0,21,size = theta.shape[0])  # SNR in dB, uniform over [0, 20]
sigmas_train = 10**(SNR_db_train/20)
sigmas2_train = sigmas_train**2
X = dt.array_received(theta,sigmas_train,samples)  # Array received signal
NUser = theta.shape[1]  # Number of users (sources)
N_data = theta.shape[0]  # Number of training observations
posterior = 1/Nx*np.ones((N_data,Nx))  # Uniform posterior
label = (1/NUser)*ap.label_gen(theta,1)[1]
train_loader = dt.make_dataloader(theta,X,posterior,label,batchsize,True)
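
# Quick shape check on one training batch (added sketch). The tuple order
# (theta, X, posterior, label) is assumed to match what make_dataloader packs,
# i.e. the same order unpacked from test_loader further below.
_theta_b, _X_b, _post_b, _label_b = next(iter(train_loader))
print('theta:', _theta_b.shape, '| X:', _X_b.shape,
      '| posterior:', _post_b.shape, '| label:', _label_b.shape)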


# Test data

test_obs = 1000  # Number of observations generated
test_theta = dt.remove_repeat(test_obs,9)  # Generating data
test_theta = test_theta[:500,:]  # DOAs for testing
SNR_db_test = np.random.randint(0,21,size = test_theta.shape[0])
sigmas_test = 10**(SNR_db_test/20)
sigmas2_test = sigmas_test**2
test_X = dt.array_received(test_theta, sigmas_test,samples)  # Array received signal
N_test = test_theta.shape[0]  # Number of test observations
test_posterior = 1/Nx*np.ones((N_test,Nx))  # Uniform posterior
test_label = (1/NUser)*ap.label_gen(test_theta,1)[1]
test_loader = dt.make_dataloader(test_theta,test_X,test_posterior,test_label,N_test,False)

# Training hyperparameters
initialLR = 0.1
stepsize = batchsize*len(train_loader)
L2lambda = 0  # No weight decay
dropoutRate = 0  # No dropout

losses, net, train_post, train_theta, test_post, test_theta, test_losses, currentLR = M.trainModel(
    train_loader, test_loader, 150, samples, grid, device,
    initialLR, True, batchsize, L2lambda, dropoutRate)
plt.plot(np.arange(len(losses)),losses)
plt.plot(np.arange(len(test_losses)),test_losses)
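# Optional labelling for the loss curves above (added sketch); the x-axis unit
# depends on how often trainModel records the losses (per epoch or per step).
plt.legend(['train loss', 'test loss'])
plt.xlabel('Recorded step')
plt.ylabel('Loss')
plt.show()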
torch.save(net.state_dict(),'/content/drive/MyDrive/ICASSP2022/Saved Models/mixedsnr_150epoch.pt')
np.savez('/content/drive/MyDrive/ICASSP2022/Saved Losses/losses_mixedsnr',
         losses=losses, test_losses=test_losses)  # np.save stores a single array; savez keeps both curves

torch.save(net.state_dict(),'/content/drive/MyDrive/ICASSP2022/Saved Models/Model_1000data_150epochs_new.pt')

kk = train_post[20,:].cpu().detach().numpy()  # Posterior for one training example
plt.plot(grid,kk)

# Load model on cpu
net = M.createModel(initialLR,stepsize,L2lambda,dropoutRate)[0]
net.load_state_dict(torch.load('/content/drive/MyDrive/ICASSP2022/Saved Models/mixedsnr_150epoch.pt',map_location=torch.device('cpu')))
net.to(device)

net = M.createModel(initialLR,stepsize,L2lambda,dropoutRate)[0]
net.load_state_dict(torch.load('/content/drive/MyDrive/ICASSP2022/Saved Models/mixedsnr.pt',map_location=device))
net.to(device)

# Evaluation over a range of SNR values
SNR_db = np.arange(0,21,2)
sigmas = 10**(SNR_db/20)
sigmas2 = sigmas**2

new_theta = np.array([-55,-48,-44,-20,8,20,31,41,45])  # DOAs used for evaluation (case 2)
#new_theta = np.array([-8,-6,-4,-2,0,2,4,6,8])  # Alternative: closely spaced DOAs

nTrials = 300
#test_out_posterior = torch.zeros((len(sigmas),nTrials,Nx))
theta_o = np.zeros((len(sigmas),nTrials,9))  # Estimated DOAs: (SNR index, trial, source)
new_theta = new_theta.reshape(1,9)
for idx,sig in enumerate(sigmas):  # Loop through the different SNR values

  for trial in range(nTrials):  # For a particular SNR, run multiple trials

    test_X = dt.array_received(new_theta, sig,samples)  # Array received signal
    N_test = new_theta.shape[0]  # Number of test observations (1 here)

    test_posterior = 1/Nx*np.ones((N_test,Nx))  # Uniform posterior
    test_label = (1/NUser)*ap.label_gen(new_theta,1)[1]
    test_loader = dt.make_dataloader(new_theta,test_X,test_posterior,test_label,N_test,False)

    net.eval()
    test_theta, test_X_data, test_posterior, test_label = next(iter(test_loader))
    test_X_data = test_X_data.to(device)
    test_posterior = test_posterior.to(device)
    test_batch = test_X_data.shape[0]
    R_test = torch.zeros((test_batch,MM,MM),dtype = torch.cfloat).to(device)  # Preallocated MM x MM complex matrix for net_loop
    test_out_posterior = t.net_loop(test_X_data,test_posterior,net,grid,samples,test_batch,R_test,device)
    kk = test_out_posterior.cpu().detach().numpy()
    sorted_ind = np.argsort(kk)  # Grid indices ordered by posterior mass
    theta_o[idx,trial,:] = np.sort(grid[sorted_ind[:,-9:]]).ravel()  # Keep the 9 angles with the largest posterior

scipy.io.savemat('/content/drive/MyDrive/ICASSP2022/rmse_ml_data_case2_mixed_150ep.mat',{'theta_out':theta_o})
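
# The estimates saved above feed an RMSE-vs-SNR evaluation (the .mat filename
# starts with rmse_). A minimal in-notebook sketch of that computation, assuming
# the estimated and true DOAs are both sorted ascending so they pair up source
# by source:
true_doas = np.sort(new_theta.ravel())  # True DOAs, shape (9,)
err = theta_o - true_doas[None,None,:]  # (SNR, trial, source) errors in degrees
rmse_per_snr = np.sqrt(np.mean(err**2, axis=(1,2)))  # RMSE in degrees per SNR value
for snr_db, rmse in zip(SNR_db, rmse_per_snr):
  print(f'SNR {snr_db:2d} dB: RMSE = {rmse:.3f} deg')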


scipy.io.savemat('/content/drive/MyDrive/ICASSP2022/rmse_ml_data_case2_10db.mat',{'theta_out':theta_o})

a = np.random.rand(5,5)

scipy.io.savemat('/content/drive/MyDrive/ICASSP2022/bal.mat',{'var': a})  # savemat expects a dict mapping variable names to arrays

# Inspect the posterior of one example and its 9 strongest peaks
kk = test_out_posterior.cpu().detach().numpy()
#kk[np.logical_and(kk<.02,kk>.001)] = 0
kkk = kk[0,:]  # First (and, after the loop above, only) example in the batch
sorted_ind = np.argsort(kk)
theta_o = np.sort(grid[sorted_ind[:,-9:]])  # Grid angles with the largest posterior mass
plt.plot(grid,kkk)
print(theta_o)
