import os
import torch
import wandb
import albumentations as aug
from tqdm import tqdm
from models import StereoHeatmapAngle
from dataset import Dataset
from utils import set_seeds, PolyLR, default_hyperparameters


# Per-run hyperparameters. Each entry is written over a copy of
# `default_hyperparameters` below, replacing the default when the key exists.
new_hyperparameter = {
    "model_name": "StereoHeatmapAngle",
    "epochs": 26,
    "lr": 1e-4,
    "batchsize": 4,
    "encoder_type": "resnet50",
    "feature_dim": 256,
    "dataset_path": "../../DATA/DATA_07_05_PIN2/dataset",
    "angle_step": 1,
    "smooth_sigma": 3,
    "img_size": 256,
    "seed": 99,
}

# Layer the run-specific values on top of a copy of the defaults, announcing
# every key that shadows an existing default.
config = default_hyperparameters.copy()
for name, value in new_hyperparameter.items():
    if name in config:
        print(f"Overriding default hyperparameter {name} with {value}")
config.update(new_hyperparameter)

# Dump the merged configuration, one tab-indented "key:   value" line per entry.
print("\nConfig:", *(f"\t{name}:   {value}" for name, value in config.items()), sep="\n")

# Register the run with wandb; from here on `config` is the wandb run config.
wandb.init(project="project_name", entity="namehere", mode="online", config=config)
config = wandb.config


# Fix the random seed through the project helper `set_seeds`.
run_seed = config.seed
set_seeds(run_seed)


# Augmentation pipeline passed to the Dataset instances built by Trainer.
# Ranges are given explicitly as (min, max) pairs: albumentations samples a value
# uniformly from such a pair per image, whereas a scalar `rotate` for Affine is
# used as a fixed angle, and a degenerate pair such as [0.2, 0.2] is a fixed value.
TRANSFORM = aug.Compose(
    [
        # Symmetric +/-20% brightness and contrast jitter.
        aug.RandomBrightnessContrast(contrast_limit=(-0.2, 0.2), brightness_limit=(-0.2, 0.2)),
        aug.ChannelShuffle(p=0.5),
        aug.ColorJitter(p=0.5),
        # Random zoom, shift and rotation in [-10, 10] degrees.
        aug.Affine(scale=(0.9, 1.1), translate_percent=(-0.1, 0.1), rotate=(-10, 10)),
    ],
    p=1,
)


