Skip to content

Commit

Permalink
Only import mamba libraries if cuda available
Browse files Browse the repository at this point in the history
This should remove the CPU-only inference errors caused by importing CUDA-dependent libraries on machines without a GPU.
  • Loading branch information
gkielian authored Feb 25, 2025
1 parent f01f5c6 commit 1d547e8
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions variations/attention_variations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from quantization.quant_utils import set_variant, create_activation_buffers
from variations.softmax_variations import softmax_dictionary
from variations.position_encoding_variations import QuantizedEmbedding, RotaryEmbedding, SymmetricalOverlapAngularPositions, FIRE
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
# Mamba related imports
if torch.cuda.is_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

class CausalSelfAttention(nn.Module):
def __init__(self, config, fire_pos_enc=None):
Expand Down

0 comments on commit 1d547e8

Please sign in to comment.