Graph: Implement a MessagePassing layer in PyTorch Geometric
How do you implement a custom MessagePassing layer in PyTorch Geometric (PyG)?
Before you start, here are some things you need to know:
- special arguments: e.g. `x_j`, `x_i`, `edge_index_j`, `edge_index_i`
- aggregate: `scatter_add`, `scatter_mean`, `scatter_min`, `scatter_max`
- The PyG `MessagePassing` framework only works for `node_graph`.
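For example, the `_j` / `_i` suffix convention lets `message()` request per-edge views of node features just by naming its arguments. A minimal sketch (`EdgeDiff` is a made-up layer that messages the feature difference along each edge):

```python
from torch_geometric.nn import MessagePassing


class EdgeDiff(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')  # scatter_add under the hood

    def forward(self, x, edge_index):
        # x: [num_nodes, embed_size], edge_index: [2, num_edges]
        return self.propagate(edge_index, x=x)

    def message(self, x_j, x_i):
        # x_j / x_i: source / target node features per edge,
        # both with shape [num_edges, embed_size]
        return x_j - x_i
```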
MessagePassing in PyTorch Geometric
Principle
Message passing graph neural networks can be described as
$$ \mathbf{x}_{i}^{(k)} = \gamma^{(k)} \left( \mathbf{x}_{i}^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)} \left( \mathbf{x}_{i}^{(k-1)}, \mathbf{x}_{j}^{(k-1)}, \mathbf{e}_{j, i} \right) \right) $$
- $\mathbf{x}_{i}^{(k-1)}$: node features of node $i$ in layer $(k-1)$
- $\mathbf{e}_{j,i} \in \mathbb{R}^D$: (optional) edge features from node $j$ to node $i$
- $\square$: aggregation method (a permutation-invariant function), e.g., mean, sum, max
- $\gamma$, $\phi$: differentiable functions, such as MLPs
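For a concrete instantiation, take $\square$ to be a sum, $\phi$ a linear map of the neighbor feature, and $\gamma$ an addition with a transformed self feature; this yields a simple GraphSAGE-style layer (one illustrative choice among many):

$$ \mathbf{x}_{i}^{(k)} = \mathbf{W}_1 \mathbf{x}_{i}^{(k-1)} + \sum_{j \in \mathcal{N}(i)} \mathbf{W}_2 \mathbf{x}_{j}^{(k-1)} $$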
In PyTorch Geometric, `self.propagate` will do the following:
- execute `self.message`, $\phi$: construct the message of node pairs `(x_i, x_j)`
- execute `self.aggregate`, $\square$: aggregate the messages from neighbors. Internally, the aggregation works like this (see torch_scatter):

```python
import torch
from torch_scatter import scatter_add

num_nodes = 4
embed_size = 5

src = torch.randint(0, num_nodes, (num_nodes, embed_size))
src_index = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3])
tmp = torch.index_select(src, 0, src_index)  # shape [num_edges, embed_size]
print("input: ")
print(tmp)

target_index = torch.tensor([1, 2, 3, 3, 0, 0, 0, 2])
aggr = scatter_add(tmp, target_index, 0)  # shape [num_nodes, embed_size]
print("agg out:")
print(aggr)

# behind the scenes, torch.scatter_add is used:
# repeat the edge_index
index2 = target_index.expand((embed_size, target_index.size(0))).T
# same result by using torch.scatter_add
aggr2 = torch.zeros(num_nodes, embed_size, dtype=tmp.dtype).scatter_add(0, index2, tmp)
print(aggr2)
```

- execute `self.update`, $\gamma$: update the embedding of each node $i \in \mathcal{V}$ with the aggregated message, e.g. by combining the aggregated neighbor message with the node's own message.
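Putting the three steps together, a `propagate` call unfolds roughly as follows (a conceptual sketch of the control flow, not the actual PyG source; it assumes `flow='source_to_target'` and a `message(x_j, x_i)` signature):

```python
def propagate_sketch(layer, edge_index, x):
    j, i = edge_index[0], edge_index[1]
    x_j, x_i = x[j], x[i]                # __collect__: lift node features onto edges
    msg = layer.message(x_j, x_i)        # phi: per-edge messages
    out = layer.aggregate(msg, i, dim_size=x.size(0))  # square: scatter-reduce to nodes
    return layer.update(out)             # gamma: transform the aggregated output
```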
Aggregate

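The reducer is chosen via the `aggr` argument of `MessagePassing.__init__` (`'add'`, `'mean'`, `'max'`, ...). Each reducer can be reproduced directly with torch_scatter; a small demo in the spirit of the snippet above (toy shapes, random messages):

```python
import torch
from torch_scatter import scatter_mean, scatter_max

num_nodes = 4
embed_size = 5
msgs = torch.randn(8, embed_size)                      # one row per edge
target_index = torch.tensor([1, 2, 3, 3, 0, 0, 0, 2])  # target node of each edge

mean_aggr = scatter_mean(msgs, target_index, 0)        # average incoming messages
max_aggr, argmax = scatter_max(msgs, target_index, 0)  # element-wise max (+ source positions)
print(mean_aggr.shape, max_aggr.shape)                 # both: [num_nodes, embed_size]
```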
Propagate
When `propagate` is called, the execution proceeds as follows:
1. `__check_input__(**kwargs)`: check whether the input is a `SparseTensor` or not.
2. `__collect__(**kwargs)`: construct the message data for node $i$, $i \in \mathcal{V}$.
   - Take care of the direction of the message:
     - `flow='source_to_target'`: $j \rightarrow i$, that's $(j, i) \in \mathcal{E}$
     - `flow='target_to_source'`: $i \rightarrow j$, that's $(i, j) \in \mathcal{E}$
   - Construct the message data from variables whose names are suffixed with `_i`, `_j`:
     - `x_j`, `x_i`, with shape `[num_edges, embed_size]`
     - You can even use `z_i`, `z_j` if you have passed them to `propagate`.

   ```python
   # example code
   # src: node_attr
   # args: arguments defined in `message()`, e.g., x_j, x_i

   # 1. direction
   i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
   out = {}
   # 2. construct message x_j, x_i, both with shape [num_edges, embed_size]
   for arg in args:
       if arg.endswith("_i") or arg.endswith("_j"):
           dim = j if arg[-2:] == '_j' else i
           index = edge_index[dim]
           out[arg] = src.index_select(0, index)
   # 3. generate edge_index_i / edge_index_j
   out['edge_index_i'] = edge_index[i]
   out['edge_index_j'] = edge_index[j]
   # 4. return a dict
   # return out
   ```

   - Generate `edge_index_j`, `edge_index_i`.
   - Return a dict.
3. `message(**kwargs)`:
   - arguments: the output of `__collect__`, plus the kwargs passed to `propagate`, e.g. `x_j`, `edge_attr`, `size`
   - construct node $i$'s messages using the variables suffixed with `_i`, `_j`
     - that's why you see arguments with the suffix `_i`, `_j`
4. `aggregate(**kwargs)`:
   - arguments: the output of step 3 (`message`), plus the kwargs passed to `propagate`
   - aggregation methods: mean, add, max, min
5. `update(**kwargs)`:
   - arguments: the output of step 4 (`aggregate`), plus the kwargs passed to `propagate`
   - update the node embeddings
Code snippets of MessagePassing. See full source code here
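Below is a minimal skeleton of a custom layer, showing where `message`, `aggregate`, and `update` hook into `propagate`. This is a sketch, not library code: `MyConv` and its weight names are made up, it implements the GraphSAGE-style update from the Principle section, and it relies on the default `aggregate` provided by the base class.

```python
import torch
from torch.nn import Linear
from torch_geometric.nn import MessagePassing


class MyConv(MessagePassing):
    """Hypothetical layer: x_i' = W1 x_i + sum_{j in N(i)} W2 x_j."""

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add', flow='source_to_target')
        self.lin_self = Linear(in_channels, out_channels)
        self.lin_neigh = Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x: [num_nodes, in_channels], edge_index: [2, num_edges]
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # phi: transform the source node features on every edge
        return self.lin_neigh(x_j)  # [num_edges, out_channels]

    def update(self, aggr_out, x):
        # gamma: combine the self message with the aggregated neighbor message
        return self.lin_self(x) + aggr_out  # [num_nodes, out_channels]
```

Usage: `conv = MyConv(16, 32)`, then `out = conv(x, edge_index)`.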
Code Example: GCN
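Following the official PyG tutorial, the GCN layer

$$ \mathbf{x}_{i}^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\deg(i)} \sqrt{\deg(j)}} \mathbf{W} \mathbf{x}_{j}^{(k-1)} $$

maps onto the framework like this (a sketch; bias handling and edge weights are omitted):

```python
import torch
from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # square: sum aggregation
        self.lin = Linear(in_channels, out_channels, bias=False)

    def forward(self, x, edge_index):
        # x: [num_nodes, in_channels], edge_index: [2, num_edges]
        # add self-loops so each node also receives its own message
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # linearly transform node features
        x = self.lin(x)
        # compute the symmetric normalization coefficients
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        # propagate: message -> aggregate -> update
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # scale each neighbor message by its normalization coefficient
        return norm.view(-1, 1) * x_j  # [num_edges, out_channels]
```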