import torch
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

from models import StereoHeatmapAngle
from dataset import Dataset
from utils import heatmap2tiptail


# Placeholder paths: edit these before running the script.
# Checkpoint file that is handed to torch.load().
CKPT_TO_LOAD = "path_here_to_your_checkpoint.pth"
# Dataset root, passed to Dataset as main_path.
PATH = "path_here_to_your_dataset"
# Folders to evaluate; each one is processed separately as a one-item list_folders.
FOLDERS = ["path_here_to_your_folder"]


def _show_sample(img_l, img_r, heat_l, heat_r, points_l, points_r, pred, label):
    """Display one stereo sample.

    The first two panels show the left/right image with its heatmap overlaid
    (alpha 0.5) and the given points drawn in red; the third panel plots the
    predicted and label angle curves.

    Args:
        img_l, img_r: Channel-last image arrays accepted by ``imshow``.
        heat_l, heat_r: 2-D heatmaps overlaid on the matching image.
        points_l, points_r: Iterables of points; each point is indexed as
            ``point[0]``, ``point[1]`` and passed to ``scatter`` in that order.
        pred: Per-bin prediction scores.
        label: Per-bin label values.
    """
    fig, ax = plt.subplots(1, 3, figsize=(14, 7))
    views = (
        (ax[0], img_l, heat_l, points_l),
        (ax[1], img_r, heat_r, points_r),
    )
    for axis, img, heat, points in views:
        axis.imshow(img)
        axis.imshow(heat, alpha=0.5)
        for point in points:
            axis.scatter(point[0], point[1], c="r", s=10)
    ax[2].plot(pred, label="pred")
    ax[2].plot(label, label="label")
    ax[2].legend()
    plt.tight_layout()
    plt.show()
    # Release the figure explicitly: with a non-interactive backend show()
    # returns immediately and figures would otherwise pile up sample after sample.
    plt.close(fig)


def _plot_summary(out_dict, title):
    """Plot per-folder bar charts of the angular error and the prediction score.

    Args:
        out_dict: Maps folder name -> {label bin index: (error, pred_val)}.
        title: Text used as the figure title.
    """
    fig, ax = plt.subplots(2, 1, figsize=(14, 7))
    bar_width = 0.2
    for it, (folder, error_dict) in enumerate(out_dict.items()):
        if not error_dict:
            # An empty folder would yield a 1-D empty array and make the
            # column slicing below raise IndexError.
            continue

        values = np.array(list(error_dict.values()), dtype=np.float32)
        errors = values[:, 0]
        pred_vals = values[:, 1]

        angles = np.array(list(error_dict.keys()), dtype=np.float32)
        # Shift each folder's bars so several folders can sit side by side.
        x_pos = angles - 0.5 + it * bar_width

        ax[0].bar(x_pos, errors, width=bar_width, label=folder)
        ax[1].bar(x_pos + bar_width, pred_vals, width=bar_width, label=f"{folder} pred")

    ax[0].legend()
    ax[1].legend()
    # Label each axis explicitly: the pyplot-level xlabel/ylabel/title helpers
    # only act on the current axes (the last subplot), which put "Error" on
    # the prediction-score panel and left the error panel unlabeled.
    ax[0].set_xlabel("Angle")
    ax[0].set_ylabel("Error")
    ax[1].set_xlabel("Angle")
    ax[1].set_ylabel("Prediction score")
    fig.suptitle(title)
    plt.tight_layout()
    plt.show()


def main(show_samples=True):
    """Evaluate the checkpoint on every folder in FOLDERS and plot the results.

    Loads the checkpoint at CKPT_TO_LOAD, rebuilds StereoHeatmapAngle from the
    hyper-parameters stored in it (falling back to defaults when a key is
    missing), runs the model sample by sample and records, for each label bin,
    the circular angular error and the maximum sigmoid score.

    Results are keyed by the label's argmax bin, so a later sample with the
    same bin replaces an earlier one.

    Args:
        show_samples: When true, display a figure for every sample (each
            plt.show() call blocks with interactive backends).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device {device}")

    checkpoint = torch.load(CKPT_TO_LOAD, map_location=torch.device("cpu"), weights_only=True)
    print(f"Checkpoint {os.path.basename(CKPT_TO_LOAD)} loaded !")

    angle_step = checkpoint.get("angle_step", 1)
    smooth_sigma = checkpoint.get("smooth_sigma", 2)
    img_size = checkpoint.get("img_size", 256)
    encoder_type = checkpoint.get("encoder_type", "resnet50")
    feature_dim = checkpoint.get("feature_dim", 256)

    print(f"Image size: {img_size}")
    print(f"Encoder type: {encoder_type}")
    print(f"Feature dim: {feature_dim}")

    model = StereoHeatmapAngle(resnet=encoder_type, angle_step=angle_step, feature_dim=feature_dim)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    print("Model loaded !")

    out_dict = {}
    for pred_folder in FOLDERS:
        print(f"Processing folder {pred_folder}")

        dataset = Dataset(
            main_path=PATH,
            list_folders=[pred_folder],
            img_size=img_size,
            angle_step=angle_step,
            smooth_sigma=smooth_sigma,
            transform=None,
        )
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            # Pinned memory only helps host-to-GPU copies; on CPU it just warns.
            pin_memory=device.type == "cuda",
        )

        tmp_dict = {}
        # The ground-truth heatmaps yielded by the loader are not used here.
        for img_l, img_r, _h_l, _h_r, label_t in tqdm(loader):
            img_l = img_l.to(device)
            img_r = img_r.to(device)

            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                pred_hl, pred_hr, pred_angle = model(img_l, img_r)

            # heatmap2tiptail receives the raw model output tensors.
            point_ltip, point_ltail = heatmap2tiptail(pred_hl)
            point_rtip, point_rtail = heatmap2tiptail(pred_hr)

            # batch_size is 1, so squeeze() drops the batch axis.
            heat_l = pred_hl.squeeze().detach().cpu().numpy()
            heat_r = pred_hr.squeeze().detach().cpu().numpy()
            pred = pred_angle.sigmoid().squeeze().detach().cpu().numpy()
            label = label_t.squeeze().detach().cpu().numpy()

            # Merge the first two heatmap channels into one overlay
            # (presumably tip and tail -- confirm against heatmap2tiptail)
            # and scale it so its maximum is 1.
            heat_l = heat_l[0] + heat_l[1]
            heat_r = heat_r[0] + heat_r[1]
            heat_l = heat_l / np.max(heat_l)
            heat_r = heat_r / np.max(heat_r)

            label_idx = np.argmax(label)
            pred_idx = np.argmax(pred)
            pred_val = np.max(pred)

            # Bin distance scaled by angle_step, folded so it never exceeds 180.
            error = np.abs(label_idx - pred_idx) * angle_step
            if error > 180:
                error = 360 - error
            tmp_dict[label_idx] = (error, pred_val)

            if show_samples:
                # Channel-first -> channel-last for imshow.
                img_l_np = img_l.squeeze().detach().cpu().numpy().transpose(1, 2, 0)
                img_r_np = img_r.squeeze().detach().cpu().numpy().transpose(1, 2, 0)
                _show_sample(
                    img_l_np,
                    img_r_np,
                    heat_l,
                    heat_r,
                    (point_ltip, point_ltail),
                    (point_rtip, point_rtail),
                    pred,
                    label,
                )

        out_dict[pred_folder] = tmp_dict

    _plot_summary(out_dict, CKPT_TO_LOAD)


if __name__ == "__main__":
    main()