class Trainer:
    """Train a ``StereoHeatmapAngle`` model and checkpoint it.

    Each batch is ``(img_l, img_r, h_l, h_r, label)``. The model maps
    ``(img_l, img_r)`` to ``(pred_hl, pred_hr, pred_angle)`` and the loss is
    ``MSE(pred_hl, h_l) + MSE(pred_hr, h_r) + BCEWithLogits(pred_angle, label)``.

    Losses are logged to the active wandb run. The checkpoint with the lowest
    validation loss is written to ``checkpoint_dir/best_name``. The state after
    every completed epoch, including the optimizer, is written to
    ``checkpoint_dir/last_name`` so it can be passed back as ``ckpt_resume``.

    Args:
        config: Hyperparameter mapping (a dict or ``wandb.config``); it must support
            ``config[key]`` and ``dict(config)``.
        checkpoint_dir: Directory checkpoints are written to; created if missing.
    """

    def __init__(self, config, checkpoint_dir):
        # Keep references on the instance instead of reading module globals, so the
        # class also works when imported from another script.
        self.config = config
        self.checkpoint_dir = checkpoint_dir

        self.train_folders = ["ALL_DATA_train", "ALL_DATA_aug_train"]
        self.val_folders = ["ALL_DATA_val", "ALL_DATA_aug_val"]

        model_name = config["model_name"]
        self.best_name = f"CP_{model_name}_{wandb.run.name}.pth"
        self.last_name = f"CP_{model_name}_LAST_{wandb.run.name}.pth"

        os.makedirs(checkpoint_dir, exist_ok=True)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Using device ", self.device)

        # The model is moved before the optimizer is built, so the optimizer's params
        # (and any optimizer state restored below) live on the training device.
        self.model = StereoHeatmapAngle(
            resnet=config["encoder_type"],
            feature_dim=config["feature_dim"],
            angle_step=config["angle_step"],
        )
        self.model.to(self.device)

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=config["lr"])
        # NOTE(review): the scheduler state is neither saved nor restored; confirm how
        # PolyLR should behave for a resumed run.
        self.scheduler = PolyLR(self.optimizer, config["epochs"], power=0.97, min_lr=1e-9)
        self.criterion_cls = torch.nn.BCEWithLogitsLoss()
        self.criterion_mse = torch.nn.MSELoss()

        self.current_epoch, self.global_step, self.lr = self._restore_checkpoint(config)
        print("Starting training:")

        # Dataset ----------------------------------------------------------------------------------
        self.train_dataset = self._make_dataset(config, self.train_folders)
        # NOTE(review): validation also applies the random TRANSFORM augmentation, which
        # makes val_loss (and therefore BEST checkpoint selection) noisy.
        self.val_dataset = self._make_dataset(config, self.val_folders)
        self.n_train = len(self.train_dataset)

        self.train_loader = self._make_loader(self.train_dataset, config["batchsize"], shuffle=True)
        self.val_loader = self._make_loader(self.val_dataset, config["batchsize"], shuffle=False)

    def _restore_checkpoint(self, config):
        """Optionally load a checkpoint and return ``(current_epoch, global_step, lr)``.

        * ``ckpt_pre_train`` set: load model weights non-strictly from
          ``checkpoints/<ckpt_pre_train>`` and start counting from zero.
        * ``ckpt_resume`` set to an existing file: restore the model weights, the
          optimizer state when the checkpoint contains one, and the counters.
        * otherwise: train from scratch.
        """
        if config["ckpt_pre_train"] is not None:
            # NOTE(review): this reads from "checkpoints", not from self.checkpoint_dir.
            checkpoint = torch.load(
                os.path.join("checkpoints", config["ckpt_pre_train"]), map_location=torch.device("cpu")
            )
            self.model.load_state_dict(checkpoint["model_state_dict"], strict=False)
            return 0, 0, config["lr"]

        if config["ckpt_resume"] is not None and os.path.isfile(config["ckpt_resume"]):
            checkpoint = torch.load(config["ckpt_resume"], map_location=torch.device("cpu"))
            self.model.load_state_dict(checkpoint["model_state_dict"])
            # Weights-only checkpoints (get_state() default) carry no optimizer state.
            # Optimizer.load_state_dict casts the state to each param's device itself,
            # so no manual .cuda() pass is needed (that pass broke CPU-only runs).
            if "optimizer_state_dict" in checkpoint:
                self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            current_epoch = checkpoint["epoch"]
            # get_state() writes "global_step"; still accept a legacy "step" key.
            global_step = checkpoint["global_step"] if "global_step" in checkpoint else checkpoint["step"]
            lr = self.optimizer.param_groups[0]["lr"]
            print(
                f"""Resume training from: 
                current epoch:   {current_epoch}
                global step:     {global_step}
            """
            )
            return current_epoch, global_step, lr

        print("""[!] Retrain""")
        return 0, 0, config["lr"]

    def _make_dataset(self, config, folders):
        """Build a ``Dataset`` over ``folders`` below ``config["dataset_path"]`` using TRANSFORM."""
        return Dataset(
            main_path=config["dataset_path"],
            list_folders=folders,
            img_size=config["img_size"],
            angle_step=config["angle_step"],
            smooth_sigma=config["smooth_sigma"],
            transform=TRANSFORM,
        )

    @staticmethod
    def _make_loader(dataset, batch_size, shuffle):
        """Wrap ``dataset`` in a DataLoader (8 workers, pinned memory, last partial batch dropped)."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=8,
            pin_memory=True,
            drop_last=True,
        )

    def _compute_losses(self, batch):
        """Move one ``(img_l, img_r, h_l, h_r, label)`` batch to the device, run the model,
        and return ``(total_loss, heatmap_loss, angle_loss)``."""
        img_l, img_r, h_l, h_r, label = (t.to(device=self.device) for t in batch)
        pred_hl, pred_hr, pred_angle = self.model(img_l, img_r)
        loss_h = self.criterion_mse(pred_hl, h_l) + self.criterion_mse(pred_hr, h_r)
        loss_cls = self.criterion_cls(pred_angle, label)
        return loss_h + loss_cls, loss_h, loss_cls

    def _apply_warmup(self):
        """Linearly ramp the lr from ``warmup_lr`` to ``lr`` during the first ``warmup_steps`` steps."""
        warmup_steps = self.config["warmup_steps"]
        if warmup_steps > 0 and self.global_step <= warmup_steps:
            start_lr, target_lr = self.config["warmup_lr"], self.config["lr"]
            self.lr = start_lr + (target_lr - start_lr) * (self.global_step / warmup_steps)
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = self.lr

    @torch.no_grad()
    def eval_net(self, loader):
        """Return the mean per-batch loss over ``loader``.

        The model is switched to eval mode for the pass and back to train mode after it.
        An empty loader (e.g. fewer samples than one batch with ``drop_last=True``)
        yields ``float("inf")`` instead of a ZeroDivisionError.
        """
        n_val = len(loader)
        if n_val == 0:
            return float("inf")

        self.model.eval()
        val_loss = 0.0
        with tqdm(total=n_val, desc="Validation round", unit="batch", leave=False) as pbar:
            for batch in loader:
                loss, _, _ = self._compute_losses(batch)
                val_loss += loss.item()
                pbar.update()

        self.model.train()
        return val_loss / n_val

    def _validate_and_save_best(self, min_loss):
        """Validate, log ``val_loss``, and save the BEST checkpoint on a new minimum.

        Returns the (possibly updated) minimum validation loss.
        """
        print(f"Validation step {self.global_step}")
        val_loss = self.eval_net(loader=self.val_loader)
        wandb.log({"val_loss": val_loss}, step=self.global_step)
        print(f"Evaluation loss: val {val_loss:.5f}")

        if val_loss < min_loss:
            torch.save(self.get_state(), os.path.join(self.checkpoint_dir, self.best_name))
            print(f"*** New min validation loss {val_loss:.5f}, checkpoint BEST saved!")
            return val_loss
        return min_loss

    def train_net(self):
        """Train from ``self.current_epoch`` up to ``config["epochs"]``.

        Validates every ``val_steps`` optimizer steps (saving BEST on a new minimum
        validation loss). After each epoch it steps the scheduler and saves LAST
        (with optimizer state). It stops early once the mean training loss has not
        improved for ``early_stopping_patience`` epochs past
        ``early_stopping_min_epochs``.
        """
        config = self.config
        epochs_no_improve = 0
        # +inf so the first measured loss always counts as an improvement.
        min_loss = min_epoch_loss = float("inf")

        for epoch in range(self.current_epoch, config["epochs"]):
            self.current_epoch = epoch
            self.model.train()
            epoch_loss = 0.0

            with tqdm(total=self.n_train, desc=f'Epoch {epoch + 1}/{config["epochs"]}', unit="img") as pbar:
                for batch in self.train_loader:
                    self._apply_warmup()
                    self.optimizer.zero_grad()

                    loss, loss_h, loss_cls = self._compute_losses(batch)
                    loss_value = loss.item()
                    wandb.log(
                        {"train_loss": loss_value, "train_loss_h": loss_h.item(), "train_loss_cls": loss_cls.item()},
                        step=self.global_step,
                    )
                    epoch_loss += loss_value

                    loss.backward()
                    self.optimizer.step()

                    pbar.set_postfix(**{"loss ": loss_value})
                    pbar.update(batch[0].shape[0])
                    self.global_step += 1

                    if self.global_step % config["val_steps"] == 0:
                        min_loss = self._validate_and_save_best(min_loss)

            epoch_loss /= max(len(self.train_loader), 1)
            print(f"Train | loss: {epoch_loss:.5f}")

            # scheduler step every epoch
            self.scheduler.step()

            # The epoch is complete, so a run resumed from LAST should start at the next one.
            self.current_epoch = epoch + 1
            torch.save(
                self.get_state(include_optimizer=True),
                os.path.join(self.checkpoint_dir, self.last_name),
            )

            # EARLY STOPPING on the mean training loss
            if epoch_loss < min_epoch_loss:
                epochs_no_improve = 0
                min_epoch_loss = epoch_loss
            else:
                epochs_no_improve += 1

            # `>=` rather than `==`: the counter can pass the patience value while still
            # within early_stopping_min_epochs and would otherwise never match again.
            if (
                epoch > config["early_stopping_min_epochs"]
                and epochs_no_improve >= config["early_stopping_patience"]
            ):
                print("Early Stopping!")
                break

    def get_state(self, include_optimizer=False):
        """Return a checkpoint dict: every config entry plus progress counters and weights.

        Args:
            include_optimizer: Also store ``optimizer_state_dict`` so that
                ``ckpt_resume`` restores the optimizer. Defaults to a weights-only checkpoint.
        """
        state = dict(self.config)
        state["epoch"] = self.current_epoch
        state["global_step"] = self.global_step
        state["model_state_dict"] = self.model.state_dict()
        if include_optimizer:
            state["optimizer_state_dict"] = self.optimizer.state_dict()
        return state


if __name__ == "__main__":
    # Script entry point: build a Trainer from the wandb run config and train it.
    checkpoint_dir = "../checkpoints"  # relative to the current working directory
    trainer = Trainer(config=config, checkpoint_dir=checkpoint_dir)
    trainer.train_net()
