import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------- 空间注意力 --------------------
class BalancedSpatialAttention(nn.Module):
    """Spatial attention gate with a learnable temperature.

    For every pixel, the channel-wise mean and channel-wise max are stacked
    into a 2-channel map. A single conv turns that map into one logit per
    pixel. The logit is divided by a learnable temperature (clamped to at
    least 0.1) and squashed with a sigmoid. The input is then rescaled
    pixel-wise by the resulting map.

    Args:
        kernel_size: Side length of the square conv kernel. Must be a positive
            odd integer so that ``padding = kernel_size // 2`` keeps the
            attention map the same height/width as the input.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            # With an even kernel, padding k // 2 grows the map to (H+1, W+1).
            # The multiply in forward() then either fails to broadcast or,
            # for 1x1 inputs, silently broadcasts to a wrong-shaped output.
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()
        # Learnable softmax-style temperature, initialised to 0.5.
        self.temperature = nn.Parameter(torch.ones(1) * 0.5)

    def forward(self, x):
        # Per-pixel statistics across channels (dim=1), each with one channel.
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)
        concat = torch.cat([avg_pool, max_pool], dim=1)
        attention = self.conv(concat)
        # Clamp keeps the divisor away from zero / negative values.
        attention = self.sigmoid(attention / self.temperature.clamp(min=0.1))
        return x * attention

# -------------------- 通道注意力（ECA） --------------------
class ECAAttention(nn.Module):
    """ECA-style channel attention gate.

    Each channel is global-average-pooled. A 1-D conv with ``k_size`` taps
    slides across the channel axis, so each channel's gate depends on its
    neighbouring channels. A sigmoid is applied and the input is rescaled
    channel-wise.

    Args:
        k_size: Kernel size of the 1-D conv across channels. Must be a
            positive odd integer so that ``padding = (k_size - 1) // 2``
            yields exactly one gate per channel.

    Raises:
        ValueError: If ``k_size`` is not a positive odd integer.
    """

    def __init__(self, k_size=3):
        super().__init__()
        if k_size < 1 or k_size % 2 == 0:
            # An even kernel produces C - 1 gates, which only fails later
            # (and opaquely) at expand_as() in forward().
            raise ValueError(f"k_size must be a positive odd integer, got {k_size}")
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Shapes below assume x is (B, C, H, W), which the dims used here require.
        y = torch.mean(x, dim=(2, 3), keepdim=True)           # (B, C, 1, 1)
        y = self.conv(y.squeeze(-1).transpose(-1, -2))        # (B, 1, C): conv over channels
        y = self.sigmoid(y.transpose(-1, -2).unsqueeze(-1))   # back to (B, C, 1, 1)
        return x * y.expand_as(x)

# -------------------- ASPP模块 --------------------
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling block.

    Four parallel convolutions read the same input:

    * a 1x1 conv;
    * 3x3 convs at dilation 6, 12 and 18.

    Each 3x3 conv is padded by its dilation, so every branch keeps the input
    height/width. The branch outputs are concatenated on the channel axis and
    fused back to ``out_ch`` channels by a 1x1 conv. No normalisation or
    activation is applied anywhere in the block.

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Channels produced by each branch and by the fusion conv.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names (atrous_block<rate>) and their creation order are
        # significant: they fix state_dict keys and seeded initialisation.
        for rate in (1, 6, 12, 18):
            is_pointwise = rate == 1
            setattr(
                self,
                f"atrous_block{rate}",
                nn.Conv2d(
                    in_ch,
                    out_ch,
                    kernel_size=1 if is_pointwise else 3,
                    padding=0 if is_pointwise else rate,
                    dilation=rate,
                ),
            )
        self.conv1x1 = nn.Conv2d(4 * out_ch, out_ch, kernel_size=1)

    def forward(self, x):
        branches = (
            self.atrous_block1,
            self.atrous_block6,
            self.atrous_block12,
            self.atrous_block18,
        )
        stacked = torch.cat([branch(x) for branch in branches], dim=1)
        return self.conv1x1(stacked)

# -------------------- 卷积 & 反卷积模块 --------------------
class Conv3x3_5x5(nn.Module):
    """Two stacked convolutions: 3x3 (``in_ch -> out_ch``) then 5x5 (``out_ch -> out_ch``).

    Both are zero-padded by half their kernel size, so height and width are
    unchanged. No activation is applied between or after them.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # The 3x3 conv changes the channel count; the 5x5 conv keeps it.
        self.conv3 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=5, padding=2)

    def forward(self, x):
        return self.conv5(self.conv3(x))

class TConv3x3_5x5(nn.Module):
    """Two stacked stride-1 transposed convolutions.

    The first is 3x3 (``in_ch -> in_ch``) and the second is 5x5
    (``in_ch -> out_ch``). Their padding (1 and 2) cancels the kernel growth,
    so height and width are unchanged. No activation is applied between or
    after them.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Channel count changes only in the second (5x5) layer.
        self.tconv3 = nn.ConvTranspose2d(in_channels=in_ch, out_channels=in_ch, kernel_size=3, padding=1)
        self.tconv5 = nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch, kernel_size=5, padding=2)

    def forward(self, x):
        return self.tconv5(self.tconv3(x))

# -------------------- RDA-Net 模型主体 --------------------
class RDANET(nn.Module):
    """RDA-Net: single-channel encoder/decoder with attention and skip connections.

    Every conv stage preserves height and width.

    * Encoder: six ``Conv3x3_5x5`` stages (``conv1``..``conv6``), each
      followed by LeakyReLU. Stages 2-4 additionally go through ECA channel
      attention, spatial attention and another LeakyReLU. Their outputs are
      kept as skip tensors.
    * Bottleneck: adaptive average pooling to a fixed 16x16 grid, then ASPP.
    * Decoder: six ``TConv3x3_5x5`` stages (``tconv1``..``tconv6``).
      Stages 2-4 are followed by ECA + spatial attention + LeakyReLU.
      Features are resized with ``F.interpolate`` (default mode) to each skip
      tensor's size before the encoder skip is added.
    * Output: ``tconv6`` maps to one channel. The network input is added back,
      then LeakyReLU is applied and the result is clamped to [0, 1].

    Args:
        out_ch: Channel width of every hidden feature map.

    Sub-module registration order in ``__init__`` is significant. It fixes the
    state_dict key order and the order in which seeded random initialisation
    is drawn.
    """

    def __init__(self, out_ch=96):
        super().__init__()
        # Encoder convs conv1..conv6; only conv1 consumes the 1-channel input.
        for idx in range(1, 7):
            in_ch = 1 if idx == 1 else out_ch
            setattr(self, f"conv{idx}", Conv3x3_5x5(in_ch, out_ch))

        # Decoder transposed convs tconv1..tconv6; only tconv6 emits 1 channel.
        for idx in range(1, 7):
            final_ch = 1 if idx == 6 else out_ch
            setattr(self, f"tconv{idx}", TConv3x3_5x5(out_ch, final_ch))

        self.relu = nn.LeakyReLU(0.01)

        # Attention modules: all spatial ones (sa2, sa3, sa4, sa_d2, sa_d3,
        # sa_d4) first, then the matching channel ones (eca2 ... eca_d4).
        stage_tags = ("2", "3", "4", "_d2", "_d3", "_d4")
        for tag in stage_tags:
            setattr(self, f"sa{tag}", BalancedSpatialAttention())
        for tag in stage_tags:
            setattr(self, f"eca{tag}", ECAAttention())

        self.adaptive_pool = nn.AdaptiveAvgPool2d((16, 16))
        self.aspp = ASPP(out_ch, out_ch)

    def _attend(self, feat, eca, sa):
        """Apply channel attention, then spatial attention, then LeakyReLU."""
        return self.relu(sa(eca(feat)))

    def forward(self, x):
        # Assumes x is (B, 1, H, W): conv1 takes one channel and the attention
        # blocks reduce over dims 2 and 3. TODO confirm with callers.

        # Encoder; stages 2-4 also serve as skip connections.
        feat = self.relu(self.conv1(x))
        skip2 = feat = self._attend(self.relu(self.conv2(feat)), self.eca2, self.sa2)
        skip3 = feat = self._attend(self.relu(self.conv3(feat)), self.eca3, self.sa3)
        skip4 = feat = self._attend(self.relu(self.conv4(feat)), self.eca4, self.sa4)
        feat = self.relu(self.conv6(self.relu(self.conv5(feat))))

        # Bottleneck at a fixed 16x16 resolution.
        feat = self.aspp(self.adaptive_pool(feat))

        # Decoder. interpolate() returns a fresh tensor, so the in-place skip
        # additions never touch the stored skip tensors.
        feat = F.interpolate(self.tconv1(feat), size=skip4.shape[2:])
        feat += skip4
        feat = self._attend(self.tconv2(feat), self.eca_d2, self.sa_d2)
        feat = self._attend(self.tconv3(feat), self.eca_d3, self.sa_d3)
        feat = self._attend(self.tconv4(feat), self.eca_d4, self.sa_d4)

        feat = F.interpolate(feat, size=skip3.shape[2:])
        feat += skip3

        feat = F.interpolate(self.tconv5(feat), size=skip2.shape[2:])
        feat += skip2

        # Global residual back to the network input, then squash to [0, 1].
        feat = self.tconv6(self.relu(feat))
        feat += x
        return torch.clamp(self.relu(feat), 0., 1.)
