Skip to content

Commit

Permalink
Happy black?
Browse files Browse the repository at this point in the history
  • Loading branch information
theophilegervet committed Feb 4, 2024
1 parent 42337c8 commit 9d64245
Showing 1 changed file with 66 additions and 22 deletions.
88 changes: 66 additions & 22 deletions img2dataset/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def is_disallowed(headers, user_agent_token, disallowed_header_directives):
try:
uatoken_directives = values.split(":", 1)
directives = [x.strip().lower() for x in uatoken_directives[-1].split(",")]
ua_token = uatoken_directives[0].lower() if len(uatoken_directives) == 2 else None
ua_token = (
uatoken_directives[0].lower() if len(uatoken_directives) == 2 else None
)
if (ua_token is None or ua_token == user_agent_token) and any(
x in disallowed_header_directives for x in directives
):
Expand All @@ -35,15 +37,21 @@ def is_disallowed(headers, user_agent_token, disallowed_header_directives):
return False


def download_image(row, timeout, user_agent_token, disallowed_header_directives, ignore_ssl_certificate):
def download_image(
row, timeout, user_agent_token, disallowed_header_directives, ignore_ssl_certificate
):
"""Download an image with urllib"""
key, url = row
img_stream = None
user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
user_agent_string = (
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
)
if user_agent_token:
user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/rom1504/img2dataset)"
try:
request = urllib.request.Request(url, data=None, headers={"User-Agent": user_agent_string})
request = urllib.request.Request(
url, data=None, headers={"User-Agent": user_agent_string}
)
ctx = ssl.create_default_context()
if ignore_ssl_certificate:
ctx.check_hostname = False
Expand All @@ -69,12 +77,16 @@ def download_image_with_retry(
retries,
user_agent_token,
disallowed_header_directives,
ignore_ssl_certificate
ignore_ssl_certificate,
):
"""Download an image with urllib, retrying if it fails."""
for _ in range(retries + 1):
key, img_stream, err = download_image(
row, timeout, user_agent_token, disallowed_header_directives, ignore_ssl_certificate
row,
timeout,
user_agent_token,
disallowed_header_directives,
ignore_ssl_certificate,
)
if img_stream is not None:
return key, img_stream, err
Expand All @@ -84,8 +96,10 @@ def download_image_with_retry(
def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count):
true_key = (10**oom_sample_per_shard) * shard_id + key
key_format = oom_sample_per_shard + oom_shard_count
str_key = "{true_key:0{key_format}d}".format( # pylint: disable=consider-using-f-string
key_format=key_format, true_key=true_key
str_key = (
"{true_key:0{key_format}d}".format( # pylint: disable=consider-using-f-string
key_format=key_format, true_key=true_key
)
)
return str_key

Expand Down Expand Up @@ -128,11 +142,15 @@ def __init__(
self.verify_hash_type = verify_hash_type
self.encode_format = encode_format
self.retries = retries
self.user_agent_token = None if user_agent_token is None else user_agent_token.strip().lower()
self.user_agent_token = (
None if user_agent_token is None else user_agent_token.strip().lower()
)
self.disallowed_header_directives = (
None
if disallowed_header_directives is None
else {directive.strip().lower() for directive in disallowed_header_directives}
else {
directive.strip().lower() for directive in disallowed_header_directives
}
)
self.blurring_bbox_col = blurring_bbox_col
self.ignore_ssl_certificate = ignore_ssl_certificate
Expand Down Expand Up @@ -189,11 +207,19 @@ def download_shard(
failed_to_download = 0
failed_to_resize = 0
url_indice = self.column_list.index("url")
caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None
caption_indice = (
self.column_list.index("caption") if "caption" in self.column_list else None
)
hash_indice = (
self.column_list.index(self.verify_hash_type) if self.verify_hash_type in self.column_list else None
self.column_list.index(self.verify_hash_type)
if self.verify_hash_type in self.column_list
else None
)
bbox_indice = (
self.column_list.index(self.blurring_bbox_col)
if self.blurring_bbox_col is not None
else None
)
bbox_indice = self.column_list.index(self.blurring_bbox_col) if self.blurring_bbox_col is not None else None
key_url_list = [(key, x[url_indice]) for key, x in shard_to_dl]

# this prevents an accumulation of more than twice the number of threads in sample ready to resize
Expand Down Expand Up @@ -231,7 +257,9 @@ def data_generator():
):
try:
_, sample_data = shard_to_dl[key]
str_key = compute_key(key, shard_id, oom_sample_per_shard, self.oom_shard_count)
str_key = compute_key(
key, shard_id, oom_sample_per_shard, self.oom_shard_count
)
meta = {
# Skip columns containing the verification hash and only save the computed hash
**{
Expand Down Expand Up @@ -261,15 +289,19 @@ def data_generator():
sample_writer.write(
None,
str_key,
sample_data[caption_indice] if caption_indice is not None else None,
sample_data[caption_indice]
if caption_indice is not None
else None,
meta,
)
semaphore.release()
continue

if hash_indice is not None:
img_stream.seek(0)
test_hash = getattr(hashlib, self.verify_hash_type)(img_stream.read()).hexdigest()
test_hash = getattr(hashlib, self.verify_hash_type)(
img_stream.read()
).hexdigest()
if test_hash != sample_data[hash_indice]:
failed_to_download += 1
status = "failed_to_download"
Expand All @@ -279,7 +311,9 @@ def data_generator():
sample_writer.write(
None,
str_key,
sample_data[caption_indice] if caption_indice is not None else None,
sample_data[caption_indice]
if caption_indice is not None
else None,
meta,
)
img_stream.close()
Expand All @@ -288,7 +322,9 @@ def data_generator():
continue

img_stream.seek(0)
bbox_list = sample_data[bbox_indice] if bbox_indice is not None else None
bbox_list = (
sample_data[bbox_indice] if bbox_indice is not None else None
)
(
img,
width,
Expand All @@ -306,7 +342,9 @@ def data_generator():
sample_writer.write(
None,
str_key,
sample_data[caption_indice] if caption_indice is not None else None,
sample_data[caption_indice]
if caption_indice is not None
else None,
meta,
)
img_stream.close()
Expand All @@ -323,7 +361,9 @@ def data_generator():
exif = json.dumps(
{
k: str(v).strip()
for k, v in exifread.process_file(img_stream, details=False).items()
for k, v in exifread.process_file(
img_stream, details=False
).items()
if v is not None
}
)
Expand All @@ -333,7 +373,9 @@ def data_generator():

if self.compute_hash is not None:
img_stream.seek(0)
meta[self.compute_hash] = getattr(hashlib, self.compute_hash)(img_stream.read()).hexdigest()
meta[self.compute_hash] = getattr(hashlib, self.compute_hash)(
img_stream.read()
).hexdigest()

meta["status"] = status
meta["width"] = width
Expand All @@ -346,7 +388,9 @@ def data_generator():
sample_writer.write(
img,
str_key,
sample_data[caption_indice] if caption_indice is not None else None,
sample_data[caption_indice]
if caption_indice is not None
else None,
meta,
)
except Exception as err: # pylint: disable=broad-except
Expand Down

0 comments on commit 9d64245

Please sign in to comment.