import numpy as np
import torch, os, cv2
import albumentations as aug
from albumentations import ReplayCompose
import matplotlib.pyplot as plt
import glob

from utils import generate_gaussian, pre_process, get_tip_point
from utils import pre_process, gaussian_1D_label


def single_scale_retinex(img, sigma):
    """Single-scale retinex: log10 of the image minus log10 of its Gaussian surround.

    Args:
        img: Image array of any numeric dtype; it is converted to float32.
        sigma: Standard deviation (in pixels) of the Gaussian used as the surround.

    Returns:
        float32 array with the same shape as ``img``.
    """
    # Offset by 1 so log10 never sees 0. The blur is taken on the *offset* image, so it is
    # already positive and needs no second offset. Adding 1.0 again (as before) made a flat
    # region of value v map to log10((v + 1) / (v + 2)) instead of 0.
    img = img.astype(np.float32) + 1.0
    blur = cv2.GaussianBlur(img, (0, 0), sigma)
    return np.log10(img) - np.log10(blur)


def multi_scale_retinex(img, sigmas=(15, 80, 250)):
    """Average ``single_scale_retinex`` responses computed at several Gaussian scales.

    Args:
        img: Image array passed unchanged to ``single_scale_retinex`` for every scale.
        sigmas: Gaussian surround standard deviations, one per scale.

    Returns:
        float32 array shaped like ``img``: the mean of the per-scale responses.
    """
    total = np.zeros_like(img, dtype=np.float32)
    for scale in sigmas:
        # In-place accumulation keeps the running sum in float32.
        total += single_scale_retinex(img, scale)
    scale_count = len(sigmas)
    total /= scale_count
    return total


def simplest_color_balance(img, low_clip=1, high_clip=99):
    """Clip each channel to its [low_clip, high_clip] percentiles, then stretch it to [0, 255].

    Args:
        img: Image of shape (H, W, C) with any number of channels, or a 2-D (H, W) image
            (treated as a single channel).
        low_clip: Lower percentile (0-100) used as each channel's clipping floor.
        high_clip: Upper percentile (0-100) used as each channel's clipping ceiling.

    Returns:
        Array with the same shape and dtype as ``img``. Each channel is min-max normalized
        with ``cv2.normalize`` (a constant channel comes out as 0).
    """
    # Previously the loop hard-coded 3 channels. Grayscale input raised IndexError and any
    # extra channel was returned as zeros. Process every channel instead.
    is_2d = img.ndim == 2
    if is_2d:
        img = img[:, :, np.newaxis]

    result = np.zeros_like(img)
    for c in range(img.shape[2]):
        channel = img[:, :, c]
        low_val = np.percentile(channel, low_clip)
        high_val = np.percentile(channel, high_clip)
        clipped = np.clip(channel, low_val, high_val)
        result[:, :, c] = cv2.normalize(clipped, None, 0, 255, cv2.NORM_MINMAX)

    return result[:, :, 0] if is_2d else result


def apply_color_balance(image):
    """Enhance an image with multi-scale retinex followed by percentile color balancing.

    Args:
        image: 3- or 4-channel image. It is passed through ``cv2.COLOR_BGR2RGB``, which swaps
            the first and third channels. BGR input becomes RGB; RGB input would become BGR.

    Returns:
        uint8 image with 3 channels, as produced by ``simplest_color_balance``.
    """
    img_float = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
    msr = multi_scale_retinex(img_float)

    # Bound the log-ratio values before the global min-max rescale ("keep range reasonable").
    msr = np.clip(msr, -5, 5)

    # Min-max rescale to [0, 255]. A flat response has zero range; dividing by it would
    # produce NaN, whose uint8 cast is undefined, so emit zeros explicitly instead.
    lo = np.min(msr)
    hi = np.max(msr)
    if hi > lo:
        msr = ((msr - lo) / (hi - lo) * 255).astype(np.uint8)
    else:
        msr = np.zeros_like(msr, dtype=np.uint8)

    # The former cv2.merge of msr's own three channels was a no-op copy and is dropped.
    return simplest_color_balance(msr)


