Added num_workers to export files in parallel #923

Open · wants to merge 9 commits into main

Changes from all commits
34 changes: 30 additions & 4 deletions src/datachain/lib/dc.py
@@ -23,6 +23,7 @@
 from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
 from sqlalchemy.sql.sqltypes import NullType
+from tqdm import tqdm

 from datachain.dataset import DatasetRecord
 from datachain.func import literal
@@ -32,7 +33,14 @@
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
-from datachain.lib.file import ArrowRow, File, FileType, get_file_type
+from datachain.lib.file import (
+    EXPORT_FILES_MAX_THREADS,
+    ArrowRow,
+    File,
+    FileExporter,
+    FileType,
+    get_file_type,
+)
 from datachain.lib.file import ExportPlacement as FileExportPlacement
 from datachain.lib.listing import get_file_info, get_listing, list_bucket, ls
 from datachain.lib.listing_info import ListingInfo
@@ -2500,8 +2508,10 @@ def to_storage(
         placement: FileExportPlacement = "fullpath",
         use_cache: bool = True,
         link_type: Literal["copy", "symlink"] = "copy",
+        num_workers: Optional[int] = EXPORT_FILES_MAX_THREADS,
     ) -> None:
-        """Export files from a specified signal to a directory.
+        """Export files from a specified signal to a directory. Files can be
+        exported to a local or cloud directory.

         Args:
             output: Path to the target directory for exporting files.
@@ -2511,6 +2521,8 @@ def to_storage(
             use_cache: If `True`, cache the files before exporting.
             link_type: Method to use for exporting files.
                 Falls back to `'copy'` if symlinking fails.
+            num_workers: Number of workers to use for exporting files.
+                Defaults to 5 workers.

         Example:
             Cross cloud transfer
@@ -2525,8 +2537,22 @@
         ):
             raise ValueError("Files with the same name found")

-        for file in self.collect(signal):
-            file.export(output, placement, use_cache, link_type=link_type)  # type: ignore[union-attr]
+        progress_bar = tqdm(
+            desc=f"Exporting files to {output}: ",
+            unit=" files",
+            unit_scale=True,
+            unit_divisor=10,
+            total=self.count(),
+            leave=False,
+        )
+        file_exporter = FileExporter(
+            output,
+            placement,
+            use_cache,
+            link_type,
+            max_threads=num_workers or 1,
+        )
+        file_exporter.run(self.collect(signal), progress_bar)

     def shuffle(self) -> "Self":
         """Shuffle the rows of the chain deterministically."""
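For context, here is a minimal usage sketch of the new `num_workers` parameter; the chain construction and bucket URIs are illustrative assumptions, not part of this diff:

```python
from datachain import DataChain

# Hypothetical source; any fsspec-supported URI (local path, s3://, gs://, ...)
# should work for both the source and the export destination.
chain = DataChain.from_storage("s3://source-bucket/images/")

# Export every file referenced by the chain, copying with 10 worker threads.
chain.to_storage(
    "gs://destination-bucket/images-copy/",
    placement="fullpath",
    use_cache=True,
    link_type="copy",
    num_workers=10,
)
```

Passing `num_workers=0` or `None` falls back to a single worker because of the `num_workers or 1` guard above.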
30 changes: 30 additions & 0 deletions src/datachain/lib/file.py
@@ -24,6 +24,7 @@
 from datachain.client.fileslice import FileSlice
 from datachain.lib.data_model import DataModel
 from datachain.lib.utils import DataChainError
+from datachain.nodes_thread_pool import NodesThreadPool
 from datachain.sql.types import JSON, Boolean, DateTime, Int, String
 from datachain.utils import TIME_ZERO
@@ -43,6 +44,35 @@
 ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]

 FileType = Literal["binary", "text", "image", "video"]
+EXPORT_FILES_MAX_THREADS = 5
+
+
+class FileExporter(NodesThreadPool):
+    """Class that exports files concurrently using a thread pool."""
+
+    def __init__(
+        self,
+        output: str,
+        placement: ExportPlacement,
+        use_cache: bool,
+        link_type: Literal["copy", "symlink"],
+        max_threads: int = EXPORT_FILES_MAX_THREADS,
+    ):
+        super().__init__(max_threads)
+        self.output = output
+        self.placement = placement
+        self.use_cache = use_cache
+        self.link_type = link_type
+
+    def done_task(self, done):
+        for task in done:
+            task.result()
+
+    def do_task(self, file):
+        file.export(
+            self.output, self.placement, self.use_cache, link_type=self.link_type
+        )
+        self.increase_counter(1)
+
+
 class VFileError(DataChainError):
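And a hedged sketch of driving `FileExporter` directly, mirroring what `to_storage()` now does internally; the source URI and the `"file"` signal name are assumptions for illustration:

```python
from tqdm import tqdm

from datachain import DataChain
from datachain.lib.file import EXPORT_FILES_MAX_THREADS, FileExporter

# Assumed source and signal name, standing in for the self.collect(signal)
# generator that to_storage() feeds into the exporter.
files = DataChain.from_storage("s3://source-bucket/images/").collect("file")

exporter = FileExporter(
    "/tmp/export",  # output directory
    "filename",     # placement strategy
    True,           # use_cache
    "copy",         # link_type
    max_threads=EXPORT_FILES_MAX_THREADS,
)

# run() submits one do_task() per file and drains futures via done_task(),
# so any export error is re-raised through task.result().
exporter.run(files, tqdm(desc="Exporting", unit=" files"))
```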
43 changes: 32 additions & 11 deletions src/datachain/nodes_thread_pool.py
@@ -57,44 +57,65 @@
         self._max_threads = max_threads
         self._thread_counter = 0
         self._thread_lock = threading.Lock()
Review comment (Member), on the `self._thread_lock` line: separate: this could probably be done without a lock if it is only about incrementing by one; locks impose some performance loss. There should be something like an atomic int for this.
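Python's standard library has no atomic integer, but for a pure +1 counter a lock-free option is `itertools.count`, whose `next()` call is a single atomic operation in CPython. A sketch of the reviewer's idea, not code from this PR:

```python
import itertools

# next() on a C-implemented iterator cannot be interleaved by other threads
# in CPython, so this increments without an explicit lock.
_counter = itertools.count(1)

def increase_counter() -> int:
    """Increment and return the new count."""
    return next(_counter)
```

The trade-off is that such a counter only moves forward by a fixed step and cannot be reset cheaply, which is sufficient for progress counting.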

+        self.tasks = set()
+        self.canceled = False
+        self.th_pool = None

     def run(
         self,
         chunk_gen,
         progress_bar=None,
     ):
         results = []
-        with concurrent.futures.ThreadPoolExecutor(self._max_threads) as th_pool:
-            tasks = set()
+        self.th_pool = concurrent.futures.ThreadPoolExecutor(self._max_threads)
+        try:
             self._thread_counter = 0
             for chunk in chunk_gen:
-                while len(tasks) >= self._max_threads:
+                if self.canceled:
+                    break
+
+                while len(self.tasks) >= self._max_threads:
                     done, _ = concurrent.futures.wait(
-                        tasks, timeout=1, return_when="FIRST_COMPLETED"
+                        self.tasks, timeout=1, return_when="FIRST_COMPLETED"
                     )
                     self.done_task(done)

-                    tasks = tasks - done
+                    self.tasks = self.tasks - done
                     self.update_progress_bar(progress_bar)

-                tasks.add(th_pool.submit(self.do_task, chunk))
+                self.tasks.add(self.th_pool.submit(self.do_task, chunk))
                 self.update_progress_bar(progress_bar)

-            while tasks:
+            while self.tasks:
+                if self.canceled:
+                    break
+
                 done, _ = concurrent.futures.wait(
-                    tasks, timeout=1, return_when="FIRST_COMPLETED"
+                    self.tasks, timeout=1, return_when="FIRST_COMPLETED"
                 )
                 task_results = self.done_task(done)
                 if task_results:
                     results.extend(task_results)

-                tasks = tasks - done
+                self.tasks = self.tasks - done
                 self.update_progress_bar(progress_bar)

-            th_pool.shutdown()
+        except:
+            self.cancel_all()
+            raise
+        else:
+            self.th_pool.shutdown()

         return results

+    def cancel_all(self):
+        self.canceled = True
+        # Canceling tasks just in case any of them is scheduled to run.
+        # Note that running tasks cannot be canceled; instead we wait for
+        # them to finish when shutting down the thread pool executor via
+        # its shutdown() method.
+        for task in self.tasks:
+            task.cancel()
+        if self.th_pool:
+            self.th_pool.shutdown()  # this will wait for running tasks to finish
Review comment (Member): so, what does this mean for cancellation? (potentially taking a long time?) @skshetry, how does cancellation for prefetch work? Is it more or less immediate, or with some finite timeout?

Reply (@skshetry, Feb 20, 2025): We schedule a (worker) task to run in a separate thread (not the main thread). When we receive KeyboardInterrupt or any exception on the main thread, we cancel the task, which raises asyncio.CancelledError inside the worker task. This gets raised "at the next opportunity", i.e. at some await point.

The cancellation should be nearly immediate for the purposes of this discussion. Technically, however, it depends on what is currently running in the asyncio event loop; a synchronous function, for example, will unfortunately block until completion.
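A minimal, self-contained sketch of the cancellation behavior described above (illustrative only, not DataChain's prefetch code):

```python
import asyncio

async def worker():
    try:
        while True:
            # task.cancel() raises asyncio.CancelledError here,
            # at the next await point the worker reaches.
            await asyncio.sleep(1)
    except asyncio.CancelledError:
        print("worker cancelled almost immediately")
        raise  # re-raise so the task is marked as cancelled

async def main():
    task = asyncio.create_task(worker())
    await asyncio.sleep(0.1)  # let the worker reach its await point
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass  # expected

asyncio.run(main())
```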


     def update_progress_bar(self, progress_bar):
         if progress_bar is not None:
             with self._thread_lock:
19 changes: 16 additions & 3 deletions tests/func/test_datachain.py
@@ -301,21 +301,34 @@ def test_read_file(cloud_test_catalog, use_cache):
 @pytest.mark.parametrize("use_map", [True, False])
 @pytest.mark.parametrize("use_cache", [True, False])
 @pytest.mark.parametrize("file_type", ["", "binary", "text"])
+@pytest.mark.parametrize("num_workers", [0, 2])
 @pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
 def test_to_storage(
-    tmp_dir, cloud_test_catalog, test_session, placement, use_map, use_cache, file_type
+    tmp_dir,
+    cloud_test_catalog,
+    test_session,
+    placement,
+    use_map,
+    use_cache,
+    file_type,
+    num_workers,
 ):
     ctc = cloud_test_catalog
     df = DataChain.from_storage(ctc.src_uri, type=file_type, session=test_session)
     if use_map:
-        df.to_storage(tmp_dir / "output", placement=placement, use_cache=use_cache)
+        df.to_storage(
+            tmp_dir / "output",
+            placement=placement,
+            use_cache=use_cache,
+            num_workers=num_workers,
+        )
         df.map(
             res=lambda file: file.export(
                 tmp_dir / "output", placement=placement, use_cache=use_cache
             )
         ).exec()
     else:
-        df.to_storage(tmp_dir / "output", placement=placement)
+        df.to_storage(tmp_dir / "output", placement=placement, num_workers=num_workers)

     expected = {
         "description": "Cats and Dogs",