Commit

Update embedding_split_host_pt2_autograd_template.cpp

spcyppt authored Jan 31, 2025
1 parent 54cdaf7 commit 3bad375
Showing 1 changed file with 16 additions and 19 deletions.
@@ -40,10 +40,6 @@
#include "fbgemm_gpu/config/feature_gates.h"
#include "torch/csrc/autograd/record_function_ops.h"

{%- if has_vbe_support %}
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
{%- endif %}

using Tensor = at::Tensor;

using namespace fbgemm_gpu;
@@ -120,6 +116,9 @@ enum SSDTensor {
const c10::SymInt /*vbe_output_size*/,
const int64_t /*info_B_num_bits*/,
const int64_t /*info_B_mask_int64*/,
const Tensor& /*vbe_B_offsets_rank_per_feature*/,
const Tensor& /*vbe_output_offsets_feature_rank*/,
const c10::SymInt /*max_B*/,
{%- endif %}
{%- if is_gwd %}
const Tensor& /*prev_iter_dev*/,
@@ -164,6 +163,9 @@ enum SSDTensor {
vbe_output_size,
info_B_num_bits,
info_B_mask_int64,
vbe_B_offsets_rank_per_feature,
vbe_output_offsets_feature_rank,
max_B,
{%- endif %} {# /* if vbe */ #}
{%- if is_gwd %}
prev_iter_dev_,
@@ -239,6 +241,8 @@ enum SSDTensor {
const Tensor& /*B_offsets*/,
const Tensor& /*vbe_row_output_offsets*/,
const Tensor& /*vbe_b_t_map*/,
const Tensor& /*vbe_B_offsets_rank_per_feature*/,
const c10::SymInt /*max_B*/,
{%- endif %}
const bool /*use_uniq_cache_locations_bwd*/,
const bool /*use_homogeneous_placements*/,
@@ -297,6 +301,8 @@ enum SSDTensor {
B_offsets,
vbe_row_output_offsets,
vbe_b_t_map,
vbe_B_offsets_rank_per_feature,
max_B,
{%- endif %} {# /* if vbe */ #}
{%- if not dense %}
use_uniq_cache_locations_bwd,
@@ -926,19 +932,6 @@ static torch::autograd::variable_list backward(
// {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda)
weights_dev = weights_dev.flatten();
{%- endif %}
{%- if vbe %}
// TODO: remove this once vbe_metadata for cpu is implemented
// MTIA kernel uses weights_host but follows the CUDA implementation,
// so grad_output is already in the correct shape and must not be reshaped.
// Reshaping on weights_host here causes the MTIA kernel to fail.
// As a hotfix to unblock MTIA, we check the dimension of vbe_b_t_map so that reshaping is skipped on MTIA:
// CUDA and MTIA vbe_b_t_map has size {total_B}, so it should be 1-dimensional;
// CPU vbe_b_t_map is B_offsets_rank_per_feature, so its shape should be {num_features, batch_offsets}.
// This will be removed totally once vbe_metadata for cpu is implemented
if (weights_host.numel() > 1 && vbe_b_t_map.dim() > 1){
grad_output = reshape_vbe_output(grad_output, B_offsets, vbe_b_t_map, D_offsets);
}
{%- endif %}
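For illustration only, a minimal standalone sketch (assuming libtorch; tensor shapes and contents are placeholders, not taken from the diff) of the dimensionality check described in the removed comment above: a 1-D vbe_b_t_map corresponds to the CUDA/MTIA layout of size {total_B}, while a 2-D map corresponds to the CPU B_offsets_rank_per_feature layout that triggered the grad_output reshape.

#include <torch/torch.h>
#include <iostream>

int main() {
  // CUDA/MTIA-style vbe_b_t_map: flat tensor of length total_B (values are placeholders).
  const auto cuda_style_map = torch::arange(12, torch::kInt32);    // dim() == 1
  // CPU-style map (B_offsets_rank_per_feature): {num_features, batch_offsets}.
  const auto cpu_style_map = torch::zeros({3, 4}, torch::kInt32);  // dim() == 2

  for (const auto& b_t_map : {cuda_style_map, cpu_style_map}) {
    // The removed hotfix only reshaped grad_output when the map was multi-dimensional,
    // i.e., on the CPU path; CUDA/MTIA (1-D map) was left untouched.
    const bool reshape_grad_output = b_t_map.dim() > 1;
    std::cout << "dim=" << b_t_map.dim()
              << " reshape_grad_output=" << std::boolalpha << reshape_grad_output << "\n";
  }
  return 0;
}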

{%- set grad_indice_weights_op =
"{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc)
@@ -968,7 +961,9 @@ static torch::autograd::variable_list backward(
const Tensor& /*vbe_row_output_offsets*/,
const Tensor& /*vbe_b_t_map*/,
const int64_t /*info_B_num_bits*/,
const int64_t /*info_B_mask_int64*/
const int64_t /*info_B_mask_int64*/,
const Tensor& /*vbe_B_offsets_rank_per_feature*/,
const c10::SymInt /*max_B*/
{%- else %}
const Tensor& /*feature_requires_grad*/
{%- endif %}
@@ -998,7 +993,9 @@ static torch::autograd::variable_list backward(
vbe_row_output_offsets,
vbe_b_t_map,
info_B_num_bits,
info_B_mask_int64
info_B_mask_int64,
vbe_B_offsets_rank_per_feature,
max_B
{%- else %}
feature_requires_grad
{%- endif %}
