MLUtils and Flux v0.13 compatibility (#155)
* MLUtils and Flux v0.13 compatibility

* bump version

* cleanup

* cleanup

* import MLUtils

* using numobs, getobs
CarloLucibello authored Apr 9, 2022
1 parent 2801b51 commit c7d0afe
Showing 8 changed files with 22 additions and 153 deletions.
14 changes: 7 additions & 7 deletions Project.toml
@@ -1,7 +1,7 @@
 name = "GraphNeuralNetworks"
 uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 authors = ["Carlo Lucibello and contributors"]
-version = "0.3.15"
+version = "0.4.0"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -12,8 +12,8 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
-LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -29,17 +29,17 @@ Adapt = "3"
 CUDA = "3.3"
 ChainRulesCore = "1"
 DataStructures = "0.18"
-Flux = "0.12.7"
+Flux = "0.13"
 Functors = "0.2"
 Graphs = "1.4"
 KrylovKit = "0.5"
-LearnBase = "0.4, 0.5, 0.6"
+MLUtils = "0.2.3"
 MacroTools = "0.5"
-NNlib = "0.7, 0.8"
-NNlibCUDA = "0.1, 0.2"
+NNlib = "0.8"
+NNlibCUDA = "0.2"
 NearestNeighbors = "0.4"
 Reexport = "1"
-StatsBase = "0.32, 0.33"
+StatsBase = "0.33"
 julia = "1.6"

 [extras]
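
For a downstream project, the corresponding upgrade is a compat bump plus the dependency swap. A minimal sketch, assuming LearnBase was only used for getobs/nobs (package names taken from this diff; adjust versions to your environment):

using Pkg
Pkg.add(Pkg.PackageSpec(name="GraphNeuralNetworks", version="0.4"))
Pkg.add(Pkg.PackageSpec(name="MLUtils", version="0.2.3"))  # provides numobs/getobs
Pkg.rm("LearnBase")  # only if your project depended on it directly
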
3 changes: 2 additions & 1 deletion src/GNNGraphs/GNNGraphs.jl
@@ -9,11 +9,12 @@ import Flux
 using Flux: batch
 import NearestNeighbors
 import NNlib
-import LearnBase
 import StatsBase
 import KrylovKit
 using ChainRulesCore
 using LinearAlgebra, Random, Statistics
+import MLUtils
+using MLUtils: getobs, numobs

 include("gnngraph.jl")
 export GNNGraph,
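
The two new lines are complementary: `import MLUtils` lets the package add methods to MLUtils functions under their qualified names, while `using MLUtils: getobs, numobs` brings the generics into scope for plain calls. A minimal sketch of the same pattern (the ToyData type is illustrative, not part of the package):

import MLUtils
using MLUtils: numobs

struct ToyData
    n::Int
end

MLUtils.numobs(d::ToyData) = d.n  # extend the MLUtils generic for a custom type

numobs(ToyData(5))  # 5 — dispatches to the method defined above
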
17 changes: 5 additions & 12 deletions src/GNNGraphs/gnngraph.jl
@@ -223,21 +223,14 @@ function Base.show(io::IO, g::GNNGraph)
     end
 end

-### StatsBase/LearnBase compatibility
-StatsBase.nobs(g::GNNGraph) = g.num_graphs
-LearnBase.getobs(g::GNNGraph, i) = getgraph(g, i)
-
-# Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683
-Flux.Data._nobs(g::GNNGraph) = g.num_graphs
-Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i)
+MLUtils.numobs(g::GNNGraph) = g.num_graphs
+MLUtils.getobs(g::GNNGraph, i) = getgraph(g, i)

 # DataLoader compatibility passing a vector of graphs and
 # effectively using `batch` as a collate function.
-StatsBase.nobs(data::Vector{<:GNNGraph}) = length(data)
-LearnBase.getobs(data::Vector{<:GNNGraph}, i::Int) = data[i]
-LearnBase.getobs(data::Vector{<:GNNGraph}, i) = Flux.batch(data[i])
-Flux.Data._nobs(g::Vector{<:GNNGraph}) = StatsBase.nobs(g)
-Flux.Data._getobs(g::Vector{<:GNNGraph}, i) = LearnBase.getobs(g, i)
+MLUtils.numobs(data::Vector{<:GNNGraph}) = length(data)
+MLUtils.getobs(data::Vector{<:GNNGraph}, i::Int) = data[i]
+MLUtils.getobs(data::Vector{<:GNNGraph}, i) = Flux.batch(data[i])


 #########################
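
A quick sketch of what these overloads provide, using `rand_graph` as the tests below do (sizes are illustrative):

import MLUtils
using GraphNeuralNetworks, Flux

data = [rand_graph(10, 30) for _ in 1:4]  # a Vector{GNNGraph}
MLUtils.numobs(data)           # 4
MLUtils.getobs(data, 2)        # a single GNNGraph
g = MLUtils.getobs(data, 1:3)  # collated via Flux.batch into one GNNGraph
g.num_graphs                   # 3
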
61 changes: 0 additions & 61 deletions src/GNNGraphs/utils.jl
@@ -156,64 +156,3 @@ binarize(x) = map(>(0), x)
 @non_differentiable binarize(x...)
 @non_differentiable edge_encoding(x...)
 @non_differentiable edge_decoding(x...)
-
-
-
-####################################
-# FROM MLBASE.jl
-# https://github.com/JuliaML/MLBase.jl/pull/1/files
-# remove when package is registered
-##############################################
-
-numobs(A::AbstractArray{<:Any, N}) where {N} = size(A, N)
-
-# 0-dim arrays
-numobs(A::AbstractArray{<:Any, 0}) = 1
-
-function getobs(A::AbstractArray{<:Any, N}, idx) where N
-    I = ntuple(_ -> :, N-1)
-    return A[I..., idx]
-end
-
-getobs(A::AbstractArray{<:Any, 0}, idx) = A[idx]
-
-function getobs!(buffer::AbstractArray, A::AbstractArray{<:Any, N}, idx) where N
-    I = ntuple(_ -> :, N-1)
-    buffer .= A[I..., idx]
-    return buffer
-end
-
-# --------------------------------------------------------------------
-# Tuples and NamedTuples
-
-_check_numobs_error() =
-    throw(DimensionMismatch("All data containers must have the same number of observations."))
-
-function _check_numobs(tup::Union{Tuple, NamedTuple})
-    length(tup) == 0 && return
-    n1 = numobs(tup[1])
-    for i=2:length(tup)
-        numobs(tup[i]) != n1 && _check_numobs_error()
-    end
-end
-
-function numobs(tup::Union{Tuple, NamedTuple})::Int
-    _check_numobs(tup)
-    return length(tup) == 0 ? 0 : numobs(tup[1])
-end
-
-function getobs(tup::Union{Tuple, NamedTuple}, indices)
-    _check_numobs(tup)
-    return map(x -> getobs(x, indices), tup)
-end
-
-function getobs!(buffers::Union{Tuple, NamedTuple},
-                 tup::Union{Tuple, NamedTuple},
-                 indices)
-    _check_numobs(tup)
-
-    return map(buffers, tup) do buffer, x
-        getobs!(buffer, x, indices)
-    end
-end
-#######################################################
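
The removed helpers match what the registered MLUtils package now supplies. A minimal sketch of the same semantics through MLUtils (version as pinned above):

using MLUtils

A = rand(2, 5)
numobs(A)              # 5 — observations are indexed along the last dimension
getobs(A, 1:2)         # 2×2 slice: the first two observation columns
numobs((x=A, y=1:5))   # 5 — tuples/NamedTuples must agree on numobs
getobs((A, 1:5), 3)    # (A[:, 3], 3) — getobs maps over the elements
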
33 changes: 0 additions & 33 deletions src/deprecations.jl
@@ -1,33 +0,0 @@
-## Deprecated in v0.2
-
-function compute_message end
-function update_node end
-function update_edge end
-
-compute_message(l, xi, xj, e) = compute_message(l, xi, xj)
-update_node(l, x, m̄) = m̄
-update_edge(l, e, m) = e
-
-function propagate(l::GNNLayer, g::GNNGraph, aggr, x, e=nothing)
-    @warn """
-          Passing a GNNLayer to propagate is deprecated,
-          you should pass the message function directly.
-          The new signature is `propagate(f, g, aggr; [xi, xj, e])`.
-          The functions `compute_message`, `update_node`,
-          and `update_edge` have been deprecated as well. Please
-          refer to the documentation.
-          """
-    m = apply_edges((a...) -> compute_message(l, a...), g, x, x, e)
-    m̄ = aggregate_neighbors(g, aggr, m)
-    x = update_node(l, x, m̄)
-    e = update_edge(l, e, m)
-    return x, e
-end
-
-## Deprecated in v0.3
-
-@deprecate copyxj(xi, xj, e) copy_xj(xi, xj, e)
-
-@deprecate CGConv(nin::Int, ein::Int, out::Int, args...; kws...) CGConv((nin, ein) => out, args...; kws...)
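
The replacement API named in the deprecation warning, as a minimal sketch (it mirrors the `new_forward` helper in the removed test file below):

using GraphNeuralNetworks, Graphs

g = GNNGraph(random_regular_graph(10, 4), ndata=randn(Float32, 3, 10))
# Pass the message function directly, then the graph and the aggregation;
# `copy_xj` forwards each neighbor's features, `+` sums them per node.
m = propagate(copy_xj, g, +; xj=g.ndata.x)  # 3×10 matrix of aggregated features
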
14 changes: 7 additions & 7 deletions test/GNNGraphs/gnngraph.jl
@@ -245,7 +245,7 @@
     @test_throws AssertionError rand_graph(10, 30, ndata=1, graph_type=GRAPH_T)
 end

-@testset "LearnBase and DataLoader compat" begin
+@testset "MLUtils and DataLoader compat" begin
     n, m, num_graphs = 10, 30, 50
     X = rand(10, n)
     E = rand(10, m)
@@ -255,18 +255,18 @@
     g = Flux.batch(data)

     @testset "batch then pass to dataloader" begin
-        @test LearnBase.getobs(g, 3) == getgraph(g, 3)
-        @test LearnBase.getobs(g, 3:5) == getgraph(g, 3:5)
-        @test StatsBase.nobs(g) == g.num_graphs
+        @test MLUtils.getobs(g, 3) == getgraph(g, 3)
+        @test MLUtils.getobs(g, 3:5) == getgraph(g, 3:5)
+        @test MLUtils.numobs(g) == g.num_graphs

         d = Flux.Data.DataLoader(g, batchsize=2, shuffle=false)
         @test first(d) == getgraph(g, 1:2)
     end

     @testset "pass to dataloader and collate" begin
-        @test LearnBase.getobs(data, 3) == getgraph(g, 3)
-        @test LearnBase.getobs(data, 3:5) == getgraph(g, 3:5)
-        @test StatsBase.nobs(data) == g.num_graphs
+        @test MLUtils.getobs(data, 3) == getgraph(g, 3)
+        @test MLUtils.getobs(data, 3:5) == getgraph(g, 3:5)
+        @test MLUtils.numobs(data) == g.num_graphs

         d = Flux.Data.DataLoader(data, batchsize=2, shuffle=false)
         @test first(d) == getgraph(g, 1:2)
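
The loop these tests imply, sketched out — with the Vector{GNNGraph} route, each mini-batch is collated into a single batched graph (sizes illustrative):

using GraphNeuralNetworks, Flux

data = [rand_graph(10, 30, ndata=rand(Float32, 3, 10)) for _ in 1:50]
loader = Flux.Data.DataLoader(data, batchsize=2, shuffle=false)
for g in loader
    @assert g isa GNNGraph && g.num_graphs == 2  # two graphs per batch
end
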
31 changes: 0 additions & 31 deletions test/deprecations.jl
@@ -1,34 +1,3 @@
 @testset "deprecations" begin
-    @testset "propagate" begin
-        struct GCN{A<:AbstractMatrix, B, F} <: GNNLayer
-            weight::A
-            bias::B
-            σ::F
-        end
-
-        Flux.@functor GCN  # allow collecting params, gpu movement, etc...
-
-        function GCN(ch::Pair{Int,Int}, σ=identity)
-            in, out = ch
-            W = Flux.glorot_uniform(out, in)
-            b = zeros(Float32, out)
-            GCN(W, b, σ)
-        end
-
-        GraphNeuralNetworks.compute_message(l::GCN, xi, xj, e) = xj
-
-        function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T
-            x, _ = propagate(l, g, +, x)
-            return l.σ.(l.weight * x .+ l.bias)
-        end
-
-        function new_forward(l, g, x)
-            x = propagate(copy_xj, g, +, xj=x)
-            return l.σ.(l.weight * x .+ l.bias)
-        end
-
-        g = GNNGraph(random_regular_graph(10, 4), ndata=randn(3, 10))
-        l = GCN(3 => 5, tanh)
-        @test l(g, g.ndata.x) ≈ new_forward(l, g, g.ndata.x)
-    end
 end
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -5,7 +5,7 @@ using CUDA
 using Flux: gpu, @functor
 using LinearAlgebra, Statistics, Random
 using NNlib
-using LearnBase
+import MLUtils
 import StatsBase
 using SparseArrays
 using Graphs

2 comments on commit c7d0afe

@CarloLucibello
Member Author


@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/58227

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" c7d0afee487f72a762d03c5fafb145a6f49be23b
git push origin v0.4.0
