Generalized Loop Checkpointing #135

Open
rlnsanz opened this issue Aug 27, 2024 · 0 comments

In PyTorch, the model training loop is usually doubly nested: an inner loop that traverses the data batch by batch runs inside an outer epoch loop. For example:

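```python
import flor

# net, optimizer, criterion, num_epochs, trainloader, and testloader are
# assumed to be defined by the training script.
with flor.checkpointing(model=net, optimizer=optimizer):
    for epoch in flor.loop("epoch", range(num_epochs)):
        for data in flor.loop("step", trainloader):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            flor.log("loss", loss.item())
            optimizer.step()
        # User-defined evaluation routine (named to avoid shadowing the builtin eval)
        evaluate(net, testloader)
```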

In such cases, we recommend wrapping both iterators (outer and nested) with `flor.loop`: on replay, the nested `flor.loop` can act as a Skip Block, skipping the GPU-intensive computation and loading the state from the checkpoint instead.
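To make the Skip Block idea concrete, here is a toy sketch in plain Python. It is not FlorDB's implementation: the class `ToySkipBlock`, its `run` method, the `after_epoch` hook, and the use of `copy.deepcopy` as the checkpoint mechanism are all hypothetical. During the record run, the expensive inner work executes and the state is snapshotted at the end of each outer iteration; during replay, that work is skipped and the state is loaded from the snapshot, while outer-loop code still runs.

```python
import copy


class ToySkipBlock:
    """Toy model of record/replay with a skip block (hypothetical, not FlorDB's code)."""

    def __init__(self):
        self.checkpoints = {}  # outer-loop index -> deep copy of the state

    def run(self, state, num_epochs, body, after_epoch=None, replay=False):
        for epoch in range(num_epochs):
            if replay and epoch in self.checkpoints:
                # Replay: skip the expensive inner work and load the checkpoint.
                state = copy.deepcopy(self.checkpoints[epoch])
            else:
                # Record: run the expensive inner work, then snapshot the state.
                body(state)
                self.checkpoints[epoch] = copy.deepcopy(state)
            if after_epoch is not None:
                after_epoch(epoch, state)  # outer-loop code runs in both modes
        return state


calls = []


def train_one_epoch(state):
    calls.append(1)
    state["weights"] += 1  # stand-in for one epoch of batch updates


seen = []
sb = ToySkipBlock()
recorded = sb.run({"weights": 0}, 3, train_one_epoch)
# On replay, a newly added outer-loop hook observes every epoch's state
# without re-running train_one_epoch.
replayed = sb.run({"weights": 0}, 3, train_one_epoch,
                  after_epoch=lambda e, s: seen.append(s["weights"]), replay=True)
print(len(calls), seen)  # 3 [1, 2, 3]
```

Snapshotting once per outer iteration, rather than once per batch, keeps the number of checkpoints proportional to the number of epochs.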

However, as FlorDB grows to span applications beyond model training (e.g., data ingestion, featurization, and feedback integration), we will encounter workloads with a single main loop that may nevertheless require checkpointing across iterations. For example:

```python
import flor

# complex_transformation, data_chunks, and validate_ingestion stand in for
# application-specific code.


class Aggregator:
    def __init__(self):
        self.state = []

    def update(self, data_chunk):
        # Perform some processing and update the state
        processed_data = complex_transformation(data_chunk)
        self.state.append(processed_data)

    def get_state(self):
        return self.state


with flor.checkpointing(aggregator=Aggregator()) as chk_set:
    for chunk_id, data_chunk in flor.loop("chunk", enumerate(data_chunks)):
        flor.log("chunk_id", chunk_id)

        # Update the object with the new chunk
        chk_set.aggregator.update(data_chunk)

        # Record that this chunk has been processed
        flor.log("status", "complete")

    # Finalize the processing with the modified object
    final_result = validate_ingestion(chk_set.aggregator)
```
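The kind of checkpointing this example needs can be illustrated with a minimal, framework-free sketch for a single loop. It is not FlorDB's implementation: `checkpointed_loop`, the in-memory `store` dict (standing in for persistent checkpoint storage), the use of `types.SimpleNamespace` to mimic `chk_set`, and `copy.deepcopy` as the snapshot mechanism are all assumptions for illustration. Each completed iteration snapshots every object in the checkpoint set; on resume, the latest snapshot is restored and the iterations it already covers are skipped.

```python
import copy
from types import SimpleNamespace


def checkpointed_loop(iterable, chk_set, store, resume=False):
    """Hypothetical generalized loop checkpointing (not FlorDB's API).

    Snapshots every attribute of `chk_set` into `store` (index -> snapshot)
    after each completed iteration. With resume=True, restores the latest
    snapshot and skips the iterations it already covers.
    """
    start = 0
    if resume and store:
        last = max(store)
        for name, obj in store[last].items():
            setattr(chk_set, name, copy.deepcopy(obj))
        start = last + 1
    for i, item in enumerate(iterable):
        if i < start:
            continue  # already covered by the restored checkpoint
        yield item
        # Control returns here only after the caller finishes the loop body.
        store[i] = copy.deepcopy(vars(chk_set))


class Aggregator:
    def __init__(self):
        self.state = []

    def update(self, data_chunk):
        self.state.append(sum(data_chunk))  # stand-in for complex_transformation


data_chunks = [[1, 2], [3, 4], [5, 6], [7, 8]]
store = {}

# First run: a simulated failure while processing the third chunk.
chk_set = SimpleNamespace(aggregator=Aggregator())
try:
    for chunk_id, chunk in checkpointed_loop(enumerate(data_chunks), chk_set, store):
        if chunk_id == 2:
            raise RuntimeError("simulated failure")
        chk_set.aggregator.update(chunk)
except RuntimeError:
    pass

# Second run: restore the snapshot taken after chunk 1 and finish the rest.
chk_set = SimpleNamespace(aggregator=Aggregator())
for chunk_id, chunk in checkpointed_loop(enumerate(data_chunks), chk_set, store, resume=True):
    chk_set.aggregator.update(chunk)

print(chk_set.aggregator.state)  # [3, 7, 11, 15]
```

Because the snapshot is taken only when control returns to the generator, an iteration interrupted by an exception (or a `break`) is never checkpointed, so a resumed run redoes exactly that chunk. A real implementation would likely persist snapshots to disk and control how often they are taken, since deep-copying large state on every iteration can be expensive.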

Thanks to @xllgit for identifying this issue.
