From 8b6f3c2ae46e8af2a9d7a68577aa9bfbdb11a065 Mon Sep 17 00:00:00 2001 From: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com> Date: Thu, 9 May 2024 16:28:10 +0800 Subject: [PATCH] [graphbolt] skip non-existent types in input_nodes (#7386) --- python/dgl/graphbolt/feature_fetcher.py | 28 +++++++++--------- .../pytorch/graphbolt/test_feature_fetcher.py | 29 ++++++++++++++++--- 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/python/dgl/graphbolt/feature_fetcher.py b/python/dgl/graphbolt/feature_fetcher.py index 01ff25af8c15..dc41e3883890 100644 --- a/python/dgl/graphbolt/feature_fetcher.py +++ b/python/dgl/graphbolt/feature_fetcher.py @@ -88,13 +88,12 @@ def record_stream(tensor): if self.node_feature_keys and input_nodes is not None: if is_heterogeneous: - for type_name, feature_names in self.node_feature_keys.items(): - nodes = input_nodes[type_name] - if nodes is None: + for type_name, nodes in input_nodes.items(): + if type_name not in self.node_feature_keys or nodes is None: continue if nodes.is_cuda: nodes.record_stream(torch.cuda.current_stream()) - for feature_name in feature_names: + for feature_name in self.node_feature_keys[type_name]: node_features[ (type_name, feature_name) ] = record_stream( @@ -126,21 +125,22 @@ def record_stream(tensor): if is_heterogeneous: # Convert edge type to string. original_edge_ids = { - etype_tuple_to_str(key) - if isinstance(key, tuple) - else key: value + ( + etype_tuple_to_str(key) + if isinstance(key, tuple) + else key + ): value for key, value in original_edge_ids.items() } - for ( - type_name, - feature_names, - ) in self.edge_feature_keys.items(): - edges = original_edge_ids.get(type_name, None) - if edges is None: + for type_name, edges in original_edge_ids.items(): + if ( + type_name not in self.edge_feature_keys + or edges is None + ): continue if edges.is_cuda: edges.record_stream(torch.cuda.current_stream()) - for feature_name in feature_names: + for feature_name in self.edge_feature_keys[type_name]: edge_features[i][ (type_name, feature_name) ] = record_stream( diff --git a/tests/python/pytorch/graphbolt/test_feature_fetcher.py b/tests/python/pytorch/graphbolt/test_feature_fetcher.py index b1944f06bc44..552a0bf5b055 100644 --- a/tests/python/pytorch/graphbolt/test_feature_fetcher.py +++ b/tests/python/pytorch/graphbolt/test_feature_fetcher.py @@ -160,12 +160,21 @@ def test_FeatureFetcher_hetero(): num_layer = 2 fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts) + # "n3" is not in the sampled input nodes. + node_feature_keys = {"n1": ["a"], "n2": ["a"], "n3": ["a"]} fetcher_dp = gb.FeatureFetcher( - sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]} + sampler_dp, feature_store, node_feature_keys=node_feature_keys ) - assert len(list(fetcher_dp)) == 3 + # Do not fetch feature for "n1". + node_feature_keys = {"n2": ["a"]} + fetcher_dp = gb.FeatureFetcher( + sampler_dp, feature_store, node_feature_keys=node_feature_keys + ) + for mini_batch in fetcher_dp: + assert ("n1", "a") not in mini_batch.node_features + def test_FeatureFetcher_with_edges_hetero(): a = torch.tensor([[random.randint(0, 10)] for _ in range(20)]) @@ -208,7 +217,11 @@ def add_node_and_edge_ids(minibatch): return data features = {} - keys = [("node", "n1", "a"), ("edge", "n1:e1:n2", "a")] + keys = [ + ("node", "n1", "a"), + ("edge", "n1:e1:n2", "a"), + ("edge", "n2:e2:n1", "a"), + ] features[keys[0]] = gb.TorchBasedFeature(a) features[keys[1]] = gb.TorchBasedFeature(b) feature_store = gb.BasicFeatureStore(features) @@ -220,8 +233,15 @@ def add_node_and_edge_ids(minibatch): ) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids) + # "n3:e3:n3" is not in the sampled edges. + # Do not fetch feature for "n2:e2:n1". + node_feature_keys = {"n1": ["a"]} + edge_feature_keys = {"n1:e1:n2": ["a"], "n3:e3:n3": ["a"]} fetcher_dp = gb.FeatureFetcher( - converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]} + converter_dp, + feature_store, + node_feature_keys=node_feature_keys, + edge_feature_keys=edge_feature_keys, ) assert len(list(fetcher_dp)) == 5 @@ -230,3 +250,4 @@ def add_node_and_edge_ids(minibatch): assert len(data.edge_features) == 3 for edge_feature in data.edge_features: assert edge_feature[("n1:e1:n2", "a")].size(0) == 10 + assert ("n2:e2:n1", "a") not in edge_feature