Skip to content

Commit

Permalink
[GraphBolt] Optimize hetero sampling on CPU (#7360)
Browse files Browse the repository at this point in the history
  • Loading branch information
RamonZhou authored Apr 28, 2024
1 parent 9090a87 commit 658b208
Show file tree
Hide file tree
Showing 3 changed files with 448 additions and 64 deletions.
42 changes: 30 additions & 12 deletions graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
private:
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& seeds,
torch::optional<std::vector<int64_t>>& seed_offsets,
const std::vector<int64_t>& fanouts, bool return_eids,
NumPickFn num_pick_fn, PickFn pick_fn) const;

template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const;

Expand Down Expand Up @@ -498,13 +505,14 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @param offset The starting edge ID for the connected neighbors of the given
* node.
* @param num_neighbors The number of neighbors of this node.
*
* @return The pick number of the given node.
* @param num_picked_ptr The pointer of the tensor which stores the pick
* numbers.
*/
int64_t NumPick(
template <typename PickedNumType>
void NumPick(
int64_t fanout, bool replace,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors);
int64_t num_neighbors, PickedNumType* num_picked_ptr);

int64_t TemporalNumPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout,
Expand All @@ -513,11 +521,13 @@ int64_t TemporalNumPick(
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors);

int64_t NumPickByEtype(
const std::vector<int64_t>& fanouts, bool replace,
template <typename PickedNumType>
void NumPickByEtype(
bool with_seed_offsets, const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors);
int64_t num_neighbors, PickedNumType* num_picked_ptr, int64_t seed_index,
const std::vector<int64_t>& etype_id_to_num_picked_offset);

int64_t TemporalNumPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
Expand Down Expand Up @@ -610,16 +620,24 @@ int64_t TemporalPick(
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
* @param picked_data_ptr The destination address where the picked neighbors
* @param picked_data_ptr The pointer of the tensor where the picked neighbors
* should be put. Enough memory space should be allocated in advance.
* @param seed_offset The offset(index) of the seed among the group of seeds
* which share the same node type.
* @param subgraph_indptr_ptr The pointer of the tensor which stores the indptr
* of the sampled subgraph.
* @param etype_id_to_num_picked_offset A vector storing the mappings from each
* etype_id to the offset of its pick numbers in the tensor.
*/
template <SamplerType S, typename PickedType>
int64_t PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge,
bool with_seed_offsets, int64_t offset, int64_t num_neighbors,
const std::vector<int64_t>& fanouts, bool replace,
const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);
PickedType* picked_data_ptr, int64_t seed_offset,
PickedType* subgraph_indptr_ptr,
const std::vector<int64_t>& etype_id_to_num_picked_offset);

template <typename PickedType>
int64_t TemporalPickByEtype(
Expand Down
Loading

0 comments on commit 658b208

Please sign in to comment.