From c1df7f885f9496784268eccfb3d8df77aec54b3b Mon Sep 17 00:00:00 2001 From: "Gal Cohen (galco)" Date: Tue, 20 Aug 2024 15:50:13 +0300 Subject: [PATCH] fix: jamba cache fails to use torch.nn.module (#32894) Co-authored-by: Gal Cohen --- src/transformers/models/jamba/modeling_jamba.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 5449c1fb97d4..c6e8d425459f 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -210,6 +210,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache): """ def __init__(self, config, batch_size, dtype=torch.float16, device=None): + super().__init__() self.dtype = dtype self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba