Mini-batch Sampling
Real-world graphs can be very large, with millions or even billions of nodes and edges. The naive full-batch implementation of a GNN is not feasible for such large-scale graphs.
Two frequently used methods are summarized here:
Other samplers in PyG
HGTLoader, GraphSAINTLoader
Overall, all heterogeneous graph loaders produce a HeteroData object as output, holding a subset of the original data; they mainly differ in the way their sampling procedures work.
Neighbor Sampling with Different Ratios
In PyG 2.0, NeighborLoader allows for mini-batch training of GNNs on large-scale graphs
where full-batch training is not feasible
1
2
3
4
5
6
7
8
9
10
11
12
| #import torch_geometric.transforms as T
from torch_geometric.loader import NeighborLoader
# Build a mini-batch loader that samples a fixed number of neighbors per hop.
# NOTE(review): `data` must be defined by the surrounding notebook. The
# `data.edge_types` and `data['paper']` accesses below assume it is a
# heterogeneous PyG graph with a "paper" node type carrying a `train_mask`
# attribute -- confirm against the dataset in use.
train_loader = NeighborLoader(
    data,
    # Sample 15 neighbors for each node and each edge type for 2 iterations.
    # (For a homogeneous graph, pass the plain list `[15] * 2` instead.)
    num_neighbors={key: [15] * 2 for key in data.edge_types},
    # Use a batch size of 128 for sampling training nodes of type "paper":
    batch_size=128,
    # The node type and its mask must be passed together as one tuple.
    input_nodes=('paper', data['paper'].train_mask),
)
# Draw a single sampled mini-batch from the loader.
batch = next(iter(train_loader))
|
Please see the full DeepSNAP docs here.
The content below is almost the same as in the Colab notebooks. It is reproduced here just for easy and quick viewing on any device.
1. Neighbor Sampling
sampling code using networkX as backend
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def sample_neighbors(nodes, G, ratio, all_nodes):
    """Sample a fraction of each node's neighbors from a NetworkX graph.

    Args:
        nodes: Iterable of node ids whose neighborhoods are sampled.
        G: NetworkX graph queried through ``nx.neighbors``.
        ratio: Fraction of each node's neighbors to keep. The per-node
            sample size is ``int(len(neighbors) * ratio)`` (truncated).
        all_nodes: Iterable of node ids collected so far.

    Returns:
        A tuple ``(neighbors, all_nodes, edges)``: the set of sampled
        neighbors, that set unioned with the incoming ``all_nodes``, and a
        list of ``(neighbor, node)`` pairs, one per sampled edge.
    """
    neighbors = set()
    edges = []
    for node in nodes:
        neighbors_list = list(nx.neighbors(G, node))
        # We only sample the (ratio * number of neighbors) neighbors
        num = int(len(neighbors_list) * ratio)
        # NOTE(review): the flattened source does not show whether the loop
        # that records neighbors was nested under `if num > 0`. Nesting is
        # assumed here, so a node whose sample size truncates to zero
        # contributes no neighbors and no edges -- confirm against the
        # original notebook.
        if num <= 0:
            continue
        # Shuffle in place and keep the first `num` entries as the sample.
        random.shuffle(neighbors_list)
        for neighbor in neighbors_list[:num]:
            neighbors.add(neighbor)
            edges.append((neighbor, node))
    return neighbors, neighbors.union(all_nodes), edges
def nodes_to_tensor(nodes):
    """Convert a collection of node ids into a 1-D ``torch.long`` tensor.

    Args:
        nodes: Iterable of integer node ids (for example a set or a list).
            The output order follows the iteration order of ``nodes``.

    Returns:
        A ``torch.long`` tensor holding the ids, one entry per node.
    """
    return torch.tensor(list(nodes), dtype=torch.long)
