diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl
index 0a8c4e290..6faa98a4e 100644
--- a/GNNLux/src/GNNLux.jl
+++ b/GNNLux/src/GNNLux.jl
@@ -36,9 +36,9 @@ export AGNNConv,
        NNConv,
        ResGatedGraphConv,
        SAGEConv,
-       SGConv
+       SGConv,
        # TAGConv,
-       # TransformerConv
+       TransformerConv
 
 include("layers/temporalconv.jl")
 export TGCN,
@@ -49,4 +49,4 @@ export TGCN,
        EvolveGCNO
 
 end #module
- 
\ No newline at end of file
+ 
diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl
index f0b51066b..0e97fecfa 100644
--- a/GNNLux/src/layers/conv.jl
+++ b/GNNLux/src/layers/conv.jl
@@ -845,6 +845,124 @@ function Base.show(io::IO, l::ResGatedGraphConv)
     print(io, ")")
 end
 
+@concrete struct TransformerConv <: GNNContainerLayer{(:W1, :W2, :W3, :W4, :W5, :W6, :FF, :BN1, :BN2)}
+    in_dims::NTuple{2, Int}
+    out_dims::Int
+    heads::Int
+    add_self_loops::Bool
+    concat::Bool
+    skip_connection::Bool
+    sqrt_out::Float32
+    W1
+    W2
+    W3
+    W4
+    W5
+    W6
+    FF
+    BN1
+    BN2
+end
+
+function TransformerConv(ch::Pair{Int, Int}, args...; kws...)
+    return TransformerConv((ch[1], 0) => ch[2], args...; kws...)
+end
+
+function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
+                         heads::Int = 1,
+                         concat::Bool = true,
+                         init_weight = glorot_uniform,
+                         init_bias = zeros32,
+                         add_self_loops::Bool = false,
+                         bias_qkv = true,
+                         bias_root::Bool = true,
+                         root_weight::Bool = true,
+                         gating::Bool = false,
+                         skip_connection::Bool = false,
+                         batch_norm::Bool = false,
+                         ff_channels::Int = 0)
+    (in, ein), out = ch
+
+    if add_self_loops
+        @assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported."
+    end
+
+    if skip_connection
+        @assert in == (concat ? out * heads : out) "In-channels must correspond to out-channels * heads (or just out_channels if concat=false) if skip_connection is used"
+    end
+
+    W1 = root_weight ? Dense(in => out * (concat ? heads : 1); use_bias=bias_root, init_weight, init_bias) : nothing
+    W2 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
+    W3 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
+    W4 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
+    out_mha = out * (concat ? heads : 1)
+    W5 = gating ? Dense(3 * out_mha => 1, sigmoid; use_bias=false, init_weight, init_bias) : nothing
+    W6 = ein > 0 ? Dense(ein => out * heads; use_bias=bias_qkv, init_weight, init_bias) : nothing
+    FF = ff_channels > 0 ?
+         Chain(Dense(out_mha => ff_channels, relu; init_weight, init_bias),
+               Dense(ff_channels => out_mha; init_weight, init_bias)) : nothing
+    BN1 = batch_norm ? BatchNorm(out_mha) : nothing
+    BN2 = (batch_norm && ff_channels > 0) ? BatchNorm(out_mha) : nothing
+
+    return TransformerConv((in, ein), out, heads, add_self_loops, concat,
+                           skip_connection, Float32(√out), W1, W2, W3, W4, W5, W6, FF, BN1, BN2)
+end
+
+LuxCore.outputsize(l::TransformerConv) = (l.concat ? l.out_dims * l.heads : l.out_dims,)
+
+function (l::TransformerConv)(g, x, ps, st)
+    return l(g, x, nothing, ps, st)
+end
+
+function (l::TransformerConv)(g, x, e, ps, st)
+    W1 = l.W1 === nothing ? nothing :
+         StatefulLuxLayer{true}(l.W1, ps.W1, _getstate(st, :W1))
+    W2 = StatefulLuxLayer{true}(l.W2, ps.W2, _getstate(st, :W2))
+    W3 = StatefulLuxLayer{true}(l.W3, ps.W3, _getstate(st, :W3))
+    W4 = StatefulLuxLayer{true}(l.W4, ps.W4, _getstate(st, :W4))
+    W5 = l.W5 === nothing ? nothing :
+         StatefulLuxLayer{true}(l.W5, ps.W5, _getstate(st, :W5))
+    W6 = l.W6 === nothing ? nothing :
+         StatefulLuxLayer{true}(l.W6, ps.W6, _getstate(st, :W6))
+    FF = l.FF === nothing ? nothing :
+         StatefulLuxLayer{true}(l.FF, ps.FF, _getstate(st, :FF))
+    BN1 = l.BN1 === nothing ? nothing :
+          StatefulLuxLayer{true}(l.BN1, ps.BN1, _getstate(st, :BN1))
+    BN2 = l.BN2 === nothing ? nothing :
+          StatefulLuxLayer{true}(l.BN2, ps.BN2, _getstate(st, :BN2))
+    m = (; W1, W2, W3, W4, W5, W6, FF, BN1, BN2, l.sqrt_out,
+         l.heads, l.concat, l.skip_connection, l.add_self_loops, l.in_dims, l.out_dims)
+    return GNNlib.transformer_conv(m, g, x, e), st
+end
+
+function LuxCore.parameterlength(l::TransformerConv)
+    n = parameterlength(l.W2) + parameterlength(l.W3) + parameterlength(l.W4)
+    n += l.W1 === nothing ? 0 : parameterlength(l.W1)
+    n += l.W5 === nothing ? 0 : parameterlength(l.W5)
+    n += l.W6 === nothing ? 0 : parameterlength(l.W6)
+    n += l.FF === nothing ? 0 : parameterlength(l.FF)
+    n += l.BN1 === nothing ? 0 : parameterlength(l.BN1)
+    n += l.BN2 === nothing ? 0 : parameterlength(l.BN2)
+    return n
+end
+
+function LuxCore.statelength(l::TransformerConv)
+    n = statelength(l.W2) + statelength(l.W3) + statelength(l.W4)
+    n += l.W1 === nothing ? 0 : statelength(l.W1)
+    n += l.W5 === nothing ? 0 : statelength(l.W5)
+    n += l.W6 === nothing ? 0 : statelength(l.W6)
+    n += l.FF === nothing ? 0 : statelength(l.FF)
+    n += l.BN1 === nothing ? 0 : statelength(l.BN1)
+    n += l.BN2 === nothing ? 0 : statelength(l.BN2)
+    return n
+end
+
+function Base.show(io::IO, l::TransformerConv)
+    (in, ein), out = (l.in_dims, l.out_dims)
+    print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
+end
+
+
 @concrete struct SAGEConv <: GNNLayer
     in_dims::Int
     out_dims::Int
@@ -891,4 +1009,4 @@ function (l::SAGEConv)(g, x, ps, st)
 
     m = (; ps.weight, bias = _getbias(ps), l.σ, l.aggr)
     return GNNlib.sage_conv(m, g, x), st
-end
\ No newline at end of file
+end
diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl
index 4f871b64e..162d0aafd 100644
--- a/GNNLux/test/layers/conv_tests.jl
+++ b/GNNLux/test/layers/conv_tests.jl
@@ -5,6 +5,18 @@
     out_dims = 5
     x = randn(rng, Float32, in_dims, 10)
 
+    @testset "TransformerConv" begin
+        x = randn(rng, Float32, 6, 10)
+        ein = 2
+        e = randn(rng, Float32, ein, g.num_edges)
+
+        l = TransformerConv((6, ein) => 8, heads = 2, gating = true, bias_qkv = true)
+        test_lux_layer(rng, l, g, x, outputsize = (16,), e = e, container = true)
+
+        # l = TransformerConv((16, ein) => 16, heads = 2, concat = false, skip_connection = true)
+        # test_lux_layer(rng, l, g, x, outputsize = (16,), e = e, container = true)
+    end
+
     @testset "GCNConv" begin
         l = GCNConv(in_dims => out_dims, tanh)
         test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl
index e310fa81c..9644d63aa 100644
--- a/GNNlib/src/layers/conv.jl
+++ b/GNNlib/src/layers/conv.jl
@@ -559,7 +559,7 @@ function transformer_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractM
         g = add_self_loops(g)
     end
 
-    out = l.channels[2]
+    out = l.out_dims
     heads = l.heads
     W1x = !isnothing(l.W1) ? l.W1(x) : nothing
     W2x = reshape(l.W2(x), out, heads, :)
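
For reference, a minimal usage sketch of the new GNNLux layer (not part of the patch). It assumes the standard Lux setup/apply workflow used by the other GNNLux layers and mirrors the sizes in the test above, so with the default concat = true the output has out * heads = 16 rows; names such as rng, g, x, e, l are illustrative only.

using GNNLux, GNNGraphs, Lux, Random

rng = Random.default_rng()
g = rand_graph(rng, 10, 40)                 # random graph with 10 nodes and 40 edges
x = randn(rng, Float32, 6, g.num_nodes)     # node features: 6 per node
e = randn(rng, Float32, 2, g.num_edges)     # edge features: 2 per edge

# (6, 2) => 8: 6 input node channels, 2 edge channels, 8 output channels per head
l = TransformerConv((6, 2) => 8; heads = 2, gating = true)
ps, st = Lux.setup(rng, l)

y, st = l(g, x, e, ps, st)                  # size(y) == (16, 10) since concat = true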