import torch
from torch import nn
from encoder import Backbone


class StereoHeatmap(nn.Module):
    """Predict a 2-channel heatmap for each image of a stereo pair.

    The left and right views are processed independently by the same backbone
    and the same 1x1-convolution head, so all weights are shared between views.

    Args:
        feature_dim: Number of channels the heatmap head expects from the
            backbone's second output; also forwarded to ``Backbone``.
        resnet: Backbone variant name, forwarded to ``Backbone``.
    """

    def __init__(self, feature_dim=256, resnet="resnet101"):
        super().__init__()

        self.backbone = Backbone(resnet=resnet, feature_dim=feature_dim)

        # heatmap head: two pointwise convolutions, feature_dim -> feature_dim // 4 -> 2
        hidden_dim = feature_dim // 4
        self.conv_head1 = nn.Conv2d(feature_dim, hidden_dim, kernel_size=1, stride=1, padding=0)
        self.conv_head2 = nn.Conv2d(hidden_dim, 2, kernel_size=1, stride=1, padding=0)

    def forward_h(self, x):
        """Return the sigmoid-activated heatmaps for a single view ``x``."""
        # The backbone returns a pair; only its second element feeds this head.
        _, dense_features = self.backbone(x)
        hidden = torch.relu(self.conv_head1(dense_features))
        return torch.sigmoid(self.conv_head2(hidden))

    def forward(self, x_l, x_r):
        """Return ``(left_heatmaps, right_heatmaps)`` for the stereo pair."""
        left_heatmaps = self.forward_h(x_l)
        right_heatmaps = self.forward_h(x_r)
        return left_heatmaps, right_heatmaps


