diff --git a/torchqrnn/forget_mult.py b/torchqrnn/forget_mult.py index 6967850..9c07119 100644 --- a/torchqrnn/forget_mult.py +++ b/torchqrnn/forget_mult.py @@ -99,7 +99,7 @@ def __init__(self): def compile(self): if self.ptx is None: - program = Program(kernel.encode(), 'recurrent_forget_mult.cu'.encode()) + program = Program(kernel, 'recurrent_forget_mult.cu') GPUForgetMult.ptx = program.compile() if torch.cuda.current_device() not in GPUForgetMult.configured_gpus: