#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 18 00:07:39 2022

@author: saidurrahmanpavel
"""

import os 

# Directory containing the project helper modules imported below
# (APBasicSimp, icassp_data_v2, icassp_ModelArch_tmp6). Changing into it is
# presumably so those imports resolve when the script is run interactively,
# where the current working directory is on sys.path.
# Set ICASSP_SIMPLIFIED_DIR to run from another checkout or machine; without
# it, the original author's local path is used (unchanged behaviour).
path = os.environ.get(
    'ICASSP_SIMPLIFIED_DIR',
    '/Users/saidurrahmanpavel/Library/CloudStorage/OneDrive-TempleUniversity/My Google Drive/Temple Material/ASP Lab/ICASSP 2022/ML3/icassp2/Regression/Simplified',
)
os.chdir(path)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,TensorDataset
# Render inline matplotlib figures as SVG when running under IPython/Jupyter.
# This is cosmetic only, so it is done best-effort. Either of these cases
# emits a warning instead of aborting the whole experiment:
#   - IPython is not installed;
#   - the IPython release no longer provides the deprecated
#     IPython.display.set_matplotlib_formats helper.
try:
    from IPython import display
    display.set_matplotlib_formats('svg')
except (ImportError, AttributeError) as err:
    import warnings
    warnings.warn(f"SVG figure output not enabled: {err}")

import APBasicSimp as ap
import icassp_data_v2 as dt
import icassp_ModelArch_tmp6 as M

# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)


# ---------------------------------------------------------------------------
# Experiment configuration and simulated data
# ---------------------------------------------------------------------------

# Seed NumPy and PyTorch with the same value so runs are repeatable.
seed = 10
np.random.seed(seed)
torch.manual_seed(seed)

obs = 10000  # number of observations (not used elsewhere in this script)

# Angular grid: the integers -90 .. 90 inclusive (181 points), stored as floats.
grid = np.round(np.arange(-90, 90.1))
Nx = grid.shape[0]  # number of angular grid points

# SNR in dB converted to an amplitude ratio: 10 ** (20 / 20) == 10.0.
# NOTE(review): `sigmas` is passed to dt.array_received below; presumably it is
# the signal amplitude relative to unit-power noise — confirm in icassp_data_v2.
SNR_db = 20
sigmas = 10 ** (SNR_db / 20)
sigmas2 = sigmas ** 2             # corresponding power ratio
sigmas2_assume = sigmas2 * 100    # not used elsewhere in this script

N_antenna = 50  # not used elsewhere in this script
MM = 10         # not used elsewhere in this script

# Generate the angles, keep only the first 10 rows, then simulate the
# corresponding received array data.
theta = dt.remove_repeat()[:10, :]
X = dt.array_received(theta, sigmas)

# Rows of theta are samples; columns are users.
N_samp, NUser = theta.shape[:2]

# Uniform initial posterior over the angular grid for every sample.
posterior = np.full((N_samp, Nx), 1 / Nx)

# Training targets: second output of ap.label_gen, scaled by 1/NUser.
label = (1 / NUser) * ap.label_gen(theta, 1)[1]

      

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------

# Build the training loader from the angles, received data, initial posterior
# and labels via dt.make_dataloader.
# NOTE(review): the trailing 5 is presumably the batch size — confirm against
# icassp_data_v2.make_dataloader.
train_loader = dt.make_dataloader(theta,X,posterior,label,5)

# NOTE(review): net, lossfun and optimizer returned by createModel() are never
# passed to trainModel, and `net` is rebound by trainModel's return value below,
# so this call looks redundant. It is kept because building a model may draw
# from torch's RNG (e.g. weight initialisation). Removing it could therefore
# change the results of the seeded training run — verify before deleting.
net,lossfun,optimizer = M.createModel()
net.to(device)


# Run training and unpack its four return values.
# NOTE(review): the meaning of the positional arguments 10 and 30 is defined by
# icassp_ModelArch_tmp6.trainModel — confirm there before changing them.
losses, net, out, ang = M.trainModel(train_loader, 10, 30,sigmas,grid,device)



