ScalableReader changes
daviswer committed Feb 6, 2025
1 parent 934d37b commit 8d0cfd8
Showing 2 changed files with 401 additions and 386 deletions.
121 changes: 59 additions & 62 deletions examples/ibm_rescaling/rescaling_demo.py
@@ -1,23 +1,23 @@
import argparse
import math
import os

import pyarrow as pa
import torch
from torch import distributed as dist

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.ibm_rescalable import (
DummyDataset,
ArrowHandler,
PreprocessDataset,
SamplingDataset,
ScalableShardDataset,
ScalableReader,
load_distributed_state_dict,
save_distributed_state_dict,
)

# This example script validates the rescaling behavior of the ibm rescalable distributed datasets.
# On first run, saves a distributed checkpoint to the desired location.
# On first run, creates a dummy dataset and saves a distributed checkpoint at the desired location.
# On subsequent runs, loads the checkpoint (possibly on a different world size / num workers)
# and verifies that previous data is not revisited, while upcoming data is.
# and verifies that all remaining data is covered by the time the epoch finishes.

# Example usage:
# torchrun [torchrun args] examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=6
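# To exercise the rescaling behavior, rerun against the same --ckpt_path with a different topology.
# A hypothetical second invocation (exact torchrun flags are up to the user) might look like:
# torchrun --nproc-per-node=2 examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=3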
@@ -28,36 +28,50 @@
parser.add_argument(
"--logical_shards",
type=int,
default=96,
help="Total number of data partitions. (worldsize * n_workers) must divide this evenly.",
default=350,
help="Total number of data partitions. Must exceed (worldsize * n_workers) but not n_docs (1000).",
)
parser.add_argument("--num_workers", type=int, default=1, help="Number of dataloader workers per device")
parser.add_argument("--b_size", type=int, default=1, help="Number of data points per step per device")
parser.add_argument("--b_size", type=int, default=2, help="Number of data points per step per device")
parser.add_argument("--n_steps", type=int, default=50, help="Number of steps to take before saving. (n_steps * b_size * worldsize) cannot exceed number of items in epoch (3000)")
parser.add_argument("--seed", type=int, default=42)

args = parser.parse_args()


# Setup
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
dist.init_process_group()
mesh = dist.device_mesh.init_device_mesh("cpu", [world_size])
placement = [dist.tensor.placement_types.Shard(0)]
subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"]
[os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas]

# Check input args
assert args.logical_shards >= world_size*args.num_workers, f"Logical shards {args.logical_shards} cannot be less than total workers {world_size*args.num_workers}"
assert args.logical_shards <= 1000, f"Logical shards {args.logical_shards} cannot exceed number of documents 1000"
assert args.n_steps*args.b_size*world_size < 3000, f"Number of items drawn before saving {args.n_steps*args.b_size*world_size} cannot exceed number of document chunks 3000."

# Build dataset
datapath = os.path.join(args.ckpt_path, "dataset")
if not os.path.exists(datapath):
os.mkdir(datapath)
schema = pa.schema([pa.field("tokens", pa.uint32())])
with pa.ipc.new_file(
os.path.join(datapath, "fileshard_1.arrow"), schema
) as writer:
for i in range(500):
out = list(range(i * 100, i * 100 + 100))
writer.write(pa.record_batch([out], schema=schema))

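# (This assumes the "subfolder" directory already exists under datapath; if not, it would need to be
# created first, e.g. os.makedirs(os.path.join(datapath, "subfolder"), exist_ok=True).)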
with pa.ipc.new_file(
os.path.join(datapath, "subfolder/fileshard_2.arrow"), schema
) as writer:
for i in range(500):
out = list(range(50000 + i * 100, 50000 + i * 100 + 100))
writer.write(pa.record_batch([out], schema=schema))
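# As a quick sanity check (not part of the demo's logic), a shard written above can be read back
# with pyarrow's IPC file reader; a minimal sketch:
#   reader = pa.ipc.open_file(os.path.join(datapath, "fileshard_1.arrow"))
#   assert reader.get_batch(0).column(0).to_pylist() == list(range(100))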

# Build dataloader
data = DummyDataset(os.path.join(args.ckpt_path, "data"), rank, world_size, delimiter_token=-1, seed=args.seed)
# Pretend that we're sampling over multiple sub-datasets
data = SamplingDataset(
os.path.join(args.ckpt_path, "data"),
data,
delimiter_token=-1,
datasets=subdatas,
weights=[12, 17, 5],
)
# Apply rescalability layer
data = ScalableShardDataset(data, n_logical_shards=args.logical_shards)
data = ScalableReader(datapath, rank, world_size, ArrowHandler, -1, seed=args.seed, max_chunksize=30, n_logical_shards=args.logical_shards)
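# (Positional args are presumably: data directory, rank, world size, file handler class, and the
# delimiter token -1, matching delimiter_token=-1 in the previous setup; max_chunksize=30 controls
# how each 100-token document is split into chunks.)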
# Statelessly convert all outputs to tensors
data = PreprocessDataset(data, torch.tensor)
# Wrap in StatefulDataLoader
@@ -69,16 +83,16 @@
os.makedirs(ckpt_path, exist_ok=True)
# Iterate, assemble values to exclude
if rank == 0:
print("No existing checkpoint. Processing 100 steps.")
print(f"No existing checkpoint. Processing {args.n_steps} steps.")

avoid = []
for i, inp in enumerate(data):
if i == 100:
if i == args.n_steps:
if rank == 0:
print("Iteration complete!")
save_distributed_state_dict(data, ckpt_path, mesh)
break
avoid.append(inp)
avoid.append(inp[:,0])
avoid = torch.cat(avoid)
# Get all vals onto each rank
avoid = dist.tensor.DTensor.from_local(
@@ -87,63 +101,46 @@
placement,
).full_tensor()
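# (Shard(0) placement treats each rank's local rows as one shard along dim 0, so full_tensor()
# all-gathers and concatenates every rank's values onto every rank.)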

# Continue, assemble values to include
load_distributed_state_dict(data, ckpt_path, mesh)
if rank == 0:
print("DCP state loaded!")

include = []
for i, inp in enumerate(data):
if i == 10:
break
include.append(inp)
include = torch.cat(include)
if rank == 0:
print("Iteration round 2 complete!")
# Get all vals onto each rank
include = dist.tensor.DTensor.from_local(include, mesh, placement).full_tensor()

if rank == 0:
torch.save(avoid, os.path.join(args.ckpt_path, "avoid.pth"))
torch.save(include, os.path.join(args.ckpt_path, "include.pth"))
print(
"Generation complete! Please rerun (with different world size / workers if desired) to complete the check."
)

# If checkpoint does exist, load and take 100 steps.
# Ensure avoid values are avoided, and all include values are included.
# If checkpoint does exist, load and finish the epoch.
# Ensure all expected values are covered once the epoch concludes.
else:
if rank == 0:
print("Checkpoint detected!")
load_distributed_state_dict(data, ckpt_path, mesh)
avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")).tolist()

# Finish out epoch (extra 2*ceil(ndocs/nshards) steps to account for worst-case uneven finishing times)
vals = []
n_steps = (
math.ceil((3000 - len(avoid)) / (world_size * args.num_workers))
+ 2 * math.ceil(1000/args.logical_shards)
)
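# (Worked example under a hypothetical 2-rank, 6-workers-per-rank run with the defaults above:
# the first run drew 50 steps * batch 2 * 2 ranks = 200 chunks, so
# n_steps = ceil((3000 - 200) / 12) + 2 * ceil(1000 / 350) = 234 + 6 = 240.)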
for i, inp in enumerate(data):
if i == 100:
if i == n_steps:
break
vals.append(inp)
vals = torch.cat(vals)
# Get all vals onto each rank
vals = dist.tensor.DTensor.from_local(vals, mesh, placement).full_tensor()

# Perform avoid/include checks on rank 0 only
# Perform data coverage check on rank 0 only
if rank == 0:
avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth"))
include = torch.load(os.path.join(args.ckpt_path, "include.pth"))

def _in(v, m):
# Returns whether vector v is a row of matrix m (both tensors)
return m.sub(v[None]).abs().sum(1).sign().prod().bool().logical_not().item()
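# (The row-wise difference is all-zero only for an exact match; sign() maps each row's abs-sum to
# 0 or 1, prod() is 0 iff some row matched, and logical_not() turns that into True.)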

# Avoid check
for i, x in enumerate(avoid.split(1)):
assert not _in(x[0], vals), i
print("Check passed: seen data was not revisited!")

# Include check
for i, x in enumerate(include.split(1)):
assert _in(x[0], vals), i
print("Check passed: upcoming data appears as expected!")
# Invert avoid to get expected vals
expect = []
for i in range(1000):
for offset in [0,40,80]:
if i*100+offset not in avoid:
expect.append(i*100+offset)

for x in expect:
assert x in vals, x
print("Check passed: upcoming data is covered as expected!")

dist.barrier()
dist.destroy_process_group()