import os
from glob import glob
import numpy as np
from torch.utils.data import Dataset, DataLoader


class ct_dataset(Dataset):
    """Paired CT image dataset backed by ``*_input.npy`` / ``*_target.npy`` files.

    Files are discovered directly inside ``saved_path`` (sorted by path) and
    split by patient: ``mode='test'`` keeps only the files whose *file name*
    contains a test-patient ID, ``mode='train'`` keeps all the others.

    Args:
        mode: ``'train'`` or ``'test'``.
        load_mode: ``0`` stores file paths and loads each array on access;
            ``1`` loads every selected array into memory in the constructor.
        saved_path: Directory containing the ``.npy`` files.
        test_patient: Patient ID substring, or an iterable of such substrings,
            identifying the held-out test files.
        patch_n: Number of patches per image, forwarded to ``get_patch``.
        patch_size: Patch side length forwarded to ``get_patch``; when falsy,
            whole images are returned instead of patches.
        transform: Optional callable applied to the input image and to the
            target image (each call is made separately).

    Raises:
        AssertionError: If ``mode`` or ``load_mode`` is not an accepted value.
    """

    def __init__(self, mode, load_mode, saved_path, test_patient, patch_n=None, patch_size=None, transform=None):
        assert mode in ['train', 'test'], "mode is 'train' or 'test'"
        assert load_mode in [0,1], "load_mode is 0 or 1"

        input_path = sorted(glob(os.path.join(saved_path, '*_input.npy')))
        target_path = sorted(glob(os.path.join(saved_path, '*_target.npy')))
        self.load_mode = load_mode
        self.patch_n = patch_n
        self.patch_size = patch_size
        self.transform = transform

        # Accept either one patient ID or several of them.
        if isinstance(test_patient, str):
            patients = (test_patient,)
        else:
            patients = tuple(test_patient)

        def is_test_file(path):
            # Match on the file name only: matching on the full path would
            # misclassify every file when saved_path itself contains the ID.
            name = os.path.basename(path)
            return any(p in name for p in patients)

        # Train mode excludes the test patients; test mode keeps only them.
        keep_test = mode == 'test'
        input_ = [f for f in input_path if is_test_file(f) == keep_test]
        target_ = [f for f in target_path if is_test_file(f) == keep_test]

        if load_mode == 0:  # lazy: keep paths, load per item in __getitem__
            self.input_ = input_
            self.target_ = target_
        else:  # eager: load all selected arrays now
            self.input_ = [np.load(f) for f in input_]
            self.target_ = [np.load(f) for f in target_]

    def __len__(self):
        """Return the number of target images selected for this split."""
        return len(self.target_)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at ``idx``.

        The pair is returned as whole images, or as the stacked patches
        produced by ``get_patch`` when ``patch_size`` is set.
        """
        input_img, target_img = self.input_[idx], self.target_[idx]
        if self.load_mode == 0:
            # Entries are file paths in lazy mode; load them now.
            input_img, target_img = np.load(input_img), np.load(target_img)

        if self.transform:
            # NOTE: the transform is invoked once per image, so a randomized
            # transform is not guaranteed to treat both images identically.
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        if self.patch_size:
            input_patches, target_patches = get_patch(input_img,
                                                      target_img,
                                                      self.patch_n,
                                                      self.patch_size)
            return (input_patches, target_patches)
        else:
            return (input_img, target_img)


def get_patch(full_input_img, full_target_img, patch_n, patch_size):
    """Sample aligned random square patches from an input/target image pair.

    Each patch is cut from the same location in both images, so the returned
    arrays stay pixel-aligned.

    Args:
        full_input_img: 2-D array to sample input patches from.
        full_target_img: 2-D array with the same shape as ``full_input_img``.
        patch_n: Number of patches to draw.
        patch_size: Side length of each square patch.

    Returns:
        A tuple ``(input_patches, target_patches)`` of arrays built by
        stacking the ``patch_n`` sampled patches.

    Raises:
        AssertionError: If the two images differ in shape.
        ValueError: If ``patch_size`` is larger than either image dimension.
    """
    assert full_input_img.shape == full_target_img.shape
    h, w = full_input_img.shape
    new_h, new_w = patch_size, patch_size
    if new_h > h or new_w > w:
        raise ValueError(
            f"patch_size {patch_size} exceeds image size {h}x{w}")

    patch_input_imgs = []
    patch_target_imgs = []
    for _ in range(patch_n):
        # randint's upper bound is exclusive: the +1 makes the last valid
        # offset reachable and allows a patch as large as the image itself.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        patch_input_imgs.append(
            full_input_img[top:top + new_h, left:left + new_w])
        patch_target_imgs.append(
            full_target_img[top:top + new_h, left:left + new_w])
    return np.array(patch_input_imgs), np.array(patch_target_imgs)


def get_loader(mode='test', load_mode=0,
               saved_path=None, test_patient='L010',
               patch_n=None, patch_size=None,
               transform=None, batch_size=32, num_workers=6,
               shuffle=True):
    """Build a ``DataLoader`` over a ``ct_dataset``.

    Args:
        mode: ``'train'`` or ``'test'``; forwarded to ``ct_dataset``.
        load_mode: ``0`` or ``1``; forwarded to ``ct_dataset``.
        saved_path: Directory with the ``.npy`` files; forwarded to ``ct_dataset``.
        test_patient: Held-out patient identifier; forwarded to ``ct_dataset``.
        patch_n: Patches per image; forwarded to ``ct_dataset``.
        patch_size: Patch side length; forwarded to ``ct_dataset``.
        transform: Optional per-image callable; forwarded to ``ct_dataset``.
        batch_size: Batch size passed to ``DataLoader``.
        num_workers: Worker count passed to ``DataLoader``.
        shuffle: Whether the loader shuffles the samples. Defaults to ``True``
            (the previously hard-coded value); pass ``False`` to iterate in
            dataset order, e.g. for evaluation.

    Returns:
        The configured ``DataLoader``.
    """
    dataset_ = ct_dataset(mode, load_mode, saved_path, test_patient, patch_n, patch_size, transform)
    data_loader = DataLoader(dataset=dataset_, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return data_loader
if __name__ == '__main__':
    # Parameters for a quick manual check of the loader.
    mode = 'test'  # 'train' or 'test', depending on what you need
    load_mode = 0  # or 1
    saved_path = './newnpy_img10/'  # replace with your data directory
    test_patient = 'L506'  # change to the patient you want to hold out
    patch_n = 10  # number of patches wanted per image
    patch_size = 64  # patch side length wanted
    batch_size = 1  # batch size wanted
    num_workers = 0  # number of data-loading worker processes
    shuffle = False  # NOTE: defined here but not passed to get_loader below

    # Gather the arguments, then create the data loader.
    loader_kwargs = {
        'mode': mode,
        'load_mode': load_mode,
        'saved_path': saved_path,
        'test_patient': test_patient,
        'patch_n': patch_n,
        'patch_size': patch_size,
        'batch_size': batch_size,
        'num_workers': num_workers,
    }
    data_loader = get_loader(**loader_kwargs)

    # Report basic facts about the loader.
    print(len(data_loader))
    print(f"DataLoader with length: {len(data_loader)}")
    print(f"Batch size: {data_loader.batch_size}")
    # Add more print statements here to inspect other DataLoader attributes.