Fix _quota_exceeded check and gdrive download #4109

Merged
merged 3 commits into from Jun 24, 2021
30 changes: 20 additions & 10 deletions torchvision/datasets/utils.py
@@ -5,15 +5,15 @@
 import gzip
 import re
 import tarfile
-from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple
+from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
 from urllib.parse import urlparse
 import zipfile
 import lzma
 import contextlib
 import urllib
 import urllib.request
 import urllib.error
 import pathlib
+import itertools

 import torch
 from torch.utils.model_zoo import tqdm
@@ -184,11 +184,10 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
     return files


-def _quota_exceeded(response: "requests.models.Response") -> bool:  # type: ignore[name-defined]
+def _quota_exceeded(first_chunk: bytes) -> bool:  # type: ignore[name-defined]
     try:
-        start = next(response.iter_content(chunk_size=128, decode_unicode=True))
-        return isinstance(start, str) and "Google Drive - Quota exceeded" in start
-    except StopIteration:
+        return "Google Drive - Quota exceeded" in first_chunk.decode()
+    except UnicodeDecodeError:
         return False


@@ -224,15 +223,25 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
             params = {'id': file_id, 'confirm': token}
             response = session.get(url, params=params, stream=True)

-        if _quota_exceeded(response):
+        # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
+        # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
+        # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
+        # the first_chunk of the payload
+        response_content_generator = response.iter_content(32768)
Member

Does it matter that the chunk size used to be chunk_size=128 for this specific check and now it's 32768?

Collaborator

I think I tested it back then and 128 was enough to get to the important part of the message. But since we use the chunk anyway, any reasonable chunk size should be ok.
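
A toy illustration of this point (iter_chunks and page are invented stand-ins, not part of requests or of this PR): chunk_size only caps how many bytes the first chunk contains, so the check works with either value as long as the marker phrase appears that early in the quota page, which is what was verified for 128.

def iter_chunks(data: bytes, chunk_size: int):
    # Stand-in for response.iter_content(chunk_size): yields at most chunk_size bytes at a time.
    for start in range(0, len(data), chunk_size):
        yield data[start:start + chunk_size]

# Hypothetical quota page; the real HTML served by Google Drive may differ.
page = b"<html><head><title>Google Drive - Quota exceeded</title></head>...</html>"

for size in (128, 32768):
    first = next(iter_chunks(page, size))
    assert b"Google Drive - Quota exceeded" in first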

Contributor Author
@ORippler Jun 24, 2021

As far as I could debug manually, Google Drive returns either:

  1. An HTML page stating that the quota is exceeded. We currently parse the contents of this HTML in _quota_exceeded, and yes, a size of 128 is sufficient to reach the phrase we check against here. Larger chunk sizes didn't hurt in my debugging.
  2. The desired payload, e.g. a tarball of a dataset (which cannot be decoded and parsed).

The reason parsing the HTML is worse than using response.status_code is that if Google decides to change the formatting of the returned HTML, or returns different HTML pages for edge cases we are currently unaware of, our check will fail: we will write the HTML to disk, subsequently fail to unpack it in the next step, and leave users with a cryptic error message.

We could of course be stricter here and reject every response whose payload's first chunk can be fully decoded to a string, but the best long-term solution would probably be for Google to fix/adhere to their own API.
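
To make the two cases concrete, here is a small self-contained sketch of how the merged check behaves (the sample byte strings are invented; real responses come from Google Drive): the quota page decodes to text containing the marker phrase, while the first chunk of a gzipped tarball starts with bytes that are not valid UTF-8, so .decode() raises and the check falls through to False.

def _quota_exceeded(first_chunk: bytes) -> bool:
    try:
        return "Google Drive - Quota exceeded" in first_chunk.decode()
    except UnicodeDecodeError:
        # Binary payloads such as tarballs are not valid UTF-8, so this is the expected path.
        return False

html_chunk = b"<html><title>Google Drive - Quota exceeded</title></html>"  # case 1: quota page
tar_gz_chunk = b"\x1f\x8b\x08\x00" + b"\x00" * 28                          # case 2: gzip magic bytes

assert _quota_exceeded(html_chunk)
assert not _quota_exceeded(tar_gz_chunk)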

+        first_chunk = None
+        while not first_chunk:  # filter out keep-alive new chunks
+            first_chunk = next(response_content_generator)
+
+        if _quota_exceeded(first_chunk):
Member

I know this doesn't come from this PR, but I would suggest removing the _quota_exceeded function and just inlining it here. We'd then be able to use raise RuntimeError(msg) from decode_error, which would provide cleaner tracebacks.

Contributor Author
@ORippler Jun 24, 2021

While I can surely inline the function, I don't get what you mean by raise ... from UnicodeDecodeError.
We use the UnicodeDecodeError as the passing condition (i.e. UnicodeDecodeError will only happen if we have a valid payload).

Did I misunderstand you here? Please clarify (sorry, I am new to exception chaining).

Member

oh you're right, it wouldn't make sense then!
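
For readers unfamiliar with exception chaining, a minimal unrelated sketch (parse_config is invented for illustration): raise ... from ... is useful when the caught exception is the reason for failing, which is not the situation in this PR, where UnicodeDecodeError marks the expected non-error path.

def parse_config(raw: bytes) -> str:
    try:
        return raw.decode()
    except UnicodeDecodeError as decode_error:
        # Here the decode failure *is* the error, so chaining attaches the original
        # exception to the new, more descriptive one as its __cause__.
        raise RuntimeError("config file is not valid UTF-8") from decode_error

# parse_config(b"\xff") raises RuntimeError with the UnicodeDecodeError attached to the traceback.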

             msg = (
                 f"The daily quota of the file {filename} is exceeded and it "
                 f"can't be downloaded. This is a limitation of Google Drive "
                 f"and can only be overcome by trying again later."
             )
             raise RuntimeError(msg)

-        _save_response_content(response, fpath)
+        _save_response_content(itertools.chain((first_chunk, ), response_content_generator), fpath)
         response.close()


 def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:  # type: ignore[name-defined]
@@ -244,12 +253,13 @@ def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:


 def _save_response_content(
-    response: "requests.models.Response", destination: str, chunk_size: int = 32768,  # type: ignore[name-defined]
+    response_gen: Iterator[bytes], destination: str,  # type: ignore[name-defined]
 ) -> None:
     with open(destination, "wb") as f:
         pbar = tqdm(total=None)
         progress = 0
-        for chunk in response.iter_content(chunk_size):
+
+        for chunk in response_gen:
             if chunk:  # filter out keep-alive new chunks
                 f.write(chunk)
                 progress += len(chunk)
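
Putting the pieces together, a rough end-to-end sketch of the pattern this PR adopts (save_stream and fake_stream are invented stand-ins for _save_response_content and response.iter_content(32768)): pull the first non-empty chunk so it can be inspected, then chain it back in front of the remaining generator so the saved file is complete.

import itertools
from typing import Iterator


def save_stream(chunks: Iterator[bytes], destination: str) -> None:
    # Stand-in for _save_response_content: write every non-empty chunk to disk.
    with open(destination, "wb") as f:
        for chunk in chunks:
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)


# Stand-in for response.iter_content(32768); the empty chunk mimics a keep-alive.
fake_stream = iter([b"", b"first part of the payload, ", b"rest of the payload"])

first_chunk = None
while not first_chunk:  # skip keep-alive chunks until real data arrives
    first_chunk = next(fake_stream)

if b"Google Drive - Quota exceeded" in first_chunk:
    raise RuntimeError("The daily quota of the file is exceeded.")

# Re-attach the chunk we consumed so nothing is missing from the output file.
save_stream(itertools.chain((first_chunk,), fake_stream), "payload.bin")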