Bug fix in GNN summary network
sainsbmd authored and sainsbmd committed Oct 27, 2024
1 parent cc0aa2c commit 260ff8e
Showing 7 changed files with 129 additions and 100 deletions.
6 changes: 3 additions & 3 deletions Project.toml
@@ -1,7 +1,7 @@
name = "NeuralEstimators"
uuid = "38f6df31-6b4a-4144-b2af-7ace2da57606"
authors = ["Matthew Sainsbury-Dale <[email protected]> and contributors"]
version = "0.1.1"
version = "0.2.0"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
@@ -39,9 +39,9 @@ NeuralEstimatorsPlotExt = ["AlgebraOfGraphics", "CairoMakie", "ColorSchemes"]
AlgebraOfGraphics = "0.8"
BSON = "0.3"
CSV = "0.10"
ColorSchemes = "2, 3"
CairoMakie = "0.12"
CUDA = "4, 5"
CairoMakie = "0.12"
ColorSchemes = "2, 3"
DataFrames = "1"
Distances = "0.10, 0.11"
Flux = "0.14"
6 changes: 3 additions & 3 deletions docs/src/API/architectures.md
@@ -32,7 +32,7 @@ NeighbourhoodVariogram

## Layers

In addition to the [built-in layers](https://fluxml.ai/Flux.jl/stable/reference/models/layers/) provided by Flux, the following layers may be used when constructing a neural-network architecture.
In addition to the [built-in layers](https://fluxml.ai/Flux.jl/stable/reference/models/layers/) provided by Flux, the following layers may be used when building a neural-network architecture.

```@docs
DensePositive
@@ -45,14 +45,14 @@ SpatialGraphConv
```


# Output activation functions
## Output layers

```@index
Order = [:type, :function]
Pages = ["activationfunctions.md"]
```

In addition to the [standard activation functions](https://fluxml.ai/Flux.jl/stable/models/activation/) provided by Flux, the following structs can be used at the end of an architecture to act as output activation functions that ensure valid estimates for certain models. **NB:** Although we refer to the following objects as "activation functions", they should be treated as layers that are included in the final stage of a Flux `Chain()`.
In addition to the [standard activation functions](https://fluxml.ai/Flux.jl/stable/models/activation/) provided by Flux, the following layers can be used at the end of an architecture to act as "output activation functions" that ensure valid estimates for certain models. Note that, although we may conceptualise the following structs as "output activation functions", they should be treated as separate layers included in the final stage of a Flux `Chain()`. In particular, they cannot be used as the activation function of a `Dense` layer.
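
For example, a minimal sketch (the dimensions and parameter bounds below are purely illustrative) of a mapping network that places a `Compress` layer at the end of a `Chain`, rather than using it as the activation function of the final `Dense` layer:

```
using NeuralEstimators, Flux
p = 2                       # number of parameters in the statistical model
a = [0.0f0, 0.0f0]          # assumed lower bounds for each parameter
b = [1.0f0, 10.0f0]         # assumed upper bounds for each parameter
ϕ = Chain(Dense(64, 32, relu), Dense(32, p), Compress(a, b))
ϕ(rand(Float32, 64, 100))   # estimates lie within the specified bounds
```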

```@docs
Compress
28 changes: 14 additions & 14 deletions docs/src/workflow/examples.md
@@ -96,7 +96,7 @@ plot(assessment)

![Univariate Gaussian example: Estimates vs. truth](../assets/figures/univariate.png)

As an alternative form of uncertainty quantification, one may approximate a set of marginal posterior quantiles by training a second estimator under the quantile loss function, which allows one to generate approximate marginal posterior credible intervals. This is facilitated with [`IntervalEstimator`](@ref) which, by default, targets 95% central credible intervals:
As an alternative form of uncertainty quantification, one may approximate a set of marginal posterior quantiles by training a neural Bayes estimator under the quantile loss function, which allows one to generate approximate marginal posterior credible intervals. This is facilitated with [`IntervalEstimator`](@ref) which, by default, targets 95% central credible intervals:

```
q̂ = IntervalEstimator(architecture)
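# A hedged sketch of how this workflow might continue (training under the quantile
# loss and extracting credible intervals); exact keyword arguments may differ:
# q̂ = train(q̂, θ_train, θ_val, simulate, m = m)
# ci = interval(q̂, Z)   # approximate 95% central credible intervals for observed data Z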
@@ -325,10 +325,10 @@ simulate(parameters::Parameters, m::Integer = 1) = simulate(parameters, range(m,
Next, we construct an appropriate GNN architecture, as illustrated below. Here, our goal is to construct a point estimator; however, any other kind of estimator (see [Estimators](@ref)) can be constructed by simply substituting the appropriate estimator class in the final line below:

```
# Spatial weight function constructed using 0-1 basis functions
# Spatial weight functions: continuous surrogates for 0-1 basis functions
h_max = 0.15 # maximum distance to consider
q = 10 # output dimension of the spatial weights
w = IndicatorWeights(h_max, q)
w = KernelWeights(h_max, q)
# Propagation module
propagation = GNNChain(
@@ -339,11 +339,11 @@ propagation = GNNChain(
# Readout module
readout = GlobalPool(mean)
# Global features
globalfeatures = SpatialGraphConv(1 => q, relu, w = w, w_out = q, glob = true)
# Summary network
ψ = GNNSummary(propagation, readout, globalfeatures)
ψ = GNNSummary(propagation, readout)
# Expert summary statistics, the empirical variogram
S = NeighbourhoodVariogram(h_max, q)
# Mapping module
ϕ = Chain(
@@ -353,7 +353,7 @@ globalfeatures = SpatialGraphConv(1 => q, relu, w = w, w_out = q, glob = true)
)
# DeepSet object
deepset = DeepSet(ψ, ϕ)
deepset = DeepSet(ψ, ϕ; S = S)
# Point estimator
θ̂ = PointEstimator(deepset)
@@ -363,10 +363,10 @@ Next, we train the estimator:

```
m = 1
K = 3000
K = 5000
θ_train = sample(K)
θ_val = sample(K÷5)
θ̂ = train(θ̂, θ_train, θ_val, simulate, m = m, epochs = 5)
θ̂ = train(θ̂, θ_train, θ_val, simulate, m = m, epochs = 20)
```

Then, we assess our trained estimator as before:
@@ -375,15 +375,15 @@ Then, we assess our trained estimator as before:
θ_test = sample(1000)
Z_test = simulate(θ_test, m)
assessment = assess(θ̂, θ_test, Z_test)
bias(assessment) # 0.001
rmse(assessment) # 0.037
risk(assessment) # 0.029
bias(assessment)
rmse(assessment)
risk(assessment)
plot(assessment)
```

![Estimates from a graph neural network (GNN) based neural Bayes estimator](../assets/figures/spatial.png)

Finally, once the estimator has been assessed and is deemed to be performant, it may be applied to observed data, with bootstrap-based uncertainty quantification facilitated by [`bootstrap`](@ref) and [`interval`](@ref). Below, we use simulated data as a substitute for observed data:
Finally, once the estimator has been assessed, it may be applied to observed data, with bootstrap-based uncertainty quantification facilitated by [`bootstrap`](@ref) and [`interval`](@ref). Below, we use simulated data as a substitute for observed data:

```
parameters = sample(1) # sample a single parameter vector
9 changes: 4 additions & 5 deletions src/Architectures.jl
@@ -203,8 +203,6 @@ function (d::DeepSet)(tup::Tup) where {Tup <: Tuple{V₁, V₂}} where {V₁ <:
reduce(hcat, vec.(permutedims.(result)))
end

#TODO document summarystatistics()

# Fallback method to allow neural estimators to be called directly
summarystatistics(est, Z) = summarystatistics(est.deepset, Z)
# Single data set
@@ -260,7 +258,7 @@ function summarystatistics(d::DeepSet, Z::V) where {V <: AbstractVector{A}} wher

return t
else
# Array sizes differ, so therefor cannot stack together; use simple (and slower) broadcasting method (identical to general fallback method defined above)
# Array sizes differ, so therefore cannot stack together; use simple (and slower) broadcasting method (identical to general fallback method defined above)
return summarystatistics.(Ref(d), Z)
end
end
@@ -278,10 +276,11 @@ function summarystatistics(d::DeepSet, Z::V) where {V <: AbstractVector{G}} wher
# independent replicate), record the grouping of independent replicates
# so that they can be combined again later in the function
m = numberreplicates.(Z)
g = @ignore_derivatives Flux.batch(Z) # NB batch() causes array mutation, so do not attempt to compute derivatives through this call

# Propagation and readout
g = @ignore_derivatives Flux.batch(Z) # NB batch() causes array mutation, so do not attempt to compute derivatives through this call
R = d.ψ(g)
#R = stackarrays(d.ψ.(Z)) # Version that doesn't require us to use ignore_derivatives, but much slower (and haven't tested with m>1 independent replicates)

# Split R based on the original vector of data sets Z
if ndims(R) == 2
@@ -313,7 +312,7 @@ function summarystatistics(d::DeepSet, Z::V) where {V <: AbstractVector{G}} wher
return t
end

# TODO For graph data, currently not allowed to have data sets with variable number of independent replicates, since in this case we can't stack the three-dimensional arrays:
# TODO For graph data, currently not allowed to have data sets with variable numbers of independent replicates, since in this case we can't stack the three-dimensional arrays:
# θ = sample(2)
# g = simulate(θ, 5)
# g = Flux.batch(g)
128 changes: 58 additions & 70 deletions src/Graphs.jl
@@ -198,13 +198,59 @@ end
Flux.trainable(l::IndicatorWeights) = ()


@doc raw"""
KernelWeights(h_max, n_bins::Integer)
(w::KernelWeights)(h::Matrix)
For spatial locations $\boldsymbol{s}$ and $\boldsymbol{u}$, creates a spatial weight function defined as
```math
\boldsymbol{w}(\boldsymbol{s}, \boldsymbol{u}) \equiv (\exp(-(h - \mu_k)^2 / (2\sigma_k^2)) : k = 1, \dots, K)',
```
where $h \equiv \|\boldsymbol{s} - \boldsymbol{u}\|$ is the spatial distance between $\boldsymbol{s}$ and $\boldsymbol{u}$, and $\{\mu_k : k = 1, \dots, K\}$ and $\{\sigma_k : k = 1, \dots, K\}$ are the means and standard deviations of the Gaussian kernels for each bin, covering the spatial distances between 0 and `h_max`.
# Examples
```
using NeuralEstimators
h_max = 1
n_bins = 10
w = KernelWeights(h_max, n_bins)
h = rand(1, 30) # distances between 30 pairs of spatial locations
w(h)
```
"""
struct KernelWeights
mu
sigma
end
function KernelWeights(h_max, n_bins::Integer)
h_cutoffs = range(0, stop=h_max, length=n_bins+1)
h_cutoffs = collect(h_cutoffs)
mu = [(h_cutoffs[i] + h_cutoffs[i+1]) / 2 for i in 1:n_bins] # midpoints of the intervals
sigma = [(h_cutoffs[i+1] - h_cutoffs[i]) / 4 for i in 1:n_bins] # std dev so that 95% of mass is within the bin
mu = Float32.(mu)
sigma = Float32.(sigma)
KernelWeights(mu, sigma)
end
function (l::KernelWeights)(h::M) where M <: AbstractMatrix{T} where T
mu = l.mu
sigma = l.sigma
N = [exp.(-(h .- mu[i:i]).^2 ./ (2 * sigma[i:i].^2)) for i in eachindex(mu)] # Gaussian kernel for each bin (NB avoid scalar indexing by i:i)
N = reduce(vcat, N)
Float32.(N)
end
@layer KernelWeights
Flux.trainable(l::KernelWeights) = ()


# ---- GraphConv ----

# 3D array version of GraphConv to allow the option to forego spatial information
"""
(l::GraphConv)(g::GNNGraph, x::A) where A <: AbstractArray{T, 3} where {T}
Given a graph with node features a three dimensional array of size `in` × m × n,
Given a graph with node features consisting of a three dimensional array of size `in` × m × n,
where n is the number of nodes in the graph, this method yields an array with
dimensions `out` × m × n.
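
# Examples
A minimal, hedged sketch of the shape contract (the graph, layer, and dimensions below are illustrative):
```
using NeuralEstimators, Flux, GraphNeuralNetworks
n = 100                       # number of nodes in the graph
m = 5                         # number of independent replicates
g = rand_graph(n, 400)        # random graph with 100 nodes and 400 edges
l = GraphConv(1 => 16, relu)
x = rand(Float32, 1, m, n)    # node features of size `in` × m × n
size(l(g, x))                 # (16, 5, 100), i.e., `out` × m × n
```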
@@ -283,14 +329,13 @@ although a custom choice for this function can be provided using the keyword arg
- `w = nothing`
- `w_width = 128`: (Only applicable if `w = nothing`) The width of the hidden layer in the MLP used to model $\boldsymbol{w}(\cdot, \cdot)$.
- `w_out = in`: (Only applicable if `w = nothing`) The output dimension of $\boldsymbol{w}(\cdot, \cdot)$.
- `glob = false`: If `true`, global features will be computed directly from the entire spatial graph. These features are of the form: $\boldsymbol{T} = \sum_{j=1}^n\sum_{j' \in \mathcal{N}(j)}\boldsymbol{w}^{(l)}(\|\boldsymbol{s}_{j'} - \boldsymbol{s}_j\|) \odot f^{(l)}(\boldsymbol{h}^{(l-1)}_{j}, \boldsymbol{h}^{(l-1)}_{j'})$. Note that these global features are no longer associated with a graph structure, and should therefore only be used in the final layer of a summary-statistics module.
# Examples
```
using NeuralEstimators, Flux, GraphNeuralNetworks
# Toy spatial data
m = 5 # number of replicates
m = 5 # number of independent replicates
d = 2 # spatial dimension
n = 250 # number of spatial locations
S = rand(n, d) # spatial locations
@@ -309,7 +354,6 @@ struct SpatialGraphConv{W<:AbstractMatrix, A, B,C, F} <: GNNLayer
w::A
f::C
g::F
glob::Bool
end
@layer SpatialGraphConv
WeightedGraphConv = SpatialGraphConv; export WeightedGraphConv # alias for backwards compatibility
@@ -321,8 +365,7 @@ function SpatialGraphConv(
w = nothing,
f = nothing,
w_out::Union{Integer, Nothing} = nothing,
w_width::Integer = 128,
glob::Bool = false
w_width::Integer = 128
)

in, out = ch
@@ -359,16 +402,7 @@ function SpatialGraphConv(
# Bias vector
b = bias ? Flux.create_bias(Γ1, true, out) : false

SpatialGraphConv(Γ1, Γ2, b, w, f, g, glob)
end
function (l::SpatialGraphConv)(g::GNNGraph)
Z = :Z ∈ keys(g.ndata) ? g.ndata.Z : first(values(g.ndata))
h = l(g, Z)
if l.glob
@ignore_derivatives GNNGraph(g, gdata = (g.gdata..., R = h))
else
@ignore_derivatives GNNGraph(g, ndata = (g.ndata..., Z = h))
end
SpatialGraphConv(Γ1, Γ2, b, w, f, g)
end
function (l::SpatialGraphConv)(g::GNNGraph, x::M) where M <: AbstractMatrix{T} where {T}
l(g, reshape(x, size(x, 1), 1, size(x, 2)))
@@ -395,39 +429,17 @@ function (l::SpatialGraphConv)(g::GNNGraph, x::A) where A <: AbstractArray{T, 3}
# 3. Vector output with vector input features, in which case the dimensionalities must match
w = l.w(s)

if l.glob
w̃ = normalise_edges(g, w) # Sanity check: sum(w̃; dims = 2) # all close to one
else
w̃ = normalise_edge_neighbors(g, w) # Sanity check: aggregate_neighbors(g, +, w̃) # zeros and ones
end
w̃ = normalise_edge_neighbors(g, w) # Sanity check: aggregate_neighbors(g, +, w̃) # zeros and ones

# Coerce to three-dimensional array, repeated to match the number of independent replicates
w̃ = coerce3Darray(w̃, m)

# Compute spatially-weighted sum of input features over each neighbourhood
msg = apply_edges((l, xi, xj, w̃) -> w̃ .* l.f(xi, xj), g, l, x, x, w̃)
if l.glob
h̄ = reduce_edges(+, g, msg) # sum over all neighbourhoods in the graph
else
#TODO Need this to be a summation that ignores missing
h̄ = aggregate_neighbors(g, +, msg) # sum over each neighbourhood
end
#msg = apply_edges((l, xi, xj, w̃) -> w̃ .* l.f(xi, xj), g, l, x, x, w̃)
msg = apply_edges((xi, xj, w̃) -> w̃ .* l.f(xi, xj), g, x, x, w̃)
h̄ = aggregate_neighbors(g, +, msg) # sum over each neighbourhood individually

# Remove elements in which w summed to zero (i.e., deal with possible division by zero by omitting these terms from the convolution)
# (currently only do this for locally constructed summary statistics)
# if !l.glob
# w_sums = aggregate_neighbors(g, +, w)
# w_zero = w_sums .== 0
# w_zero = coerce3Darray(w_zero, m)
# h̄ = removedata(h̄, vec(w_zero))
# end

if l.glob
return h̄
else
return l.g.(l.Γ1 ⊠ x .+ l.Γ2 ⊠ h̄ .+ l.b) # ⊠ is shorthand for batched_mul #NB any missingness will cause the feature vector to be entirely missing
#return [ismissing(a) ? missing : l.g(a) for a in x .+ h̄ .+ l.b]
end
l.g.(l.Γ1 ⊠ x .+ l.Γ2 ⊠ h̄ .+ l.b) # ⊠ is shorthand for batched_mul #NB any missingness will cause the feature vector to be entirely missing
end
function Base.show(io::IO, l::SpatialGraphConv)
in_channel = size(l.Γ1, ndims(l.Γ1))
@@ -471,20 +483,14 @@ Normalise the edge features `e` to sum to one over each node's neighborhood,
```
"""
function normalise_edge_neighbors(g::AbstractGNNGraph, e)
if g isa GNNHeteroGraph
for (key, value) in g.num_edges
@assert size(e)[end] == value
end
else
@assert size(e)[end] == g.num_edges
end
@assert size(e)[end] == g.num_edges
s, t = edge_index(g)
den = gather(scatter(+, e, t), t)
return e ./ (den .+ eps(eltype(e)))
end

@doc raw"""
GNNSummary(propagation, readout; globalfeatures = nothing)
GNNSummary(propagation, readout)
A graph neural network (GNN) module designed to serve as the summary network `ψ`
in the [`DeepSet`](@ref) representation when the data are graphical (e.g.,
Expand All @@ -496,13 +502,6 @@ a single hidden feature vector of fixed length (i.e., a vector of summary
statistics). The summary network is then defined as the composition of the
propagation and readout modules.
Optionally, one may also include a module that extracts features directly
from the graph, through the keyword argument `globalfeatures`. This module,
when applied to a `GNNGraph`, should return a matrix of features,
where the columns of the matrix correspond to the independent replicates
(e.g., a 5x10 matrix is expected for 5 hidden features for each of 10
independent replicates stored in the graph).
The data should be stored as a `GNNGraph` or `Vector{GNNGraph}`, where
each graph is associated with a single parameter vector. The graphs may contain
subgraphs corresponding to independent replicates.
@@ -544,12 +543,10 @@ g₃ = batch([g₁, g₂])
θ̂([g₁, g₂, g₃])
```
"""
struct GNNSummary{F, G, H}
struct GNNSummary{F, G}
propagation::F # propagation module
readout::G # readout module
globalfeatures::H
end
GNNSummary(propagation, readout; globalfeatures = nothing) = GNNSummary(propagation, readout, globalfeatures)
@layer GNNSummary
Base.show(io::IO, D::GNNSummary) = print(io, "\nThe propagation and readout modules of a graph neural network (GNN), with a total of $(nparams(D)) trainable parameters:\n\nPropagation module ($(nparams(D.propagation)) parameters): $(D.propagation)\n\nReadout module ($(nparams(D.readout)) parameters): $(D.readout)")

@@ -565,15 +562,6 @@ function (ψ::GNNSummary)(g::GNNGraph)
# ncols = number of independent replicates
R = ψ.readout(h, Z)

if !isnothing(ψ.globalfeatures)
R₂ = ψ.globalfeatures(g)
if isa(R₂, GNNGraph)
@assert length(R₂.gdata) > 0 "The `globalfeatures` field of a `GNNSummary` object must return either an array or a graph with a non-empty field `gdata`"
R₂ = first(values(R₂.gdata))
end
R = vcat(R, R₂)
end

# Reshape from three-dimensional array to matrix
R = reshape(R, size(R, 1), :) #NB not ideal to do this here, I think, makes the output of summarystatistics() quite confusing. (keep in mind the behaviour of summarystatistics on a vector of graphs and a single graph)
