Commit
[GraphBolt] Hyperlink support in subgraph_sampler. (#7354)
Co-authored-by: Ubuntu <[email protected]>
yxy235 and Ubuntu authored Apr 28, 2024
1 parent ce37a93 commit 9090a87
Showing 5 changed files with 576 additions and 15 deletions.
32 changes: 32 additions & 0 deletions python/dgl/graphbolt/base.py
@@ -25,6 +25,7 @@
"expand_indptr",
"CSCFormatBase",
"seed",
"seed_type_str_to_ntypes",
]

CANONICAL_ETYPE_DELIMITER = ":"
@@ -185,6 +186,37 @@ def etype_str_to_tuple(c_etype):
return ret


def seed_type_str_to_ntypes(seed_type, seed_size):
"""Convert seeds type to node types from string to list.
Examples
--------
1. node pairs
>>> seed_type = "user:like:item"
>>> seed_size = 2
>>> node_type = seed_type_str_to_ntypes(seed_type, seed_size)
>>> print(node_type)
["user", "item"]
2. hyperlink
>>> seed_type = "query:user:item"
>>> seed_size = 3
>>> node_type = seed_type_str_to_ntypes(seed_type, seed_size)
>>> print(node_type)
["query", "user", "item"]
"""
assert isinstance(
seed_type, str
), f"Passed-in seed type should be string, but got {type(seed_type)}"
ntypes = seed_type.split(CANONICAL_ETYPE_DELIMITER)
is_hyperlink = len(ntypes) == seed_size
if not is_hyperlink:
ntypes = ntypes[::2]
return ntypes


def apply_to(x, device):
"""Apply `to` function to object x only if it has `to`."""

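For reference, a minimal usage sketch of the new helper (assuming the usual `import dgl.graphbolt as gb`; the seed-type strings are illustrative):

import dgl.graphbolt as gb

# Edge-type seeds: "src:relation:dst" with a 2-column seed tensor keeps
# only the endpoint node types.
print(gb.seed_type_str_to_ntypes("user:like:item", 2))   # ['user', 'item']

# Hyperlink seeds: the string names one node type per column, so a
# 3-column seed tensor keeps all three entries.
print(gb.seed_type_str_to_ntypes("query:user:item", 3))  # ['query', 'user', 'item']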
40 changes: 40 additions & 0 deletions python/dgl/graphbolt/itemset.py
@@ -110,6 +110,21 @@ class requires each input itemset to be iterable.
tensor([1, 1, 0, 0, 0]))
>>> item_set.names
('seeds', 'labels')
6. Tuple of iterables of different shapes: hyperlink and labels.
>>> seeds = torch.arange(0, 10).reshape(-1, 5)
>>> labels = torch.tensor([1, 0])
>>> item_set = gb.ItemSet(
... (seeds, labels), names=("seeds", "labels"))
>>> list(item_set)
[(tensor([0, 1, 2, 3, 4]), tensor([1])),
(tensor([5, 6, 7, 8, 9]), tensor([0]))]
>>> item_set[:]
(tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
tensor([1, 0]))
>>> item_set.names
('seeds', 'labels')
"""

def __init__(
@@ -315,6 +330,31 @@ class ItemSetDict:
tensor([1, 1, 0]))}
>>> item_set.names
('seeds', 'labels')
4. Tuple of iterables of different shapes: hyperlink and labels.
>>> first_seeds = torch.arange(0, 6).reshape(-1, 3)
>>> first_labels = torch.tensor([1, 0])
>>> second_seeds = torch.arange(0, 2).reshape(-1, 1)
>>> second_labels = torch.tensor([1, 0])
>>> item_set = gb.ItemSetDict({
... "query:user:item": gb.ItemSet(
... (first_seeds, first_labels),
... names=("seeds", "labels")),
... "user": gb.ItemSet(
... (second_seeds, second_labels),
... names=("seeds", "labels"))})
>>> list(item_set)
[{'query:user:item': (tensor([0, 1, 2]), tensor(1))},
{'query:user:item': (tensor([3, 4, 5]), tensor(0))},
{'user': (tensor([0]), tensor(1))},
{'user': (tensor([1]), tensor(0))}]
>>> item_set[:]
{'query:user:item': (tensor([[0, 1, 2], [3, 4, 5]]),
tensor([1, 0])),
'user': (tensor([[0], [1]]), tensor([1, 0]))}
>>> item_set.names
('seeds', 'labels')
"""

def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
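Building on the doctest above, a minimal sketch of packaging hyperlink seeds with labels (the tensors and the "query:user:item" key are illustrative):

import torch
import dgl.graphbolt as gb

# Each row of `seeds` is one hyperlink spanning three node types, in the
# order given by the key "query:user:item".
seeds = torch.arange(0, 6).reshape(-1, 3)
labels = torch.tensor([1, 0])
item_set = gb.ItemSetDict({
    "query:user:item": gb.ItemSet((seeds, labels), names=("seeds", "labels")),
})
print(list(item_set))
# [{'query:user:item': (tensor([0, 1, 2]), tensor(1))},
#  {'query:user:item': (tensor([3, 4, 5]), tensor(0))}]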
40 changes: 25 additions & 15 deletions python/dgl/graphbolt/subgraph_sampler.py
@@ -6,7 +6,7 @@
import torch
from torch.utils.data import functional_datapipe

from .base import etype_str_to_tuple
from .base import seed_type_str_to_ntypes
from .internal import compact_temporal_nodes, unique_and_compact
from .minibatch_transformer import MiniBatchTransformer

@@ -93,7 +93,8 @@ def _seeds_preprocess(minibatch):
"""Preprocess `seeds` in a minibatch to construct `unique_seeds`,
`node_timestamp` and `compacted_seeds` for further sampling. It
optionally incorporates timestamps for temporal graphs, organizing and
compacting seeds based on their types and timestamps.
compacting seeds based on their types and timestamps. In a heterogeneous
graph, `seeds` with the same node type will be uniqued together.
Parameters
----------
@@ -121,7 +122,7 @@ def _seeds_preprocess(minibatch):
nodes_timestamp = None
if use_timestamp:
nodes_timestamp = defaultdict(list)
for etype, typed_seeds in seeds.items():
for seed_type, typed_seeds in seeds.items():
# When typed_seeds is a one-dimensional tensor, it represents
# seed nodes, which does not need to do unique and compact.
if typed_seeds.ndim == 1:
@@ -131,25 +132,27 @@
else None
)
return seeds, nodes_timestamp, None
assert typed_seeds.ndim == 2 and typed_seeds.shape[1] == 2, (
"Only tensor with shape 1*N and N*2 is "
assert typed_seeds.ndim == 2, (
"Only tensor with shape 1*N and N*M is "
+ f"supported now, but got {typed_seeds.shape}."
)
ntypes = etype[:].split(":")[::2]
ntypes = seed_type_str_to_ntypes(
seed_type, typed_seeds.shape[1]
)
if use_timestamp:
negative_ratio = (
typed_seeds.shape[0]
// minibatch.timestamp[etype].shape[0]
// minibatch.timestamp[seed_type].shape[0]
- 1
)
neg_timestamp = minibatch.timestamp[
etype
seed_type
].repeat_interleave(negative_ratio)
for i, ntype in enumerate(ntypes):
nodes[ntype].append(typed_seeds[:, i])
if use_timestamp:
nodes_timestamp[ntype].append(
minibatch.timestamp[etype]
minibatch.timestamp[seed_type]
)
nodes_timestamp[ntype].append(neg_timestamp)
# Unique and compact the collected nodes.
@@ -164,11 +167,16 @@
nodes_timestamp = None
compacted_seeds = {}
# Map back in same order as collect.
for etype, typed_seeds in seeds.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0)
compacted_seeds[etype] = torch.cat((src, dst)).view(2, -1).T
for seed_type, typed_seeds in seeds.items():
ntypes = seed_type_str_to_ntypes(
seed_type, typed_seeds.shape[1]
)
compacted_seed = []
for ntype in ntypes:
compacted_seed.append(compacted[ntype].pop(0))
compacted_seeds[seed_type] = (
torch.cat(compacted_seed).view(len(ntypes), -1).T
)
else:
# When seeds is a one-dimensional tensor, it represents seed nodes,
# which does not need to do unique and compact.
@@ -193,7 +201,9 @@ def _seeds_preprocess(minibatch):
seeds_timestamp = torch.cat(
(minibatch.timestamp, neg_timestamp)
)
nodes_timestamp = [seeds_timestamp for _ in range(seeds.ndim)]
nodes_timestamp = [
seeds_timestamp for _ in range(seeds.shape[1])
]
# Unique and compact the collected nodes.
if use_timestamp:
(
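The per-node-type compaction that `_seeds_preprocess` now performs for hyperlink seeds can be illustrated with a standalone sketch (not the library code; `unique_and_compact` is approximated here with `torch.unique`, which sorts the ids):

from collections import defaultdict

import torch

# Illustrative hyperlink seeds for a single seed type.
seeds = {"query:user:item": torch.tensor([[0, 1, 2], [3, 1, 2]])}

# 1. Collect seed columns per node type; columns of the same type share a list.
nodes = defaultdict(list)
for seed_type, typed_seeds in seeds.items():
    ntypes = seed_type.split(":")  # hyperlink: one node type per column
    for i, ntype in enumerate(ntypes):
        nodes[ntype].append(typed_seeds[:, i])

# 2. Unique the ids per node type and keep the compacted (inverse) indices,
#    split back into one chunk per collected column.
unique_nodes, compacted = {}, {}
for ntype, columns in nodes.items():
    ids = torch.cat(columns)
    unique_nodes[ntype], inverse = torch.unique(ids, return_inverse=True)
    compacted[ntype] = list(inverse.split([c.numel() for c in columns]))

# 3. Map back in the same order as collected, producing an (N, M) tensor of
#    compacted ids per seed type.
compacted_seeds = {}
for seed_type, typed_seeds in seeds.items():
    ntypes = seed_type.split(":")
    columns = [compacted[ntype].pop(0) for ntype in ntypes]
    compacted_seeds[seed_type] = torch.cat(columns).view(len(ntypes), -1).T
print(compacted_seeds)  # {'query:user:item': tensor([[0, 0, 0], [1, 0, 0]])}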
25 changes: 25 additions & 0 deletions tests/python/pytorch/graphbolt/test_base.py
@@ -169,6 +169,31 @@ def test_etype_str_to_tuple():
_ = gb.etype_str_to_tuple(c_etype_str)


def test_seed_type_str_to_ntypes():
"""Convert etype from string to tuple."""
# Test for node pairs.
seed_type_str = "user:like:item"
seed_size = 2
node_type = gb.seed_type_str_to_ntypes(seed_type_str, seed_size)
assert node_type == ["user", "item"]

# Test for hyperlink.
seed_type_str = "user:item:user"
seed_size = 3
node_type = gb.seed_type_str_to_ntypes(seed_type_str, seed_size)
assert node_type == ["user", "item", "user"]

# Test for unexpected input: list.
seed_type_str = ["user", "item"]
with pytest.raises(
AssertionError,
match=re.escape(
"Passed-in seed type should be string, but got <class 'list'>"
),
):
_ = gb.seed_type_str_to_ntypes(seed_type_str, 2)


def test_isin():
elements = torch.tensor([2, 3, 5, 5, 20, 13, 11], device=F.ctx())
test_elements = torch.tensor([2, 5], device=F.ctx())
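For completeness, a sketch of one more test in the same style (not part of this commit), covering the three-distinct-type hyperlink case from the docstring:

def test_seed_type_str_to_ntypes_hyperlink():
    """Hypothetical extra check: hyperlink with three distinct node types."""
    seed_type_str = "query:user:item"
    seed_size = 3
    node_type = gb.seed_type_str_to_ntypes(seed_type_str, seed_size)
    assert node_type == ["query", "user", "item"]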