Graph: Implement a MessagePassing layer in PyTorch Geometric
How to implement a custom MessagePassing layer in PyTorch Geometric (PyG)?
Before you start, some things you need to know:

- special arguments: e.g. `x_j`, `x_i`, `edge_index_j`, `edge_index_i`
- aggregate: `scatter_add`, `scatter_mean`, `scatter_min`, `scatter_max`
- The PyG `MessagePassing` framework only works for `node_graph`.
MessagePassing in PyTorch Geometric
Principle
Message passing graph neural networks can be described as
$$ \mathbf{x}_{i}^{(k)}=\gamma^{(k)} \left(\mathbf{x}_{i}^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_{i}^{(k-1)}, \mathbf{x}_{j}^{(k-1)}, \mathbf{e}_{j, i}\right)\right) $$
- $\mathbf{x}_{i}^{(k-1)}$: node features of node $i$ in layer $(k-1)$
- $\mathbf{e}_{j,i} \in \mathbb{R}^D$: (optional) edge features from node $j$ to node $i$
- $\square$: aggregation method (permutation invariant function). i.e., mean, sum, max
- $\gamma$, $\phi$: differentiable functions, such as MLP
In PyTorch Geometric, `self.propagate` will do the following:

1. Execute `self.message`, $\phi$: construct the messages of node pairs `(x_i, x_j)`.
2. Execute `self.aggregate`, $\square$: aggregate the messages from neighbors. Internally, the aggregation works like this:

```python
import torch
from torch_scatter import scatter_add

num_nodes = 4
embed_size = 5

src = torch.randint(0, num_nodes, (num_nodes, embed_size))
src_index = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3])
tmp = torch.index_select(src, 0, src_index)  # shape [num_edges, embed_size]
print("input: ")
print(tmp)

target_index = torch.tensor([1, 2, 3, 3, 0, 0, 0, 2])
aggr = scatter_add(tmp, target_index, 0)  # shape [num_nodes, embed_size]
print("agg out:")
print(aggr)

# behind the scenes, torch.scatter_add is used
# repeat the edge index
index2 = target_index.expand((embed_size, target_index.size(0))).T
# same result by using torch.scatter_add
aggr2 = torch.zeros(num_nodes, embed_size, dtype=tmp.dtype).scatter_add(0, index2, tmp)
```
see torch_scatter
3. Execute `self.update`, $\gamma$: update the embedding of node $i$, $i \in \mathcal{V}$, with the aggregated message, e.g. by combining the aggregated neighbor message with the node's own message.
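Put together, the three steps can be sketched in plain PyTorch without subclassing `MessagePassing` (a minimal sketch: `propagate_sketch`, the identity message, and sum aggregation are illustrative choices, not PyG's API):

```python
import torch

def propagate_sketch(x, edge_index):
    """Emulate message -> aggregate -> update with sum aggregation."""
    num_nodes, embed_size = x.size()
    j, i = edge_index[0], edge_index[1]   # source j -> target i
    # 1. message, phi: build one message per edge from x_j
    x_j = x.index_select(0, j)            # [num_edges, embed_size]
    msg = x_j                             # identity message (illustrative)
    # 2. aggregate, square: sum the messages arriving at each node i
    aggr = torch.zeros(num_nodes, embed_size, dtype=x.dtype)
    aggr.index_add_(0, i, msg)
    # 3. update, gamma: combine aggregated message with the node's own state
    return x + aggr

x = torch.arange(8, dtype=torch.float).view(4, 2)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
out = propagate_sketch(x, edge_index)  # shape [4, 2]
```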
Propagate
When `propagate` is called, the execution is as follows:
1. `__check_input__(**kwargs)`: check whether the input is a `SparseTensor` or not.
2. `__collect__(**kwargs)`: construct the message data of node $i$, $i \in \mathcal{V}$.
   - Take care of the direction of the message:
     - `flow='source_to_target'`: $j \rightarrow i$, that's $(j, i) \in \mathcal{E}$
     - `flow='target_to_source'`: $i \rightarrow j$, that's $(i, j) \in \mathcal{E}$
   - Construct message data from variables whose names are suffixed with `_i`, `_j`: `x_j`, `x_i` with shape `[num_edges, embed_size]`. It even works for `z_i`, `z_j`, if you've passed them to `propagate`.

   ```python
   # example code
   # src: node_attr
   # args: arguments defined in `message()`, e.g., x_j, x_i

   # 1. direction
   i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
   out = {}
   # 2. construct message x_j, x_i. Both with shape [num_edges, embed_size]
   for arg in args:
       if arg.endswith("_i") or arg.endswith("_j"):
           dim = j if arg[-2:] == '_j' else i
           index = edge_index[dim]
           out[arg] = src.index_select(0, index)
   out['edge_index_i'] = edge_index[i]
   out['edge_index_j'] = edge_index[j]
   # return out
   ```

   - Generate `edge_index_j`, `edge_index_i`.
   - Return a dict.
3. `message(**kwargs)`:
   - Arguments: the output of `__collect__`, and the kwargs of `propagate`, e.g. `x_j`, `edge_attr`, `size`.
   - Construct node $i$'s messages by using the variables suffixed with `_i`, `_j`.
   - That's why you see arguments with the suffixes `_i`, `_j`.
4. `aggregate(**kwargs)`:
   - Arguments: the output of step 3 (`message`), and the kwargs of `propagate`.
   - Aggregation methods: mean, add, max, min.
5. `update(**kwargs)`:
   - Arguments: the output of step 4 (`aggregate`), and the kwargs of `propagate`.
   - Update the node embeddings.
Code snippets of `MessagePassing`. See the full source code here.
Code Example: GCN