import torch
import numpy as np
import random
import os
from scipy.sparse import coo_matrix, spdiags
import scipy.sparse as sp
from tqdm import tqdm
from collections import defaultdict
from Model import Model, Denoise, GaussianDiffusion


def set_seed(seed):
    """Make a run reproducible by seeding every RNG this script touches.

    The same ``seed`` is fed to Python's ``random``, NumPy's global RNG,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.  cuDNN is then
    asked for deterministic kernels, and its autotuner is switched off.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# DataHandler: forked from DiffKG's DataHandler.py.
# Model: Gaussian diffusion module forked from DiffKG's model.py.


# -----------------------------
# Training and Evaluation Logic
# -----------------------------
class Main:
    """Train a recommender with BPR loss and evaluate it with Recall@k / NDCG@k.

    Args:
        handler: data handler. It is passed to ``Model`` and must expose
            ``trn_mat`` and ``tst_mat`` (COO matrices, users x items) and
            ``ui_matrix`` (the argument of the model's forward pass).
        lr: Adam learning rate.
        batch_size: mini-batch size for the training and evaluation loaders.
        epochs: number of training epochs run by :meth:`run`.
        topk: ranking cut-off ``k`` used by :meth:`test_epoch`.
    """

    def __init__(self, handler, lr=1e-3, batch_size=1024, epochs=25, topk=20):
        self.handler = handler
        self.lr = lr
        self.batch_size = batch_size
        self.epochs = epochs
        self.topk = topk

        self.model = None
        self.diffusion_model = None
        self.opt = None

    def prepare_model(self):
        """Build the model, its Adam optimizer and the diffusion module on CUDA."""
        self.model = Model(self.handler).cuda()
        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        # NOTE(review): the diffusion module is created here but is not used by
        # train_epoch/test_epoch in this file.
        self.diffusion_model = GaussianDiffusion(
            noise_scale=0.1, noise_min=0.0001, noise_max=0.02, steps=5
        ).cuda()

    def train_epoch(self):
        """Run one epoch of BPR training and return the mean per-batch loss."""
        self.model.train()
        loader = torch.utils.data.DataLoader(
            TrnData(self.handler.trn_mat), batch_size=self.batch_size, shuffle=True
        )
        total_loss = 0.0

        for users, items, negs in loader:
            users, items, negs = users.long().cuda(), items.long().cuda(), negs.long().cuda()
            self.opt.zero_grad()

            # Recomputed every step because the optimizer changes the weights
            # between steps.
            usr_embeds, itm_embeds = self.model(self.handler.ui_matrix)
            user_vecs = usr_embeds[users]
            pos_scores = torch.sum(user_vecs * itm_embeds[items], dim=-1)
            neg_scores = torch.sum(user_vecs * itm_embeds[negs], dim=-1)

            # BPR loss. logsigmoid is the numerically stable form of
            # log(sigmoid(x)): it does not go to -inf when sigmoid underflows.
            loss = -torch.nn.functional.logsigmoid(pos_scores - neg_scores).mean()
            loss.backward()
            self.opt.step()

            total_loss += loss.item()

        return total_loss / max(len(loader), 1)

    @staticmethod
    def _batch_recall_ndcg(topk_items, gt_items, k):
        """Return the summed ``(Recall@k, NDCG@k)`` over a batch of users.

        Args:
            topk_items: per-user ranked item ids, best first. Only the first
                ``k`` ids of each list are used.
            gt_items: per-user held-out item ids, aligned with ``topk_items``.
            k: ranking cut-off.

        Users with no held-out items (or ``k <= 0``) contribute 0 to both sums.
        """
        recall_sum, ndcg_sum = 0.0, 0.0
        for ranked, truth in zip(topk_items, gt_items):
            truth = set(truth)
            n_rel = min(len(truth), k)
            if n_rel <= 0:
                continue
            hits = [item in truth for item in ranked[:k]]
            dcg = sum(1.0 / np.log2(rank + 2) for rank, hit in enumerate(hits) if hit)
            # Ideal DCG: every relevant item ranked at the top.
            idcg = sum(1.0 / np.log2(rank + 2) for rank in range(n_rel))
            recall_sum += sum(hits) / len(truth)
            ndcg_sum += dcg / idcg
        return recall_sum, ndcg_sum

    def test_epoch(self):
        """Evaluate on the test split and return the mean ``(Recall@k, NDCG@k)``.

        Training interactions are pushed to the bottom of each user's ranking,
        so only items not seen in training can be recommended. ``k`` is
        ``self.topk``, capped at the number of items. The means are taken over
        the users listed by ``TstData``.
        """
        self.model.eval()
        loader = torch.utils.data.DataLoader(
            TstData(self.handler.tst_mat, self.handler.trn_mat),
            batch_size=self.batch_size,
            shuffle=False,
        )
        # Row lookup of each user's held-out test items.
        tst_csr = self.handler.tst_mat.tocsr()
        recall_total, ndcg_total = 0.0, 0.0

        with torch.no_grad():
            # Score with the same propagated embeddings that training optimizes.
            # Computed once, because the weights do not change during evaluation.
            usr_embeds, itm_embeds = self.model(self.handler.ui_matrix)
            for users, trn_mask in loader:
                user_ids = users.tolist()
                users = users.long().cuda()
                trn_mask = trn_mask.cuda()
                predictions = torch.mm(usr_embeds[users], itm_embeds.T)
                # Sink already-seen training items far below any real score.
                predictions = predictions * (1 - trn_mask) - trn_mask * 1e8

                k = min(self.topk, predictions.shape[1])
                topk_items = torch.topk(predictions, k, dim=-1).indices.cpu().tolist()
                gt_items = [tst_csr[u].indices.tolist() for u in user_ids]
                batch_recall, batch_ndcg = self._batch_recall_ndcg(topk_items, gt_items, k)
                recall_total += batch_recall
                ndcg_total += batch_ndcg

        num_users = max(len(loader.dataset), 1)
        return recall_total / num_users, ndcg_total / num_users

    def run(self):
        """Train for ``self.epochs`` epochs, evaluating after each one.

        Returns:
            ``(best_recall, best_ndcg, best_epoch)``, where ``best_epoch`` is the
            1-based epoch with the highest Recall@k (0 if recall never exceeded 0).
        """
        self.prepare_model()

        best_recall, best_ndcg, best_epoch = 0, 0, 0
        for epoch in range(self.epochs):
            train_loss = self.train_epoch()
            print(f"Epoch {epoch + 1}/{self.epochs}, Train Loss: {train_loss:.4f}")

            recall, ndcg = self.test_epoch()
            print(f"Epoch {epoch + 1}/{self.epochs}, Recall: {recall:.4f}, NDCG: {ndcg:.4f}")

            if recall > best_recall:
                best_recall, best_ndcg, best_epoch = recall, ndcg, epoch + 1

        print(f"Best Epoch: {best_epoch}, Best Recall: {best_recall:.4f}, Best NDCG: {best_ndcg:.4f}")
        return best_recall, best_ndcg, best_epoch



class TrnData(torch.utils.data.Dataset):
    """BPR training triples ``(user, pos_item, neg_item)`` built from a COO matrix.

    There is one sample per stored entry of ``mat``. The negative item is drawn
    uniformly at random on every access, from the items that have no stored
    entry in the user's row.

    Args:
        mat: users x items interaction matrix in COO format.
        max_tries: rejection-sampling attempts before falling back to an
            exhaustive scan of the user's candidates.
    """

    def __init__(self, mat, max_tries=100):
        self.rows, self.cols = mat.row, mat.col
        self.dok_mat = mat.todok()
        # Use the matrix width, not the largest observed column, so that items
        # with no training interaction are still eligible negatives. Computing
        # it once also avoids rescanning ``cols`` on every sample.
        self.n_items = mat.shape[1]
        self.max_tries = max_tries
        # Kept for compatibility; not read by __getitem__.
        self.negs = np.zeros(len(self.rows), dtype=np.int32)

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        user = self.rows[idx]
        return user, self.cols[idx], self._sample_negative(user)

    def _sample_negative(self, user):
        """Return a random item id in ``[0, n_items)`` not stored for ``user``.

        Raises:
            ValueError: if the user has a stored entry for every item.
        """
        # Rejection sampling: a few O(1) membership checks for sparse rows,
        # instead of building an O(n_items) candidate list per sample.
        for _ in range(self.max_tries):
            cand = random.randrange(self.n_items)
            if (user, cand) not in self.dok_mat:
                return cand
        # Very dense row: enumerate the remaining candidates explicitly.
        candidates = [i for i in range(self.n_items) if (user, i) not in self.dok_mat]
        if not candidates:
            raise ValueError(
                f"user {user} has interactions with every item; cannot sample a negative"
            )
        return random.choice(candidates)


class TstData(torch.utils.data.Dataset):
    """Evaluation dataset: one sample per distinct user in the test matrix.

    Each sample is ``(user, mask)``. ``mask`` is a dense float32 vector over
    items, with 1.0 where the user has a nonzero training entry and 0.0
    elsewhere.
    """

    def __init__(self, tst_mat, trn_mat):
        distinct_users = set(tst_mat.row)
        self.users = list(distinct_users)
        trn_csr = trn_mat.tocsr()
        self.trn_mask = (trn_csr != 0).astype(np.float32)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        uid = self.users[idx]
        dense_row = self.trn_mask[uid].toarray()
        return uid, dense_row.reshape(-1)



if __name__ == "__main__":
    # Restrict device visibility before any torch CUDA call can initialize
    # the CUDA runtime.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    set_seed(123)

    print("Loading Data...")
    # NOTE(review): DataHandler is neither imported nor defined in this file
    # (the header comment says it is forked from DiffKG's DataHandler.py);
    # confirm it is brought into scope, otherwise this line raises NameError.
    handler = DataHandler(dataset_name="mimic3")
    handler.load_data()
    print("Data Loaded")

    print("Starting Training...")
    # run() is an instance method: keep the constructed trainer and call it on
    # that instance, not on the class.
    trainer = Main(handler, lr=1e-3, batch_size=1024, epochs=25)
    trainer.run()
