Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Training on yet-another-retnet script #4

Open
Akbarable opened this issue Aug 19, 2023 · 3 comments
Open

Training on yet-another-retnet script #4

Akbarable opened this issue Aug 19, 2023 · 3 comments

Comments

@Akbarable
Copy link

Akbarable commented Aug 19, 2023

Hello Frank!

I love what you have created, and am having a great time going through and parsing through your implementation of the paper. It appears you have nailed the dilated attention calculation method.

Here are my observations so far:

  • I'm relatively new to this field, and am learning a lot of concepts on the go, so bear with me if I miss out finer details!
  • I'm trying to train the LM variant that you designed on a text dataset for language modelling. My goal is to ultimately test out how many tokens can my GPUs handle during training and inference (2x RTX 3090).
  • During training, it was taking around 8GB RAM to process 1024 tokens, scaling it up I think we can manage around 10k tokens within 50GBs of RAM consumption.
  • I was trying to use the training script from this https://github.com/fkodom/yet-another-retnet repo that you created, and while I got to clearing the shape mismatch and other issues, the loss, starts very low (around 0.0004) and then goes to NaN and the iterations stop.
  • I'm sharing how I did the training script with you, please let me know if you have any suggestions for me! I want to get a checkpoint that I can later use for inference. Thanks! :)
import os
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import tiktoken
import torch
from lightning import Fabric, seed_everything
from lightning.fabric.loggers import TensorBoardLogger
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

#from yet_another_retnet.retnet import RetNet
from yet_another_retnet.utils.gutenberg import project_gutenberg_top_100_datapipe


torch.set_float32_matmul_precision("medium")
TOKENIZER = tiktoken.get_encoding("gpt2")
EVAL_PROMPT = "A Lannister always pays his debts."

def collate_fn(
    batch: List[str],
    max_length: int = 4096,
    device: Optional[Union[torch.device, str]] = None,
) -> Tuple[Tensor, Tensor]:
    x = torch.zeros(len(batch), max_length, device=device, dtype=torch.long)
    y = torch.zeros(len(batch), max_length, device=device, dtype=torch.long)

    for i, text in enumerate(batch):
        encoding = torch.as_tensor(
            TOKENIZER.encode(text), device=device, dtype=torch.long
        )
        seq_length = min(len(encoding) - 1, max_length)
        x[i, :seq_length] = encoding[:seq_length]
        y[i, :seq_length] = encoding[1 : seq_length + 1]

    return x, y

@dataclass
class TrainingState:
    fabric: Fabric
    model: LongNetLM
    optimizer: torch.optim.Optimizer
    callbacks: Sequence[Callable[["TrainingState", float], None]] = ()

    current_step: int = 0
    current_epoch: int = 0
    accumulate_grad_batches: int = 1
    monitor: str = "val_loss"
    monitor_mode: Literal["min", "max"] = "min"


@dataclass
class ModelCheckpoint:
    state_dict: Dict[str, Tensor]
    optimizer_state: Dict[str, Tensor]
    current_step: int
    current_epoch: int

    @classmethod
    def from_training_state(cls, state: TrainingState) -> "ModelCheckpoint":
        return cls(
            state_dict=state.model.state_dict(),
            optimizer_state=state.optimizer.state_dict(),
            current_step=state.current_step,
            current_epoch=state.current_epoch,
        )

    def to_dict(self) -> Dict[str, Any]:
        return {
            "state_dict": self.state_dict,
            "optimizer_state": self.optimizer_state,
            "current_step": self.current_step,
            "current_epoch": self.current_epoch,
        }

    def save(self, path: str) -> None:
        torch.save(self.to_dict(), path)

    @classmethod
    def load(cls, path: str) -> "ModelCheckpoint":
        checkpoint_dict = torch.load(path)
        return cls(**checkpoint_dict)

class CheckpointCallback:
    def __init__(
        self, save_dir: str, name: str = "checkpoint_epoch-{epoch:03d}.pt"
    ) -> None:
        self.save_dir = save_dir
        self.name = name
        self.best_path: Optional[str] = None
        self.best_loss: Optional[float] = None

    def __call__(self, state: TrainingState, loss: float) -> None:
        if self.best_loss is None:
            self.best_loss = loss

        fabric = state.fabric
        # 'local_rank == 0' means this only happens for the main process
        if fabric.local_rank == 0 and loss <= self.best_loss:
            checkpoint = ModelCheckpoint.from_training_state(state)
            self.best_loss = loss
            if self.best_path is not None:
                os.remove(self.best_path)
            self.best_path = os.path.join(
                self.save_dir, self.name.format(epoch=state.current_epoch)
            )
            torch.save(checkpoint, self.best_path)

        # All processes wait for main to finish saving the checkpoint.
        fabric.barrier()

def train_one_epoch(
    state: TrainingState,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    log_frequency: int = 25,
) -> None:
    state.current_epoch += 1
    fabric, model, optimizer = state.fabric, state.model, state.optimizer
    is_training = model.training
    model.train()

    with tqdm(desc=f"Ep: {state.current_epoch}") as progbar:
        train_loss, val_loss = 0.0, 0.0
        for x, y in train_dataloader:
            state.current_step += 1
            accumulating = state.current_step % state.accumulate_grad_batches != 0
            with fabric.no_backward_sync(model, enabled=accumulating):
                loss = model.forward(x)
                loss_value = loss.mean()
                fabric.backward(loss_value)

            if not accumulating:
                optimizer.step()
                optimizer.zero_grad()

            if state.current_step % log_frequency == 0:
                loss_scalar = loss.mean().item()  # Calculate the mean and extract the scalar value
                fabric.log("loss", loss_scalar, step=state.current_step)  # Log the scalar loss
                train_loss = loss_scalar  # Update the train_loss variable
                progbar.set_postfix_str(f"loss={train_loss:.4f}", refresh=False)
            progbar.update(1)

        model.eval()
        val_progbar = tqdm(desc="val", position=1, leave=False)
        for i, (x) in enumerate(val_dataloader):
            with torch.inference_mode():
                loss = model.forward(x)
            val_loss = (val_loss * i + loss.mean().item()) / (i + 1)

            if i % log_frequency == 0:
                val_progbar.set_postfix_str(f"val_loss={val_loss:.4f}", refresh=False)
            val_progbar.update(1)
            progbar.update(1)

        fabric.log("val_loss", val_loss, step=state.current_step)
        val_progbar.close()
        progbar.set_postfix_str(
            f"loss={train_loss:.4f}, val_loss={val_loss:.4f}", refresh=False
        )

        for callback in state.callbacks:
            callback(state, val_loss)

        model.train(mode=is_training)



def train(
    longnet: LongNetLM,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    accelerator: str = "auto",
    strategy: str = "auto",
    precision: Optional[str] = None,
    epochs: int = 10,
    lr: float = 3e-4,
    log_frequency: int = 25,
):
    if precision is None:
        if torch.cuda.is_available():
            # use bfloat16 if supported
            version, _ = torch.cuda.get_device_capability()
            precision = "bf16-mixed" if version >= 8 else "16-mixed"
        else:
            precision = "float32"

    logger = TensorBoardLogger(root_dir="./")
    fabric = Fabric(
        accelerator=accelerator,
        strategy=strategy,
        precision=precision,  # type: ignore
        loggers=[logger],
    )
    fabric.launch()
    print(f"Experiment version: {logger.version}")
    print("-" * 40)

    # Setup with fabric.
    optimizer = torch.optim.AdamW(longnet.parameters(), lr=lr)
    longnet, optimizer = fabric.setup(longnet, optimizer)
    train_dataloader, val_dataloader = fabric.setup_dataloaders(
        train_dataloader, val_dataloader
    )
    # Construct a training state and run the training loop.
    state = TrainingState(
        fabric=fabric,
        model=longnet,
        optimizer=optimizer,
        callbacks=[CheckpointCallback(save_dir=logger.log_dir)],
    )
    for _ in range(epochs):
        train_one_epoch(
            state=state,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            log_frequency=log_frequency,
        )

def generate(
    longnet: LongNet,
    prompt: str,
    prompt_chunk_size: Optional[int] = None,
    max_new_tokens: int = 4096,
    stop_tokens: Sequence[str] = (),
    top_k: int = 10,
    temperature: float = 1.0,
    seed: int = 42,
) -> Iterator[str]:
    seed_everything(seed)
    device = next(iter(longnet.parameters())).device
    is_training = longnet.training
    longnet.eval()

    # Tokenize the prompt and convert to a tensor.
    tokenized = TOKENIZER.encode(prompt)
    x = torch.as_tensor(tokenized, dtype=torch.long, device=device).unsqueeze_(0)

    if not prompt_chunk_size:
        prompt_chunk_size = x.size(1)

    prev_states: List[Optional[Tensor]] = [None] * longnet.num_layers
    start_idx: int = 0
    for start_idx in range(0, x.size(1), prompt_chunk_size):
        y, prev_states = longnet.forward(
            x, start_idx=start_idx, prev_states=prev_states
        )
        y = y[:, -1]

    # Generate tokens until we reach the maximum number of tokens or a stop token.
    for i in range(max_new_tokens):
        probs: Tensor = torch.softmax(y.squeeze() / max(temperature, 1e-8), dim=-1)
        # Get top-k tokens, renormalize their probabilities, and weighted sample.
        tokens: Tensor  # for mypy
        probs, tokens = probs.topk(k=top_k, dim=-1)
        probs /= probs.sum()

        # Take weighted random sample from the top-k tokens.
        sampled_idx: int = torch.multinomial(probs, num_samples=1).item()  # type: ignore
        token: int = tokens[sampled_idx].item()  # type: ignore
        tokenized.append(token)
        yield TOKENIZER.decode(tokenized)

        token_str: str = TOKENIZER.decode([token])
        if token_str in stop_tokens:
            break
        elif i < (max_new_tokens - 1):
            start_idx += 1
            x = torch.as_tensor([token], dtype=torch.long, device=device)
            y, prev_states = longnet.forward(
                x, start_idx, prev_states=prev_states
            )

    # Restore the model's original training state.
    longnet.train(mode=is_training)

