From 5fedf6dffccf881345986142fc3afc586ae38fec Mon Sep 17 00:00:00 2001
From: raccoonliukai <903896015@qq.com>
Date: Tue, 13 Aug 2019 15:47:48 +0800
Subject: [PATCH] fix(encoder): fix unused variable

---
 gnes/encoder/text/torch_transformers.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/gnes/encoder/text/torch_transformers.py b/gnes/encoder/text/torch_transformers.py
index b83151ab..dc88c6e1 100644
--- a/gnes/encoder/text/torch_transformers.py
+++ b/gnes/encoder/text/torch_transformers.py
@@ -67,7 +67,7 @@ def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
         batch_data = np.zeros([batch_size, max_len], dtype=np.int64)
         # batch_mask = np.zeros([batch_size, max_len], dtype=np.float32)
         for i, ids in enumerate(tokens_ids):
-            batch_data[i, :tokens_lens[i]] = tokens_ids[i]
+            batch_data[i, :tokens_lens[i]] = ids
             # batch_mask[i, :tokens_lens[i]] = 1
 
         # Convert inputs to PyTorch tensors
@@ -85,8 +85,7 @@ def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
         with torch.no_grad():
             out_tensor = self.model(tokens_tensor)[0]
             out_tensor = torch.mul(out_tensor, mask_tensor.unsqueeze(2))
-
-            if self.use_cuda:
-                output_tensor = output_tensor.cpu()
+            if self.use_cuda:
+                out_tensor = out_tensor.cpu()
 
         return out_tensor.numpy()
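
For context, the following is a minimal, self-contained sketch of the padding and masking logic this patch touches; it is not the GNES TorchTransformersEncoder itself. The `encode_ids` helper, the `hidden_dim` parameter, and the random `torch.nn.Embedding` standing in for `self.model` are illustrative assumptions. It shows why the two changes matter: the padding loop should use the loop variable `ids` it already iterates over, and the output must be moved to CPU before calling `.numpy()`, a branch the old code broke by referencing the undefined name `output_tensor` instead of `out_tensor`.

import numpy as np
import torch


def encode_ids(tokens_ids, use_cuda=False, hidden_dim=4):
    """Sketch of the padding + masking flow fixed by the patch above.

    `tokens_ids` is a list of variable-length lists of token ids; a random
    embedding layer stands in for the real transformer model.
    """
    tokens_lens = [len(ids) for ids in tokens_ids]
    batch_size, max_len = len(tokens_ids), max(tokens_lens)

    batch_data = np.zeros([batch_size, max_len], dtype=np.int64)
    batch_mask = np.zeros([batch_size, max_len], dtype=np.float32)
    for i, ids in enumerate(tokens_ids):
        batch_data[i, :tokens_lens[i]] = ids  # the fixed line: use `ids`
        batch_mask[i, :tokens_lens[i]] = 1

    model = torch.nn.Embedding(30522, hidden_dim)  # stand-in for self.model
    tokens_tensor = torch.tensor(batch_data)
    mask_tensor = torch.tensor(batch_mask)
    if use_cuda:
        model = model.cuda()
        tokens_tensor = tokens_tensor.cuda()
        mask_tensor = mask_tensor.cuda()

    with torch.no_grad():
        out_tensor = model(tokens_tensor)
        # zero out the embeddings of padded positions
        out_tensor = torch.mul(out_tensor, mask_tensor.unsqueeze(2))
        if use_cuda:
            # .numpy() only works on CPU tensors; before the patch this
            # branch referenced the undefined name `output_tensor` and
            # never actually moved `out_tensor` off the GPU
            out_tensor = out_tensor.cpu()

    return out_tensor.numpy()


print(encode_ids([[101, 2023, 102], [101, 102]]).shape)  # (2, 3, 4)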