Streaming refacto #353

Open · wants to merge 6 commits into main
7 changes: 5 additions & 2 deletions img2dataset/__init__.py
@@ -1,4 +1,7 @@
"""Img2dataset"""

-from img2dataset.main import main
-from img2dataset.main import download
+from img2dataset.batch.main import main
+from img2dataset.batch.main import download
+from img2dataset.batch.preparer import preparer
+from img2dataset.service.service import service
+from img2dataset.service.launcher import launcher
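
After this change the package root keeps exposing `main` and `download` (now coming from the new `batch` subpackage) and additionally exposes `preparer`, `service` and `launcher`. A hypothetical smoke test of the new import surface, assuming this branch is installed:

```python
# Hypothetical check of the re-exported entry points on this branch.
from img2dataset import main, download, preparer, service, launcher

print(download.__module__)  # expected to be img2dataset.batch.main on this branch
```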
Empty file added img2dataset/batch/__init__.py
Empty file.
28 changes: 20 additions & 8 deletions img2dataset/distributor.py → img2dataset/batch/distributor.py
@@ -3,10 +3,23 @@
from contextlib import contextmanager
from multiprocessing import get_context
from itertools import islice, chain
from ..core.downloader import download_shard

from tqdm import tqdm
from functools import partial
from pydantic import BaseModel


class DistributorOptions(BaseModel):
distributor: str = "multiprocessing"
subjob_size: int = 1000
max_shard_retry: int = 1
    processes_count: int = 1


def rrun(d):
return download_shard(**d)

def retrier(runf, failed_shards, max_shard_retry):
# retry failed shards max_shard_retry times
for i in range(max_shard_retry):
@@ -20,20 +33,19 @@ def retrier(runf, failed_shards, max_shard_retry):
"still failed. You may restart the same command to retry again."
)


def multiprocessing_distributor(processes_count, downloader, reader, _, max_shard_retry):
def multiprocessing_distributor(processes_count, worker_config_generator, _, max_shard_retry):
"""Distribute the work to the processes using multiprocessing"""
ctx = get_context("spawn")
with ctx.Pool(processes_count, maxtasksperchild=5) as process_pool:

def run(gen):
failed_shards = []
for (status, row) in tqdm(process_pool.imap_unordered(downloader, gen)):
for (status, row) in tqdm(process_pool.imap_unordered(rrun, gen)):
if status is False:
failed_shards.append(row)
return failed_shards

failed_shards = run(reader)
failed_shards = run(worker_config_generator)

retrier(run, failed_shards, max_shard_retry)

@@ -42,7 +54,7 @@ def run(gen):
del process_pool


def pyspark_distributor(processes_count, downloader, reader, subjob_size, max_shard_retry):
def pyspark_distributor(processes_count, worker_config_generator, subjob_size, max_shard_retry):
"""Distribute the work to the processes using pyspark"""

with _spark_session(processes_count) as spark:
@@ -56,12 +68,12 @@ def run(gen):
failed_shards = []
for batch in batcher(gen, subjob_size):
rdd = spark.sparkContext.parallelize(batch, len(batch))
for (status, row) in rdd.map(downloader).collect():
for (status, row) in rdd.map(rrun).collect():
if status is False:
failed_shards.append(row)
return failed_shards

failed_shards = run(reader)
failed_shards = run(worker_config_generator)

retrier(run, failed_shards, max_shard_retry)

@@ -77,7 +89,7 @@ def _spark_session(processes_count: int):
if spark_major_version >= 3:
spark = SparkSession.getActiveSession()
else:
spark = pyspark.sql.SparkSession._instantiatedSession # pylint: disable=protected-access
spark = pyspark.sql.SparkSession._instantiatedSession # type: ignore # pylint: disable=protected-access

if spark is None:
print("No pyspark session found, creating a new one!")
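
With this refactor the distributors no longer receive a `downloader` plus a `reader`; they consume a single generator of per-shard config dicts and map the module-level `rrun` wrapper over it (the wrapper has to live at module level so the "spawn" context can pickle it). Below is a standalone sketch of that pattern with illustrative stand-in names rather than the PR's real `download_shard`:

```python
# Standalone sketch of the generator + module-level wrapper pattern used above.
# process_shard stands in for the real download_shard; all names are illustrative.
from multiprocessing import get_context

from tqdm import tqdm


def process_shard(shard_id, output_file_prefix):
    # pretend to process one shard; return True on success
    return True


def run_one(config):
    # unpack one worker config dict, like rrun() above
    return process_shard(**config), config


def config_generator():
    # yield one config dict per shard
    for shard_id in range(4):
        yield {"shard_id": shard_id, "output_file_prefix": f"out/{shard_id:05d}"}


if __name__ == "__main__":
    ctx = get_context("spawn")
    with ctx.Pool(2, maxtasksperchild=5) as pool:
        failed = []
        for status, config in tqdm(pool.imap_unordered(run_one, config_generator())):
            if status is False:
                failed.append(config)
    print(f"{len(failed)} shards failed")
```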
168 changes: 168 additions & 0 deletions img2dataset/batch/main.py
@@ -0,0 +1,168 @@
"""Img2dataset"""

from typing import List, Optional
import fire
import logging
from ..core.logger import LoggerProcess
from ..core.reader import Reader
from .distributor import multiprocessing_distributor, pyspark_distributor
import fsspec
import sys
import signal
import os

import pydantic

logging.getLogger("exifread").setLevel(level=logging.CRITICAL)


def arguments_validator(params):
"""Validate the arguments"""
if params["save_additional_columns"] is not None:
save_additional_columns_set = set(params["save_additional_columns"])

forbidden_columns = set(
[
"key",
"caption",
"url",
"width",
"height",
"original_width",
"original_height",
"status",
"error_message",
"exif",
"md5",
]
)
intersection = save_additional_columns_set.intersection(forbidden_columns)
if intersection:
            raise ValueError(
                f"The following columns cannot be used in save_additional_columns: {intersection}. "
                "img2dataset reserves these columns for its own use. Please remove them from save_additional_columns."
            )


from ..core.writer import WriterOptions
from ..core.downloader import DownloaderOptions
from ..core.resizer import ResizingOptions
from ..batch.distributor import DistributorOptions
from ..core.logger import LoggerOptions
from ..core.reader import ReaderOptions

MainOptions = pydantic.create_model("MainOptions", __base__=(
ResizingOptions,
WriterOptions,
DownloaderOptions,
DistributorOptions,
LoggerOptions,
ReaderOptions,))

DownloaderWorkerInMainOptions = pydantic.create_model("DownloaderWorkerOptions", __base__=(
WriterOptions,
DownloaderOptions,
ResizingOptions,))

# accept **kwargs and validate them by parsing into the pydantic MainOptions model
def download(**kwargs):
"""Download is the main entry point of img2dataset, it uses multiple processes and download multiple files"""
opts = MainOptions(**kwargs)

def make_path_absolute(path):
fs, p = fsspec.core.url_to_fs(path)
if fs.protocol == "file":
return os.path.abspath(p)
return path

opts.output_folder = make_path_absolute(opts.output_folder)
opts.url_list = make_path_absolute(opts.url_list)
config_parameters = opts.dict()
arguments_validator(config_parameters)

logger_process = LoggerProcess(opts.output_folder, opts.enable_wandb, opts.wandb_project, config_parameters)

tmp_path = opts.output_folder + "/_tmp"
fs, tmp_dir = fsspec.core.url_to_fs(tmp_path)
if not fs.exists(tmp_dir):
fs.mkdir(tmp_dir)

def signal_handler(signal_arg, frame): # pylint: disable=unused-argument
try:
fs.rm(tmp_dir, recursive=True)
except Exception as _: # pylint: disable=broad-except
pass
logger_process.terminate()
sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)

fs, output_path = fsspec.core.url_to_fs(opts.output_folder)

if not fs.exists(output_path):
fs.mkdir(output_path)
done_shards = set()
else:
if opts.incremental_mode == "incremental":
done_shards = set(int(x.split("/")[-1].split("_")[0]) for x in fs.glob(output_path + "/*.json"))
elif opts.incremental_mode == "overwrite":
fs.rm(output_path, recursive=True)
fs.mkdir(output_path)
done_shards = set()
else:
raise ValueError(f"Unknown incremental mode {opts.incremental_mode}")

logger_process.done_shards = done_shards
logger_process.start()

reader = Reader(
opts.url_list,
opts.input_format,
opts.url_col,
opts.caption_col,
opts.save_additional_columns,
opts.number_sample_per_shard,
done_shards,
tmp_path,
)

def worker_config_generator():
for (shard_id, input_file) in reader:
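            # zero-pad the shard id to oom_shard_count digits, e.g. shard_id=5 with oom_shard_count=5 gives "00005"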
shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string
shard_id=shard_id, oom_shard_count=opts.oom_shard_count
)
output_file_prefix = f"{opts.output_folder}/{shard_name}"
param_keys = set(DownloaderWorkerInMainOptions.__fields__.keys())
param_keys.add("caption_col")
param_keys.add("input_format")
param_keys.add("save_additional_columns")
conf = {k: config_parameters[k] for k in param_keys}
conf["input_file"] = input_file
conf["output_file_prefix"] = output_file_prefix
yield conf

print("Starting the downloading of this file")
if opts.distributor == "multiprocessing":
distributor_fn = multiprocessing_distributor
elif opts.distributor == "pyspark":
distributor_fn = pyspark_distributor
else:
raise ValueError(f"Distributor {opts.distributor} not supported")

distributor_fn(
opts.processes_count,
worker_config_generator(),
opts.subjob_size,
opts.max_shard_retry,
)
logger_process.join()
fs.rm(tmp_dir, recursive=True)


def main():
fire.Fire(download)


if __name__ == "__main__":
main()
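
`MainOptions` and `DownloaderWorkerInMainOptions` are built by merging several small pydantic option groups (resizing, writing, downloading, …) into one flat model, so `download(**kwargs)` can validate everything in one place. A minimal sketch of that merging pattern with made-up option classes; it assumes a pydantic version whose `create_model` accepts a tuple of base classes, which the definitions above rely on:

```python
# Minimal sketch (not the PR's models) of merging option groups with pydantic.
from pydantic import BaseModel, create_model


class ResizeOpts(BaseModel):
    image_size: int = 256


class WriterOpts(BaseModel):
    output_format: str = "files"


# Assumes create_model accepts a tuple of bases, as MainOptions above does.
MergedOpts = create_model("MergedOpts", __base__=(ResizeOpts, WriterOpts))

opts = MergedOpts(image_size=384)
print(opts.dict())  # includes both image_size=384 and output_format="files"
```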
113 changes: 113 additions & 0 deletions img2dataset/batch/preparer.py
@@ -0,0 +1,113 @@
"""Img2dataset"""

from typing import List, Optional
import fire
import logging
from ..core.reader import Reader
import fsspec
import os

# is this needed? maybe only for slurm

logging.getLogger("exifread").setLevel(level=logging.CRITICAL)


def arguments_validator(params):
"""Validate the arguments"""
if params["save_additional_columns"] is not None:
save_additional_columns_set = set(params["save_additional_columns"])

forbidden_columns = set(
[
"key",
"caption",
"url",
"width",
"height",
"original_width",
"original_height",
"status",
"error_message",
"exif",
"md5",
]
)
intersection = save_additional_columns_set.intersection(forbidden_columns)
if intersection:
            raise ValueError(
                f"The following columns cannot be used in save_additional_columns: {intersection}. "
                "img2dataset reserves these columns for its own use. Please remove them from save_additional_columns."
            )


def preparer(
url_list: str,
output_folder: str = "images",
input_format: str = "txt",
url_col: str = "url",
caption_col: Optional[str] = None,
number_sample_per_shard: int = 10000,
save_additional_columns: Optional[List[str]] = None,
incremental_mode: str = "incremental",
number: int = 1,
):
"""Prepare the dataset for downloading"""


# move all this to Reader

config_parameters = dict(locals())
arguments_validator(config_parameters)

def make_path_absolute(path):
fs, p = fsspec.core.url_to_fs(path)
if fs.protocol == "file":
return os.path.abspath(p)
return path

output_folder = make_path_absolute(output_folder)
url_list = make_path_absolute(url_list)


tmp_path = output_folder + "/_tmp"
fs, tmp_dir = fsspec.core.url_to_fs(tmp_path)
if not fs.exists(tmp_dir):
fs.mkdir(tmp_dir)

fs, output_path = fsspec.core.url_to_fs(output_folder)

if not fs.exists(output_path):
fs.mkdir(output_path)
done_shards = set()
else:
if incremental_mode == "incremental":
done_shards = set(int(x.split("/")[-1].split("_")[0]) for x in fs.glob(output_path + "/*.json"))
elif incremental_mode == "overwrite":
fs.rm(output_path, recursive=True)
fs.mkdir(output_path)
done_shards = set()
else:
raise ValueError(f"Unknown incremental mode {incremental_mode}")


reader = Reader(
url_list,
input_format,
url_col,
caption_col,
save_additional_columns,
number_sample_per_shard,
done_shards,
tmp_path,
)

    reader.prepare(number)



def main():
fire.Fire(preparer)


if __name__ == "__main__":
main()
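
`preparer` runs the same argument validation and shard bookkeeping as `download` but stops after `Reader.prepare`, so it can presumably be used to pre-shard an input list ahead of time (e.g. for a slurm setup, as the comment at the top of the file suggests). A hypothetical call, assuming this branch's layout is installed and with placeholder paths:

```python
# Hypothetical usage of the new preparer entry point; paths are placeholders.
from img2dataset import preparer

preparer(
    url_list="urls.txt",        # one url per line, matching input_format="txt"
    output_folder="images",
    number_sample_per_shard=10000,
)
```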
Empty file added img2dataset/core/__init__.py
Empty file.