def edges_to_tensor(edges):
    """Convert a collection of edges into a symmetric edge-index tensor.

    Every input edge ``(u, v)`` is emitted in both directions: the original
    pairs come first, followed by the flipped ``(v, u)`` pairs.

    Args:
        edges: Iterable of ``(u, v)`` integer pairs.

    Returns:
        A ``torch.long`` tensor of shape ``(2, 2 * E)`` where ``E`` is the
        number of input edges. An empty input yields shape ``(2, 0)``.
    """
    edge_index = torch.tensor(list(edges), dtype=torch.long)
    if edge_index.numel() == 0:
        # An empty list produces a 1-D tensor, and flipping along dim 1
        # below would fail; give it an explicit (0, 2) shape instead.
        edge_index = edge_index.reshape(0, 2)
    # Append the reversed pairs so each edge appears in both directions.
    edge_index = torch.cat([edge_index, torch.flip(edge_index, [1])], dim=0)
    # Switch from one-row-per-edge to the (2, num_edges) layout.
    edge_index = edge_index.permute(1, 0)
    return edge_index
def relable(nodes, labeled_nodes, edges_list):
    """Relabel node ids to consecutive integers in ascending id order.

    The smallest id in ``nodes`` becomes 0, the next smallest 1, and so on.

    Args:
        nodes: Iterable of all node ids that may appear in the other
            arguments.
        labeled_nodes: Iterable of node ids (each present in ``nodes``).
        edges_list: List of edge lists; each edge is indexable with the two
            endpoints at positions 0 and 1.

    Returns:
        A tuple ``(relabeled_edges_list, relabeled_nodes,
        relabeled_labeled_nodes)`` mirroring the inputs with every id
        replaced by its new index. ``relabeled_nodes`` follows the iteration
        order of ``nodes``.

    Raises:
        KeyError: If an edge endpoint or labeled node is not in ``nodes``.
    """
    node_mapping = {node: i for i, node in enumerate(sorted(nodes))}
    relabeled_edges_list = [
        [(node_mapping[edge[0]], node_mapping[edge[1]]) for edge in orig_edges]
        for orig_edges in edges_list
    ]
    relabeled_labeled_nodes = [node_mapping[node] for node in labeled_nodes]
    relabeled_nodes = [node_mapping[node] for node in nodes]
    return relabeled_edges_list, relabeled_nodes, relabeled_labeled_nodes
def neighbor_sampling(graph, K=2, ratios=(0.1, 0.1, 0.1)):
    """Sample a K-hop computation graph around a subset of labeled nodes.

    Args:
        graph: Object exposing ``node_label_index`` (supports ``tolist()``),
            ``G`` (the NetworkX graph handed to ``sample_neighbors``) and
            ``node_feature`` (indexable by a long tensor). The comments in
            the source describe it as a DeepSNAP graph.
        K: Number of GNN layers, i.e. the number of sampling hops.
        ratios: ``K + 1`` sampling ratios. The last entry subsamples the
            labeled nodes; the remaining entries are consumed from right to
            left, one per hop moving outwards.

    Returns:
        A tuple ``(node_feature, edge_indices, node_label_index)``:
        the features of every sampled node, with row ``r`` belonging to the
        node relabeled to ``r``; one edge-index tensor per hop, outermost hop
        first; and the relabeled indices of the sampled labeled nodes.
    """
    assert K + 1 == len(ratios)

    labeled_nodes = graph.node_label_index.tolist()
    random.shuffle(labeled_nodes)
    num = int(len(labeled_nodes) * ratios[-1])
    # When the sample size truncates to zero, all labeled nodes are kept.
    if num > 0:
        labeled_nodes = labeled_nodes[:num]

    nodes_list = [set(labeled_nodes)]
    edges_list = []
    all_nodes = labeled_nodes
    for k in range(K):
        # Get nodes and edges from the previous layer
        nodes, all_nodes, edges = sample_neighbors(
            nodes_list[-1], graph.G, ratios[len(ratios) - k - 2], all_nodes
        )
        nodes_list.append(nodes)
        edges_list.append(edges)

    # Reverse the lists so the outermost hop comes first
    nodes_list.reverse()
    edges_list.reverse()

    relabled_edges_list, _, relabeled_labeled_nodes = relable(
        all_nodes, labeled_nodes, edges_list
    )

    # Gather features with the ORIGINAL node ids in ascending order. `relable`
    # maps the i-th smallest original id to i, so this ordering makes row i of
    # `node_feature` belong to the node that the relabeled edges call i.
    # Indexing with the relabeled ids (a permutation of 0..n-1) would instead
    # pick the first n rows of `graph.node_feature`.
    node_index = nodes_to_tensor(sorted(all_nodes))
    # All node features that will be used
    node_feature = graph.node_feature[node_index]
    edge_indices = [edges_to_tensor(edges) for edges in relabled_edges_list]
    node_label_index = nodes_to_tensor(relabeled_labeled_nodes)

    # Each edge is stored in both directions, hence the halving. With K == 0
    # there is no edge tensor to inspect, so report zero edges.
    num_edges = edge_indices[0].shape[1] // 2 if edge_indices else 0
    log = "Sampled {} nodes, {} edges, {} labeled nodes"
    print(log.format(node_feature.shape[0], num_edges, node_label_index.shape[0]))
    return node_feature, edge_indices, node_label_index
|
Sampling with Clusters
Instead of neighbor sampling, we can use another approach, subgraph (cluster) sampling, to scale up GNNs. This approach was proposed in Cluster-GCN (Chiang et al. (2019)).
see PyG’s torch_geometric.loader.ClusterLoader
1. Partition the Graph into Clusters
Three community detection / partition algorithms to partition the graph into different clusters:
To make the training more stable, we discard clusters that have 10 or fewer nodes.
use networkx as backend
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
| # the package name on pip is python-louvain but it is imported as community in python
# pip install python-louvain
import community as community_louvain
def preprocess(G, node_label_index, method="louvain"):
    """Partition a graph into clusters and wrap each usable one as a graph.

    A cluster is kept only if it has more than 10 nodes and contains at
    least one labeled node. Kept clusters are relabeled to consecutive
    integer ids before being wrapped.

    Args:
        G: NetworkX graph to partition.
        node_label_index: Tensor-like object with ``tolist()`` holding the
            ids of the labeled nodes in ``G``.
        method: Partition algorithm: ``"louvain"``, ``"bisection"`` or
            ``"greedy"``.

    Returns:
        A list of ``Graph`` objects (presumably DeepSNAP graphs -- confirm
        against the file's imports), each with ``node_label_index`` set to
        the relabeled ids of the labeled nodes it contains.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    graphs = []
    labeled_nodes = set(node_label_index.tolist())

    if method == "louvain":
        community_mapping = community_louvain.best_partition(G, resolution=10)
        # Invert the node -> community mapping into community -> node set.
        grouped = {}
        for node, comm in community_mapping.items():
            grouped.setdefault(comm, set()).add(node)
        communities = grouped.values()
    elif method == "bisection":
        communities = nx.algorithms.community.kernighan_lin_bisection(G)
    elif method == "greedy":
        communities = nx.algorithms.community.greedy_modularity_communities(G)
    else:
        # Without this branch `communities` would be unbound below.
        raise ValueError(
            "Unknown partition method {!r}; expected 'louvain', 'bisection' "
            "or 'greedy'".format(method)
        )

    for community in communities:
        subgraph = G.subgraph(set(community))
        # Make sure each subgraph has more than 10 nodes
        if subgraph.number_of_nodes() <= 10:
            continue
        node_mapping = {node: i for i, node in enumerate(subgraph.nodes())}
        subgraph = nx.relabel_nodes(subgraph, node_mapping)
        # Get the id of the training set labeled node in the new graph
        train_label_index = [
            node_mapping[node] for node in labeled_nodes if node in node_mapping
        ]
        # Make sure the subgraph contains at least one training set labeled node
        if not train_label_index:
            continue
        dg = Graph(subgraph)
        # Update node_label_index
        dg.node_label_index = torch.tensor(train_label_index, dtype=torch.long)
        graphs.append(dg)
    return graphs
|