import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample.

    Intended for the main path of a residual block: for every sample in the
    batch the whole tensor is either kept (and rescaled by ``1 / keep_prob``)
    or replaced by zeros.

    The technique is sometimes called "DropConnect" in EfficientNet-style
    code, but that name belongs to a different form of dropout, hence
    "drop path" here.
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x: Input tensor whose first dimension indexes the samples.
        drop_prob: Probability of dropping a sample's path. ``None`` is
            treated like ``0.`` (nothing is dropped).
        training: Paths are only dropped when this is ``True``.

    Returns:
        ``x`` itself when nothing can be dropped, otherwise a new tensor with
        the same shape as ``x``.
    """
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims
    # (works with tensors of any rank, not just 2D ConvNets).
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob > 0.:
        # Rescale kept samples. Skipped when every path is dropped
        # (keep_prob == 0), where dividing would produce inf * 0 = NaN.
        x = x.div(keep_prob)
    return x * mask

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample, as a module.

    Meant for the main path of a residual block. It is a regularization
    technique whose effect is to randomly "drop" branches of a multi-branch
    structure in a deep model.

    Args:
        drop_prob: Probability of dropping a sample's path. ``None`` (the
            default) or ``0`` disables dropping entirely.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Nothing to drop when the probability is unset/zero or in eval mode.
        # Returning early also keeps the default ``drop_prob=None`` from
        # reaching the arithmetic in ``drop_path``.
        if not self.drop_prob or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        return f"drop_prob={self.drop_prob}"


class LayerNorm(nn.Module):
    """Layer normalization over the channel axis for two tensor layouts.

    ``data_format`` selects where the channel axis lives:

    * ``"channels_first"`` (the default of this class): inputs shaped
      (batch_size, channels, height, width); statistics are taken over dim 1.
    * ``"channels_last"``: inputs shaped (batch_size, height, width, channels);
      delegated to ``F.layer_norm`` over the last dim.

    Args:
        normalized_shape: Number of channels to normalize over.
        eps: Value added to the variance for numerical stability.
        data_format: ``"channels_last"`` or ``"channels_first"``.

    Raises:
        ValueError: If ``data_format`` is not one of the two supported values.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        # Learnable per-channel scale and shift.
        self.weight = nn.Parameter(torch.ones(normalized_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(normalized_shape), requires_grad=True)
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            raise ValueError(f"not support data format '{self.data_format}'")
        # F.layer_norm expects the normalized shape as a sequence.
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        if self.data_format == "channels_first":
            # Manual normalization over dim 1 of [batch, channels, height, width].
            mu = x.mean(1, keepdim=True)
            centered = x - mu
            variance = centered.pow(2).mean(1, keepdim=True)  # biased variance
            normed = centered / torch.sqrt(variance + self.eps)
            # Broadcast the per-channel affine parameters over height and width.
            scale = self.weight[:, None, None]
            shift = self.bias[:, None, None]
            return scale * normed + shift

class Block(nn.Module):
    """ConvNeXt block.

    Pipeline implemented here:
    7x7 depthwise conv -> permute to (N, H, W, C) -> LayerNorm (channels_last)
    -> Linear (dim -> 4*dim) -> GELU -> Linear (4*dim -> dim) -> optional
    per-channel layer scale -> permute back to (N, C, H, W) -> residual add
    (with optional stochastic depth on the branch).

    The pointwise (1x1) convolutions are implemented as linear layers on the
    channels-last tensor.

    Args:
        dim (int): Number of input channels (the output has the same number).
        drop_rate (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Initial value of the layer-scale
            parameter; a value <= 0 disables layer scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_rate=0., layer_scale_init_value=1e-6):
        super().__init__()
        hidden_dim = 4 * dim

        # Depthwise conv: one 7x7 filter per channel, spatial size preserved.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6, data_format="channels_last")
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)

        # Learnable per-channel scaling of the branch output.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim,)), requires_grad=True
            )
        else:
            self.gamma = None

        # Stochastic depth on the residual branch, or a no-op when disabled.
        if drop_rate > 0.:
            self.drop_path = DropPath(drop_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        branch = self.dwconv(x)
        branch = branch.permute(0, 2, 3, 1)  # [N, C, H, W] -> [N, H, W, C]
        branch = self.pwconv2(self.act(self.pwconv1(self.norm(branch))))
        if self.gamma is not None:
            branch = self.gamma * branch
        branch = branch.permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]

        return identity + self.drop_path(branch)
class ResidualBlock(nn.Module):
    """Residual block built from two depthwise-separable convs with ECA gating.

    Main branch: DepthwiseSeparableConv -> BatchNorm -> ReLU ->
    DepthwiseSeparableConv -> BatchNorm -> ECA gating. The block input is then
    added to the branch output in place and a final ReLU is applied.

    No projection is applied to the shortcut, so the branch output must have
    the same shape as the input (or the input must broadcast into it) for the
    addition to succeed. Both convolutions use the same ``stride``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        kernel_size: Kernel size of the depthwise convolutions.
        stride: Stride of the depthwise convolutions.
        padding: Padding of the depthwise convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv_args = (kernel_size, stride, padding)
        self.conv1 = DepthwiseSeparableConv(in_channels, out_channels, *conv_args)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = DepthwiseSeparableConv(out_channels, out_channels, *conv_args)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.eca = eca_layer(channel=out_channels)

    def forward(self, x):
        identity = x

        feat = self.conv1(x)
        feat = self.bn1(feat)
        feat = self.relu(feat)
        feat = self.bn2(self.conv2(feat))

        # NOTE(review): eca_layer.forward already returns the input multiplied
        # by its attention weights, so this product is feat**2 * weights.
        # Kept as-is to preserve existing behaviour; confirm it is intended.
        gated = self.eca(feat)
        feat = feat * gated

        feat += identity
        return self.relu(feat)
class DepthwiseSeparableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    The depthwise stage filters each input channel independently
    (``groups=in_channels``); the pointwise stage then mixes channels and maps
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        kernel_size: Kernel size of the depthwise stage.
        stride: Stride of the depthwise stage.
        padding: Padding of the depthwise stage.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Spatial filtering, one filter per input channel.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
        )
        # Channel mixing with a 1x1 kernel.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

