import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv




# 图卷积层
class GraphConvolutionalLayer(nn.Module):
    """Thin wrapper that selects one graph convolution by name.

    Args:
        in_channels: Passed as the first argument to the selected convolution.
        out_channels: Passed as the second argument to the selected convolution.
        layer_type: ``'GCN'`` builds a ``GCNConv``; ``'GAT'`` builds a
            ``GATConv`` with ``heads=1``.

    Raises:
        ValueError: If ``layer_type`` is neither ``'GCN'`` nor ``'GAT'``.
    """

    def __init__(self, in_channels, out_channels, layer_type='GCN'):
        super().__init__()
        # Reject unknown layer names before building anything.
        if layer_type not in ('GCN', 'GAT'):
            raise ValueError("Unsupported layer_type. Choose between 'GCN' and 'GAT'.")

        if layer_type == 'GCN':
            graph_conv = GCNConv(in_channels, out_channels)
        else:
            # The number of attention heads is fixed to one here; adjust if needed.
            graph_conv = GATConv(in_channels, out_channels, heads=1)
        self.conv = graph_conv

    def forward(self, x, edge_index):
        """Apply the wrapped convolution to ``x`` using ``edge_index``."""
        result = self.conv(x, edge_index)
        return result


# 特征提取器，加入图卷积层
class FeatureExtractorWithGNN(nn.Module):
    """Multi-kernel 1-D convolutional feature extractor with an optional graph step.

    One branch is built per entry of ``kernel_sizes``; each branch is
    Conv1d -> BatchNorm1d -> ReLU -> Conv1d -> BatchNorm1d -> ReLU with
    ``padding = kernel_size // 2``. Branch outputs are trimmed to a common
    length, concatenated along the channel axis and mapped back to
    ``in_channels`` features by ``self.fc``.

    Args:
        in_channels: Channel count of the input to every branch, and the
            output width of ``self.fc``.
        out_channels: Channel count produced by every branch.
        kernel_sizes: Iterable of kernel sizes, one branch each.
        gnn_out_channels: Output width of the graph layer.
        gnn_layer_type: Forwarded to ``GraphConvolutionalLayer``.
        stride: Stride used by both convolutions of every branch.

    NOTE(review): the graph layer is built for ``in_channels`` input features
    and ``self.fc`` for ``out_channels * len(kernel_sizes)`` input features,
    but ``forward`` feeds the concatenated branch features into the graph
    layer and the graph output into ``self.fc``. These widths only agree when
    ``out_channels * len(kernel_sizes) == in_channels == gnn_out_channels``;
    confirm the intended wiring before relying on ``edge_index``.
    """

    def __init__(self, in_channels, out_channels, kernel_sizes, gnn_out_channels, gnn_layer_type='GCN', stride=1):
        super().__init__()

        def build_branch(width):
            # Two conv/norm/activation stages sharing kernel size, stride and padding.
            half = width // 2
            return nn.Sequential(
                nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=width,
                          stride=stride, padding=half),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
                nn.Conv1d(out_channels, out_channels, kernel_size=width, stride=stride, padding=half),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
            )

        self.convs = nn.ModuleList([build_branch(width) for width in kernel_sizes])
        concat_width = out_channels * len(kernel_sizes)
        self.fc = nn.Linear(concat_width, in_channels)

        self.gnn = GraphConvolutionalLayer(in_channels, gnn_out_channels, layer_type=gnn_layer_type)

    def forward(self, x, edge_index=None):
        """Run every branch, fuse the results and project back to ``in_channels``.

        ``x`` is permuted with ``(0, 2, 1)`` before the convolutions, so its
        last axis is what the branches see as channels. The graph layer is
        only applied when ``edge_index`` is not ``None``.
        """
        channels_first = x.permute(0, 2, 1)
        branch_outputs = [branch(channels_first) for branch in self.convs]

        # Branches may differ in output length; cut all of them to the shortest.
        shortest = min(feat.shape[2] for feat in branch_outputs)
        trimmed = [feat[:, :, :shortest] for feat in branch_outputs]

        fused = torch.cat(trimmed, dim=1).permute(0, 2, 1)

        if edge_index is not None:
            fused = self.gnn(fused, edge_index)

        return self.fc(fused)


