import os
import time
import cv2
import numpy as np
import matplotlib
import torchvision

matplotlib.use('Agg')
import matplotlib.pyplot as plt
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from prep import printProgressBar
from networks import RDANET
from measure import compute_measure


class ExampleDataset(Dataset):
    """Grayscale image-folder dataset that yields each image as its own target.

    Files located directly inside ``root_dir`` whose extension is ``.png``,
    ``.jpg`` or ``.jpeg`` are indexed. The extension is matched
    case-insensitively and the listing is sorted, so that the mapping from
    index to file is reproducible regardless of the order ``os.listdir``
    happens to return.

    Args:
        root_dir: Directory that holds the image files.
        transform: Optional callable applied to the ``(1, H, W)`` float tensor
            after it has been divided by 255.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist.
        ValueError: If ``root_dir`` contains no supported image file.
    """

    # Lower-case suffixes accepted as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        if not os.path.exists(root_dir):
            raise FileNotFoundError(f"路径 {root_dir} 不存在")
        self.image_files = sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        if not self.image_files:
            raise ValueError("路径中未找到图像文件")

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` and return it twice, as ``(input, target)``.

        Raises:
            ValueError: If OpenCV cannot decode the file.
        """
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        image = cv2.imread(img_name, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise ValueError(f"无法读取图像文件 {img_name}")
        # Add a channel axis and divide by 255 -> float tensor of shape (1, H, W).
        image = torch.from_numpy(image).unsqueeze(0).float() / 255.0
        if self.transform:
            image = self.transform(image)
        # The same tensor object is used as both input and target.
        return image, image


class SingleChannelVGG(nn.Module):
    """Frozen single-channel feature extractor built from a VGG16 prefix.

    The first nine entries of torchvision's VGG16 ``features`` stack are kept,
    and the first convolution is replaced by a 1-channel -> 64-channel 3x3
    convolution with Kaiming-normal weights. No pretrained weights are loaded,
    and every parameter is frozen (``requires_grad=False``).
    """

    def __init__(self):
        super().__init__()
        # Called without arguments on purpose: torchvision builds a randomly
        # initialised network by default, while the explicit
        # ``pretrained=False`` keyword is deprecated since torchvision 0.13.
        backbone = torchvision.models.vgg16().features[:9]
        # Swap the 3-channel stem for a single-channel one.
        backbone[0] = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        nn.init.kaiming_normal_(backbone[0].weight, mode='fan_out', nonlinearity='relu')
        self.features = backbone
        # Freeze the extractor so an optimiser never updates it.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the feature maps produced by the frozen stack for ``x``."""
        return self.features(x)


class CTPerceptualLoss(nn.Module):
    """Feature-space MSE between two images using ``SingleChannelVGG``.

    Both images are standardised with ``(image - mean) / std`` before being
    passed through the feature extractor; the loss is the mean squared error
    between the two resulting feature maps.

    Args:
        ct_mean: Value subtracted from both images before feature extraction.
        ct_std: Value both images are divided by after mean subtraction.
    """

    def __init__(self, ct_mean=542.9525, ct_std=490.6945):
        super().__init__()
        self.vgg = SingleChannelVGG()
        self.criterion = nn.MSELoss()
        # Stored as (1, 1, 1, 1) buffers so they broadcast over 4-D inputs and
        # follow the module across devices via ``.to(...)``.
        for buffer_name, value in (('mean', ct_mean), ('std', ct_std)):
            self.register_buffer(buffer_name, torch.tensor([value]).view(1, 1, 1, 1))

    def forward(self, input, target):
        """Return the MSE between the features of ``input`` and ``target``."""
        input_features, target_features = [
            self.vgg((image - self.mean) / self.std) for image in (input, target)
        ]
        return self.criterion(input_features, target_features)


class Solver(object):
    """Train and evaluate the RDANET network on a supplied data loader.

    The solver owns the model, the combined MSE + perceptual loss, the Adam
    optimiser, checkpoint I/O under ``save_path`` and the evaluation report.

    Args:
        args: Configuration object exposing the attributes read in
            ``__init__``: ``device``, ``norm_range_min``, ``norm_range_max``,
            ``trunc_min``, ``trunc_max``, ``save_path``, ``multi_gpu``,
            ``num_epochs``, ``print_iters``, ``decay_iters``, ``save_iters``,
            ``test_iters``, ``result_fig``, ``patch_size`` and ``lr``.
        data_loader: Iterable of ``(x, y)`` tensor pairs. ``len()`` of it is
            used for progress reporting and for averaging the test metrics.
    """

    # Key prefix that nn.DataParallel adds to every state-dict entry.
    _PARALLEL_PREFIX = 'module.'

    def __init__(self, args, data_loader):
        # An explicit args.device wins; otherwise prefer CUDA when available.
        self.device = torch.device(args.device or ('cuda' if torch.cuda.is_available() else 'cpu'))
        self.data_loader = data_loader

        self.norm_range_min = args.norm_range_min
        self.norm_range_max = args.norm_range_max
        self.trunc_min = args.trunc_min
        self.trunc_max = args.trunc_max

        self.save_path = args.save_path
        self.multi_gpu = args.multi_gpu

        self.num_epochs = args.num_epochs
        self.print_iters = args.print_iters
        self.decay_iters = args.decay_iters
        self.save_iters = args.save_iters
        self.test_iters = args.test_iters
        self.result_fig = args.result_fig

        self.patch_size = args.patch_size

        self.RDANET = RDANET()
        # Only wrap in DataParallel when more than one GPU is really present.
        if self.multi_gpu and torch.cuda.device_count() > 1:
            print(f'Use {torch.cuda.device_count()} GPUs')
            self.RDANET = nn.DataParallel(self.RDANET)
        self.RDANET.to(self.device)

        self.lr = args.lr
        self.mse_loss = nn.MSELoss()  # pixel-wise mean squared error
        # Perceptual (feature-space) loss, moved to the training device.
        self.percep_loss = CTPerceptualLoss(ct_mean=542.9525, ct_std=490.6945).to(self.device)
        self.loss_weight = 0.1  # weight of the perceptual term in the total loss
        self.optimizer = optim.Adam(self.RDANET.parameters(), self.lr)
        self.loss_list = []  # one total-loss value per training iteration
        os.makedirs(self.save_path, exist_ok=True)  # checkpoint directory
        os.makedirs(os.path.join(self.save_path, 'fig'), exist_ok=True)  # result figures

    def save_model(self, iter_):
        """Write the model's state dict to ``RDANET_<iter_>iter.ckpt``."""
        torch.save(self.RDANET.state_dict(), os.path.join(self.save_path, f'RDANET_{iter_}iter.ckpt'))

    def load_model(self, iter_):
        """Load the checkpoint written at iteration ``iter_`` into the model.

        Checkpoints saved from a ``DataParallel`` wrapper (keys prefixed with
        ``module.``) and from a bare model are both accepted, whether or not
        the current model is itself wrapped. Tensors are mapped onto
        ``self.device`` so a CUDA checkpoint also loads on a CPU-only machine.
        """
        path = os.path.join(self.save_path, f'RDANET_{iter_}iter.ckpt')
        state = torch.load(path, map_location=self.device)

        # Always load into the underlying module, never the parallel wrapper.
        target = self.RDANET.module if isinstance(self.RDANET, nn.DataParallel) else self.RDANET
        expected = set(target.state_dict())

        prefix = self._PARALLEL_PREFIX
        cleaned = OrderedDict()
        for key, value in state.items():
            # Strip the wrapper prefix only from keys that actually carry it
            # and that the target does not already know under that name.
            if key not in expected and key.startswith(prefix):
                key = key[len(prefix):]
            cleaned[key] = value
        target.load_state_dict(cleaned)

    def lr_decay(self):
        """Halve the learning rate of every optimiser parameter group."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] *= 0.5

    def denormalize_(self, image):
        """Map ``image`` from [0, 1] back to [norm_range_min, norm_range_max]."""
        return image * (self.norm_range_max - self.norm_range_min) + self.norm_range_min

    def trunc(self, mat):
        """Clamp ``mat`` in place to [trunc_min, trunc_max] and return it."""
        mat[mat <= self.trunc_min] = self.trunc_min
        mat[mat >= self.trunc_max] = self.trunc_max
        return mat

    @staticmethod
    def _as_batch(tensor):
        """Add a leading axis unless ``tensor`` already has four dimensions."""
        return tensor if tensor.dim() >= 4 else tensor.unsqueeze(0)

    def save_fig(self, x, y, pred, fig_name, ori_result, pred_result):
        """Save a three-panel input / prediction / target comparison figure.

        Args:
            x, y, pred: 2-D CPU tensors for the input, target and prediction.
            fig_name: Suffix used in the output file name ``result_<fig_name>.png``.
            ori_result, pred_result: Indexable metric triples printed as
                PSNR, SSIM and RMSE under the first two panels.
        """
        x, y, pred = x.numpy(), y.numpy(), pred.numpy()
        f, ax = plt.subplots(1, 3, figsize=(30, 10))
        ax[0].imshow(x, cmap='gray', vmin=self.trunc_min, vmax=self.trunc_max)
        ax[0].set_title('Quarter-dose', fontsize=30)
        ax[0].set_xlabel(f"PSNR: {ori_result[0]:.4f}\nSSIM: {ori_result[1]:.4f}\nRMSE: {ori_result[2]:.4f}", fontsize=20)
        ax[1].imshow(pred, cmap='gray', vmin=self.trunc_min, vmax=self.trunc_max)
        ax[1].set_title('Result', fontsize=30)
        ax[1].set_xlabel(f"PSNR: {pred_result[0]:.4f}\nSSIM: {pred_result[1]:.4f}\nRMSE: {pred_result[2]:.4f}", fontsize=20)
        ax[2].imshow(y, cmap='gray', vmin=self.trunc_min, vmax=self.trunc_max)
        ax[2].set_title('Full-dose', fontsize=30)
        f.savefig(os.path.join(self.save_path, 'fig', f'result_{fig_name}.png'))
        # Close this exact figure so figures do not accumulate across calls.
        plt.close(f)

    def train(self):
        """Run ``num_epochs`` passes over the data loader.

        Every iteration minimises ``MSE + loss_weight * perceptual``. The
        learning rate is halved every ``decay_iters`` iterations, and the model
        plus the loss history are saved every ``save_iters`` iterations.
        """
        total_iters = 0
        start_time = time.time()
        # Epochs are 1-based and inclusive so exactly ``num_epochs`` passes run.
        for epoch in range(1, self.num_epochs + 1):
            self.RDANET.train()
            for iter_, (x, y) in enumerate(self.data_loader):
                total_iters += 1
                x = x.float().to(self.device)
                y = y.float().to(self.device)

                if self.patch_size:
                    # NOTE(review): view() only regroups memory; this assumes
                    # every sample already is a patch_size x patch_size patch.
                    # Confirm against the dataset feeding this loader.
                    x = x.view(-1, 1, self.patch_size, self.patch_size)
                    y = y.view(-1, 1, self.patch_size, self.patch_size)

                pred = self.RDANET(x)
                # Total loss: pixel MSE plus the weighted perceptual term.
                loss = self.mse_loss(pred, y) + self.loss_weight * self.percep_loss(pred, y)
                self.loss_list.append(loss.item())

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if total_iters % self.print_iters == 0:
                    print(f"STEP [{total_iters}], EPOCH [{epoch}/{self.num_epochs}], ITER [{iter_+1}/{len(self.data_loader)}] \nLOSS: {loss.item():.8f}, TIME: {time.time() - start_time:.1f}s")

                if total_iters % self.decay_iters == 0:
                    self.lr_decay()
                if total_iters % self.save_iters == 0:
                    self.save_model(total_iters)
                    np.save(os.path.join(self.save_path, f'loss_{total_iters}_iter.npy'), np.array(self.loss_list))
                    print(f"[+] Saved loss curve at {total_iters} iters.")

        print(f"Total training time: {time.time() - start_time:.2f} seconds")

    def test(self):
        """Evaluate the checkpoint saved at ``test_iters`` on the data loader.

        A fresh model is built, the checkpoint is loaded and the model is put
        in evaluation mode. For every batch the input, target and prediction
        are denormalised, truncated and scored with ``compute_measure``; the
        averaged PSNR / SSIM / RMSE are printed, and comparison figures are
        saved when ``result_fig`` is set. Each batch must hold one square
        single-channel image, since it is reshaped to ``(size, size)``.
        """
        self.RDANET = RDANET().to(self.device)
        self.load_model(self.test_iters)
        # Inference behaviour for layers such as dropout / batch norm.
        self.RDANET.eval()
        start_time = time.time()

        ori_psnr_avg = ori_ssim_avg = ori_rmse_avg = 0
        pred_psnr_avg = pred_ssim_avg = pred_rmse_avg = 0

        print("data_loader长度:", len(self.data_loader))

        with torch.no_grad():
            for i, (x, y) in enumerate(self.data_loader):
                shape_ = x.shape[-1]
                # Loaders that already deliver 4-D batches (as train() assumes)
                # must not receive a fifth axis; lower-rank input still gets
                # the leading axis it always got.
                x = self._as_batch(x).float().to(self.device)
                y = self._as_batch(y).float().to(self.device)

                pred = self.RDANET(x)
                x = self.trunc(self.denormalize_(x.view(shape_, shape_).cpu()))
                y = self.trunc(self.denormalize_(y.view(shape_, shape_).cpu()))
                pred = self.trunc(self.denormalize_(pred.view(shape_, shape_).cpu()))

                data_range = self.trunc_max - self.trunc_min
                ori_result, pred_result = compute_measure(x, y, pred, data_range)

                ori_psnr_avg += ori_result[0]
                ori_ssim_avg += ori_result[1]
                ori_rmse_avg += ori_result[2]
                pred_psnr_avg += pred_result[0]
                pred_ssim_avg += pred_result[1]
                pred_rmse_avg += pred_result[2]

                if self.result_fig:
                    self.save_fig(x, y, pred, i, ori_result, pred_result)

                printProgressBar(i, len(self.data_loader), prefix="Compute measurements ..", suffix='Complete', length=25)

        print('\nOriginal ===')
        print(f"PSNR avg: {ori_psnr_avg/len(self.data_loader):.4f}\nSSIM avg: {ori_ssim_avg/len(self.data_loader):.4f}\nRMSE avg: {ori_rmse_avg/len(self.data_loader):.4f}")
        print('\nPredictions ===')
        print(f"PSNR avg: {pred_psnr_avg/len(self.data_loader):.4f}\nSSIM avg: {pred_ssim_avg/len(self.data_loader):.4f}\nRMSE avg: {pred_rmse_avg/len(self.data_loader):.4f}")
        print(f"Total testing time: {time.time() - start_time:.2f} seconds")


if __name__ == "__main__":
    class Args:
        """Hard-coded run configuration read by ``Solver`` and the data loader."""

        def __init__(self):
            self.data_path = './10AAPM-Mayo-CT-Challenge10/'
            self.save_path = './5.20.1save10/'
            self.result_fig = True
            self.norm_range_min = -1024.0
            self.norm_range_max = 3072.0
            self.trunc_min = -160.0
            self.trunc_max = 240.0
            # NOTE(review): Solver.train reshapes every batch to
            # (-1, 1, patch_size, patch_size); confirm the images under
            # data_path are patch_size x patch_size, otherwise set this to None.
            self.patch_size = 64
            self.batch_size = 16
            self.num_epochs = 100
            self.print_iters = 20
            self.decay_iters = 3000
            self.save_iters = 1000
            self.test_iters = 14000
            self.lr = 1e-4
            self.device = None  # None lets Solver pick CUDA when available
            self.num_workers = 7
            self.multi_gpu = False

    args = Args()
    dataset = ExampleDataset(root_dir=args.data_path)
    # num_workers was configured above but previously never reached the loader.
    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )

    solver = Solver(args, data_loader)
    solver.train()