# ECA-Net channel attention (added 6/28)
class eca_layer(nn.Module):
    """ECA (efficient channel attention) module.

    Global-average-pools the input to one value per channel, runs a 1-D
    convolution across the channel axis, applies a sigmoid and multiplies the
    input by the resulting per-channel weights.

    Args:
        channel: Number of channels of the input feature map. Accepted for
            interface compatibility; not used by this implementation.
        k_size: Kernel size of the 1-D convolution across channels. Padding is
            ``(k_size - 1) // 2``.

    Note:
        ``forward`` returns the already-gated feature map (input times
        weights), not the attention weights alone.
    """

    def __init__(self, channel, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]; the unpacking also
        # rejects inputs that are not 4-D.
        b, c, h, w = x.size()

        # Per-channel descriptor from global spatial average: [b, c, 1, 1].
        pooled = self.avg_pool(x)

        # Treat channels as a length-c sequence for the 1-D conv: [b, 1, c].
        channel_seq = pooled.squeeze(-1).transpose(-1, -2)
        mixed = self.conv(channel_seq)

        # Back to [b, c, 1, 1] and squash into (0, 1) attention weights.
        weights = self.sigmoid(mixed.transpose(-1, -2).unsqueeze(-1))

        return x * weights.expand_as(x)
class RED_CNN(nn.Module):
    """Encoder-decoder CNN with ConvNeXt blocks, ECA gating and skip connections.

    Encoder: five unpadded 5x5 convolutions (each shrinks height and width
    by 4), each followed by a ConvNeXt ``Block`` and a shared ECA gate. Only
    the first convolution is followed by a ReLU.
    Decoder: five 5x5 transposed convolutions (each grows height and width
    by 4) interleaved with ConvNeXt blocks. Three skip connections add the
    network input, the stage-2 features and the stage-4 features back in at
    the matching resolution.

    The network maps a 1-channel input to a 1-channel output of the same
    spatial size.

    Args:
        out_ch: Number of feature channels used throughout the network.

    Note:
        ``convnext11``, ``convnext12``, ``layern`` and ``resblock1`` ..
        ``resblock10`` are constructed but never used in ``forward``. They are
        still registered submodules, so they appear in ``state_dict()``.
    """

    def __init__(self, out_ch=96):
        super().__init__()

        # Submodules are created in a fixed order (convnext*, layern,
        # resblock*, conv*, tconv*, eca, relu) under fixed attribute names.
        for idx in range(1, 13):
            setattr(self, f"convnext{idx}", Block(dim=out_ch))

        self.layern = LayerNorm(normalized_shape=out_ch)

        for idx in range(1, 11):
            setattr(self, f"resblock{idx}", ResidualBlock(out_ch, out_ch))

        conv_kwargs = dict(kernel_size=5, stride=1, padding=0)

        # Encoder convolutions: 1 -> out_ch, then out_ch -> out_ch.
        self.conv1 = nn.Conv2d(1, out_ch, **conv_kwargs)
        for idx in range(2, 6):
            setattr(self, f"conv{idx}", nn.Conv2d(out_ch, out_ch, **conv_kwargs))

        # Decoder transposed convolutions: out_ch -> out_ch, last one -> 1.
        for idx in range(1, 5):
            setattr(
                self,
                f"tconv{idx}",
                nn.ConvTranspose2d(out_ch, out_ch, **conv_kwargs),
            )
        self.tconv5 = nn.ConvTranspose2d(out_ch, 1, **conv_kwargs)

        # A single ECA layer and a single ReLU are shared by all stages.
        self.eca = eca_layer(channel=1)
        self.relu = nn.ReLU()

    def _gate(self, feat):
        """Multiply ``feat`` by the output of the shared ECA layer.

        NOTE(review): eca_layer.forward already returns the input multiplied
        by its attention weights, so the result is feat**2 * weights. Kept
        as-is to preserve existing behaviour; confirm it is intended.
        """
        return feat * self.eca(feat)

    def forward(self, x):
        skip_input = x  # first residual: the raw network input

        # ----- encoder -----
        out = self.relu(self.conv1(x))
        out = self._gate(self.convnext1(out))                 # stage 1
        out = self._gate(self.convnext2(self.conv2(out)))     # stage 2
        skip_mid = out  # second residual: features after stage 2
        out = self._gate(self.convnext3(self.conv3(out)))     # stage 3
        out = self._gate(self.convnext4(self.conv4(out)))     # stage 4
        skip_deep = out  # third residual: features after stage 4
        out = self._gate(self.convnext5(self.conv5(out)))     # stage 5

        # ----- decoder -----
        out = self.convnext6(self.tconv1(out))                # stage A
        out += skip_deep  # add the third residual to the upsampled features
        out = self.convnext7(self.tconv2(out))                # stage B
        out = self.convnext8(self.tconv3(out))                # stage C
        out += skip_mid
        out = self.convnext9(self.tconv4(out))                # stage D
        out = self.convnext10(out)                            # stage E
        out = self.tconv5(out)
        out += skip_input
        return self.relu(out)


