diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py index 401032b577..c93814bade 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py @@ -599,6 +599,7 @@ def forward( offsets, self.bounds_check_mode_int, self.bounds_check_warning, + per_sample_weights, ) self.step += 1 if len(self.timesteps_prefetched) == 0: @@ -1972,6 +1973,7 @@ def forward( offsets, self.bounds_check_mode_int, self.bounds_check_warning, + per_sample_weights, ) # Note: CPU and CUDA ops use the same interface to facilitate JIT IR # generation for CUDA/CPU. For CPU op, we don't need weights_uvm and