Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🍒 Fix cherrypicks #61

Merged
merged 12 commits into the base branch from the source branch
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions turbo_alignment/cherry_picks/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import abstractmethod
from typing import Generic, Iterable, TypeVar

from accelerate import Accelerator
from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
Expand Down Expand Up @@ -72,3 +73,8 @@ def on_evaluate(
model.train()

return dataset_metrics

@staticmethod
@abstractmethod
def _get_sharded_dataset(dataset: InferenceDatasetT, accelerator: Accelerator) -> InferenceDatasetT:
    """Return the contiguous shard of ``dataset`` owned by the current process.

    Subclasses split the dataset across ``accelerator.num_processes`` ranks so
    that each rank evaluates a disjoint subset during the cherry-pick callback.

    Args:
        dataset: The full inference dataset to be sharded.
        accelerator: Provides ``process_index`` / ``num_processes`` used to
            pick this rank's slice.

    Returns:
        A dataset of the same type containing only this rank's records.
    """
    ...
15 changes: 14 additions & 1 deletion turbo_alignment/cherry_picks/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from typing import Iterable

from accelerate import Accelerator
Expand Down Expand Up @@ -41,12 +42,17 @@ def _get_dataset_metrics(
tokenizer=tokenizer,
transformers_settings=self._generator_transformers_settings,
custom_generation_settings=self._custom_generation_settings,
accelerator=accelerator,
return_logits=True,
)

batch_size = self._generator_transformers_settings.num_return_sequences

if accelerator is not None:
dataset = self._get_sharded_dataset(
dataset=dataset,
accelerator=accelerator,
)

generations = generator.generate_from_dataset(dataset)

prompts = [record['prompt'] for record in dataset]
Expand Down Expand Up @@ -104,3 +110,10 @@ def _get_dataset_metrics(

metric_outputs.extend(metric_results)
return metric_outputs

@staticmethod
def _get_sharded_dataset(dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset:
    """Slice out the contiguous shard of ``dataset`` assigned to this rank.

    The dataset is partitioned into ``num_processes`` chunks of
    ``ceil(len(dataset) / num_processes)`` records; the last rank's chunk may
    be shorter (or empty) when the sizes do not divide evenly.
    """
    shard_len = math.ceil(len(dataset) / accelerator.num_processes)
    start = accelerator.process_index * shard_len
    return dataset.get_slice(start, start + shard_len)
27 changes: 22 additions & 5 deletions turbo_alignment/cherry_picks/classification.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
import math
from typing import Iterable

from accelerate import Accelerator
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from turbo_alignment.cherry_picks.base import CherryPickCallbackBase
from turbo_alignment.dataset.chat.conversation import Conversation
from turbo_alignment.dataset.classification.classification import ClassificationDataset
from turbo_alignment.dataset.classification.classification import (
InferenceClassificationDataset,
)
from turbo_alignment.generators.classification import ClassificationGenerator
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.settings.cherry_pick import ClassificationCherryPickSettings
from turbo_alignment.settings.metric import ElementWiseScores, MetricResults


class ClassificationCherryPickCallback(CherryPickCallbackBase[ClassificationDataset]):
class ClassificationCherryPickCallback(CherryPickCallbackBase[InferenceClassificationDataset]):
def __init__(
self,
cherry_pick_settings: ClassificationCherryPickSettings,
datasets: Iterable[ClassificationDataset],
datasets: Iterable[InferenceClassificationDataset],
metrics: list[Metric],
) -> None:
super().__init__(cherry_pick_settings=cherry_pick_settings, datasets=datasets, metrics=metrics)

def _get_dataset_metrics(
self,
dataset: ClassificationDataset,
dataset: InferenceClassificationDataset,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
**kwargs,
Expand All @@ -33,9 +36,14 @@ def _get_dataset_metrics(
generator = ClassificationGenerator(
model=model,
tokenizer=tokenizer,
accelerator=accelerator,
)

if accelerator is not None:
dataset = self._get_sharded_dataset(
dataset=dataset,
accelerator=accelerator,
)

generations = generator.generate_from_dataset(dataset)
predictions = [record.predicted_label for record in generations]
labels = [record['labels'] for record in dataset]
Expand Down Expand Up @@ -64,3 +72,12 @@ def _get_dataset_metrics(
]

return metric_outputs

@staticmethod
def _get_sharded_dataset(
    dataset: InferenceClassificationDataset, accelerator: Accelerator
) -> InferenceClassificationDataset:
    """Return this rank's contiguous shard of ``dataset`` for distributed evaluation.

    Chunk size is ``ceil(len(dataset) / num_processes)`` so every record is
    covered exactly once across all ranks.
    """
    chunk = math.ceil(len(dataset) / accelerator.num_processes)
    offset = accelerator.process_index * chunk
    return dataset.get_slice(offset, offset + chunk)
2 changes: 2 additions & 0 deletions turbo_alignment/cherry_picks/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@


class MultimodalCherryPickCallback(CherryPickCallbackBase[InferenceMultimodalDataset]):
# pylint: disable=abstract-method
# TODO: add _get_sharded_dataset method
def __init__(
self,
cherry_pick_settings: MultimodalCherryPickSettings,
Expand Down
16 changes: 15 additions & 1 deletion turbo_alignment/cherry_picks/rag.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

from accelerate import Accelerator
from transformers import PreTrainedModel, PreTrainedTokenizerBase

Expand All @@ -22,9 +24,14 @@ def _get_dataset_metrics(
tokenizer=tokenizer,
transformers_settings=self._generator_transformers_settings,
custom_generation_settings=self._custom_generation_settings,
accelerator=accelerator,
)

if accelerator is not None:
dataset = self._get_sharded_dataset(
dataset=dataset,
accelerator=accelerator,
)

generations = generator.generate_from_dataset(dataset)

prompts = [dataset[i]['prompt'] for i in range(len(dataset))]
Expand Down Expand Up @@ -57,3 +64,10 @@ def _get_dataset_metrics(
metric_outputs.extend(metric_results)

return metric_outputs

@staticmethod
def _get_sharded_dataset(dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset:
    """Return the contiguous shard of ``dataset`` owned by the current process.

    Splits the dataset into ``accelerator.num_processes`` contiguous chunks of
    ``ceil(len(dataset) / num_processes)`` records and returns the chunk at
    ``accelerator.process_index``; the last rank's chunk may be shorter or empty.
    """
    rank_device = accelerator.process_index
    slice_size = math.ceil(len(dataset) / accelerator.num_processes)

    return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
14 changes: 14 additions & 0 deletions turbo_alignment/cherry_picks/rm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from typing import Iterable

from accelerate import Accelerator
Expand Down Expand Up @@ -35,6 +36,12 @@ def _get_dataset_metrics(
accelerator=accelerator,
)

if accelerator is not None:
dataset = self._get_sharded_dataset(
dataset=dataset,
accelerator=accelerator,
)

generations = generator.generate_from_dataset(dataset)
generations_w = [gen.reward_w for gen in generations]
generations_l = [gen.reward_l for gen in generations]
Expand Down Expand Up @@ -69,3 +76,10 @@ def _get_dataset_metrics(
]

return metric_outputs

@staticmethod
def _get_sharded_dataset(dataset: PairPreferenceDataset, accelerator: Accelerator) -> PairPreferenceDataset:
    """Give each accelerator process its own contiguous slice of ``dataset``.

    Per-rank size is ``ceil(len(dataset) / num_processes)``; the final rank may
    receive fewer records (possibly none) when sizes do not divide evenly.
    """
    per_rank = math.ceil(len(dataset) / accelerator.num_processes)
    begin = per_rank * accelerator.process_index
    return dataset.get_slice(begin, begin + per_rank)
19 changes: 19 additions & 0 deletions turbo_alignment/dataset/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy.typing as npt
import torch
from transformers import PreTrainedTokenizerBase
from typing_extensions import Self

from turbo_alignment.common.data.io import read_jsonl
from turbo_alignment.common.logging import get_project_logger
Expand Down Expand Up @@ -361,3 +362,21 @@ def __init__(

def convert_records(self, records: list[ChatDatasetRecord]) -> list[dict[str, Any] | None]:
return self._encode(records, inference=True, random_cut=self._random_cut)

def get_slice(self, start: int, end: int) -> Self:
    """Return a new dataset instance holding only ``records[start:end]``.

    A fresh instance is constructed with ``read=False`` so the source file is
    not re-read and re-encoded; the slice's encoded records are copied from
    this instance instead.

    Args:
        start: Start index of the slice (inclusive).
        end: End index of the slice (exclusive).

    Returns:
        A new instance of the same class containing the sliced records.
    """
    new_instance = self.__class__(
        source=self.source,
        settings=self.settings,
        tokenizer=self.tokenizer,
        read=False,
        random_cut=self._random_cut,
    )

    # NOTE(review): this materializes every record and maps ALL ids, not just
    # those in [start:end] — appears to be O(len(self)) work per slice; confirm
    # whether the full original_records_map is intentionally preserved.
    dataset_records = [self[idx] for idx in range(len(self))]

    new_instance.records = self.records[start:end]
    new_instance.original_records_map = {
        record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records
    }

    return new_instance
17 changes: 17 additions & 0 deletions turbo_alignment/dataset/classification/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, overload

from transformers import PreTrainedTokenizerBase
from typing_extensions import Self

from turbo_alignment.common.data.io import read_jsonl
from turbo_alignment.common.logging import get_project_logger
Expand Down Expand Up @@ -92,3 +93,19 @@ def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[di
class InferenceClassificationDataset(ClassificationDataset):
def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]:
return self._encode(records, inference=True)

def get_slice(self, start: int, end: int) -> Self:
    """Return a new dataset instance holding only ``records[start:end]``.

    Args:
        start: Start index of the slice (inclusive).
        end: End index of the slice (exclusive).

    Returns:
        A new instance of the same class containing the sliced records.
    """
    # NOTE(review): unlike the chat and pair-preference get_slice
    # implementations, this constructor call does not pass read=False, so the
    # new instance presumably re-reads and re-encodes the source before its
    # records are overwritten below — confirm against the base constructor.
    new_instance = self.__class__(
        source=self.source,
        settings=self.settings,
        tokenizer=self.tokenizer,
    )

    # NOTE(review): maps ALL record ids, not just those in [start:end] —
    # O(len(self)) per slice; confirm whether the full map is intentional.
    dataset_records = [self[idx] for idx in range(len(self))]

    new_instance.records = self.records[start:end]
    new_instance.original_records_map = {
        record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records
    }

    return new_instance
19 changes: 19 additions & 0 deletions turbo_alignment/dataset/pair_preferences/pair_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, overload

from transformers import PreTrainedTokenizerBase
from typing_extensions import Self

from turbo_alignment.common.data.io import read_jsonl
from turbo_alignment.common.logging import get_project_logger
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(
read=False,
)
super().__init__(source=source, settings=settings, tokenizer=tokenizer)
self.settings: PairPreferenceDatasetSettings = settings

if read:
self._read()
Expand Down Expand Up @@ -107,3 +109,20 @@ def _read_records(records) -> list[PairPreferenceRecord]:
if isinstance(records, list):
return [PairPreferenceRecord(**record) for record in records]
raise NotImplementedError

def get_slice(self, start: int, end: int) -> Self:
    """Return a new dataset instance holding only ``records[start:end]``.

    Constructs a fresh instance with ``read=False`` so the source is not
    re-read; the encoded records for the slice are copied from this instance.

    Args:
        start: Start index of the slice (inclusive).
        end: End index of the slice (exclusive).

    Returns:
        A new instance of the same class containing the sliced records.
    """
    new_instance = self.__class__(
        source=self.source,
        settings=self.settings,
        tokenizer=self.tokenizer,
        read=False,
    )

    # NOTE(review): maps ALL record ids, not just those in [start:end] —
    # O(len(self)) per slice; confirm whether the full map is intentional.
    dataset_records = [self[idx] for idx in range(len(self))]

    new_instance.records = self.records[start:end]
    new_instance.original_records_map = {
        record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records
    }

    return new_instance
10 changes: 6 additions & 4 deletions turbo_alignment/pipelines/train/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from turbo_alignment.cherry_picks.classification import ClassificationCherryPickCallback
from turbo_alignment.common.logging import get_project_logger
from turbo_alignment.constants import TRAINER_LOGS_FOLDER
from turbo_alignment.dataset.classification.classification import ClassificationDataset
from turbo_alignment.dataset.classification.classification import (
InferenceClassificationDataset,
)
from turbo_alignment.dataset.loader import DatasetLoader
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.metrics.registry import MetricSettingsRegistry
Expand Down Expand Up @@ -47,9 +49,9 @@ def _get_cherry_pick_callback(
) -> ClassificationCherryPickCallback:
cherry_pick_settings = experiment_settings.cherry_pick_settings

cherry_pick_datasets = DatasetLoader[ClassificationDataset](ClassificationDataset).load_datasets(
cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
)
cherry_pick_datasets = DatasetLoader[InferenceClassificationDataset](
InferenceClassificationDataset
).load_datasets(cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE)

metrics = [
Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
Expand Down
Loading