class StereoAngle(nn.Module):
    """Classify a discretised angle from a stereo pair.

    Both views pass through one shared backbone. The first element of each
    backbone output is concatenated along dim 1 and fed to a three-layer MLP
    that emits one raw logit per angle bin (no softmax is applied here).

    Args:
        feature_dim: Size of the per-view feature consumed by the MLP; also
            forwarded to ``Backbone``.
        resnet: Backbone variant name, forwarded to ``Backbone``.
        angle_step: Bin width in degrees; the head has ``360 // angle_step``
            outputs.
        input_dim: Forwarded to ``Backbone`` as ``input_dim`` (presumably the
            number of input image channels — confirm against ``Backbone``).
            Defaults to 4, the value this class previously hard-coded.
    """

    def __init__(self, feature_dim=256, resnet="resnet101", angle_step=5, input_dim=4):
        super().__init__()

        self.backbone = Backbone(resnet=resnet, feature_dim=feature_dim, input_dim=input_dim)

        # cls head
        self.linear1 = nn.Linear(feature_dim * 2, feature_dim)
        self.linear2 = nn.Linear(feature_dim, feature_dim // 4)
        self.out_cls = nn.Linear(feature_dim // 4, 360 // angle_step)

    def forward(self, x_l, x_r):
        """Return the angle logits for the stereo pair ``(x_l, x_r)``."""
        # Inputs go to the backbone unmodified; no channel slicing is done here.
        # Only the first element of each backbone output is used.
        zl, _ = self.backbone(x_l)
        zr, _ = self.backbone(x_r)

        z = torch.cat([zl, zr], dim=1)
        z = self.linear1(z).relu()
        z = self.linear2(z).relu()
        return self.out_cls(z)


class StereoHeatmapAngle(nn.Module):
    """Jointly predict per-view heatmaps and a stereo angle class.

    One shared backbone encodes each view and returns a pair. The second
    element goes through a shared 1x1-convolution head that produces two
    sigmoid heatmap channels per view. The first elements of both views are
    concatenated along dim 1 and classified into ``360 // angle_step`` bins
    (raw logits).

    Args:
        feature_dim: Feature width expected by both heads; forwarded to
            ``Backbone``.
        resnet: Backbone variant name, forwarded to ``Backbone``.
        angle_step: Bin width in degrees for the angle classifier.
        input_dim: Forwarded to ``Backbone`` as ``input_dim``.
    """

    def __init__(self, feature_dim=256, resnet="resnet101", angle_step=5, input_dim=3):
        super().__init__()

        self.backbone = Backbone(resnet=resnet, feature_dim=feature_dim, input_dim=input_dim)

        reduced_dim = feature_dim // 4

        # heatmap head
        # NOTE: the attribute name (including its spelling) is part of the
        # state_dict keys, so it is kept as-is.
        heatmap_layers = [
            nn.Conv2d(feature_dim, reduced_dim, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_dim, 2, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.conv_head_hetmaps = nn.Sequential(*heatmap_layers)

        # cls head: consumes the concatenated left/right features
        cls_layers = [
            nn.Linear(feature_dim * 2, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, reduced_dim),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_dim, 360 // angle_step),
        ]
        self.linear_head_cls = nn.Sequential(*cls_layers)

    def forward(self, x_l, x_r):
        """Return ``(left_heatmaps, right_heatmaps, angle_logits)``."""
        # image encoding (left first, then right)
        left_vec, left_map = self.backbone(x_l)
        right_vec, right_map = self.backbone(x_r)

        # heatmap head, shared between views
        left_heatmaps = self.conv_head_hetmaps(left_map)
        right_heatmaps = self.conv_head_hetmaps(right_map)

        # angle head on the fused features
        fused = torch.cat([left_vec, right_vec], dim=1)
        angle_logits = self.linear_head_cls(fused)

        return left_heatmaps, right_heatmaps, angle_logits


class StereoHeatmapAngleTestSingleView(nn.Module):
    """Heatmaps for both views, angle classification from the left view only.

    A single shared backbone and a shared heatmap head process both images.
    Unlike the two-view angle head, the classifier here takes only the left
    view's first backbone output (``feature_dim`` inputs); the right view's
    first output is discarded.

    Args:
        feature_dim: Feature width expected by both heads; forwarded to
            ``Backbone``.
        resnet: Backbone variant name, forwarded to ``Backbone``.
        angle_step: Bin width in degrees; the classifier has
            ``360 // angle_step`` outputs (raw logits).
        input_dim: Forwarded to ``Backbone`` as ``input_dim``.
    """

    @staticmethod
    def _make_heatmap_head(feature_dim):
        """Build the 1x1-conv head: feature_dim -> feature_dim // 4 -> 2, sigmoid."""
        return nn.Sequential(
            nn.Conv2d(feature_dim, feature_dim // 4, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim // 4, 2, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )

    @staticmethod
    def _make_angle_head(feature_dim, angle_step):
        """Build the MLP classifier over a single view's feature vector."""
        return nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim // 4),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim // 4, 360 // angle_step),
        )

    def __init__(self, feature_dim=256, resnet="resnet101", angle_step=5, input_dim=3):
        super().__init__()

        self.backbone = Backbone(resnet=resnet, feature_dim=feature_dim, input_dim=input_dim)

        # Attribute names (including the "hetmaps" spelling) are state_dict
        # keys, so they are left untouched.
        self.conv_head_hetmaps = self._make_heatmap_head(feature_dim)
        self.linear_head_cls = self._make_angle_head(feature_dim, angle_step)

    def forward(self, x_l, x_r):
        """Return ``(left_heatmaps, right_heatmaps, angle_logits)``."""
        # image encoding; the right view only contributes its second output
        left_vec, left_map = self.backbone(x_l)
        _, right_map = self.backbone(x_r)

        # heatmap head
        left_heatmaps = self.conv_head_hetmaps(left_map)
        right_heatmaps = self.conv_head_hetmaps(right_map)

        # angle head (left view only)
        angle_logits = self.linear_head_cls(left_vec)

        return left_heatmaps, right_heatmaps, angle_logits


class StereoHeatmapAngleTestSingleView2(nn.Module):
    """Per-view backbones and heatmap heads; angle classified from the left view.

    Each view has its own backbone and its own heatmap head, so no weights are
    shared between left and right. The angle classifier consumes only the
    first output of the left backbone.

    Args:
        feature_dim: Feature width expected by the heads; forwarded to both
            ``Backbone`` instances.
        resnet: Backbone variant name, forwarded to both ``Backbone`` instances.
        angle_step: Bin width in degrees; the classifier has
            ``360 // angle_step`` outputs (raw logits).
        input_dim: Forwarded to both ``Backbone`` instances as ``input_dim``.
    """

    @staticmethod
    def _make_heatmap_head(feature_dim):
        """Build one 1x1-conv head: feature_dim -> feature_dim // 4 -> 2, sigmoid."""
        squeezed = feature_dim // 4
        return nn.Sequential(
            nn.Conv2d(feature_dim, squeezed, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, 2, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def __init__(self, feature_dim=256, resnet="resnet101", angle_step=5, input_dim=3):
        super().__init__()

        backbone_kwargs = dict(resnet=resnet, feature_dim=feature_dim, input_dim=input_dim)
        self.backbone_l = Backbone(**backbone_kwargs)
        self.backbone_r = Backbone(**backbone_kwargs)

        # heatmap heads, one per view (attribute names are state_dict keys)
        self.conv_head_hetmaps_l = self._make_heatmap_head(feature_dim)
        self.conv_head_hetmaps_r = self._make_heatmap_head(feature_dim)

        # cls head over the left view's feature vector
        self.linear_head_cls = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim // 4),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim // 4, 360 // angle_step),
        )

    def forward(self, x_l, x_r):
        """Return ``(left_heatmaps, right_heatmaps, angle_logits)``."""
        # left branch
        left_vec, left_map = self.backbone_l(x_l)
        left_heatmaps = self.conv_head_hetmaps_l(left_map)

        # right branch: only the last backbone output is needed
        _, right_map = self.backbone_r(x_r)
        right_heatmaps = self.conv_head_hetmaps_r(right_map)

        # angle head (left view only)
        angle_logits = self.linear_head_cls(left_vec)

        return left_heatmaps, right_heatmaps, angle_logits


if __name__ == "__main__":
    # Manual smoke test: run one model on random stereo input and print the
    # output shapes. Flip the `if False` / `if True` guards to pick the model.

    B = 2
    H = 512
    W = 512
    C = 3

    # Only shapes are inspected, so skip building the autograd graph.
    with torch.no_grad():

        if False:

            x_l = torch.randn(B, C, H, W)
            x_r = torch.randn(B, C, H, W)

            model = StereoHeatmap(resnet="resnet50", feature_dim=256)

            zh_l, zh_r = model(x_l, x_r)

            print(zh_l.shape)  # Should be [2, 2, 512, 512]
            print(zh_r.shape)  # Should be [2, 2, 512, 512]

        if False:

            # StereoAngle's backbone is built with input_dim=4, hence C + 1 channels.
            x_l = torch.randn(B, C + 1, H, W)
            x_r = torch.randn(B, C + 1, H, W)

            model = StereoAngle(feature_dim=256, resnet="resnet50", angle_step=1)
            z = model(x_l, x_r)

            print(z.shape)  # Should be [2, 360]: 360 // angle_step logits per sample

        if False:

            x_l = torch.randn(B, C, H, W)
            x_r = torch.randn(B, C, H, W)

            model = StereoHeatmapAngle(feature_dim=256, resnet="resnet50", angle_step=1)
            zh_l, zh_r, z = model(x_l, x_r)

            print(zh_l.shape)  # Should be [2, 2, 512, 512]
            print(zh_r.shape)  # Should be [2, 2, 512, 512]
            print(z.shape)  # Should be [2, 360]: 360 // angle_step logits per sample

        if True:

            x_l = torch.randn(B, C, H, W)
            x_r = torch.randn(B, C, H, W)

            model = StereoHeatmapAngleTestSingleView2(feature_dim=256, resnet="resnet50", angle_step=1)
            zh_l, zh_r, z = model(x_l, x_r)

            print(zh_l.shape)
            print(zh_r.shape)
            print(z.shape)  # Should be [2, 360]: 360 // angle_step logits per sample
