Mini-batch Sampling
Real-world graphs can be very large, with millions or even billions of nodes and edges, and the naive full-batch implementation of a GNN is not feasible on such large-scale graphs.
Two frequently used methods are summarized here:
Other samplers in PyG
HGTLoader
GraphSAINTLoader
Overall, all heterogeneous graph loaders will produce a HeteroData object as output, holding a subset of the original data; they mainly differ in the way their sampling procedures work.
Neighbor Sampling with Different Ratios
In PyG 2.0, NeighborLoader
allows for mini-batch training of GNNs on large-scale graphs
where full-batch training is not feasible
1
2
3
4
5
6
7
8
9
10
11
12
#import torch_geometric.transforms as T
from torch_geometric.loader import NeighborLoader

train_loader = NeighborLoader(
    data,
    # Sample 15 neighbors for each node and each edge type for 2 iterations:
    num_neighbors={key: [15] * 2 for key in data.edge_types},  # hetero graph
    # Use a batch size of 128 for sampling training nodes of type "paper":
    batch_size=128,
    # Seed the sampling from the "paper" nodes selected by the training mask.
    # input_nodes takes a (node_type, node_mask) tuple for heterogeneous data:
    input_nodes=('paper', data['paper'].train_mask),
)
batch = next(iter(train_loader))
|
Please see the full deepsnap’s docs here
The content below is almost the same as in the Colab notebooks. It is reproduced here for easy and quick viewing on any device.
1. Neighbor Sampling
sampling code using networkX as backend
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def sample_neighbors(nodes, G, ratio, all_nodes):
    """Randomly sample a fraction of each node's neighbors from graph G.

    For every node in `nodes`, keep roughly `ratio * degree` randomly chosen
    neighbors (nodes whose truncated count is 0 contribute nothing) and record
    each sampled (neighbor, node) pair as an edge.

    Returns a tuple of:
        - the set of sampled neighbors,
        - that set unioned with `all_nodes`,
        - the list of sampled (neighbor, node) edges.
    """
    sampled = set()
    sampled_edges = []
    for node in nodes:
        candidates = list(nx.neighbors(G, node))
        # Only int(ratio * degree) neighbors of this node are kept.
        keep = int(len(candidates) * ratio)
        if keep > 0:
            # Shuffle so the kept prefix is a uniform random subset.
            random.shuffle(candidates)
            for neighbor in candidates[:keep]:
                sampled.add(neighbor)
                sampled_edges.append((neighbor, node))
    return sampled, sampled.union(all_nodes), sampled_edges
def nodes_to_tensor(nodes):
    """Convert an iterable of node ids into a 1-D long tensor of indices."""
    return torch.tensor(list(nodes), dtype=torch.long)
def edges_to_tensor(edges):
    """Convert (u, v) edge pairs into a [2, 2*E] undirected edge_index tensor.

    Every edge is mirrored as (v, u) so both directions are present, then the
    result is transposed into the conventional [2, num_edges] layout.
    """
    directed = torch.tensor(list(edges), dtype=torch.long)
    # Append the reversed copy (v, u) of every (u, v) pair.
    undirected = torch.cat((directed, torch.flip(directed, [1])), dim=0)
    return undirected.permute(1, 0)
def relable(nodes, labeled_nodes, edges_list):
    """Relabel node ids to contiguous indices 0..n-1, assigned in sorted order.

    The same mapping is applied to `labeled_nodes` and to every edge list in
    `edges_list`. Returns (relabeled edge lists, relabeled nodes, relabeled
    labeled nodes), preserving the input iteration orders.
    """
    node_mapping = {node: i for i, node in enumerate(sorted(nodes))}
    relabled_edges_list = [
        [(node_mapping[u], node_mapping[v]) for u, v in orig_edges]
        for orig_edges in edges_list
    ]
    relabeled_labeled_nodes = [node_mapping[node] for node in labeled_nodes]
    relabeled_nodes = [node_mapping[node] for node in nodes]
    return relabled_edges_list, relabeled_nodes, relabeled_labeled_nodes
