Graph: Implement a MessagePassing layer in PyTorch Geometric

How do you implement a custom MessagePassing layer in PyTorch Geometric (PyG)?

Before you start, there are a few things you need to know.

  • special_arguments: e.g. x_j, x_i, edge_index_j, edge_index_i
  • aggregate: scatter_add, scatter_mean, scatter_min, scatter_max
  • The PyG MessagePassing framework only works for node_graph.

The _j and _i suffixes select the source and target node of each edge:

x = ...           # Node features of shape [num_nodes, num_features]
edge_index = ...  # Edge indices of shape [2, num_edges]

x_j = x[edge_index[0]]  # Source node features [num_edges, num_features]
x_i = x[edge_index[1]]  # Target node features [num_edges, num_features]
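
As a concrete check of the indexing above, here is a tiny hand-made graph with three nodes and three edges. The feature values are made up for illustration only.

```python
import torch

# 3 nodes with 2 features each (illustrative values)
x = torch.tensor([[1., 10.],
                  [2., 20.],
                  [3., 30.]])
# edges 0->1, 0->2, 2->1; row 0 holds the sources (j), row 1 the targets (i)
edge_index = torch.tensor([[0, 0, 2],
                           [1, 2, 1]])

x_j = x[edge_index[0]]  # one row per edge: features of the source node
x_i = x[edge_index[1]]  # one row per edge: features of the target node

print(x_j)  # rows of nodes 0, 0, 2
print(x_i)  # rows of nodes 1, 2, 1
```

Note that both tensors have one row per edge, not one row per node, so a node that appears in several edges is repeated.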

MessagePassing in PyTorch Geometric

Principle

Message passing graph neural networks can be described as

$$ \mathbf{x}_{i}^{(k)}=\gamma^{(k)} (\mathbf{x} _{i}^{(k-1)}, \square _{j \in \mathcal{N}(i)} \phi^{(k)}(\mathbf{x} _{i}^{(k-1)}, \mathbf{x} _{j}^{(k-1)}, \mathbf{e} _{j, i})) $$

  • $\mathbf{x}_i^{(k-1)}$: features of node $i$ in layer $k-1$
  • $\mathbf{e}_{j,i} \in \mathbb{R}^D$: (optional) edge features from node $j$ to node $i$
  • $\square$: aggregation method (a permutation-invariant function), e.g. mean, sum, max
  • $\gamma$, $\phi$: differentiable functions, such as MLP
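
To make the formula concrete, here is a minimal sketch of one layer in plain PyTorch. The choices $\phi = \mathbf{e}_{j,i} \cdot \mathbf{x}_j$, $\square$ = sum and $\gamma = \mathbf{x}_i + \text{aggregated}$ are assumptions picked only to keep the arithmetic easy to follow, and `message_passing_step` is a hypothetical helper name.

```python
import torch

def message_passing_step(x, edge_index, edge_attr):
    """One layer with phi = e_ji * x_j, square = sum, gamma = x_i + aggregated."""
    j, i = edge_index                 # row 0: sources j, row 1: targets i
    messages = edge_attr * x[j]       # phi: one message per edge
    # square: sum all messages that arrive at the same target node i
    aggregated = torch.zeros_like(x).index_add_(0, i, messages)
    return x + aggregated             # gamma: combine with the node's own features

x = torch.tensor([[1.], [2.], [3.]])
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])           # edges 0->1, 1->2, 2->0
edge_attr = torch.tensor([[0.5], [1.0], [2.0]])  # one scalar weight per edge

out = message_passing_step(x, edge_index, edge_attr)
print(out.flatten().tolist())  # [7.0, 2.5, 5.0]
```

Node 0, for example, receives the single message 2.0 * 3 = 6 from node 2 and ends up at 1 + 6 = 7.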

In PyTorch Geometric, self.propagate does the following:

  1. Execute self.message ($\phi$): construct the message for each node pair (x_i, x_j).

  2. Execute self.aggregate ($\square$): aggregate the messages from the neighbors. Internally, the aggregation works like this:

     import torch
     from torch_scatter import scatter_add
     num_nodes = 4
     embed_size = 5
    
     src = torch.randint(0, num_nodes, (num_nodes, embed_size))
     src_index = torch.tensor([0,0,0,1,1,2,3,3])
     tmp = torch.index_select(src, 0, src_index) # shape [num_edges, embed_size ]
     print("input: ")
     print(tmp)
     target_index = torch.tensor([1,2,3,3,0,0,0,2])
     aggr = scatter_add(tmp, target_index, 0) # shape [num_nodes, embed_size]
     # 
     print("agg out:")
     print(aggr)
        
     # behind the scenes, torch.scatter_add is used
     # broadcast target_index to the shape of tmp: [num_edges, embed_size]
     index2 = target_index.expand((embed_size, target_index.size(0))).T
     # same result by using torch.scatter_add
     aggr2 = torch.zeros(num_nodes, embed_size, dtype=tmp.dtype).scatter_add(0, index2, tmp)
    

See torch_scatter.

  3. Execute self.update ($\gamma$).
  • Update the embedding of every node $i \in \mathcal{V}$ with the aggregated message.
  • e.g. combine the aggregated neighbor messages with the node's own message

Aggregate

[Figure: /images/ml/PyG-MessagePassing.png]
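
The other reductions (mean, max, min) follow the same scatter pattern as the sum. The sketch below uses `Tensor.scatter_reduce` from core PyTorch instead of torch_scatter; this assumes PyTorch 1.12 or newer, and `aggregate` is a hypothetical helper, not the PyG method of the same name.

```python
import torch

def aggregate(messages, target_index, num_nodes, reduce="sum"):
    # broadcast target_index [num_edges] to the shape of messages [num_edges, embed_size]
    index = target_index.view(-1, 1).expand_as(messages)
    out = torch.zeros(num_nodes, messages.size(1), dtype=messages.dtype)
    # include_self=False: the initial zeros do not take part in the reduction,
    # so nodes that receive no message simply stay at zero
    return out.scatter_reduce(0, index, messages, reduce=reduce, include_self=False)

messages = torch.tensor([[1.], [2.], [3.], [4.]])  # one message per edge
target_index = torch.tensor([0, 0, 1, 1])          # node 2 receives nothing

results = {}
for reduce in ("sum", "mean", "amax", "amin"):
    results[reduce] = aggregate(messages, target_index, 3, reduce).flatten().tolist()
    print(reduce, results[reduce])
```

Every reduction here is permutation invariant: shuffling the edges (messages and target_index together) leaves the result unchanged.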

Propagate

When propagate is called, execution proceeds as follows:

  1. __check_input__(**kwargs): check whether edge_index is a Tensor or a SparseTensor

  2. __collect__(**kwargs): construct the message inputs for every node $i \in \mathcal{V}$

      1. Mind the direction of the message flow.
      • flow='source_to_target': $j \rightarrow i$, that is, $(j, i) \in \mathcal{E}$, so j = edge_index[0] and i = edge_index[1]
      • flow='target_to_source': $i \rightarrow j$, that is, $(i, j) \in \mathcal{E}$, so i = edge_index[0] and j = edge_index[1]
      2. Construct the message data for every variable whose name ends in _i or _j.
      • x_j, x_i have shape [num_edges, embed_size]
      • The same works for other tensors: try z_i, z_j if you have passed z to propagate.

         # simplified example code
         # src: node_attr of shape [num_nodes, embed_size]
         # args: arguments defined in `message()`, e.g. x_j, x_i

         # 1. direction
         i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
         out = {}
         # 2. construct messages x_j, x_i, both with shape [num_edges, embed_size]
         for arg in args:
             if arg.endswith("_i") or arg.endswith("_j"):
                 dim = j if arg[-2:] == '_j' else i
                 index = edge_index[dim]
                 out[arg] = src.index_select(0, index)

         # 3. expose the edge indices as well
         out['edge_index_i'] = edge_index[i]
         out['edge_index_j'] = edge_index[j]
         # 4. return out

      3. Generate edge_index_j, edge_index_i.
      4. Return everything as a dict.
  3. message(**kwargs):

      1. Arguments: the output of __collect__ and the kwargs passed to propagate, e.g. x_j, edge_attr, size.
      2. Construct node i's messages from the variables suffixed with _i and _j.
      3. That is why you see arguments with the suffixes _i and _j.
  4. aggregate(**kwargs)

      1. Arguments: the output of step 3 (message) and the kwargs passed to propagate.
      2. Aggregation method: mean, add, max, min.
  5. update(**kwargs)

      1. Arguments: the output of step 4 (aggregate) and the kwargs passed to propagate.
      2. Update the node embeddings.

Code snippets from MessagePassing (the functions below are methods of the class). See the full source code in the PyG repository.

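import inspect
from inspect import Parameter
from typing import List, Optional

import torch
from torch import Tensor
from torch_scatter import gather_csr
from torch_sparse import SparseTensor
# assumed location of expand_left in the PyG 1.x source tree; it may differ between versions
from torch_geometric.nn.conv.utils.helpers import expand_left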

def __collect__(self, args, edge_index, size, kwargs):
    # 
    i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

    out = {}
    for arg in args:
        if arg[-2:] not in ['_i', '_j']:
            out[arg] = kwargs.get(arg, Parameter.empty)
        else:
            dim = 0 if arg[-2:] == '_j' else 1
            data = kwargs.get(arg[:-2], Parameter.empty)

            if isinstance(data, (tuple, list)):
                assert len(data) == 2
                if isinstance(data[1 - dim], Tensor):
                    self.__set_size__(size, 1 - dim, data[1 - dim])
                data = data[dim]

            if isinstance(data, Tensor):
                self.__set_size__(size, dim, data)
                data = self.__lift__(data, edge_index,
                                        j if arg[-2:] == '_j' else i)

            out[arg] = data

    if isinstance(edge_index, Tensor):
        out['adj_t'] = None
        out['edge_index'] = edge_index
        out['edge_index_i'] = edge_index[i]
        out['edge_index_j'] = edge_index[j]
        out['ptr'] = None
    elif isinstance(edge_index, SparseTensor):
        out['adj_t'] = edge_index
        out['edge_index'] = None
        out['edge_index_i'] = edge_index.storage.row()
        out['edge_index_j'] = edge_index.storage.col()
        out['ptr'] = edge_index.storage.rowptr()
        out['edge_weight'] = edge_index.storage.value()
        out['edge_attr'] = edge_index.storage.value()
        out['edge_type'] = edge_index.storage.value()

    out['index'] = out['edge_index_i']
    out['size'] = size
    out['size_i'] = size[1] or size[0]
    out['size_j'] = size[0] or size[1]
    out['dim_size'] = out['size_i']

    return out


def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor):
    the_size = size[dim]
    if the_size is None:
        size[dim] = src.size(self.node_dim)
    elif the_size != src.size(self.node_dim):
        raise ValueError(
            (f'Encountered tensor with size {src.size(self.node_dim)} in '
                f'dimension {self.node_dim}, but expected size {the_size}.'))

def __lift__(self, src, edge_index, dim):
    if isinstance(edge_index, Tensor):
        index = edge_index[dim]
        return src.index_select(self.node_dim, index)
    elif isinstance(edge_index, SparseTensor):
        if dim == 1:
            rowptr = edge_index.storage.rowptr()
            rowptr = expand_left(rowptr, dim=self.node_dim, dims=src.dim())
            return gather_csr(src, rowptr)
        elif dim == 0:
            col = edge_index.storage.col()
            return src.index_select(self.node_dim, col)
    raise ValueError
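
For a plain edge_index tensor, the effect of flow on the lifting step can be reproduced in a few lines of PyTorch. This is a simplified sketch, not the library code: `lift` is a hypothetical free function that mirrors only the Tensor branch of `__lift__` with node_dim fixed to 0.

```python
import torch

def lift(x, edge_index, flow="source_to_target"):
    # same convention as __collect__: i indexes the receiving nodes, j the sending nodes
    i, j = (1, 0) if flow == "source_to_target" else (0, 1)
    x_i = x.index_select(0, edge_index[i])
    x_j = x.index_select(0, edge_index[j])
    return x_i, x_j, edge_index[i]  # edge_index[i] is the aggregation index

x = torch.tensor([[1.], [10.], [100.]])
edge_index = torch.tensor([[0, 0],
                           [1, 2]])  # edges 0->1 and 0->2

results = {}
for flow in ("source_to_target", "target_to_source"):
    x_i, x_j, index = lift(x, edge_index, flow)
    # sum-aggregate the messages x_j at the receiving nodes
    results[flow] = torch.zeros_like(x).index_add_(0, index, x_j)
    print(flow, results[flow].flatten().tolist())
# source_to_target [0.0, 1.0, 1.0]
# target_to_source [110.0, 0.0, 0.0]
```

With source_to_target, nodes 1 and 2 each receive the features of node 0; with target_to_source, node 0 collects the features of nodes 1 and 2 instead.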

Code Example: GCN

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add') # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [num_nodes, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        # add_self_loops returns a tuple (edge_index, edge_weight)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3-5: Start propagating messages.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j, edge_index, size):
        # x_j has shape [num_edges, out_channels]

        # Step 3: Normalize node features.
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # aggr_out has shape [num_nodes, out_channels]

        # Step 5: Return new node embeddings.
        return aggr_out
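
As a sanity check of the normalization in message(), the following plain-PyTorch sketch reproduces steps 1 and 3 to 5 by hand and compares the result with the dense form $D^{-1/2} (A + I) D^{-1/2} X$. It leaves out the linear layer and assumes an undirected graph (every edge stored in both directions); the graph and features are made up.

```python
import torch

torch.manual_seed(0)
num_nodes = 3
x = torch.randn(num_nodes, 4)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # undirected path 0 - 1 - 2

# Step 1: add self-loops
loops = torch.arange(num_nodes)
edge_index = torch.cat([edge_index, torch.stack([loops, loops])], dim=1)

# Step 3: per-edge normalization 1 / sqrt(deg(row) * deg(col))
row, col = edge_index
deg = torch.zeros(num_nodes).index_add_(0, row, torch.ones(row.size(0)))
norm = deg.pow(-0.5)[row] * deg.pow(-0.5)[col]

# Steps 4-5: sum the normalized messages x_j at the target nodes (col)
out_sparse = torch.zeros_like(x).index_add_(0, col, norm.view(-1, 1) * x[row])

# Dense reference: A[i, j] = 1 for every edge j -> i (self-loops included)
A = torch.zeros(num_nodes, num_nodes)
A[col, row] = 1.0
D_inv_sqrt = torch.diag(A.sum(dim=1).pow(-0.5))
out_dense = D_inv_sqrt @ A @ D_inv_sqrt @ x

print(torch.allclose(out_sparse, out_dense, atol=1e-6))  # True
```

For a directed graph the two forms can differ, because the layer above counts degrees over the source row of edge_index.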