# 数据嵌入模块
class DataEmbedding_with_GNN(nn.Module):
    """Embed a series (optionally joined with its time marks) into ``d_model`` features.

    Args:
        c_in: Input width of the value embedding; also used as both the input
            and output channel count of the feature extractor.
        d_model: Output width of the value embedding.
        embed_type: Accepted but not used by this class.
        freq: Accepted but not used by this class.
        dropout: Dropout probability applied to the embedded output.
        gnn_out_channels: Forwarded to ``FeatureExtractorWithGNN``.
        gnn_layer_type: Forwarded to ``FeatureExtractorWithGNN``.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1, gnn_out_channels=64,
                 gnn_layer_type='GCN'):
        super().__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)
        extractor_options = dict(
            in_channels=c_in,
            out_channels=c_in,
            kernel_sizes=[3, 5, 7],
            gnn_out_channels=gnn_out_channels,
            gnn_layer_type=gnn_layer_type,
        )
        self.feature_extractor = FeatureExtractorWithGNN(**extractor_options)

    def forward(self, x, x_mark, edge_index=None):
        """Return the embedded (and dropped-out) representation of ``x``.

        ``x`` and ``x_mark`` are both permuted with ``(0, 2, 1)`` and joined
        along dim 1. The feature extractor only runs when ``x_mark`` is given;
        without marks the permuted ``x`` goes straight to the value embedding.
        NOTE(review): confirm that skipping the extractor without marks is intended.
        """
        series = x.permute(0, 2, 1)
        if x_mark is None:
            return self.dropout(self.value_embedding(series))

        marks = x_mark.permute(0, 2, 1)
        joined = torch.cat([series, marks], 1)
        features = self.feature_extractor(joined, edge_index)
        embedded = self.value_embedding(features)
        return self.dropout(embedded)


# 局部注意力层
class LocalAttention(nn.Module):
    """Dot-product attention evaluated in consecutive chunks along dim 1.

    Args:
        window_size: Number of dim-1 positions handled per chunk.
        scale: Divisor applied to the raw scores. When falsy (``None`` or 0)
            the divisor defaults to ``sqrt(E)``, where ``E`` is the last
            dimension of ``queries``.
        attention_dropout: Dropout probability applied to the softmax output.
        output_attention: When true, ``forward`` also returns the weights.

    NOTE(review): inputs are laid out as ``(B, L, H, E)`` and the scores are
    ``matmul(q, k.transpose(-2, -1))``, so each score matrix is ``(H, H)`` per
    position: the softmax mixes across the head axis, and chunking along dim 1
    does not restrict which positions interact. Confirm this is intended.
    """

    def __init__(self, window_size, scale=None, attention_dropout=0.1, output_attention=False):
        super(LocalAttention, self).__init__()
        self.window_size = window_size
        self.scale = scale
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask=None):
        """Compute chunked attention.

        Args:
            queries: Tensor of shape ``(B, L, H, E)``.
            keys: Tensor viewable as ``(B, S, H, -1)``.
            values: Tensor of shape ``(B, S, H, D)``.
            attn_mask: Optional boolean mask; true entries are set to ``-inf``
                before the softmax.

        Returns:
            ``(outputs, attentions)``, where ``attentions`` is ``None`` unless
            ``output_attention`` is set. Per-chunk results are concatenated
            along dim 1.
        """
        B, L, H, E = queries.shape
        S = values.shape[1]

        # The scores are divided by `scale`, so the default must be sqrt(E).
        # The previous default, 1/sqrt(E), referenced an unimported `sqrt`
        # (NameError) and as a divisor would have multiplied the scores.
        scale = self.scale or float(E) ** 0.5
        Q = queries.view(B, L, H, -1)
        K = keys.view(B, S, H, -1)
        V = values.view(B, S, H, -1)

        outputs = []
        attentions = []

        for start in range(0, L, self.window_size):
            stop = start + self.window_size
            q = Q[:, start:stop]
            k = K[:, start:stop]
            v = V[:, start:stop]

            QK = torch.matmul(q, k.transpose(-2, -1)) / scale

            if attn_mask is not None:
                # NOTE(review): the mask is sliced by chunk on its last two
                # axes, but the last two axes of the scores are heads. Only
                # masks broadcastable to the score shape work here; confirm
                # the expected mask layout.
                QK = QK.masked_fill(attn_mask[:, :, start:stop, start:stop], float('-inf'))

            A = self.dropout(F.softmax(QK, dim=-1))

            outputs.append(torch.matmul(A, v))
            if self.output_attention:
                attentions.append(A)

        outputs = torch.cat(outputs, dim=1)
        if self.output_attention:
            attentions = torch.cat(attentions, dim=1)
            return outputs.contiguous(), attentions
        return outputs.contiguous(), None


# 全局注意力层
class GlobalAttention(nn.Module):
    """Dot-product attention evaluated over the whole input in one step.

    Args:
        scale: Divisor applied to the raw scores. When falsy (``None`` or 0)
            the divisor defaults to ``sqrt(E)``, where ``E`` is the last
            dimension of ``queries``.
        attention_dropout: Dropout probability applied to the softmax output.
        output_attention: When true, ``forward`` also returns the weights.

    NOTE(review): inputs are laid out as ``(B, L, H, E)`` and the scores are
    ``matmul(Q, K.transpose(-2, -1))``, giving ``(B, L, H, H)``: the softmax
    runs across the head axis at each position. Confirm this is intended.
    """

    def __init__(self, scale=None, attention_dropout=0.1, output_attention=False):
        super(GlobalAttention, self).__init__()
        self.scale = scale
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask=None):
        """Compute attention.

        Args:
            queries: Tensor of shape ``(B, L, H, E)``.
            keys: Tensor viewable as ``(B, S, H, -1)``.
            values: Tensor of shape ``(B, S, H, D)``.
            attn_mask: Optional boolean mask; true entries are set to ``-inf``
                before the softmax.

        Returns:
            ``(output, attention)``, where ``attention`` is ``None`` unless
            ``output_attention`` is set.
        """
        B, L, H, E = queries.shape
        S = values.shape[1]

        # The scores are divided by `scale`, so the default must be sqrt(E).
        # The previous default, 1/sqrt(E), referenced an unimported `sqrt`
        # (NameError) and as a divisor would have multiplied the scores.
        scale = self.scale or float(E) ** 0.5
        Q = queries.view(B, L, H, -1)
        K = keys.view(B, S, H, -1)
        V = values.view(B, S, H, -1)

        QK = torch.matmul(Q, K.transpose(-2, -1)) / scale

        if attn_mask is not None:
            # NOTE(review): the mask's last axis is cut to S, but the last
            # axis of the scores is heads. Only masks broadcastable to the
            # score shape work here; confirm the expected mask layout.
            QK = QK.masked_fill(attn_mask[:, :, :, :S], float('-inf'))

        A = self.dropout(F.softmax(QK, dim=-1))

        output = torch.matmul(A, V)
        if self.output_attention:
            return output.contiguous(), A
        return output.contiguous(), None


# 结合注意力层
class CombinedAttentionLayer(nn.Module):
    """Project the inputs, run global and local attention, and sum the results.

    Args:
        d_model: Width of all four linear projections (input and output).
        n_heads: Number of heads the projected features are split into.
        window_size: Forwarded to ``LocalAttention``.
        scale: Forwarded to both attention modules.
        attention_dropout: Forwarded to both attention modules.
        output_attention: Forwarded to both attention modules.
    """

    def __init__(self, d_model, n_heads, window_size, scale=None, attention_dropout=0.1, output_attention=False):
        super().__init__()
        shared_options = dict(scale=scale, attention_dropout=attention_dropout,
                              output_attention=output_attention)
        self.global_attention = GlobalAttention(**shared_options)
        self.local_attention = LocalAttention(window_size=window_size, **shared_options)

        self.query_projection = nn.Linear(d_model, d_model)
        self.key_projection = nn.Linear(d_model, d_model)
        self.value_projection = nn.Linear(d_model, d_model)
        self.out_projection = nn.Linear(d_model, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask=None):
        """Return ``(output, attn)`` for three-dimensional inputs.

        ``attn`` is the pair ``(global_attn, local_attn)`` when either
        sub-module has ``output_attention`` set, otherwise ``None``.
        """
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        # Split the projected feature axis into (heads, per-head width).
        q_heads = self.query_projection(queries).view(batch, q_len, heads, -1)
        k_heads = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v_heads = self.value_projection(values).view(batch, kv_len, heads, -1)

        global_output, global_attn = self.global_attention(q_heads, k_heads, v_heads, attn_mask)
        local_output, local_attn = self.local_attention(q_heads, k_heads, v_heads, attn_mask)

        # Sum both branches, then merge the head axes back into one feature axis.
        merged = (global_output + local_output).view(batch, q_len, -1)
        output = self.out_projection(merged)

        wants_attention = self.global_attention.output_attention or self.local_attention.output_attention
        if not wants_attention:
            return output, None
        return output, (global_attn, local_attn)


# 编码器层
class EncoderLayer(nn.Module):
    """Attention block followed by a position-wise feed-forward block.

    Each block is wrapped as ``LayerNorm(x + Dropout(block(x)))``. The
    feed-forward block is two kernel-size-1 ``Conv1d`` layers with an
    activation in between.

    Args:
        attention: Callable invoked as ``attention(x, x, x, attn_mask)`` and
            returning ``(output, weights)``.
        d_model: Feature width of the input and output.
        d_ff: Hidden width of the feed-forward block; falls back to
            ``4 * d_model`` when falsy.
        dropout: Dropout probability used on both residual branches.
        activation: ``"relu"`` selects ``F.relu``; anything else selects ``F.gelu``.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        hidden_width = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden_width, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden_width, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, attn_mask=None):
        """Return ``(transformed_x, attention_weights)``."""
        attended, attn_weights = self.attention(x, x, x, attn_mask)
        after_attention = self.norm1(x + self.dropout(attended))

        # Conv1d expects channels on dim 1, so swap the last two axes around it.
        channels_first = after_attention.permute(0, 2, 1)
        expanded = self.activation(self.conv1(channels_first))
        feed_forward = self.conv2(expanded).permute(0, 2, 1)

        result = self.norm2(after_attention + self.dropout(feed_forward))
        return result, attn_weights


