# drive.py

# URL helpers, see https://github.com/NVlabs/stylegan
# ------------------------------------------------------------------------------------------

import requests
import html
import hashlib
import glob
import os
import io
from typing import Any
import re
import uuid


def is_url(obj: Any) -> bool:
    """Determine whether the given object is a valid URL string."""
    if not isinstance(obj, str) or "://" not in obj:
        return False
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except:
        return False
    return True
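
# Usage sketch (illustrative inputs, not part of the original module): is_url
# only accepts absolute URLs whose host contains a dot, e.g.
#   is_url("https://example.com/models/network.pkl")  -> True
#   is_url("http://localhost/file.pkl")               -> False  (no dot in host)
#   is_url("not-a-url")                                -> False  (no "://")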

def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_path: bool = False) -> Any:
    """Download the given URL and return a binary-mode file object to access the data."""
    assert is_url(url)
    assert num_attempts >= 1

    # Lookup from cache.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache_dir is not None:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            if return_path:
                return cache_files[0]
            else:
                return open(cache_files[0], "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive quota exceeded")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache_dir is not None:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file)  # atomic
        if return_path:
            return cache_file

    # Return data as file object.
    return io.BytesIO(url_data)
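
# ------------------------------------------------------------------------------------------
# Usage sketch (not part of the original module; the URL and cache directory below are
# illustrative placeholders). open_url caches each download under
# "<cache_dir>/<md5(url)>_<sanitized filename>" and retries up to num_attempts times:
#
#     with open_url("https://example.com/models/network.pkl", cache_dir="cache") as f:
#         data = f.read()
#
#     # Or get the path of the cached copy instead of a file object:
#     path = open_url("https://example.com/models/network.pkl",
#                     cache_dir="cache", return_path=True)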