import torch
from torch import nn
from resnet import resnet34, resnet50, resnet101


class DecoderConv(nn.Module):
    """Fuse four feature maps into one tensor at a shared spatial size.

    Every input has its own branch: a 1x1 convolution that maps it to
    ``feature_dim // 4`` channels, a ReLU, and a bilinear upsampling
    (``align_corners=True``) by that branch's scale factor. The four
    branch outputs are concatenated along dim 1, so the result has
    ``4 * (feature_dim // 4)`` channels.

    Args:
        dims: Input channel counts; entries 0..3 belong to x1..x4.
        feature_dim: Channel budget of the fused output; each branch
            receives a quarter of it (integer division).
        scales: Upsampling factors; entries 0..3 belong to x1..x4.
    """

    def __init__(self, dims, feature_dim, scales):
        super().__init__()
        branch_channels = feature_dim // 4

        def pointwise(in_channels):
            # 1x1 convolution: changes the channel count only.
            return nn.Conv2d(in_channels, branch_channels, kernel_size=1)

        def bilinear(factor):
            return nn.Upsample(
                scale_factor=factor, mode="bilinear", align_corners=True
            )

        # Numbered attributes: these names are the parameter-key prefixes
        # in the module's state_dict ("conv1.weight", ...).
        self.conv1 = pointwise(dims[0])
        self.conv2 = pointwise(dims[1])
        self.conv3 = pointwise(dims[2])
        self.conv4 = pointwise(dims[3])

        self.upsample1 = bilinear(scales[0])
        self.upsample2 = bilinear(scales[1])
        self.upsample3 = bilinear(scales[2])
        self.upsample4 = bilinear(scales[3])

    def forward(self, x1, x2, x3, x4):
        """Project, rectify, upsample and concatenate the four inputs.

        Inputs are assumed to be 4-D ``(N, C, H, W)`` tensors whose
        channel counts match ``dims``. After upsampling, all four maps
        must agree in every dimension except dim 1, otherwise the
        concatenation fails.

        Returns:
            The four processed maps joined along dim 1, in x1..x4 order.
        """
        # Stage 1: per-branch channel projection followed by ReLU.
        p1 = torch.relu(self.conv1(x1))
        p2 = torch.relu(self.conv2(x2))
        p3 = torch.relu(self.conv3(x3))
        p4 = torch.relu(self.conv4(x4))

        # Stage 2: bring every branch to the common resolution.
        u1 = self.upsample1(p1)
        u2 = self.upsample2(p2)
        u3 = self.upsample3(p3)
        u4 = self.upsample4(p4)

        return torch.cat((u1, u2, u3, u4), dim=1)


class Backbone(nn.Module):
    """ResNet trunk whose four feature maps are fused by a ``DecoderConv``.

    Args:
        resnet: Which trunk to build: ``"resnet34"``, ``"resnet50"`` or
            ``"resnet101"``.
        feature_dim: Forwarded to the trunk as ``num_classes`` and to
            ``DecoderConv`` as its ``feature_dim``.
        input_dim: Forwarded to the trunk constructor as ``input_dim``.

    Raises:
        ValueError: If ``resnet`` is not one of the three supported names.
    """

    def __init__(self, resnet="resnet101", feature_dim=256, input_dim=3):
        super().__init__()

        # Pick the constructor and the per-variant settings first; the
        # actual construction call below is shared by all variants.
        # NOTE(review): `scales` presumably mirrors each stage's stride
        # relative to the input (16 instead of 32 for the last stage when
        # it is dilated) — confirm against the local `resnet` module.
        if resnet == "resnet34":
            build_trunk = resnet34
            dilation = [False, False, False]
            dims = [64, 128, 256, 512]
            scales = [4, 8, 16, 32]
        elif resnet == "resnet50" or resnet == "resnet101":
            build_trunk = resnet50 if resnet == "resnet50" else resnet101
            dilation = [False, False, True]
            dims = [256, 512, 1024, 2048]
            scales = [4, 8, 16, 16]
        else:
            raise ValueError("Invalid resnet type")

        self.backbone = build_trunk(
            pretrained=False,
            replace_stride_with_dilation=dilation,
            num_classes=feature_dim,
            input_dim=input_dim,
        )
        self.decoder_conv = DecoderConv(dims, feature_dim, scales)

    def forward(self, x):
        """Run the trunk, then fuse its first four outputs.

        The trunk is expected to return exactly five values; the first
        four are handed to the decoder and the fifth is passed through.

        Returns:
            A pair ``(xfc, x1234)``: the trunk's fifth output and the
            decoder's fused feature map.
        """
        stage1, stage2, stage3, stage4, head_out = self.backbone(x)
        fused = self.decoder_conv(stage1, stage2, stage3, stage4)
        return head_out, fused


if __name__ == "__main__":
    # Smoke test: push a random batch of two 3-channel 512x512 inputs
    # through a resnet50-based Backbone and print both output shapes
    # (the pass-through trunk output first, then the fused feature map).
    dummy_batch = torch.randn(2, 3, 512, 512)
    net = Backbone(resnet="resnet50", feature_dim=256, input_dim=3)

    for output in net(dummy_batch):
        print(output.shape)
