
[Data] iter_torch_batches very slow on video data #50128

Open
FredrikNoren opened this issue Jan 29, 2025 · 8 comments
Labels: bug (Something that is supposed to be working; but isn't), data (Ray Data-related issues), P1 (Issue that should be fixed within a few weeks)


FredrikNoren commented Jan 29, 2025

What happened + What you expected to happen

I'm training a model on video data, and I noticed that about 95% of the training time was spent waiting for the next batch (an epoch took around 14 sec, and 13 of those were spent just waiting for the next batch). So I did some digging, and here's the flame graph for CPU usage:

[Flame graph: CPU usage with batch_size=16]

This made me suspicious of the batching code, so I changed my batch_size from 16 to 1, which took my training time from 14 sec to 10 sec and the time to fetch a batch from 13 sec to 5 sec. I profiled it again, and here's the new flame graph:

[Flame graph: CPU usage with batch_size=1]

Still not great, but a bit better. I'm using torchcodec to load the video data, which gives me a torch tensor back, so it feels a bit unnecessary for it to be converted to numpy and then back to torch again. So my questions are:

  1. Can I somehow output torch tensors from my dataset transform and make Ray keep that format?
  2. Is there any way to customize how batches are created? For example, I'd like to use pinned memory so it doesn't need to reallocate memory all the time.

Other notes:

  • To make sure the problem isn't with video loading, I'm just returning numpy.zeros arrays instead during these benchmarks (see the sketch below)
  • Each row in the dataset is 240 x 3 x 128 x 128, so they're quite big
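
A minimal synthetic sketch of this setup (a placeholder dataset that just produces zeros of the row shape above; the row count and column name are illustrative, the shapes and batch size match the numbers here):

import numpy as np
import ray

# Synthetic stand-in for the video pipeline: each row is a (240, 3, 128, 128) float32 array.
ds = ray.data.range(64).map(
    lambda row: {"x": np.zeros((240, 3, 128, 128), dtype=np.float32)}
)

# The slowdown shows up when iterating with the stock iterator.
for batch in ds.iter_torch_batches(batch_size=16):
    pass  # batch["x"] is a [16, 240, 3, 128, 128] torch.Tensor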

Versions / Dependencies

ray 2.41.0
torch 2.5.1

Reproduction script

None

Issue Severity

None

@FredrikNoren FredrikNoren added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Jan 29, 2025
@FredrikNoren FredrikNoren changed the title [Data] iter_torch_batches very slow video data [Data] iter_torch_batches very slow on video data Jan 29, 2025
@jcotant1 jcotant1 added the data Ray Data-related issues label Jan 30, 2025
FredrikNoren (Author) commented:

Alright, I've been able to optimize this a bit today. Instead of using iter_torch_batches I've implemented my own version, which takes it from 14 sec for a batch of 16 items down to 6-7 sec. The flame graph is now dominated by two items, item.pin_memory() and torch.cuda.synchronize(), which seems a lot more reasonable.

[Flame graph: custom collate_fn, dominated by pin_memory() and torch.cuda.synchronize()]

Here's roughly what my code looks like:

import torch

# Combine the pyarrow chunks of one column into a single CUDA tensor via pinned host memory.
def pyaarrow_chunks_to_torch2(chunks):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pinned = []
    batch_size = 0
    # Convert each Arrow chunk to a pinned CPU tensor.
    for item in chunks.chunks:
        item = item.to_numpy()
        item = torch.from_numpy(item)
        item = item.pin_memory()
        pinned.append(item)
        batch_size += item.shape[0]
    # Allocate the output tensor on the target device and copy each chunk in asynchronously.
    shape = list(pinned[0].shape)
    shape[0] = batch_size
    tensor = torch.zeros(torch.Size(shape), device=device, dtype=torch.float32)
    offset = 0
    for item in pinned:
        tensor[offset:offset + item.shape[0]].copy_(item, non_blocking=True)
        offset += item.shape[0]
    return tensor

def collate_fn(batch):
    res = {
        "x": pyaarrow_chunks_to_torch2(batch["x"]),
        "y": pyaarrow_chunks_to_torch2(batch["y"]),
    }
    # Wait for the non-blocking copies to finish before handing the batch to the training loop.
    torch.cuda.synchronize()
    return res

# train_data_shard and config come from the surrounding Ray Train code.
train_dataloader = train_data_shard.iter_batches(
    batch_size=config.batch_size,
    prefetch_batches=4,
    _collate_fn=collate_fn,
    batch_format=None,
)

My GPU is a lot less idle now as well. Not quite fully maxed out constantly but close.
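
For completeness, the iterator above is then consumed like a normal dataloader (a sketch; the loop body is a placeholder):

for batch in train_dataloader:
    x, y = batch["x"], batch["y"]  # already on the GPU as float32
    # ... forward/backward pass ...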

FredrikNoren (Author) commented:

I found some relevant issues:

FredrikNoren (Author) commented:

I realized my entire dataset would fit in RAM, so I just circumvented the problem by loading everything into pinned memory up front and then using CUDA streams to load the next four batches onto the GPU. Here's my code if it helps anyone:

import ray
import ray.train.torch
import torch

def iter_torch_batches_fast(dataset: ray.data.Dataset, batch_size: int):

    def pin_batch(batch):
        # Pin every tensor in the batch so device copies can run asynchronously.
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].pin_memory()
        return batch

    # Materialize the whole dataset up front as pinned CPU batches (it fits in RAM).
    pinned = [pin_batch(batch) for batch in dataset.iter_torch_batches(
        batch_size=batch_size,
        prefetch_batches=4,
        device="cpu",
        dtypes=torch.float32,
    )]

    class PinnedDatasetStreamer:
        def __init__(self, pinned):
            self.pinned = iter(pinned)
            self.queue = []
            # Keep up to four batches in flight, each copied on its own CUDA stream.
            for _ in range(4):
                self.add_to_queue()

        def add_to_queue(self):
            device = ray.train.torch.get_device()
            try:
                batch = next(self.pinned)
                s = torch.cuda.Stream()
                with torch.cuda.stream(s):
                    b = {}
                    for key in batch:
                        if isinstance(batch[key], torch.Tensor):
                            b[key] = batch[key].to(device, non_blocking=True)
                        else:
                            b[key] = batch[key]
                    self.queue.append((b, s))
            except StopIteration:
                pass

        def __next__(self):
            if len(self.queue) == 0:
                raise StopIteration
            (batch, s) = self.queue.pop(0)
            self.add_to_queue()
            # Make the current stream wait for the copy stream before the batch is used.
            torch.cuda.current_stream().wait_stream(s)
            return batch

    class PinnedDataset:
        def __init__(self, pinned):
            self.pinned = pinned

        def __iter__(self):
            return PinnedDatasetStreamer(self.pinned)

    return PinnedDataset(pinned)

With this, basically all the time is spent in training (data loading is now 0.01 sec per epoch, training is 5.4 sec).
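
A rough usage sketch, assuming this runs inside a Ray Train training function (so ray.train.torch.get_device() resolves to the worker's GPU); the dataset name, epoch count, and loop body are placeholders, and the shard passed in is a Ray Train DataIterator, which also exposes iter_torch_batches:

import ray.train

def train_func(config):
    shard = ray.train.get_dataset_shard("train")
    batches = iter_torch_batches_fast(shard, batch_size=16)
    for epoch in range(config["epochs"]):
        for batch in batches:  # each __iter__ re-streams the same pinned batches
            x, y = batch["x"], batch["y"]  # already on the GPU
            # ... forward/backward pass ...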

@gvspraveen gvspraveen removed the triage Needs triage (eg: priority, bug/not-bug, and owning component) label Jan 31, 2025
@gvspraveen gvspraveen added the P1 Issue that should be fixed within a few weeks label Jan 31, 2025
raulchen (Contributor) commented:

Hey @FredrikNoren, we've noticed this issue recently as well.
The plan is to offload the chunk combination from the train workers to previous map operators.
Also I'm wondering if increasing prefetch_batches helps in your case.
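
For reference, prefetch_batches is set on the iterator itself; a sketch with a larger value (the number is just an example):

for batch in train_data_shard.iter_torch_batches(
    batch_size=16,
    prefetch_batches=8,  # overlap more batch preparation with training
):
    ...  # training step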

FredrikNoren (Author) commented:

@raulchen I tried tweaking the prefetch_batches first, but didn't see any performance improvements in my case. I also tried enabling actor_prefetcher_enabled but didn't see any gain either.

I think in my case there are multiple bottlenecks:

  1. Loading the video data from disk was slow (which was why I disabled that and just returned np.zeros, to try to isolate the other parts)
  2. Then there's some kind of concat step which I think can be optimized away
  3. Then there's the convert-to-tensor step; this can be broken up into two parts: pinning memory and moving to the GPU (this guide is great: https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html). .to(device) does both internally. I think this could also be optimized on the Ray side to run in the background/on a different CUDA stream (at least I didn't see any code doing that right now, but I could be wrong). A sketch of the pattern is below.
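
A sketch of the two-step pattern from that guide (the tensor shape is just the batch shape from above):

import torch

cpu_batch = torch.zeros(16, 240, 3, 128, 128)     # pageable host memory
pinned = cpu_batch.pin_memory()                    # step 1: copy into page-locked host memory
gpu_batch = pinned.to("cuda", non_blocking=True)   # step 2: asynchronous host-to-device copy
torch.cuda.synchronize()                           # wait for the copy before using gpu_batch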

alexeykudinkin (Contributor) commented Feb 3, 2025

Very insightful observations @FredrikNoren!

Can you share a bit more info about the shape of your data? For example:

  • How many rows per batch are you using?
  • How are you fetching the data?
  • Are there any transformations before training ingest?

Then there's some kind of concat step which I think can be optimized away

That step happens when we convert from the internal representation (Arrow) to NumPy -- by default NumPy requires a contiguous slab of memory for its batch, hence the concatenation of the chunks produced by PyArrow. This step could obviously be circumvented, but that is use-case dependent and could have performance repercussions.
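
A small illustration of that behavior with plain pyarrow types (the real data goes through Ray's Arrow tensor extension, but the chunk handling is analogous):

import pyarrow as pa

# A column made of several chunks, e.g. one per upstream block.
col = pa.chunked_array([pa.array([1.0, 2.0]), pa.array([3.0, 4.0])])
print(col.num_chunks)  # 2

# NumPy needs one contiguous buffer, so the chunks are copied/concatenated here.
arr = col.to_numpy()
print(arr)  # [1. 2. 3. 4.]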

FredrikNoren (Author) commented:

@alexeykudinkin Sure!

  • rows per batch: I'm using a batch size of 16, so my batches are [16, 240, 3, 128, 128] (i.e. [batch_size, frames, channels, width, height])
  • torchcodec is loading the data from "disk" right now (or rather AWS EFS). This is a map step in my Ray Data pipeline.
  • Yes, the raw videos are much larger, so I scale, crop, and chunk them into 4-second chunks (rough sketch of that map step below).
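
Roughly how that map step is shaped, as a sketch only: decode_and_chunk is a stand-in for the torchcodec decode + scale/crop/chunk logic (stubbed with zeros, as in the benchmarks above), and the paths and clip count are illustrative:

import numpy as np
import ray

def decode_and_chunk(row):
    # Real version: decode row["path"] with torchcodec, scale/crop to 128x128,
    # and split into 4-second clips of 240 frames each. Stubbed with zeros here.
    num_clips = 4  # illustrative
    return [{"x": np.zeros((240, 3, 128, 128), dtype=np.float32)} for _ in range(num_clips)]

video_paths = ["/mnt/efs/videos/a.mp4", "/mnt/efs/videos/b.mp4"]  # illustrative EFS paths
ds = ray.data.from_items([{"path": p} for p in video_paths]).flat_map(decode_and_chunk)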

raulchen (Contributor) commented Feb 4, 2025

@FredrikNoren thanks for the insights!

  • Regarding the concat overheads, we are planning to fix it shortly.
  • Regarding moving data to GPU, I think your suggestion makes sense. We'll look into that. cc @justinvyu @srinathk10
