Does ChatGLM's generate method support embedding input? #18
Comments
Do you have any new ideas? I found a `PrefixEncoder` class in the source files; it seems to be used for P-Tuning v2.
It is used here:
During that process, the official code passes the resulting embedding in as `past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None`.
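A minimal sketch of that pattern, with names and shapes assumed rather than copied from ChatGLM's modeling file: a learned prefix embedding is reshaped into one (key, value) pair per layer and handed to the model as `past_key_values`.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the P-Tuning v2 idea; names, shapes, and the absence
# of the optional MLP projection are assumptions, not ChatGLM's exact code.
class PrefixEncoder(nn.Module):
    def __init__(self, pre_seq_len, num_layers, num_heads, head_dim):
        super().__init__()
        # One learned vector per prefix position, wide enough to hold a key
        # and a value for every transformer layer.
        self.embedding = nn.Embedding(pre_seq_len, num_layers * 2 * num_heads * head_dim)

    def forward(self, prefix_ids):  # (batch, pre_seq_len)
        return self.embedding(prefix_ids)

def get_past_key_values(encoder, batch, pre_seq_len, num_layers, num_heads, head_dim):
    prefix_ids = torch.arange(pre_seq_len).unsqueeze(0).expand(batch, -1)
    kv = encoder(prefix_ids)
    kv = kv.view(batch, pre_seq_len, num_layers * 2, num_heads, head_dim)
    kv = kv.permute(2, 0, 3, 1, 4)  # (num_layers * 2, batch, heads, seq, dim)
    # One (key, value) pair per layer -- the past_key_values format.
    return tuple((k, v) for k, v in zip(kv[0::2], kv[1::2]))
```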
Not sure; I'll discuss this with the algorithm team.
I'm trying to use GCG with ChatGLM3. After reading the code carefully, I think `generate()` actually supports `inputs_embeds`, which may solve the issue.
The parameter […]. So in fact, to use […]. Not sure if my understanding is correct? And I find that, when running […]. Not sure about my understanding; thanks a lot in advance for your support!
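For reference, this is the standard Hugging Face calling pattern when a model's `prepare_inputs_for_generation` does forward `inputs_embeds` (LLaMA does; the model name below is only an illustrative assumption):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint: any model whose generate() accepts inputs_embeds.
name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float16, device_map="auto")

input_ids = tokenizer("Hello", return_tensors="pt").input_ids.to(model.device)
# get_input_embeddings() returns the token-embedding layer of any HF model.
inputs_embeds = model.get_input_embeddings()(input_ids)

# No input_ids at all: generation is driven purely by the embeddings,
# which is exactly what GCG-style optimization needs.
out = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```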
The code below does pass embeddings as input; the plain forward call works, but `model.generate()` raises an error: "You passed […]".

```python
import torch

# Tokenize with fixed-length padding, then append token id 2 (assumed EOS).
inputs = tokenizer(MutilTalk_Prompt, padding='max_length', max_length=99)
tensor_input_ids = torch.tensor(inputs['input_ids'] + [2])
tensor_input_ids = tensor_input_ids.cuda()
print(tensor_input_ids)

# Look the embeddings up directly from the model's embedding layer.
input_embeds = model.transformer.embedding(tensor_input_ids.unsqueeze(0))

# A plain forward pass accepts inputs_embeds and runs fine.
outputs = model(input_ids=tensor_input_ids.unsqueeze(0), inputs_embeds=input_embeds)
logits_output = tokenizer.batch_decode(
    torch.argmax(outputs['logits'], -1).detach().cpu().numpy(),
    skip_special_tokens=True)
print(logits_output)

# error: generate() rejects inputs_embeds
outputs = model.generate(input_ids=tensor_input_ids.unsqueeze(0), inputs_embeds=input_embeds)
logits_output = tokenizer.batch_decode(
    torch.argmax(outputs['logits'], -1).detach().cpu().numpy(),
    skip_special_tokens=True)
print(logits_output)
```
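One workaround that avoids `generate()` entirely is a hand-rolled greedy loop over the model's forward pass, so only embeddings are ever fed in. This is a sketch under assumptions: that ChatGLM's forward accepts `inputs_embeds` together with `past_key_values`/`use_cache` (the forward call above suggests it does), and that token id 2 is the EOS id, as in the snippet above.

```python
import torch

@torch.no_grad()
def greedy_from_embeds(model, embed_layer, inputs_embeds, max_new_tokens=32, eos_id=2):
    """Greedy decoding driven purely by embeddings; embed_layer maps ids to embeddings."""
    generated, past = [], None
    embeds = inputs_embeds
    for _ in range(max_new_tokens):
        out = model(inputs_embeds=embeds, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = out.logits[:, -1, :].argmax(dim=-1)  # (batch,); batch size 1 here
        if next_id.item() == eos_id:
            break
        generated.append(next_id.item())
        embeds = embed_layer(next_id.unsqueeze(0))  # embed only the new token
    return generated

# e.g. tokenizer.decode(greedy_from_embeds(model, model.transformer.embedding, input_embeds))
```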
Oh, I get what you mean now; actually I do not use […].
This issue was moved to a discussion. You can continue the conversation there.
I haven't looked at the actual `generate` method code, so I'll start by analyzing `prepare_inputs_for_generation`.
As the screenshot above shows, LLaMA's `prepare_inputs_for_generation` supports embedding input, but ChatGLM's does not.
Does ChatGLM's `generate` method not support embedding input?
Apologies if I have misunderstood.
@xunkai55 @davidlvxin @duzx16
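For comparison, the relevant branch in LLaMA's `prepare_inputs_for_generation` looks roughly like the paraphrase below (simplified from memory, not a verbatim copy of the transformers source). ChatGLM's version has no such `inputs_embeds` branch, which is why its `generate()` cannot consume embeddings.

```python
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                  inputs_embeds=None, **kwargs):
    if past_key_values is not None:
        # With a cache, only the newest token needs to be fed in.
        input_ids = input_ids[:, -1:]
    # First step: prefer caller-supplied embeddings; later steps fall back
    # to the token ids produced during decoding.
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}
    model_inputs.update(kwargs)
    return model_inputs
```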