def neighbor_sampling(graph, K=2, ratios=(0.1, 0.1, 0.1)):
    """Layer-wise neighbor sampling on a DeepSNAP graph.

    Args:
        graph: a DeepSNAP graph exposing `node_label_index`, `node_feature`
            and a backing NetworkX graph `graph.G` (assumed from usage —
            TODO confirm against DeepSNAP's Graph API).
        K: number of GNN layers, i.e. number of sampling hops.
        ratios: K + 1 sampling ratios. The LAST entry subsamples the labeled
            seed nodes; the remaining K entries subsample neighbors and are
            consumed back-to-front (ratios[-2] for the first hop, ratios[0]
            for the last).

    Returns:
        (node_feature, edge_indices, node_label_index), all relabeled to the
        contiguous 0..n-1 index space of the sampled subgraph.
    """
    # One ratio for the seed nodes plus one per GNN layer.
    assert K + 1 == len(ratios)
    labeled_nodes = graph.node_label_index.tolist()
    # Subsample the labeled (seed) nodes using the last ratio.
    random.shuffle(labeled_nodes)
    num = int(len(labeled_nodes) * ratios[-1])
    if num > 0:
        labeled_nodes = labeled_nodes[:num]
    nodes_list = [set(labeled_nodes)]
    edges_list = []
    all_nodes = labeled_nodes
    for k in range(K):
        # Get nodes and edges from the previous layer; ratios are indexed
        # back-to-front so the outermost hop uses ratios[-2].
        nodes, all_nodes, edges = \
            sample_neighbors(nodes_list[-1], graph.G, ratios[len(ratios) - k - 2], all_nodes)
        nodes_list.append(nodes)
        edges_list.append(edges)
    # Reverse the lists so index 0 corresponds to the first (innermost) GNN layer.
    nodes_list.reverse()
    edges_list.reverse()
    # Map all sampled node ids onto a contiguous 0..n-1 index space.
    relabled_edges_list, relabeled_all_nodes, relabeled_labeled_nodes = \
        relable(all_nodes, labeled_nodes, edges_list)
    node_index = nodes_to_tensor(relabeled_all_nodes)
    # All node features that will be used
    node_feature = graph.node_feature[node_index]
    edge_indices = [edges_to_tensor(edges) for edges in relabled_edges_list]
    node_label_index = nodes_to_tensor(relabeled_labeled_nodes)
    log = "Sampled {} nodes, {} edges, {} labeled nodes"
    # edge_indices[0] stores both directions of every edge, hence the // 2.
    print(log.format(node_feature.shape[0], edge_indices[0].shape[1] // 2, node_label_index.shape[0]))
    return node_feature, edge_indices, node_label_index
|
Sampling with Clusters
Instead of the Neighbor Sampling, we can use another approach, subgraph (cluster) sampling, to scale up GNNs. This approach is proposed in Cluster-GCN (Chiang et al. (2019)).
see PyG’s torch_geometric.loader.ClusterLoader
1. Partition the Graph into Clusters
Three community detection / partition algorithms to partition the graph into different clusters:
To make the training more stable, we discard the cluster that has less than 10 nodes.
use networkx
as backend
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
| # the package name on pip is python-louvain but it is imported as community in python
# pip install python-louvain
import community as community_louvain
def preprocess(G, node_label_index, method="louvain"):
    """Partition G into communities and wrap each usable one as a DeepSNAP Graph.

    Supported methods: "louvain" (python-louvain best_partition with
    resolution=10), "bisection" (Kernighan-Lin), "greedy" (greedy modularity).
    Communities with 10 or fewer nodes, or without any labeled training node,
    are discarded. Each kept subgraph is relabeled to contiguous node ids and
    its `node_label_index` is set to the relabeled training nodes it contains.
    """
    graphs = []
    labeled_nodes = set(node_label_index.tolist())
    if method == "louvain":
        community_mapping = community_louvain.best_partition(G, resolution=10)
        # Invert node -> community id into community id -> node set.
        grouped = {}
        for node, comm_id in community_mapping.items():
            grouped.setdefault(comm_id, set()).add(node)
        communities = grouped.values()
    elif method == "bisection":
        communities = nx.algorithms.community.kernighan_lin_bisection(G)
    elif method == "greedy":
        communities = nx.algorithms.community.greedy_modularity_communities(G)
    for comm_nodes in communities:
        subgraph = G.subgraph(set(comm_nodes))
        # Make sure each subgraph has more than 10 nodes.
        if subgraph.number_of_nodes() > 10:
            node_mapping = {node: i for i, node in enumerate(subgraph.nodes())}
            subgraph = nx.relabel_nodes(subgraph, node_mapping)
            # Relabeled ids of the training-set nodes that fall in this community.
            train_label_index = [
                node_mapping[node] for node in labeled_nodes if node in node_mapping
            ]
            # Keep only communities containing at least one training node.
            if train_label_index:
                dg = Graph(subgraph)
                # Update node_label_index to the relabeled training nodes.
                dg.node_label_index = torch.tensor(train_label_index, dtype=torch.long)
                graphs.append(dg)
    return graphs
|