Getting statistics about filtered examples #7389

Closed
jonathanasdf opened this issue Feb 10, 2025 · 2 comments

@jonathanasdf

@lhoestq wondering if the team has thought about this and if there are any recommendations?

Currently, when processing datasets, some examples are bound to get filtered out, whether due to bad formatting, excessive length, or any other custom filter that might be applied. Let's focus on filtering by length for now, since that gets applied dynamically for each training run. Say we want to show a graph in W&B of the running total of filtered examples so far.

What would be a good way to go about hooking this up? Because the map/filter operations happen before the DataLoader batches are created, if we're just grabbing batches from the DataLoader at training time, we won't know how many examples have already been filtered. But there's not really a good way to include a 'num_filtered' key in the dataset itself either, because map/filter process examples independently and have no way to track a running sum.

The only approach I can think of is adding an 'is_filtered' key to the dataset and then creating a custom batcher/collator that reads it and tracks the metric, roughly like the sketch below?
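
A minimal sketch of that idea (the is_filtered column, the length cutoff, and the counter name are all illustrative, not an established pattern):

from torch.utils.data import DataLoader

def mark(example):
    # flag examples that would otherwise have been dropped
    example["is_filtered"] = len(example["text"]) >= 1000
    return example

ds = ds.map(mark)

num_filtered = 0

def collate(batch):
    global num_filtered
    kept = [ex for ex in batch if not ex["is_filtered"]]
    num_filtered += len(batch) - len(kept)
    return kept  # stack into tensors here as usual

# num_workers left at 0 so the global counter stays visible
loader = DataLoader(ds, batch_size=32, collate_fn=collate)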

@lhoestq
Member

lhoestq commented Feb 11, 2025

You can actually track a running sum in map() or filter() :)

num_filtered = 0

def f(x):
    global num_filtered  # update the module-level counter
    condition = len(x["text"]) < 1000  # keep only examples shorter than 1000 chars
    if not condition:
        num_filtered += 1
    return condition

ds = ds.filter(f)
print(num_filtered)

And if you want to use multiprocessing, make sure to use a variable that is shared across processes:

from multiprocess import Manager

manager = Manager()
num_filtered = manager.Value('i', 0)

def f(x):
    condition = len(x["text"]) < 1000  # same criterion as above
    if not condition:
        # note: += on a Manager Value is a read-then-write, so counts can
        # race slightly across workers; fine for approximate statistics
        num_filtered.value += 1
    return condition

ds = ds.filter(f, num_proc=4)
print(num_filtered.value)

PS: datasets uses multiprocess instead of the multiprocessing package to support lambda functions in map() and filter().
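
For example, this works with num_proc > 1, even though the standard pickle used by multiprocessing cannot serialize the lambda:

ds = ds.filter(lambda x: len(x["text"]) < 1000, num_proc=4)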

@jonathanasdf
Author

Oh that's great to know!

I guess this value wouldn't be exactly synced with the batches when there are prefetching, shuffle buffers, and so on, but that's probably fine. Thanks!
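
For reference, a minimal sketch of hooking the shared counter into W&B during training, assuming the filter is applied lazily (e.g. a streaming dataset) so the counter advances as batches are consumed; wandb.log is the real API, while the project name and loop body are illustrative:

import wandb

wandb.init(project="filter-stats")  # hypothetical project name

for step, batch in enumerate(dataloader):
    # ... run the training step on `batch` ...
    wandb.log({"num_filtered": num_filtered.value}, step=step)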
