Skip to content

Commit

Permalink
[fbsync] Fix download from google drive which was downloading empty f…
Browse files Browse the repository at this point in the history
…iles in some cases (#4109)

Reviewed By: NicolasHug

Differential Revision: D29369894

fbshipit-source-id: 52d175103eb77170963f8115dbee3f8eb373802d
  • Loading branch information
vincentqb authored and facebook-github-bot committed Jun 25, 2021
1 parent 5c12931 commit 449eb47
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import gzip
import re
import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
from urllib.parse import urlparse
import zipfile
import lzma
import contextlib
import urllib
import urllib.request
import urllib.error
import pathlib
import itertools

import torch
from torch.utils.model_zoo import tqdm
Expand Down Expand Up @@ -184,11 +184,10 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
return files


def _quota_exceeded(response: "requests.models.Response") -> bool: # type: ignore[name-defined]
def _quota_exceeded(first_chunk: bytes) -> bool: # type: ignore[name-defined]
try:
start = next(response.iter_content(chunk_size=128, decode_unicode=True))
return isinstance(start, str) and "Google Drive - Quota exceeded" in start
except StopIteration:
return "Google Drive - Quota exceeded" in first_chunk.decode()
except UnicodeDecodeError:
return False


Expand Down Expand Up @@ -224,15 +223,25 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
params = {'id': file_id, 'confirm': token}
response = session.get(url, params=params, stream=True)

if _quota_exceeded(response):
# Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
# with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
# Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
# the first_chunk of the payload
response_content_generator = response.iter_content(32768)
first_chunk = None
while not first_chunk: # filter out keep-alive new chunks
first_chunk = next(response_content_generator)

if _quota_exceeded(first_chunk):
msg = (
f"The daily quota of the file {filename} is exceeded and it "
f"can't be downloaded. This is a limitation of Google Drive "
f"and can only be overcome by trying again later."
)
raise RuntimeError(msg)

_save_response_content(response, fpath)
_save_response_content(itertools.chain((first_chunk, ), response_content_generator), fpath)
response.close()


def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined]
Expand All @@ -244,12 +253,13 @@ def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:


def _save_response_content(
response: "requests.models.Response", destination: str, chunk_size: int = 32768, # type: ignore[name-defined]
response_gen: Iterator[bytes], destination: str, # type: ignore[name-defined]
) -> None:
with open(destination, "wb") as f:
pbar = tqdm(total=None)
progress = 0
for chunk in response.iter_content(chunk_size):

for chunk in response_gen:
if chunk: # filter out keep-alive new chunks
f.write(chunk)
progress += len(chunk)
Expand Down

0 comments on commit 449eb47

Please sign in to comment.