# 编码器
class Encoder(nn.Module):
    """Run a stack of encoder layers in order, then an optional final norm.

    NOTE: a second definition of ``Encoder`` appears later in this file and
    replaces this one at import time.

    Args:
        enc_layers: Iterable of modules, each called as ``layer(x, attn_mask)``
            and returning ``(x, attn_weights)``.
        norm_layer: Module applied to the final ``x``; ``None`` skips it.
    """

    def __init__(self, enc_layers, norm_layer=None):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList(enc_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Return ``(x, attn_weights)``.

        ``attn_weights`` comes from the last layer only, and is ``None`` when
        the stack is empty.
        """
        # Pre-bind so an empty layer list does not leave the name undefined.
        attn_weights = None
        for layer in self.layers:
            x, attn_weights = layer(x, attn_mask)
        if self.norm is not None:
            x = self.norm(x)
        return x, attn_weights


# 主模型
class Model(nn.Module):
    """Forecaster that embeds, encodes and projects each input variate in turn.

    NOTE: a second definition of ``Model`` appears later in this file and
    replaces this one at import time.

    ``configs`` must provide: ``pred_len``, ``output_attention``, ``patch_len``,
    ``stride``, ``patch_size_list``, ``seq_len``, ``d_model``, ``embed``,
    ``freq``, ``dropout``, ``gnn_out_channels``, ``gnn_layer_type``,
    ``n_heads``, ``factor``, ``d_ff``, ``activation`` and ``e_layers``.
    """

    def __init__(self, configs):
        super().__init__()
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention
        self.patch_len = configs.patch_len
        self.stride = configs.stride
        self.patch_size_list = configs.patch_size_list

        self.enc_embedding = DataEmbedding_with_GNN(
            configs.seq_len, configs.d_model, configs.embed, configs.freq, configs.dropout,
            gnn_out_channels=configs.gnn_out_channels, gnn_layer_type=configs.gnn_layer_type
        )

        encoder_layers = []
        for _ in range(configs.e_layers):
            # The local-attention window is fixed at 5 positions.
            attention = CombinedAttentionLayer(
                configs.d_model,
                configs.n_heads,
                window_size=5,
                scale=configs.factor,
                attention_dropout=configs.dropout,
                output_attention=configs.output_attention
            )
            encoder_layers.append(
                EncoderLayer(
                    attention,
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                )
            )
        final_norm = torch.nn.LayerNorm(configs.d_model)
        self.encoder = Encoder(encoder_layers, norm_layer=final_norm)

        self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True)

    def process_single_variate(self, x_enc_var, x_mark_enc, dec_inp, batch_y_mark, edge_index=None):
        """Embed, encode and project one variate.

        ``dec_inp`` and ``batch_y_mark`` are accepted but not used.
        """
        # Embedding layer.
        embedded = self.enc_embedding(x_enc_var, x_mark_enc, edge_index=edge_index)

        # Encoder stack; the returned attention weights are discarded.
        encoded, _ = self.encoder(embedded, attn_mask=None)

        # Project, swap the last two axes and keep index 0 of the final axis.
        projected = self.projector(encoded)
        swapped = projected.permute(0, 2, 1)
        return swapped[:, :, 0]

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, edge_index=None):
        """Process every slice of the last axis of ``x_enc`` and stack the results on a new last axis."""
        _, _, num_variates = x_enc.shape  # x_enc must be three-dimensional

        per_variate = [
            self.process_single_variate(
                x_enc[:, :, idx].unsqueeze(-1), x_mark_enc, x_dec, x_mark_dec, edge_index
            )
            for idx in range(num_variates)
        ]

        return torch.stack(per_variate, dim=-1)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, edge_index=None):
        """Return the last ``pred_len`` steps (dim 1) of the forecast."""
        prediction = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec, edge_index=edge_index)
        return prediction[:, -self.pred_len:, :]


# 编码器
class Encoder(nn.Module):
    """Run a stack of encoder layers in order, then an optional final norm.

    NOTE: an earlier definition of ``Encoder`` in this file is shadowed by
    this one.

    Args:
        enc_layers: Iterable of modules, each called as ``layer(x, attn_mask)``
            and returning ``(x, attn_weights)``.
        norm_layer: Module applied to the final ``x``; ``None`` skips it.
    """

    def __init__(self, enc_layers, norm_layer=None):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList(enc_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Return ``(x, attn_weights)``.

        ``attn_weights`` comes from the last layer only, and is ``None`` when
        the stack is empty.
        """
        # Pre-bind so an empty layer list does not leave the name undefined.
        attn_weights = None
        for layer in self.layers:
            x, attn_weights = layer(x, attn_mask)
        if self.norm is not None:
            x = self.norm(x)
        return x, attn_weights


# 主模型
class Model(nn.Module):
    """Forecaster that embeds, encodes and projects each input variate in turn.

    NOTE: an earlier definition of ``Model`` in this file is shadowed by this one.

    ``configs`` must provide: ``pred_len``, ``output_attention``, ``patch_len``,
    ``stride``, ``patch_size_list``, ``seq_len``, ``d_model``, ``embed``,
    ``freq``, ``dropout``, ``gnn_out_channels``, ``gnn_layer_type``,
    ``n_heads``, ``factor``, ``d_ff``, ``activation`` and ``e_layers``.
    """

    def __init__(self, configs):
        super().__init__()
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention
        self.patch_len = configs.patch_len
        self.stride = configs.stride
        self.patch_size_list = configs.patch_size_list

        self.enc_embedding = DataEmbedding_with_GNN(
            configs.seq_len, configs.d_model, configs.embed, configs.freq, configs.dropout,
            gnn_out_channels=configs.gnn_out_channels, gnn_layer_type=configs.gnn_layer_type
        )

        encoder_layers = []
        for _ in range(configs.e_layers):
            # The local-attention window is fixed at 5 positions.
            attention = CombinedAttentionLayer(
                configs.d_model,
                configs.n_heads,
                window_size=5,
                scale=configs.factor,
                attention_dropout=configs.dropout,
                output_attention=configs.output_attention
            )
            encoder_layers.append(
                EncoderLayer(
                    attention,
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                )
            )
        final_norm = torch.nn.LayerNorm(configs.d_model)
        self.encoder = Encoder(encoder_layers, norm_layer=final_norm)

        self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True)

    def process_single_variate(self, x_enc_var, x_mark_enc, dec_inp, batch_y_mark, edge_index=None):
        """Embed, encode and project one variate.

        ``dec_inp`` and ``batch_y_mark`` are accepted but not used.
        """
        # Embedding layer.
        embedded = self.enc_embedding(x_enc_var, x_mark_enc, edge_index=edge_index)

        # Encoder stack; the returned attention weights are discarded.
        encoded, _ = self.encoder(embedded, attn_mask=None)

        # Project, swap the last two axes and keep index 0 of the final axis.
        projected = self.projector(encoded)
        swapped = projected.permute(0, 2, 1)
        return swapped[:, :, 0]

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, edge_index=None):
        """Process every slice of the last axis of ``x_enc`` and stack the results on a new last axis."""
        _, _, num_variates = x_enc.shape  # x_enc must be three-dimensional

        per_variate = [
            self.process_single_variate(
                x_enc[:, :, idx].unsqueeze(-1), x_mark_enc, x_dec, x_mark_dec, edge_index
            )
            for idx in range(num_variates)
        ]

        return torch.stack(per_variate, dim=-1)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, edge_index=None):
        """Return the last ``pred_len`` steps (dim 1) of the forecast."""
        prediction = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec, edge_index=edge_index)
        return prediction[:, -self.pred_len:, :]