def main(
    model_checkpoint: Optional[str] = None,
    accelerator: str = "auto",
    strategy: str = "auto",
    precision: Optional[str] = None,
    epochs: int = 10,
    batch_size: int = 16,
    lr: float = 3e-4,
    log_frequency: int = 25,
    seed: int = 42,
    eval_only: bool = False,
    eval_prompt: str = EVAL_PROMPT,
    eval_max_tokens: int = 1024,
):
    seed_everything(seed)
    # Create a (relatively small) model and dataloaders
    longnet = LongNetLM(
        num_tokens=TOKENIZER.n_vocab,
        d_model=768,
        nhead=12,
        num_encoder_layers=12,
        num_decoder_layers=12,
        dim_feedforward=3072,
        segment_lengths = [512,1024, 2048,4096],
        dilation_rates = [1, 2, 4, 6],
        dropout = 0.1,
        activation = F.relu,
        layer_norm_eps = 1e-5,
    )
    if model_checkpoint is not None:
        longnet.load_state_dict(ModelCheckpoint.load(model_checkpoint).state_dict)

    if not eval_only:
        train_dataloader = DataLoader(
            project_gutenberg_top_100_datapipe(
                split="train",
                chunk_size=4096,
                step_size=1024,
                shuffle=True,
                drop_last=True,
            ),
            batch_size=batch_size,
            collate_fn=collate_fn,
            drop_last=True,
        )
        val_dataloader = DataLoader(
            project_gutenberg_top_100_datapipe(
                split="val", chunk_size=4096, step_size=1024
            ),
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

        train(
            longnet=longnet,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            accelerator=accelerator,
            strategy=strategy,
            precision=precision,
            epochs=epochs,
            lr=lr,
            log_frequency=log_frequency,
        )

    # Generate some text
    prev_output: str = ""
    for output in generate(longnet, eval_prompt, max_new_tokens=eval_max_tokens):
        # Return to the start of the line and print the output (no newline)
        print(output[len(prev_output) :], end="", flush=True)
        prev_output = output
    print()


# Define the default values or provide your desired values
default_model_checkpoint = None
default_accelerator = "auto"
default_strategy = "dp"
default_precision = None
default_epochs = 1
default_batch_size = 1
default_lr = 3e-4
default_log_frequency = 25
default_seed = 42
default_eval_only = False
default_eval_prompt = EVAL_PROMPT
default_eval_max_tokens = 1024

# Replace the argparse-related code
model_checkpoint = default_model_checkpoint
accelerator = default_accelerator
strategy = default_strategy
precision = default_precision
epochs = default_epochs
batch_size = default_batch_size
lr = default_lr
log_frequency = default_log_frequency
seed = default_seed
eval_only = default_eval_only
eval_prompt = default_eval_prompt
eval_max_tokens = default_eval_max_tokens

# Call the main function
main(
    model_checkpoint=model_checkpoint,
    accelerator=accelerator,
    strategy=strategy,
    precision=precision,
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    log_frequency=log_frequency,
    seed=seed,
    eval_only=eval_only,
    eval_prompt=eval_prompt,
    eval_max_tokens=eval_max_tokens
)

@fkodom
Copy link
Owner

fkodom commented Aug 22, 2023

@Akbarable Letting you know I see this issue. 👀 Having a very busy week, so it may be a few days before I can dedicate time to this. Will update you as soon as I am able to.

@Akbarable
Copy link
Author

@Akbarable Letting you know I see this issue. 👀 Having a very busy week, so it may be a few days before I can dedicate time to this. Will update you as soon as I am able to.

Thanks for the response. No worries! 😁

@andrewcchu
Copy link

Bumping this as I'd also be interested in a training script -- tried the code @Akbarable posted here and saw the same behavior.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants