import numpy as np
import torch
import math
import scipy.optimize as opt


def generate_gaussian(t, x, y, sigma=10):
    """
    Draw an unnormalized 2D Gaussian (peak value 1) into ``t`` and return ``t``.

    ``t`` is a 2D tensor of shape (H, W) and is modified in place. ``x`` and ``y``
    are normalized coordinates, expected in (-1, 1), mapped to column
    ``int(0.5 * (x + 1) * W)`` and row ``int(0.5 * (y + 1) * H)``.

    ``sigma`` is the standard deviation in pixels; the Gaussian is truncated to a
    square window of radius ``3 * sigma``. Pixels inside the window overwrite the
    existing values of ``t`` (they are not max-combined). If the window lies
    entirely outside the tensor, ``t`` is left unchanged.
    """
    height, width = t.shape

    # Normalized (-1, 1) coordinates -> integer pixel indices.
    center_col = int(0.5 * (x + 1.0) * width)
    center_row = int(0.5 * (y + 1.0) * height)

    radius = sigma * 3
    left, top = int(center_col - radius), int(center_row - radius)
    # right/bottom are exclusive bounds.
    right, bottom = int(center_col + radius + 1), int(center_row + radius + 1)

    if left >= width or top >= height or right < 0 or bottom < 0:
        return t

    kernel_len = 2 * radius + 1
    cols = np.arange(0, kernel_len, 1, np.float32)
    rows = cols[:, np.newaxis]
    mid = kernel_len // 2
    squared_dist = (cols - mid) ** 2 + (rows - mid) ** 2
    # Not normalized: the centre sample equals exactly 1.
    kernel = torch.tensor(np.exp(-(squared_dist) / (2 * sigma**2)))

    # Window clipped to the tensor borders, in image and in kernel coordinates.
    dst_cols = slice(max(0, left), min(right, width))
    dst_rows = slice(max(0, top), min(bottom, height))
    src_cols = slice(max(0, -left), min(right, width) - left)
    src_rows = slice(max(0, -top), min(bottom, height) - top)

    t[dst_rows, dst_cols] = kernel[src_rows, src_cols]
    return t


def heatmap2argmax(heatmap):
    """Return the integer (x, y) location of the maximum of every heatmap channel.

    Args:
        heatmap: tensor of shape (N, C, H, W). Non-contiguous tensors
            (e.g. permuted or sliced maps) are accepted.

    Returns:
        Integer tensor of shape (N, C, 2) holding (x, y) = (column, row) of the
        per-channel argmax.
    """
    n, c, _, w = heatmap.shape
    # reshape (not view): view() raises on non-contiguous inputs.
    flat_index = heatmap.reshape(n, c, 1, -1).argmax(dim=-1)  # (N, C, 1)
    cols = flat_index % w
    rows = torch.div(flat_index, w, rounding_mode="floor")
    return torch.cat([cols, rows], dim=2)


def _gaussian_2d(xy, x0, y0, sigma):
    """Unit-amplitude isotropic 2D Gaussian evaluated at the points ``xy = (x, y)``."""
    x, y = xy
    return np.exp(-((x - x0) ** 2 / (2 * sigma**2) + (y - y0) ** 2 / (2 * sigma**2)))


def _fit_peak(channel, peak_x, peak_y, half, sigma):
    """Fit ``_gaussian_2d`` to a window of ``channel`` around an integer peak; return (x, y).

    The window ``[peak - half, peak + half)`` is clipped to the channel borders,
    so peaks near an edge still yield a valid crop. The fitted centre is
    returned in full-channel pixel coordinates as a float64 array.
    """
    height, width = channel.shape
    x_start, x_stop = max(0, peak_x - half), min(width, peak_x + half)
    y_start, y_stop = max(0, peak_y - half), min(height, peak_y + half)
    crop = channel[y_start:y_stop, x_start:x_stop].detach().cpu().numpy()

    # Grid built from the actual crop shape (may be smaller than 2*half at edges).
    grid_x, grid_y = np.meshgrid(np.arange(crop.shape[1]), np.arange(crop.shape[0]))
    params, _ = opt.curve_fit(
        _gaussian_2d,
        (grid_x.ravel(), grid_y.ravel()),
        crop.ravel(),
        p0=(peak_x - x_start, peak_y - y_start, sigma),
    )
    # Shift from crop coordinates back to channel coordinates.
    return np.array([params[0] + x_start, params[1] + y_start])


def fit_gaussian(heatmap, sigma=10, size=50):
    """Sub-pixel tip/tail locations by fitting a 2D Gaussian around each heatmap peak.

    Args:
        heatmap: tensor of shape (N, C, H, W); only batch element 0 is used,
            channel 0 is treated as the tip and channel 1 as the tail.
        sigma: initial guess for the Gaussian standard deviation (pixels).
        size: half-width of the square fitting window around each argmax.

    Returns:
        ``(tip, tail)``, each a float64 array ``[x, y]`` in heatmap pixels.

    Raises:
        RuntimeError: from ``scipy.optimize.curve_fit`` if the fit does not converge.
    """
    # Integer argmax (x, y) per channel for batch element 0 -> shape (C, 2).
    pts = heatmap2argmax(heatmap)[0].detach().cpu().numpy()
    tip = _fit_peak(heatmap[0, 0], int(pts[0][0]), int(pts[0][1]), size, sigma)
    tail = _fit_peak(heatmap[0, 1], int(pts[1][0]), int(pts[1][1]), size, sigma)
    return tip, tail


def heatmap2tiptail(heatmap):
    """Return the integer argmax (x, y) of the first two heatmap channels as (tip, tail).

    The argmax result is squeezed, so ``coords[0]`` / ``coords[1]`` pick the two
    channels only when the batch size is 1 — TODO confirm callers always pass N=1.
    """
    coords = heatmap2argmax(heatmap)
    coords = coords.detach().cpu().numpy().squeeze()
    return coords[0], coords[1]


def gaussian_1D_label(label, num_class, u=0, sig=4.0):
    """Circular 1D Gaussian soft label peaked at class index ``label``.

    Args:
        label: target class index; converted with ``int()`` (truncation) and
            wrapped modulo ``num_class``, so negative or large values are valid.
        num_class: number of classes (length of the returned vector).
        u: offset of the peak, in classes, relative to ``label``.
        sig: standard deviation of the Gaussian, in classes.

    Returns:
        float64 array of length ``num_class`` whose entry ``i`` is
        ``exp(-(d - u)**2 / (2 * sig**2))``, with ``d`` the signed circular
        distance from ``label`` to ``i`` wrapped into
        ``[-(num_class // 2), num_class - num_class // 2)``. Not normalized.
    """
    label = int(label)
    half = num_class // 2
    # Offsets -half .. num_class-half-1: exactly num_class samples for odd or even num_class.
    x = np.arange(num_class) - half
    y_sig = np.exp(-((x - u) ** 2) / (2 * sig**2))
    # y_sig peaks at index `half`; rotating by (label - half) moves the peak to
    # `label`. np.roll wraps any shift, so out-of-range labels are handled too.
    return np.roll(y_sig, label - half)


def gaussian_1D_label_tensor(label, num_class=360, sig=4.0):
    """Batched circular Gaussian soft labels from angles given in radians.

    ``label`` is converted to degrees and flattened to B = ``label.numel()``
    targets. Each of the ``num_class`` bins ``k`` is scored by the circular
    distance ``min((k - deg) % num_class, (deg - k) % num_class)``, so bins are
    one degree apart (presumably intended for ``num_class=360`` — confirm).

    Returns a tensor of shape (num_class, B), unnormalized, with every entry
    offset by 1e-10.
    """
    target_deg = (label * 180 / torch.pi).view(1, -1)  # (1, B)
    bins = torch.arange(num_class, dtype=torch.float32, device=label.device).unsqueeze(1)  # (num_class, 1)

    ahead = (bins - target_deg) % num_class
    behind = (target_deg - bins) % num_class
    circular_dist = torch.minimum(ahead, behind)

    weights = torch.exp(-(circular_dist**2) / (2 * sig**2))
    # Tiny floor keeps every bin strictly positive so log() in a KL loss stays finite.
    return weights + 1e-10


def _steps_inside(inside, origin, direction, img_h, img_w, max_steps):
    """Count unit steps from ``origin`` along ``direction`` that stay on the object.

    ``origin`` and ``direction`` are (x, y) vectors; ``inside`` is a boolean mask
    indexed as ``inside[row, col]``. Returns the index of the first sample that is
    off the image or outside the mask, or ``max_steps`` if none is.
    """
    for step in range(max_steps):
        col, row = (origin + direction * step).astype(int)
        # x (col) is bounded by the width, y (row) by the height.
        if not (0 <= col < img_w and 0 <= row < img_h) or not inside[row, col]:
            return step
    return max_steps


def get_tip_point(pred, img_h, img_w, threshold=127, max_steps=300):
    """Estimate the two end points of an elongated object in a prediction map.

    The object is ``pred > threshold``. Its main axis is the principal
    eigenvector of the covariance of the object's pixel coordinates. From the
    centroid, the function walks one pixel per step along +axis and -axis until
    it leaves the object or the image, or after ``max_steps`` steps.

    Args:
        pred: 2D array indexed as ``pred[row, col]``.
        img_h: image height (bound for rows / y).
        img_w: image width (bound for columns / x).
        threshold: pixels strictly greater than this belong to the object.
        max_steps: maximum number of steps walked in each direction.

    Returns:
        ``(tip_down, tip_up)``: float (x, y) arrays. The eigenvector's sign is not
        constrained, so which end comes back as ``tip_up`` is arbitrary.

    Raises:
        ValueError: if fewer than two pixels exceed ``threshold`` (no axis can be
            estimated).
    """
    mask = pred > threshold
    rows, cols = np.where(mask)
    if rows.size < 2:
        raise ValueError("get_tip_point: need at least two foreground pixels to estimate the main axis")

    # Centroid in (x, y) order.
    center = np.array([np.mean(cols), np.mean(rows)])

    # PCA on (row, col) coordinates: the eigenvector with the largest
    # eigenvalue of the covariance matrix is the object's elongation axis.
    coords = np.column_stack((rows, cols))
    cov_matrix = np.cov(coords - np.mean(coords, axis=0), rowvar=False)
    eigenvalues, eigenvectors = np.linalg.eig(cov_matrix)
    main_axis = eigenvectors[:, np.argmax(eigenvalues)]
    main_axis = main_axis / np.linalg.norm(main_axis)
    main_axis = main_axis[::-1]  # (row, col) -> (x, y)

    steps_up = _steps_inside(mask, center, main_axis, img_h, img_w, max_steps)
    steps_down = _steps_inside(mask, center, -main_axis, img_h, img_w, max_steps)

    tip_up = center + main_axis * steps_up
    tip_down = center - main_axis * steps_down
    return tip_down, tip_up


def retrieve_gt_angle(label_cos_sin):
    """Convert a (cos, sin) label tensor into an angle in [0, 2*pi) as a NumPy value.

    After squeezing, element 0 along the first axis is read as cos and element 1
    as sin. The result is detached, moved to CPU and squeezed.
    """
    cos_sin = label_cos_sin.squeeze()
    angle = torch.atan2(cos_sin[1], cos_sin[0])
    angle = angle.detach().cpu().numpy().squeeze()
    # atan2 returns (-pi, pi]; shift into [0, 2*pi).
    full_turn = 2 * np.pi
    return (angle + full_turn) % full_turn


def retrieve_gt_angle_numpy(label_cos_sin):
    """NumPy counterpart of ``retrieve_gt_angle``: (cos, sin) -> angle in [0, 2*pi).

    ``label_cos_sin[0]`` is cos and ``label_cos_sin[1]`` is sin; either may be an
    array, in which case the angle is computed element-wise.
    """
    cos_part, sin_part = label_cos_sin[0], label_cos_sin[1]
    full_turn = 2 * np.pi
    return (np.arctan2(sin_part, cos_part) + full_turn) % full_turn


def compute_angle_error(pred_angle, gt_angle):
    """Signed angular difference ``pred_angle - gt_angle`` wrapped into [-pi, pi) radians.

    Works on scalars or NumPy arrays (element-wise).
    """
    shifted = pred_angle - gt_angle + np.pi
    return shifted % (2 * np.pi) - np.pi


def pre_process(img):
    """Convert an (H, W) or (H, W, C) image array to channel-first (C, H, W).

    A 2D input gets a single channel axis. If the maximum value exceeds 1 the
    image is divided by 255 (presumably 8-bit input — confirm with callers);
    otherwise values are left as they are.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    chw = np.transpose(img, (2, 0, 1))
    return chw / 255 if chw.max() > 1 else chw


if __name__ == "__main__":
    # Visual sanity check: circular Gaussian labels around class 180 for several sigmas.
    import matplotlib.pyplot as plt

    angle = 180

    plt.figure()
    for sig in (4.0, 2.0, 1.0):
        soft_label = gaussian_1D_label(angle, num_class=360, sig=sig)
        plt.plot(soft_label, "o", label=f"sigma={sig:g}")
    plt.xlim(160, 200)
    plt.legend()  # the per-curve labels are only displayed once a legend is drawn
    plt.tight_layout()
    plt.show()
