Skip to content

Commit

Permalink
Happy black?
Browse files Browse the repository at this point in the history
  • Loading branch information
theophilegervet committed Feb 4, 2024
1 parent 9d64245 commit 8192b76
Showing 1 changed file with 20 additions and 60 deletions.
80 changes: 20 additions & 60 deletions img2dataset/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ def is_disallowed(headers, user_agent_token, disallowed_header_directives):
try:
uatoken_directives = values.split(":", 1)
directives = [x.strip().lower() for x in uatoken_directives[-1].split(",")]
ua_token = (
uatoken_directives[0].lower() if len(uatoken_directives) == 2 else None
)
ua_token = uatoken_directives[0].lower() if len(uatoken_directives) == 2 else None
if (ua_token is None or ua_token == user_agent_token) and any(
x in disallowed_header_directives for x in directives
):
Expand All @@ -37,21 +35,15 @@ def is_disallowed(headers, user_agent_token, disallowed_header_directives):
return False


def download_image(
row, timeout, user_agent_token, disallowed_header_directives, ignore_ssl_certificate
):
def download_image(row, timeout, user_agent_token, disallowed_header_directives, ignore_ssl_certificate):
"""Download an image with urllib"""
key, url = row
img_stream = None
user_agent_string = (
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
)
user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
if user_agent_token:
user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/rom1504/img2dataset)"
try:
request = urllib.request.Request(
url, data=None, headers={"User-Agent": user_agent_string}
)
request = urllib.request.Request(url, data=None, headers={"User-Agent": user_agent_string})
ctx = ssl.create_default_context()
if ignore_ssl_certificate:
ctx.check_hostname = False
Expand Down Expand Up @@ -96,10 +88,8 @@ def download_image_with_retry(
def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count):
    """Return the zero-padded string key of a sample within a shard.

    The numeric key is ``10**oom_sample_per_shard * shard_id + key``, i.e. the
    shard id occupies the high-order decimal digits and the per-shard sample
    index the low-order ones.

    Args:
        key: integer index of the sample inside its shard.
        shard_id: integer id of the shard.
        oom_sample_per_shard: number of decimal digits reserved for ``key``.
        oom_shard_count: number of decimal digits reserved for ``shard_id``.

    Returns:
        The numeric key rendered in decimal and left-padded with zeros to
        ``oom_sample_per_shard + oom_shard_count`` characters. Values that
        need more digits than that are not truncated.
    """
    true_key = (10**oom_sample_per_shard) * shard_id + key
    key_format = oom_sample_per_shard + oom_shard_count
    # Nested replacement field: the pad width is computed at runtime.
    return f"{true_key:0{key_format}d}"

Expand Down Expand Up @@ -142,15 +132,11 @@ def __init__(
self.verify_hash_type = verify_hash_type
self.encode_format = encode_format
self.retries = retries
self.user_agent_token = (
None if user_agent_token is None else user_agent_token.strip().lower()
)
self.user_agent_token = None if user_agent_token is None else user_agent_token.strip().lower()
self.disallowed_header_directives = (
None
if disallowed_header_directives is None
else {
directive.strip().lower() for directive in disallowed_header_directives
}
else {directive.strip().lower() for directive in disallowed_header_directives}
)
self.blurring_bbox_col = blurring_bbox_col
self.ignore_ssl_certificate = ignore_ssl_certificate
Expand Down Expand Up @@ -207,19 +193,11 @@ def download_shard(
failed_to_download = 0
failed_to_resize = 0
url_indice = self.column_list.index("url")
caption_indice = (
self.column_list.index("caption") if "caption" in self.column_list else None
)
caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None
hash_indice = (
self.column_list.index(self.verify_hash_type)
if self.verify_hash_type in self.column_list
else None
)
bbox_indice = (
self.column_list.index(self.blurring_bbox_col)
if self.blurring_bbox_col is not None
else None
self.column_list.index(self.verify_hash_type) if self.verify_hash_type in self.column_list else None
)
bbox_indice = self.column_list.index(self.blurring_bbox_col) if self.blurring_bbox_col is not None else None
key_url_list = [(key, x[url_indice]) for key, x in shard_to_dl]

# this prevents an accumulation of more than twice the number of threads in sample ready to resize
Expand Down Expand Up @@ -257,9 +235,7 @@ def data_generator():
):
try:
_, sample_data = shard_to_dl[key]
str_key = compute_key(
key, shard_id, oom_sample_per_shard, self.oom_shard_count
)
str_key = compute_key(key, shard_id, oom_sample_per_shard, self.oom_shard_count)
meta = {
# Skip columns containing the verification hash and only save the computed hash
**{
Expand Down Expand Up @@ -289,19 +265,15 @@ def data_generator():
sample_writer.write(
None,
str_key,
sample_data[caption_indice]
if caption_indice is not None
else None,
sample_data[caption_indice] if caption_indice is not None else None,
meta,
)
semaphore.release()
continue

if hash_indice is not None:
img_stream.seek(0)
test_hash = getattr(hashlib, self.verify_hash_type)(
img_stream.read()
).hexdigest()
test_hash = getattr(hashlib, self.verify_hash_type)(img_stream.read()).hexdigest()
if test_hash != sample_data[hash_indice]:
failed_to_download += 1
status = "failed_to_download"
Expand All @@ -311,9 +283,7 @@ def data_generator():
sample_writer.write(
None,
str_key,
sample_data[caption_indice]
if caption_indice is not None
else None,
sample_data[caption_indice] if caption_indice is not None else None,
meta,
)
img_stream.close()
Expand All @@ -322,9 +292,7 @@ def data_generator():
continue

img_stream.seek(0)
bbox_list = (
sample_data[bbox_indice] if bbox_indice is not None else None
)
bbox_list = sample_data[bbox_indice] if bbox_indice is not None else None
(
img,
width,
Expand All @@ -342,9 +310,7 @@ def data_generator():
sample_writer.write(
None,
str_key,
sample_data[caption_indice]
if caption_indice is not None
else None,
sample_data[caption_indice] if caption_indice is not None else None,
meta,
)
img_stream.close()
Expand All @@ -361,9 +327,7 @@ def data_generator():
exif = json.dumps(
{
k: str(v).strip()
for k, v in exifread.process_file(
img_stream, details=False
).items()
for k, v in exifread.process_file(img_stream, details=False).items()
if v is not None
}
)
Expand All @@ -373,9 +337,7 @@ def data_generator():

if self.compute_hash is not None:
img_stream.seek(0)
meta[self.compute_hash] = getattr(hashlib, self.compute_hash)(
img_stream.read()
).hexdigest()
meta[self.compute_hash] = getattr(hashlib, self.compute_hash)(img_stream.read()).hexdigest()

meta["status"] = status
meta["width"] = width
Expand All @@ -388,9 +350,7 @@ def data_generator():
sample_writer.write(
img,
str_key,
sample_data[caption_indice]
if caption_indice is not None
else None,
sample_data[caption_indice] if caption_indice is not None else None,
meta,
)
except Exception as err: # pylint: disable=broad-except
Expand Down

0 comments on commit 8192b76

Please sign in to comment.