Multi-dim prompt tuning mask padding
calpt committed Jul 6, 2024
1 parent c8c5c12 · commit 4bdc98b
Showing 2 changed files with 18 additions and 13 deletions.
src/adapters/models/bert/modeling_bert.py (2 additions, 0 deletions)
@@ -157,6 +157,8 @@ def forward(
         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor]:
+        attention_mask = prefix_attention_mask(attention_mask, [2, 3])  # type: ignore
+
         if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
             # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
             logger.warning_once(
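
Note on the `[2, 3]` argument: prompt tuning prepends `prompt_tokens_length` virtual tokens to the sequence, so an already-expanded 4D attention mask of shape (batch, 1, query_len, key_len) has to grow along both the query axis (dim 2) and the key axis (dim 3), not only along the key axis as with the previous single-dim default. Below is a minimal standalone sketch of the shape effect, assuming such a 4D additive mask; the names and sizes are illustrative, not library API.

import torch

# Assumed sizes: batch of 2, 8 real tokens, 5 prepended prompt tokens.
batch, q_len, kv_len, prompt_len = 2, 8, 8, 5
mask = torch.zeros(batch, 1, q_len, kv_len)  # additive mask: 0 = attend

for d in (2, 3):
    pad_shape = list(mask.shape)
    pad_shape[d] = prompt_len
    # a fill value of 0 keeps the prompt positions attendable in an additive mask
    pad = torch.zeros(pad_shape, dtype=mask.dtype, device=mask.device)
    mask = torch.cat((pad, mask), dim=d)

print(mask.shape)  # torch.Size([2, 1, 13, 13]): both axes grew by prompt_len
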
src/adapters/utils.py (16 additions, 13 deletions)
@@ -865,7 +865,7 @@ def get_adapter_info(adapter_id: str, source: str = "ah") -> Optional[AdapterInfo]:
         raise ValueError("Please specify either 'ah' or 'hf' as source.")
 
 
-def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
+def prefix_attention_mask(attention_mask, dim: Union[int, List[int]] = 3, prefix_value: int = 0):
     """
     Adds a prefix to an attention mask. The length of the prefix is determined by the `prefix_attention_mask_length`
     attribute in the ForwardContext.
@@ -890,18 +890,21 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
         and forward_context is not None
         and getattr(forward_context, "prompt_tokens_length", None) is not None
     ):
-        # Create a tensor of ones with the desired shape
-        ones_shape = list(attention_mask.shape)
-        ones_shape[dim] = forward_context.prompt_tokens_length
-
-        prefix_attention_mask = torch.full(
-            ones_shape,
-            prefix_value,
-            dtype=attention_mask.dtype,
-        ).to(attention_mask.device)
-
-        # Concatenate the prefix_attention_mask along the specified dimension
-        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim)
+        if isinstance(dim, int):
+            dim = [dim]
+        for d in dim:
+            # Create a tensor of ones with the desired shape
+            ones_shape = list(attention_mask.shape)
+            ones_shape[d] = forward_context.prompt_tokens_length
+
+            prefix_attention_mask = torch.full(
+                ones_shape,
+                prefix_value,
+                dtype=attention_mask.dtype,
+            ).to(attention_mask.device)
+
+            # Concatenate the prefix_attention_mask along the specified dimension
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=d)
 
     return attention_mask

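The rewritten helper generalizes the old single-`dim` behaviour: an `int` is wrapped into a one-element list, and the prefix is concatenated once per listed dimension, so existing call sites keep working unchanged. Below is a self-contained sketch of the same looping logic, with the `ForwardContext` lookup replaced by an explicit `prefix_len` parameter; that parameter and the function name are simplifications for illustration, not the library signature.

from typing import List, Union

import torch


def pad_mask_prefix(
    attention_mask: torch.Tensor,
    prefix_len: int,
    dim: Union[int, List[int]] = 3,
    prefix_value: int = 0,
) -> torch.Tensor:
    """Prepend `prefix_len` entries of `prefix_value` along each dimension in `dim`."""
    dims = [dim] if isinstance(dim, int) else dim
    for d in dims:
        prefix_shape = list(attention_mask.shape)
        prefix_shape[d] = prefix_len
        prefix = torch.full(
            prefix_shape,
            prefix_value,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat((prefix, attention_mask), dim=d)
    return attention_mask


# 2D binary mask (batch, seq_len): pad only the key axis; 1 = attend to the prompt tokens.
mask_2d = pad_mask_prefix(torch.ones(2, 8), prefix_len=5, dim=1, prefix_value=1)
assert mask_2d.shape == (2, 13)

# 4D additive mask (batch, 1, q_len, kv_len): pad query and key axes, as in the BERT change above.
mask_4d = pad_mask_prefix(torch.zeros(2, 1, 8, 8), prefix_len=5, dim=[2, 3])
assert mask_4d.shape == (2, 1, 13, 13)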
