diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
index b224c3e70..164d6008b
--- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
+++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
@@ -40,10 +40,6 @@
 #include "fbgemm_gpu/config/feature_gates.h"
 #include "torch/csrc/autograd/record_function_ops.h"
 
-{%- if has_vbe_support %}
-#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
-{%- endif %}
-
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
 
@@ -120,6 +116,9 @@ enum SSDTensor {
         const c10::SymInt /*vbe_output_size*/,
         const int64_t /*info_B_num_bits*/,
         const int64_t /*info_B_mask_int64*/,
+        const Tensor& /*vbe_B_offsets_rank_per_feature*/,
+        const Tensor& /*vbe_output_offsets_feature_rank*/,
+        const c10::SymInt /*max_B*/,
     {%- endif %}
     {%- if is_gwd %}
         const Tensor& /*prev_iter_dev*/,
@@ -164,6 +163,9 @@ enum SSDTensor {
         vbe_output_size,
         info_B_num_bits,
         info_B_mask_int64,
+        vbe_B_offsets_rank_per_feature,
+        vbe_output_offsets_feature_rank,
+        max_B,
     {%- endif %} {# /* if vbe */ #}
     {%- if is_gwd %}
         prev_iter_dev_,
@@ -239,6 +241,8 @@ enum SSDTensor {
         const Tensor& /*B_offsets*/,
         const Tensor& /*vbe_row_output_offsets*/,
         const Tensor& /*vbe_b_t_map*/,
+        const Tensor& /*vbe_B_offsets_rank_per_feature*/,
+        const c10::SymInt /*max_B*/,
     {%- endif %}
         const bool /*use_uniq_cache_locations_bwd*/,
         const bool /*use_homogeneous_placements*/,
@@ -297,6 +301,8 @@ enum SSDTensor {
         B_offsets,
         vbe_row_output_offsets,
         vbe_b_t_map,
+        vbe_B_offsets_rank_per_feature,
+        max_B,
     {%- endif %} {# /* if vbe */ #}
     {%- if not dense %}
         use_uniq_cache_locations_bwd,
@@ -926,19 +932,6 @@ static torch::autograd::variable_list backward(
     //    {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda)
     weights_dev = weights_dev.flatten();
     {%- endif %}
-    {%- if vbe %}
-    // TODO: remove this once vbe_metadata for cpu is implemented
-    // MTIA kernel uses weights_host but follows CUDA implementation,
-    // so grad_output is already in a correct shape and must not be reshaped
-    // Reshaping on weights_host here causes MTIA kernel to fail.
-    // As a hotfix to unblock MTIA, we add condition check dimension so that reshpaing would skip on MTIA
-    // CUDA and MTIA vbe_b_t_map is size of {total_B} - should be 1 dim
-    // CPU vbe_b_t_map is B_offsets_rank_per_feature, so shape should be {num_features, batch_offsets}
-    // This will be removed totally once vbe_metadata for cpu is implemented
-    if (weights_host.numel() > 1 && vbe_b_t_map.dim() > 1){
-      grad_output = reshape_vbe_output(grad_output, B_offsets, vbe_b_t_map, D_offsets);
-    }
-    {%- endif %}
 
     {%- set grad_indice_weights_op =
         "{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc)
@@ -968,7 +961,9 @@ static torch::autograd::variable_list backward(
             const Tensor& /*vbe_row_output_offsets*/,
             const Tensor& /*vbe_b_t_map*/,
             const int64_t /*info_B_num_bits*/,
-            const int64_t /*info_B_mask_int64*/
+            const int64_t /*info_B_mask_int64*/,
+            const Tensor& /*vbe_B_offsets_rank_per_feature*/,
+            const c10::SymInt /*max_B*/
         {%- else %}
             const Tensor& /*feature_requires_grad*/
         {%- endif %}
@@ -998,7 +993,9 @@ static torch::autograd::variable_list backward(
             vbe_row_output_offsets,
             vbe_b_t_map,
             info_B_num_bits,
-            info_B_mask_int64
+            info_B_mask_int64,
+            vbe_B_offsets_rank_per_feature,
+            max_B
         {%- else %}
             feature_requires_grad
         {%- endif %}