Added new parameter 'compute_key' #390

Open · wants to merge 4 commits into base: main
34 changes: 34 additions & 0 deletions README.md
@@ -83,6 +83,8 @@ with each number being the position in the list. The subfolders avoids having to

If **captions** are provided, they will be saved as 0.txt, 1.txt, ...

If **compute_key** is provided, the filename keys ({key}.jpg) will be computed by this function instead of by the default key scheme.

This can then easily be fed into machine learning training or any other use case.

Also .json files named 0.json, 1.json,... are saved with these keys:
@@ -177,11 +179,43 @@ This module exposes a single function `download` which takes the same arguments
* **max_shard_retry** Number of times to retry failed shards at the end (default *1*)
* **user_agent_token** Additional identifying token that will be added to the User-Agent header sent with HTTP requests to download images; for example: "img2downloader". (default *None*)
* **disallowed_header_directives** List of X-Robots-Tags header directives that, if present in HTTP response when downloading an image, will cause the image to be excluded from the output dataset. To ignore x-robots-tags, pass '[]'. (default '["noai", "noimageai", "noindex", "noimageindex"]')
* **compute_key** A function reference to override the default function used to compute the key for a given sample. If set to None, img2dataset creates keys as a combination of the shard number and the sample's count within the shard, e.g. 100001.jpg, 100001.txt. See the sketch right after this list. (default *None*)
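
For illustration, here is a minimal sketch of passing such a function through the Python API. The url list path, output folder, and function name are hypothetical placeholders, and the key scheme shown is just one way of keeping keys unique per sample:

```python
from img2dataset import download

def my_compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count, additional_columns):
    # Prefix the per-shard index with the shard id; any scheme works as long
    # as the returned string is unique across the whole dataset.
    return f"{shard_id}_{key}"

download(
    url_list="urls.txt",         # hypothetical input file
    output_folder="dataset",     # hypothetical output folder
    compute_key=my_compute_key,  # the new parameter introduced by this PR
)
```

Note that with the default multiprocessing distributor the function likely needs to be defined at module level (not as a lambda) so it can be pickled and shipped to the worker processes.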

## Incremental mode

If a first download was interrupted for any reason, you can run img2dataset again with --incremental "incremental" (this is the default), using the same output folder, the same number_sample_per_shard, and the same input urls, and img2dataset will complete the download.
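
As a sketch (the paths and values are hypothetical placeholders), resuming amounts to repeating the identical call; with the default incremental mode, the run picks up where it left off and only the missing shards are downloaded:

```python
from img2dataset import download

# The exact values are placeholders; what matters is that they are identical
# between the interrupted run and the retry.
common_args = dict(
    url_list="urls.txt",
    output_folder="dataset",
    number_sample_per_shard=1000,
)

download(**common_args)  # first run, interrupted partway through
download(**common_args)  # run again: completes the remaining shards
```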

## Compute key parameter function

To override the default method for calculating a sample's key, pass a function reference to the compute_key parameter. This function will be passed five parameters and must return a single string that is *unique across the entire dataset* being downloaded. The parameters are:
- key: the index of the sample within the shard
- shard_id: the id of the shard this sample belongs to
- oom_sample_per_shard: the order of magnitude of the number of samples per shard (i.e. the digit width used for the per-shard index)
- oom_shard_count: the order of magnitude of the total number of shards (i.e. the digit width used for the shard id)
- additional_columns: a dictionary containing any additional columns specified in the initial parameters for this sample

As an example, you can reconstruct the default method for computing a key:
```python
def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count, additional_columns):
    true_key = (10**oom_sample_per_shard) * shard_id + key
    key_format = oom_sample_per_shard + oom_shard_count
    str_key = "{true_key:0{key_format}d}".format(
        key_format=key_format, true_key=true_key
    )
    return str_key
```
Alternatively, if you specified additional columns for your dataset, one of which is a uid that is unique across the dataset, you could simply do the following:
```python
def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count, additional_columns):
    return str(additional_columns['uid'])
```
This would change the output to:
* output_folder
    * 00000
        * *your_uid_0*.jpg
        * *your_uid_0*.txt
        * *your_uid_1*.jpg
        * ...
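
Tying the uid variant together, here is a hedged end-to-end sketch. The parquet file name, column names, and output folder are hypothetical, and it assumes the uid column is kept via save_additional_columns so that it shows up in additional_columns:

```python
from img2dataset import download

def uid_compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count, additional_columns):
    # Use the dataset-wide unique id carried in the extra column as the filename key.
    return str(additional_columns["uid"])

download(
    url_list="urls_with_uid.parquet",  # hypothetical input with url, caption and uid columns
    input_format="parquet",
    url_col="url",
    caption_col="caption",
    save_additional_columns=["uid"],   # keep uid in the per-sample metadata
    output_folder="dataset",
    compute_key=uid_compute_key,       # the new parameter introduced by this PR
)
```
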
## Output format choice

Img2dataset supports several formats. There are trade-offs when choosing one:
16 changes: 13 additions & 3 deletions img2dataset/downloader.py
```diff
@@ -97,6 +97,7 @@ def __init__(
         user_agent_token,
         disallowed_header_directives,
         blurring_bbox_col=None,
+        compute_key_override=None,
     ) -> None:
         self.sample_writer_class = sample_writer_class
         self.resizer = resizer
@@ -119,6 +120,7 @@ def __init__(
             else {directive.strip().lower() for directive in disallowed_header_directives}
         )
         self.blurring_bbox_col = blurring_bbox_col
+        self.compute_key_override = compute_key_override

     def __call__(
         self,
@@ -213,14 +215,22 @@ def data_generator():
             ):
                 try:
                     _, sample_data = shard_to_dl[key]
-                    str_key = compute_key(key, shard_id, oom_sample_per_shard, self.oom_shard_count)
-                    meta = {
+                    additional_columns = {
                         # Skip columns containing a the verification hash and only save the compute hash
                         **{
                             self.column_list[i]: sample_data[i]
                             for i in range(len(self.column_list))
                             if (hash_indice is None or i != hash_indice)
-                        },
+                        }
+                    }
+                    if self.compute_key_override is None:
+                        str_key = compute_key(key, shard_id, oom_sample_per_shard, self.oom_shard_count)
+                    else:
+                        str_key = self.compute_key_override(
+                            key, shard_id, oom_sample_per_shard, self.oom_shard_count, additional_columns
+                        )
+                    meta = {
+                        **additional_columns,
                         "key": str_key,
                         "status": None,
                         "error_message": error_message,
```
4 changes: 3 additions & 1 deletion img2dataset/main.py
```diff
@@ -1,6 +1,6 @@
 """Img2dataset"""

-from typing import List, Optional
+from typing import List, Optional, Callable
 import fire
 import logging
 from .logger import LoggerProcess
@@ -108,6 +108,7 @@ def download(
     max_shard_retry: int = 1,
     user_agent_token: Optional[str] = None,
     disallowed_header_directives: Optional[List[str]] = None,
+    compute_key: Optional[Callable] = None,
 ):
     """Download is the main entry point of img2dataset, it uses multiple processes and download multiple files"""
     if disallowed_header_directives is None:
@@ -247,6 +248,7 @@ def signal_handler(signal_arg, frame):  # pylint: disable=unused-argument
         user_agent_token=user_agent_token,
         disallowed_header_directives=disallowed_header_directives,
         blurring_bbox_col=bbox_col,
+        compute_key_override=compute_key,
     )

     print("Starting the downloading of this file")
```