
[release] cherry-pick from master and release for 2.2.1 #7388

Merged 6 commits on May 10, 2024
2 changes: 1 addition & 1 deletion examples/multigpu/graphbolt/node_classification.py
@@ -399,7 +399,7 @@ def parse_args():
"--gpu-cache-size",
type=int,
default=0,
help="The capacity of the GPU cache, the number of features to store.",
help="The capacity of the GPU cache in bytes.",
)
parser.add_argument(
"--dataset",
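
The help-string fix above reflects a semantic change in this release: `--gpu-cache-size` is now a byte budget rather than a row count. A minimal sketch of converting a desired number of cached rows into the byte value the flag now expects (the shape and dtype here are illustrative assumptions):

```python
import torch

# Hypothetical feature matrix: 128 float32 values per row.
feat = torch.zeros(1000, 128, dtype=torch.float32)

# Bytes per cached row; equivalent to feat[:1].nbytes on torch >= 2.1.
row_bytes = feat[:1].numel() * feat[:1].element_size()  # 128 * 4 = 512

# Byte budget that caches roughly 200 rows.
gpu_cache_size = 200 * row_bytes  # 102400
print(gpu_cache_size)
```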
12 changes: 12 additions & 0 deletions examples/sampling/graphbolt/pyg/node_classification_advanced.py
@@ -350,6 +350,12 @@ def parse_args():
help="Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
parser.add_argument(
"--gpu-cache-size",
type=int,
default=0,
help="The capacity of the GPU cache in bytes.",
)
parser.add_argument(
"--sample-mode",
default="sample_neighbor",
@@ -403,6 +409,12 @@ def main():

num_classes = dataset.tasks[0].metadata["num_classes"]

if args.gpu_cache_size > 0 and args.feature_device != "cuda":
features._features[("node", None, "feat")] = gb.GPUCachedFeature(
features._features[("node", None, "feat")],
args.gpu_cache_size,
)

train_dataloader, valid_dataloader = (
create_dataloader(
graph=graph,
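
The wrapping added to `main()` above can be exercised on its own. A hedged sketch, assuming a CUDA device and a pinned host tensor as the fallback (sizes are illustrative):

```python
import torch
import dgl.graphbolt as gb

# Hypothetical fallback feature kept in pinned host memory.
host_feat = torch.randn(1000, 128).pin_memory()
fallback = gb.TorchBasedFeature(host_feat)

# Keep the hottest rows in a 1 MiB GPU cache; misses hit the fallback.
cached = gb.GPUCachedFeature(fallback, 1024 * 1024)

ids = torch.tensor([1, 5, 7], device="cuda")
print(cached.read(ids).shape)  # torch.Size([3, 128])
```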
12 changes: 12 additions & 0 deletions python/dgl/graphbolt/base.py
@@ -5,6 +5,17 @@

import torch
from torch.torch_version import TorchVersion

if TorchVersion(torch.__version__) >= "2.3.0":
# [TODO][https://github.com/dmlc/dgl/issues/7387] Remove or refine below
# check.
# Due to https://github.com/dmlc/dgl/issues/7380, we need to check if dill
# is available before using it.
torch.utils.data.datapipes.utils.common.DILL_AVAILABLE = (
torch.utils._import_utils.dill_available()
)

# pylint: disable=wrong-import-position
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

@@ -342,6 +353,7 @@ class CSCFormatBase:
>>> print(csc_format_base)
... torch.tensor([1, 4, 2])
"""

indptr: torch.Tensor = None
indices: torch.Tensor = None

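
Note that the shim has to run before the `torchdata` import that follows it, which is why those imports carry a pylint position waiver. A self-contained sketch of the same guard for use outside this module, assuming torch's private `_import_utils` helper remains available:

```python
import torch
from torch.torch_version import TorchVersion

if TorchVersion(torch.__version__) >= "2.3.0":
    import torch.utils._import_utils
    import torch.utils.data.datapipes.utils.common as dp_common

    # torchdata still reads this module-level flag; restore it so
    # environments without dill do not crash on import.
    dp_common.DILL_AVAILABLE = torch.utils._import_utils.dill_available()

# Only after the patch is it safe to import torchdata datapipes.
from torchdata.datapipes.iter import IterDataPipe  # noqa: E402
```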
28 changes: 14 additions & 14 deletions python/dgl/graphbolt/feature_fetcher.py
@@ -88,13 +88,12 @@ def record_stream(tensor):

if self.node_feature_keys and input_nodes is not None:
if is_heterogeneous:
for type_name, feature_names in self.node_feature_keys.items():
nodes = input_nodes[type_name]
if nodes is None:
for type_name, nodes in input_nodes.items():
if type_name not in self.node_feature_keys or nodes is None:
continue
if nodes.is_cuda:
nodes.record_stream(torch.cuda.current_stream())
for feature_name in feature_names:
for feature_name in self.node_feature_keys[type_name]:
node_features[
(type_name, feature_name)
] = record_stream(
@@ -126,21 +125,22 @@ def record_stream(tensor):
if is_heterogeneous:
# Convert edge type to string.
original_edge_ids = {
etype_tuple_to_str(key)
if isinstance(key, tuple)
else key: value
(
etype_tuple_to_str(key)
if isinstance(key, tuple)
else key
): value
for key, value in original_edge_ids.items()
}
for (
type_name,
feature_names,
) in self.edge_feature_keys.items():
edges = original_edge_ids.get(type_name, None)
if edges is None:
for type_name, edges in original_edge_ids.items():
if (
type_name not in self.edge_feature_keys
or edges is None
):
continue
if edges.is_cuda:
edges.record_stream(torch.cuda.current_stream())
for feature_name in feature_names:
for feature_name in self.edge_feature_keys[type_name]:
edge_features[i][
(type_name, feature_name)
] = record_stream(
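
The rewrite above inverts both loops: they now iterate over what was actually sampled and skip types without requested keys, instead of indexing the sampled dictionary by key (which failed for node or edge types absent from the minibatch). A standalone sketch of the pattern with a stubbed fetch (all names are illustrative):

```python
import torch

def fetch(type_name, feature_name, ids):
    # Stand-in for a feature_store read; returns dummy features.
    return torch.zeros(len(ids), 4)

# "n3" has requested keys but was not sampled; "n2" was sampled but has no keys.
input_nodes = {"n1": torch.tensor([0, 2]), "n2": torch.tensor([1])}
node_feature_keys = {"n1": ["a"], "n3": ["a"]}

node_features = {}
for type_name, nodes in input_nodes.items():
    if type_name not in node_feature_keys or nodes is None:
        continue  # nothing requested for this type, or no nodes sampled
    for feature_name in node_feature_keys[type_name]:
        node_features[(type_name, feature_name)] = fetch(
            type_name, feature_name, nodes
        )

assert set(node_features) == {("n1", "a")}
```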
37 changes: 29 additions & 8 deletions python/dgl/graphbolt/impl/gpu_cached_feature.py
@@ -8,6 +8,22 @@
__all__ = ["GPUCachedFeature"]


def nbytes(tensor):
"""Returns the number of bytes to store the given tensor.

Needs to be defined only for torch versions less than 2.1. In torch >= 2.1,
we can simply use "tensor.nbytes".
"""
return tensor.numel() * tensor.element_size()


def num_cache_items(cache_capacity_in_bytes, single_item):
"""Returns the number of rows to be cached."""
item_bytes = nbytes(single_item)
# Round up so that we never get a size of 0, unless bytes is 0.
return (cache_capacity_in_bytes + item_bytes - 1) // item_bytes


class GPUCachedFeature(Feature):
r"""GPU cached feature wrapping a fallback feature.

@@ -17,8 +33,8 @@ class GPUCachedFeature(Feature):
----------
fallback_feature : Feature
The fallback feature.
cache_size : int
The capacity of the GPU cache, the number of features to store.
max_cache_size_in_bytes : int
The capacity of the GPU cache in bytes.

Examples
--------
@@ -42,16 +58,17 @@ class GPUCachedFeature(Feature):
torch.Size([5])
"""

def __init__(self, fallback_feature: Feature, cache_size: int):
def __init__(self, fallback_feature: Feature, max_cache_size_in_bytes: int):
super(GPUCachedFeature, self).__init__()
assert isinstance(fallback_feature, Feature), (
f"The fallback_feature must be an instance of Feature, but got "
f"{type(fallback_feature)}."
)
self._fallback_feature = fallback_feature
self.cache_size = cache_size
self.max_cache_size_in_bytes = max_cache_size_in_bytes
# Fetching the feature dimension from the underlying feature.
feat0 = fallback_feature.read(torch.tensor([0]))
cache_size = num_cache_items(max_cache_size_in_bytes, feat0)
self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype)

def read(self, ids: torch.Tensor = None):
@@ -104,11 +121,15 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None):
updated.
"""
if ids is None:
feat0 = value[:1]
self._fallback_feature.update(value)
size = min(self.cache_size, value.shape[0])
self._feature.replace(
torch.arange(0, size, device="cuda"),
value[:size].to("cuda"),
cache_size = min(
num_cache_items(self.max_cache_size_in_bytes, feat0),
value.shape[0],
)
self._feature = None # Destroy the existing cache first.
self._feature = GPUCache(
(cache_size,) + feat0.shape[1:], feat0.dtype
)
else:
self._fallback_feature.update(value, ids)
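
A worked example of the byte-to-row conversion the class now performs internally; the ceiling division guarantees that any non-zero byte budget caches at least one row (values are illustrative):

```python
import torch

def nbytes(tensor):
    # Bytes needed to store `tensor`; same as tensor.nbytes on torch >= 2.1.
    return tensor.numel() * tensor.element_size()

def num_cache_items(cache_capacity_in_bytes, single_item):
    item_bytes = nbytes(single_item)
    # Ceiling division, so the count is 0 only for a zero budget.
    return (cache_capacity_in_bytes + item_bytes - 1) // item_bytes

row = torch.zeros(1, 128, dtype=torch.float32)  # 512 bytes per row
assert num_cache_items(1024, row) == 2  # exact fit
assert num_cache_items(1025, row) == 3  # rounds up
assert num_cache_items(1, row) == 1     # tiny budget still caches one row
assert num_cache_items(0, row) == 0     # zero budget caches nothing
```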
1 change: 1 addition & 0 deletions script/dgl_dev.yml.template
@@ -49,5 +49,6 @@ dependencies:
- lintrunner
- jupyterlab
- ipywidgets
- expecttest
variables:
DGL_HOME: __DGL_HOME__
@@ -1613,10 +1613,14 @@ def test_csc_sampling_graph_to_pinned_memory():
is_graph_pinned(graph)


@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("is_pinned", [False, True])
@pytest.mark.parametrize("nodes", [None, True])
def test_sample_neighbors_homo(labor, is_pinned, nodes):
def test_sample_neighbors_homo(
indptr_dtype, indices_dtype, labor, is_pinned, nodes
):
if is_pinned and nodes is None:
pytest.skip("Optional nodes and is_pinned is not supported together.")
"""Original graph in COO:
@@ -1630,8 +1634,10 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):
pytest.skip("Pinning is not meaningful without a GPU.")
# Initialize data.
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)
indices = torch.tensor(
[0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype
)
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)

@@ -1642,7 +1648,7 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):

# Generate subgraph via sample neighbors.
if nodes:
nodes = torch.LongTensor([1, 3, 4]).to(F.ctx())
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype).to(F.ctx())
elif F._default_context_str != "gpu":
pytest.skip("Optional nodes is supported only for the GPU.")
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
@@ -1662,8 +1668,10 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):
assert subgraph.original_edge_ids is None


@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero(labor):
def test_sample_neighbors_hetero(indptr_dtype, indices_dtype, labor):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
@@ -1677,10 +1685,12 @@ def test_sample_neighbors_hetero(labor):
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
type_per_edge = torch.tensor(
[1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=indices_dtype
)
node_type_offset = torch.tensor([0, 2, 5], dtype=indices_dtype)
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)

@@ -1696,8 +1706,8 @@ def test_sample_neighbors_hetero(labor):

# Sample on both node types.
nodes = {
"n1": torch.tensor([0], device=F.ctx()),
"n2": torch.tensor([0], device=F.ctx()),
"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
"n2": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
@@ -1725,7 +1735,7 @@ test_sample_neighbors_hetero(labor):
assert subgraph.original_edge_ids is None

# Sample on single node type.
nodes = {"n1": torch.tensor([0], device=F.ctx())}
nodes = {"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx())}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
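
The new parametrization exercises int32 alongside int64 for `indptr` and `indices`. A hedged sketch of building and sampling the homogeneous test graph with int32 tensors, assuming the `gb.fused_csc_sampling_graph` factory these tests construct graphs with:

```python
import torch
import dgl.graphbolt as gb

indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=torch.int32)
indices = torch.tensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=torch.int32)
graph = gb.fused_csc_sampling_graph(indptr, indices)

# Seeds use the same dtype as `indices`; sample up to 2 neighbors per seed.
nodes = torch.tensor([1, 3, 4], dtype=torch.int32)
subgraph = graph.sample_neighbors(nodes, fanouts=torch.tensor([2]))
```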
@@ -36,6 +36,9 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):
[[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True
)

cache_size_a *= a[:1].element_size() * a[:1].numel()
cache_size_b *= b[:1].element_size() * b[:1].numel()

feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), cache_size_a)
feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), cache_size_b)

@@ -94,3 +97,7 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):
feat_store_a.read(),
torch.tensor([[2, 0, 1], [3, 5, 2]], dtype=dtype).to("cuda"),
)

# Test with different dimensionality
feat_store_a.update(b)
assert torch.equal(feat_store_a.read(), b.to("cuda"))
29 changes: 25 additions & 4 deletions tests/python/pytorch/graphbolt/test_feature_fetcher.py
@@ -160,12 +160,21 @@ def test_FeatureFetcher_hetero():
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
# "n3" is not in the sampled input nodes.
node_feature_keys = {"n1": ["a"], "n2": ["a"], "n3": ["a"]}
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
sampler_dp, feature_store, node_feature_keys=node_feature_keys
)

assert len(list(fetcher_dp)) == 3

# Do not fetch feature for "n1".
node_feature_keys = {"n2": ["a"]}
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, node_feature_keys=node_feature_keys
)
for mini_batch in fetcher_dp:
assert ("n1", "a") not in mini_batch.node_features


def test_FeatureFetcher_with_edges_hetero():
a = torch.tensor([[random.randint(0, 10)] for _ in range(20)])
@@ -208,7 +217,11 @@ def add_node_and_edge_ids(minibatch):
return data

features = {}
keys = [("node", "n1", "a"), ("edge", "n1:e1:n2", "a")]
keys = [
("node", "n1", "a"),
("edge", "n1:e1:n2", "a"),
("edge", "n2:e2:n1", "a"),
]
features[keys[0]] = gb.TorchBasedFeature(a)
features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
@@ -220,8 +233,15 @@ def add_node_and_edge_ids(minibatch):
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
# "n3:e3:n3" is not in the sampled edges.
# Do not fetch feature for "n2:e2:n1".
node_feature_keys = {"n1": ["a"]}
edge_feature_keys = {"n1:e1:n2": ["a"], "n3:e3:n3": ["a"]}
fetcher_dp = gb.FeatureFetcher(
converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
converter_dp,
feature_store,
node_feature_keys=node_feature_keys,
edge_feature_keys=edge_feature_keys,
)

assert len(list(fetcher_dp)) == 5
Expand All @@ -230,3 +250,4 @@ def add_node_and_edge_ids(minibatch):
assert len(data.edge_features) == 3
for edge_feature in data.edge_features:
assert edge_feature[("n1:e1:n2", "a")].size(0) == 10
assert ("n2:e2:n1", "a") not in edge_feature
2 changes: 1 addition & 1 deletion third_party/cccl
Submodule cccl updated 9797 files