Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(encoder): fix unused variable
Browse files · Browse the repository at this point in the history
  • Loading branch information
raccoonliukai committed Aug 13, 2019
1 parent 732f2e6 commit 5fedf6d
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions gnes/encoder/text/torch_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
batch_data = np.zeros([batch_size, max_len], dtype=np.int64)
# batch_mask = np.zeros([batch_size, max_len], dtype=np.float32)
for i, ids in enumerate(tokens_ids):
batch_data[i, :tokens_lens[i]] = tokens_ids[i]
batch_data[i, :tokens_lens[i]] = ids
# batch_mask[i, :tokens_lens[i]] = 1

# Convert inputs to PyTorch tensors
Expand All @@ -85,8 +85,7 @@ def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
with torch.no_grad():
out_tensor = self.model(tokens_tensor)[0]
out_tensor = torch.mul(out_tensor, mask_tensor.unsqueeze(2))

if self.use_cuda:
output_tensor = output_tensor.cpu()
if self.use_cuda:
out_tensor = out_tensor.cpu()

return out_tensor.numpy()

0 comments on commit 5fedf6d

Please sign in to comment.