Skip to content

Commit

Permalink
Disable memory-efficient SDP on GPUs with CUDA compute capability below 8.0 (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
mitya52 committed Jan 3, 2024
1 parent c2896a2 commit 3f87614
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
1 change: 1 addition & 0 deletions self_hosting_machinery/finetune/modelling/flash_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def _forward(

if torch.cuda.get_device_capability() < (8, 0):
model.force_low_gpu_mem_mode = True
torch.backends.cuda.enable_mem_efficient_sdp(False)
logging.warning("Flash attention is not supported on gpus with cuda capability < 8")
return

Expand Down
2 changes: 2 additions & 0 deletions self_hosting_machinery/inference/inference_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ def infer(self, request: Dict[str, Any], upload_proxy: UploadProxy, upload_proxy
if request_id in upload_proxy.check_cancelled():
scratchpad.finish_reason = "cancelled"
return
if torch.cuda.get_device_capability() < (8, 0):
torch.backends.cuda.enable_mem_efficient_sdp(False)
with torch.inference_mode():
stopping_criteria = StoppingCriteriaList([
CancellationStoppingCriteria(scratchpad, request_id, upload_proxy),
Expand Down

0 comments on commit 3f87614

Please sign in to comment.