def log_transform(image_gray):
    """Return a log-compressed copy of a grayscale image stretched to the full uint8 range.

    Pipeline: scale to [0, 1] in float32, apply ``log(1 + x)``, min-max normalize the
    whole array to [0, 255] with ``cv2.normalize``, then truncate to uint8.
    """
    scaled = image_gray.astype(np.float32) / 255.0
    compressed = np.log1p(scaled)
    stretched = cv2.normalize(compressed, None, 0, 255, cv2.NORM_MINMAX)
    return np.uint8(stretched)


class Dataset(torch.utils.data.Dataset):
    """Stereo RGB dataset of left/right images with tip/tail heatmaps and an angle label.

    Samples are discovered as ``<main_path>/<folder>/imgs/*_left.jpg`` for every folder in
    ``list_folders``. For a left image ``<stem>_left.jpg`` an item also reads::

        <folder>/imgs/<stem>_right.jpg    right image
        <folder>/masks/<stem>_left.png    left mask (read as grayscale)
        <folder>/masks/<stem>_right.png   right mask (read as grayscale)
        <folder>/angles/<stem>.txt        angle, converted with np.rad2deg (stored in radians)

    Each item is ``(img_l, img_r, H_l, H_r, label)``:
    - the two RGB images after ``pre_process``, as float tensors;
    - two ``(2, img_size, img_size)`` heatmaps (channel 0 = tip, channel 1 = tail);
    - a float tensor built by ``gaussian_1D_label`` over ``angle_dim`` bins.

    Args:
        main_path: Dataset root directory.
        list_folders: Sub-folders of ``main_path`` to include.
        img_size: Images and masks are resized to ``img_size x img_size``.
        angle_step: Width of one angle bin in degrees. ``angle_dim = 360 // angle_step``,
            so the bins tile the full circle exactly only when ``angle_step`` divides 360.
        smooth_sigma: Passed to ``gaussian_1D_label`` as ``sig``.
        transform: Optional albumentations transform. It is wrapped in ``ReplayCompose`` so
            the parameters sampled for the left view are replayed on the right view.
        heatmap_sigma: ``sigma`` passed to ``generate_gaussian`` for the tip/tail heatmaps.
            The default of 10 matches the value that was previously hard-coded.
    """

    def __init__(self, main_path, list_folders, img_size=512, angle_step=5, smooth_sigma=4, transform=None,
                 heatmap_sigma=10):
        self.main_path = main_path
        self.imgs_dir = "imgs"
        self.masks_dir = "masks"
        self.angle_dir = "angles"
        self.img_h = img_size
        self.img_w = img_size

        self.angle_step = angle_step
        self.angle_dim = 360 // self.angle_step
        self.smooth_sigma = smooth_sigma
        self.heatmap_sigma = heatmap_sigma
        print("Angle step", self.angle_step, "Angle dim", self.angle_dim)

        self.files = []
        for folder in list_folders:
            img_path = os.path.join(self.main_path, folder, self.imgs_dir)
            # glob's order depends on the file system; sort so that index -> sample is
            # reproducible across machines and runs.
            self.files += sorted(glob.glob(os.path.join(img_path, "*_left.jpg")))

        print(f"Creating dataset with {len(self.files)} examples")

        self.transform = transform

    def __len__(self):
        return len(self.files)

    def show_datasample(self, img_l, img_r, heatmap_l, heatmap_r, label):
        """Plot one sample: each image overlaid with its tip (top) and tail (bottom) heatmaps,
        plus the angle label curve titled with its decoded angle.

        Expects the tensors returned by ``__getitem__`` (images as CHW tensors).
        """
        img_l = img_l.numpy().transpose((1, 2, 0))
        img_r = img_r.numpy().transpose((1, 2, 0))
        heatmap_l = heatmap_l.numpy().squeeze()
        heatmap_r = heatmap_r.numpy().squeeze()

        label = label.numpy().squeeze()
        # Decode with the same bin width used to encode the label (angle / angle_step).
        # The former 360 / angle_dim disagreed whenever angle_step does not divide 360.
        angle = float((np.argmax(label) * self.angle_step) % 360)

        _, ax = plt.subplots(2, 3, figsize=(10, 5))
        views = ((img_l, heatmap_l, "Left Image"), (img_r, heatmap_r, "Right Image"))
        for col, (img, heatmap, title) in enumerate(views):
            for row in range(2):  # row 0: tip heatmap, row 1: tail heatmap
                ax[row, col].imshow(img)
                ax[row, col].imshow(heatmap[row], alpha=0.5)
            ax[0, col].set_title(title)
        ax[0, 2].plot(np.arange(0, self.angle_dim), label, label="label")
        ax[0, 2].set_title(f"Angle: {angle}")
        plt.tight_layout()
        plt.show()

    @staticmethod
    def _imread(path, flags):
        """``cv2.imread`` that raises FileNotFoundError instead of returning None."""
        image = cv2.imread(path, flags)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    def _sample_paths(self, img_file_name):
        """Return ``(mask_left, img_right, mask_right, angle_txt)`` paths for a ``*_left.jpg``.

        The paths are rebuilt from path components rather than with ``str.replace`` on the
        full path. This way, "imgs", ".jpg" or "_left" in parent directories or inside the
        stem are left untouched.

        Raises:
            ValueError: if ``img_file_name`` does not end with ``_left.jpg``.
        """
        imgs_path, base_name = os.path.split(img_file_name)
        suffix = "_left.jpg"
        if not base_name.endswith(suffix):
            raise ValueError(f"Expected a '*{suffix}' image, got: {img_file_name}")
        stem = base_name[: -len(suffix)]
        folder_path = os.path.dirname(imgs_path)
        masks_path = os.path.join(folder_path, self.masks_dir)
        return (
            os.path.join(masks_path, stem + "_left.png"),
            os.path.join(imgs_path, stem + "_right.jpg"),
            os.path.join(masks_path, stem + "_right.png"),
            os.path.join(folder_path, self.angle_dir, stem + ".txt"),
        )

    def _angle_bin(self, angle):
        """Map an angle in degrees to its (float) bin index, wrapped into [0, angle_dim).

        Without the wrap, e.g. 359.9 deg rounds to bin ``angle_dim`` and negative angles give
        negative bins. Neither is a valid index on the circular label.
        """
        return np.round(angle / self.angle_step) % self.angle_dim

    def _heatmaps(self, tip, tail):
        """Return a ``(2, img_h, img_w)`` tensor: channel 0 from ``tip``, channel 1 from ``tail``.

        Points are mapped from [0, img_w] x [0, img_h] to [-1, 1] before ``generate_gaussian``.
        This assumes ``get_tip_point`` returns (x, y) in pixels (TODO confirm).
        """
        size = np.array([self.img_w, self.img_h])
        heatmaps = torch.zeros(2, self.img_h, self.img_w)
        for channel, point in enumerate((tip, tail)):
            normalized = (np.array(point) / size * 2) - 1
            heatmaps[channel] = generate_gaussian(
                heatmaps[channel], normalized[0], normalized[1], sigma=self.heatmap_sigma
            )
        return heatmaps

    def __getitem__(self, i):
        # Index first: an out-of-range i raises IndexError, which also terminates plain
        # `for sample in dataset` iteration.
        img_file_name = self.files[i]
        mask_file_name, img_right_file_name, mask_right_file_name, a_name = self._sample_paths(img_file_name)

        # Axial angle, converted from radians to degrees.
        angle = np.rad2deg(np.loadtxt(a_name))

        img_l = cv2.cvtColor(self._imread(img_file_name, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        mask_l = self._imread(mask_file_name, cv2.IMREAD_GRAYSCALE)
        img_r = cv2.cvtColor(self._imread(img_right_file_name, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        mask_r = self._imread(mask_right_file_name, cv2.IMREAD_GRAYSCALE)

        size = (self.img_w, self.img_h)  # cv2.resize takes (width, height)
        img_l = cv2.resize(img_l, size, interpolation=cv2.INTER_LINEAR)
        img_r = cv2.resize(img_r, size, interpolation=cv2.INTER_LINEAR)
        # Nearest neighbour so mask values are not blended.
        mask_l = cv2.resize(mask_l, size, interpolation=cv2.INTER_NEAREST)
        mask_r = cv2.resize(mask_r, size, interpolation=cv2.INTER_NEAREST)

        # NOTE(review): apply_color_balance converts BGR->RGB internally, but img_l/img_r are
        # already RGB at this point; re-enabling these lines would swap the channel order.
        # img_l = apply_color_balance(img_l)
        # img_r = apply_color_balance(img_r)

        if self.transform is not None:
            # Record the random parameters drawn for the left pair and replay them on the
            # right pair, so both views receive the same augmentation.
            replay_compose = ReplayCompose([self.transform])
            replay_l = replay_compose(image=img_l, mask=mask_l)
            img_l, mask_l = replay_l["image"], replay_l["mask"]
            replay_r = ReplayCompose.replay(replay_l["replay"], image=img_r, mask=mask_r)
            img_r, mask_r = replay_r["image"], replay_r["mask"]

        # Tip/tail keypoints from the (possibly augmented) masks, rendered as heatmaps.
        tip_l, tail_l = get_tip_point(mask_l, self.img_h, self.img_w)
        tip_r, tail_r = get_tip_point(mask_r, self.img_h, self.img_w)
        H_l = self._heatmaps(tip_l, tail_l)
        H_r = self._heatmaps(tip_r, tail_r)

        # pre_process presumably converts HWC -> CHW numpy (per the original comment).
        img_l = torch.from_numpy(pre_process(img_l)).type(torch.FloatTensor)
        img_r = torch.from_numpy(pre_process(img_r)).type(torch.FloatTensor)

        # Smoothed 1-D angle label centred on the wrapped angle bin.
        angle_curve = gaussian_1D_label(self._angle_bin(angle), self.angle_dim, sig=self.smooth_sigma)
        label_out = torch.tensor(angle_curve).type(torch.FloatTensor)

        return img_l, img_r, H_l, H_r, label_out


class DatasetGrayscale(torch.utils.data.Dataset):
    """Stereo grayscale dataset of left/right images with tip/tail heatmaps and an angle label.

    Samples are discovered as ``<main_path>/<folder>/imgs/*_left.jpg`` for every folder in
    ``list_folders``. For a left image ``<stem>_left.jpg`` an item also reads::

        <folder>/imgs/<stem>_right.jpg    right image
        <folder>/masks/<stem>_left.png    left mask
        <folder>/masks/<stem>_right.png   right mask
        <folder>/angles/<stem>.txt        angle, converted with np.rad2deg (stored in radians)

    All images and masks are read with ``cv2.IMREAD_GRAYSCALE``. Each item is
    ``(img_l, img_r, H_l, H_r, label)``:
    - the two images after ``pre_process``, as float tensors;
    - two ``(2, img_size, img_size)`` heatmaps (channel 0 = tip, channel 1 = tail);
    - a float tensor built by ``gaussian_1D_label`` over ``angle_dim`` bins.

    Args:
        main_path: Dataset root directory.
        list_folders: Sub-folders of ``main_path`` to include.
        img_size: Images and masks are resized to ``img_size x img_size``.
        angle_step: Width of one angle bin in degrees. ``angle_dim = 360 // angle_step``,
            so the bins tile the full circle exactly only when ``angle_step`` divides 360.
        smooth_sigma: Passed to ``gaussian_1D_label`` as ``sig``.
        transform: Optional albumentations transform. It is wrapped in ``ReplayCompose`` so
            the parameters sampled for the left view are replayed on the right view.
        heatmap_sigma: ``sigma`` passed to ``generate_gaussian`` for the tip/tail heatmaps.
            The default of 10 matches the value that was previously hard-coded.
    """

    def __init__(self, main_path, list_folders, img_size=512, angle_step=5, smooth_sigma=4, transform=None,
                 heatmap_sigma=10):
        self.main_path = main_path
        self.imgs_dir = "imgs"
        self.masks_dir = "masks"
        self.angle_dir = "angles"
        self.img_h = img_size
        self.img_w = img_size

        self.angle_step = angle_step
        self.angle_dim = 360 // self.angle_step
        self.smooth_sigma = smooth_sigma
        self.heatmap_sigma = heatmap_sigma
        print("Angle step", self.angle_step, "Angle dim", self.angle_dim)

        self.files = []
        for folder in list_folders:
            img_path = os.path.join(self.main_path, folder, self.imgs_dir)
            # glob's order depends on the file system; sort so that index -> sample is
            # reproducible across machines and runs.
            self.files += sorted(glob.glob(os.path.join(img_path, "*_left.jpg")))

        print(f"Creating dataset with {len(self.files)} examples")

        self.transform = transform

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _to_display(img):
        """Convert an image tensor to a numpy array that ``imshow`` accepts.

        A CHW tensor becomes HWC, with a singleton channel dropped so ``cmap="gray"``
        applies. A 2-D (HW) tensor is returned as-is.
        """
        arr = img.numpy()
        if arr.ndim == 3:
            arr = arr.transpose((1, 2, 0))
            if arr.shape[-1] == 1:
                arr = arr[:, :, 0]
        return arr

    def show_datasample(self, img_l, img_r, heatmap_l, heatmap_r, label):
        """Plot one sample: each image overlaid with its tip (top) and tail (bottom) heatmaps,
        plus the angle label curve titled with its decoded angle.

        Expects the tensors returned by ``__getitem__``.
        """
        img_l = self._to_display(img_l)
        img_r = self._to_display(img_r)
        heatmap_l = heatmap_l.numpy().squeeze()
        heatmap_r = heatmap_r.numpy().squeeze()

        label = label.numpy().squeeze()
        # Decode with the same bin width used to encode the label (angle / angle_step).
        # The former 360 / angle_dim disagreed whenever angle_step does not divide 360.
        angle = float((np.argmax(label) * self.angle_step) % 360)

        _, ax = plt.subplots(2, 3, figsize=(10, 5))
        views = ((img_l, heatmap_l, "Left Image"), (img_r, heatmap_r, "Right Image"))
        for col, (img, heatmap, title) in enumerate(views):
            for row in range(2):  # row 0: tip heatmap, row 1: tail heatmap
                ax[row, col].imshow(img, cmap="gray")
                ax[row, col].imshow(heatmap[row], alpha=0.5)
            ax[0, col].set_title(title)
        ax[0, 2].plot(np.arange(0, self.angle_dim), label, label="label")
        ax[0, 2].set_title(f"Angle: {angle}")
        plt.tight_layout()
        plt.show()

    @staticmethod
    def _imread(path, flags):
        """``cv2.imread`` that raises FileNotFoundError instead of returning None."""
        image = cv2.imread(path, flags)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    def _sample_paths(self, img_file_name):
        """Return ``(mask_left, img_right, mask_right, angle_txt)`` paths for a ``*_left.jpg``.

        The paths are rebuilt from path components rather than with ``str.replace`` on the
        full path. This way, "imgs", ".jpg" or "_left" in parent directories or inside the
        stem are left untouched.

        Raises:
            ValueError: if ``img_file_name`` does not end with ``_left.jpg``.
        """
        imgs_path, base_name = os.path.split(img_file_name)
        suffix = "_left.jpg"
        if not base_name.endswith(suffix):
            raise ValueError(f"Expected a '*{suffix}' image, got: {img_file_name}")
        stem = base_name[: -len(suffix)]
        folder_path = os.path.dirname(imgs_path)
        masks_path = os.path.join(folder_path, self.masks_dir)
        return (
            os.path.join(masks_path, stem + "_left.png"),
            os.path.join(imgs_path, stem + "_right.jpg"),
            os.path.join(masks_path, stem + "_right.png"),
            os.path.join(folder_path, self.angle_dir, stem + ".txt"),
        )

    def _angle_bin(self, angle):
        """Map an angle in degrees to its (float) bin index, wrapped into [0, angle_dim).

        Without the wrap, e.g. 359.9 deg rounds to bin ``angle_dim`` and negative angles give
        negative bins. Neither is a valid index on the circular label.
        """
        return np.round(angle / self.angle_step) % self.angle_dim

    def _heatmaps(self, tip, tail):
        """Return a ``(2, img_h, img_w)`` tensor: channel 0 from ``tip``, channel 1 from ``tail``.

        Points are mapped from [0, img_w] x [0, img_h] to [-1, 1] before ``generate_gaussian``.
        This assumes ``get_tip_point`` returns (x, y) in pixels (TODO confirm).
        """
        size = np.array([self.img_w, self.img_h])
        heatmaps = torch.zeros(2, self.img_h, self.img_w)
        for channel, point in enumerate((tip, tail)):
            normalized = (np.array(point) / size * 2) - 1
            heatmaps[channel] = generate_gaussian(
                heatmaps[channel], normalized[0], normalized[1], sigma=self.heatmap_sigma
            )
        return heatmaps

    def __getitem__(self, i):
        # Index first: an out-of-range i raises IndexError, which also terminates plain
        # `for sample in dataset` iteration.
        img_file_name = self.files[i]
        mask_file_name, img_right_file_name, mask_right_file_name, a_name = self._sample_paths(img_file_name)

        # Axial angle, converted from radians to degrees.
        angle = np.rad2deg(np.loadtxt(a_name))

        img_l = self._imread(img_file_name, cv2.IMREAD_GRAYSCALE)
        mask_l = self._imread(mask_file_name, cv2.IMREAD_GRAYSCALE)
        img_r = self._imread(img_right_file_name, cv2.IMREAD_GRAYSCALE)
        mask_r = self._imread(mask_right_file_name, cv2.IMREAD_GRAYSCALE)

        size = (self.img_w, self.img_h)  # cv2.resize takes (width, height)
        img_l = cv2.resize(img_l, size, interpolation=cv2.INTER_LINEAR)
        img_r = cv2.resize(img_r, size, interpolation=cv2.INTER_LINEAR)
        # Nearest neighbour so mask values are not blended.
        mask_l = cv2.resize(mask_l, size, interpolation=cv2.INTER_NEAREST)
        mask_r = cv2.resize(mask_r, size, interpolation=cv2.INTER_NEAREST)

        # img_l = log_transform(img_l)
        # img_r = log_transform(img_r)

        if self.transform is not None:
            # Record the random parameters drawn for the left pair and replay them on the
            # right pair, so both views receive the same augmentation.
            replay_compose = ReplayCompose([self.transform])
            replay_l = replay_compose(image=img_l, mask=mask_l)
            img_l, mask_l = replay_l["image"], replay_l["mask"]
            replay_r = ReplayCompose.replay(replay_l["replay"], image=img_r, mask=mask_r)
            img_r, mask_r = replay_r["image"], replay_r["mask"]

        # Tip/tail keypoints from the (possibly augmented) masks, rendered as heatmaps.
        tip_l, tail_l = get_tip_point(mask_l, self.img_h, self.img_w)
        tip_r, tail_r = get_tip_point(mask_r, self.img_h, self.img_w)
        H_l = self._heatmaps(tip_l, tail_l)
        H_r = self._heatmaps(tip_r, tail_r)

        # pre_process presumably converts the image to a CHW numpy array (per the original
        # "HWC to CHW" comment); confirm how it handles 2-D grayscale input.
        img_l = torch.from_numpy(pre_process(img_l)).type(torch.FloatTensor)
        img_r = torch.from_numpy(pre_process(img_r)).type(torch.FloatTensor)

        # Smoothed 1-D angle label centred on the wrapped angle bin.
        angle_curve = gaussian_1D_label(self._angle_bin(angle), self.angle_dim, sig=self.smooth_sigma)
        label_out = torch.tensor(angle_curve).type(torch.FloatTensor)

        return img_l, img_r, H_l, H_r, label_out


if __name__ == "__main__":

    # Visual sanity check: walk the validation split and plot every sample.
    data_root = "/home/lar/dev25/ELVEZ/DATA/DATA_APRIL_09_INTERPOLATION/"
    preview_folders = ["ALL_val"]

    preview_set = DatasetGrayscale(
        main_path=data_root,
        list_folders=preview_folders,
        img_size=512,
        angle_step=1,
        smooth_sigma=4,
        transform=None,
    )

    for sample in preview_set:
        left, right, heat_left, heat_right, angle_label = sample
        print("img shape", left.shape, right.shape)
        print("H shape", heat_left.shape, heat_right.shape)
        print("label shape", angle_label.shape)

        # argmax is the bin index; with angle_step=1 it equals the angle in degrees.
        peak_bin = np.argmax(angle_label.numpy())
        print("Angle value", peak_bin)
        preview_set.show_datasample(left, right, heat_left, heat_right, angle_label)
