From f25e31f53314c3684cd772c86fa1d9b02fffed73 Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Mon, 11 Nov 2024 15:59:00 +0300 Subject: [PATCH 01/12] in progress --- turbo_alignment/cherry_picks/chat.py | 41 ++++++++++++++++++++++++- turbo_alignment/generators/base.py | 45 ++++++++++++++++++++-------- 2 files changed, 72 insertions(+), 14 deletions(-) diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py index eec477e1..ff5c4700 100755 --- a/turbo_alignment/cherry_picks/chat.py +++ b/turbo_alignment/cherry_picks/chat.py @@ -1,3 +1,4 @@ +import math from typing import Iterable from accelerate import Accelerator @@ -41,16 +42,47 @@ def _get_dataset_metrics( tokenizer=tokenizer, transformers_settings=self._generator_transformers_settings, custom_generation_settings=self._custom_generation_settings, - accelerator=accelerator, + # accelerator=accelerator, return_logits=True, ) batch_size = self._generator_transformers_settings.num_return_sequences + if accelerator is not None: + len_records_batches = len(dataset) + world_size = accelerator.num_processes + rank_device = accelerator.process_index + window_size = math.ceil(len_records_batches / world_size) + + print(f"len records_batches", len_records_batches) + print("rank_device", rank_device) + print("world_size", world_size) + print(f"slice [{rank_device * window_size} : {rank_device * window_size + window_size}]") + + + dataset = dataset[rank_device * window_size : rank_device * window_size + window_size] + generations = generator.generate_from_dataset(dataset) + print(f"len dataset {len(dataset)}") + print(f"first 2 prompts: {[record['prompt'] for record in dataset][:2]}") + print(f"len generations {len(generations)}") + print(f"first 2 generations: {generations[:2]}") + prompts = [record['prompt'] for record in dataset] + + # len_records_batches = len(prompts) + # world_size = accelerator.num_processes + # rank_device = accelerator.process_index + # window_size = math.ceil(len_records_batches / world_size) + + # prompts = + + + string_answers = [[answer.content for answer in g.answers] for g in generations] + print(f"prompts: {len(prompts)}, {prompts[:2]}") + print(f"string_answers: {len(string_answers)}, {string_answers[:2]}") string_labels = [[g.messages[-1].content] * len(g.answers) for g in generations] flattened_answers = [answer for g in generations for answer in g.answers] @@ -66,6 +98,9 @@ def _get_dataset_metrics( if sft_model is not None: metrics_kwargs[KLType.SFT_MODEL] = get_logits(input_tokens_ids, answer_tokens_ids, sft_model) + print(f"len prompts {len(prompts)}") + print(f"batch_size {batch_size}") + print(f"prompt element_wise_scores {len([prompt for prompt in prompts for _ in range(batch_size)])}") metric_outputs = [ MetricResults( element_wise_scores=[ @@ -86,6 +121,10 @@ def _get_dataset_metrics( ] ), ] + print(f"prompts: {len(metric_outputs[0].element_wise_scores[0].values)}, {metric_outputs[0].element_wise_scores[0].values}") + print(f"labels: {len(metric_outputs[1].element_wise_scores[0].values)}, {metric_outputs[1].element_wise_scores[0].values}") + # print(f"metric_outputs: {metric_outputs}") + for metric in self._metrics: metric_results = metric.compute( diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py index 5415daf5..e4d94ada 100755 --- a/turbo_alignment/generators/base.py +++ b/turbo_alignment/generators/base.py @@ -71,19 +71,38 @@ def generate_from_dataset(self, dataset: BaseDataset) -> list[InferenceOutputT]: ) ) else: - with self._accelerator.split_between_processes( - list(zip(records_batches, original_records_batches)), apply_padding=True - ) as accelerator_records: - generations = [ - self._generate_from_batch( - records_batch, - original_records_batch, - dataset.source.name, - ) - for records_batch, original_records_batch in accelerator_records - ][: len(records_batches)] - - return sum(generations, []) + # with self._accelerator.split_between_processes( + # list(zip(records_batches, original_records_batches)), apply_padding=True + # ) as accelerator_records: + len_records_batches = len(records_batches) + world_size = self._accelerator.num_processes + rank_device = self._accelerator.process_index + window_size = math.ceil(len_records_batches / world_size) + + print(f"len records_batches", len(records_batches)) + print("original_records_batches", len(original_records_batches)) + print("rank_device", rank_device) + print("world_size", world_size) + print(f"slice [{rank_device * window_size} : {rank_device * window_size + window_size}]") + + slice_records_batches = records_batches[rank_device * window_size : rank_device * window_size + window_size] + slice_original_records_batches = original_records_batches[rank_device * window_size : (rank_device + 1) * window_size] + + flattened_records_batches = [r for batch in slice_records_batches for r in batch] + flattened_original_records_batches = [r for batch in slice_original_records_batches for r in batch] + + generations = self._generate_from_batch( + flattened_records_batches, + flattened_original_records_batches, + dataset.source.name, + ) + # for records_batch, original_records_batch in accelerator_records + # ][: len(records_batches)] + # ] + # print(generations) + + # return sum(generations, []) + return generations class ChatGeneratorBase(BaseGenerator, Generic[DatasetRecordT, InferenceOutputT]): From 97e6aae7366e5c16d5667142af989e53afcbe473 Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Sat, 23 Nov 2024 23:55:13 +0300 Subject: [PATCH 02/12] works --- turbo_alignment/cherry_picks/chat.py | 46 ++++++---------------------- turbo_alignment/dataset/chat/chat.py | 20 +++++++++++- turbo_alignment/generators/base.py | 45 ++++++++------------------- 3 files changed, 42 insertions(+), 69 deletions(-) diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py index ff5c4700..543ccafd 100755 --- a/turbo_alignment/cherry_picks/chat.py +++ b/turbo_alignment/cherry_picks/chat.py @@ -42,47 +42,22 @@ def _get_dataset_metrics( tokenizer=tokenizer, transformers_settings=self._generator_transformers_settings, custom_generation_settings=self._custom_generation_settings, - # accelerator=accelerator, return_logits=True, ) batch_size = self._generator_transformers_settings.num_return_sequences if accelerator is not None: - len_records_batches = len(dataset) - world_size = accelerator.num_processes - rank_device = accelerator.process_index - window_size = math.ceil(len_records_batches / world_size) - - print(f"len records_batches", len_records_batches) - print("rank_device", rank_device) - print("world_size", world_size) - print(f"slice [{rank_device * window_size} : {rank_device * window_size + window_size}]") - - - dataset = dataset[rank_device * window_size : rank_device * window_size + window_size] + dataset = self._get_sharded_dataset( + dataset=dataset, + accelerator=accelerator, + ) generations = generator.generate_from_dataset(dataset) - print(f"len dataset {len(dataset)}") - print(f"first 2 prompts: {[record['prompt'] for record in dataset][:2]}") - print(f"len generations {len(generations)}") - print(f"first 2 generations: {generations[:2]}") - prompts = [record['prompt'] for record in dataset] - # len_records_batches = len(prompts) - # world_size = accelerator.num_processes - # rank_device = accelerator.process_index - # window_size = math.ceil(len_records_batches / world_size) - - # prompts = - - - string_answers = [[answer.content for answer in g.answers] for g in generations] - print(f"prompts: {len(prompts)}, {prompts[:2]}") - print(f"string_answers: {len(string_answers)}, {string_answers[:2]}") string_labels = [[g.messages[-1].content] * len(g.answers) for g in generations] flattened_answers = [answer for g in generations for answer in g.answers] @@ -98,9 +73,6 @@ def _get_dataset_metrics( if sft_model is not None: metrics_kwargs[KLType.SFT_MODEL] = get_logits(input_tokens_ids, answer_tokens_ids, sft_model) - print(f"len prompts {len(prompts)}") - print(f"batch_size {batch_size}") - print(f"prompt element_wise_scores {len([prompt for prompt in prompts for _ in range(batch_size)])}") metric_outputs = [ MetricResults( element_wise_scores=[ @@ -121,10 +93,6 @@ def _get_dataset_metrics( ] ), ] - print(f"prompts: {len(metric_outputs[0].element_wise_scores[0].values)}, {metric_outputs[0].element_wise_scores[0].values}") - print(f"labels: {len(metric_outputs[1].element_wise_scores[0].values)}, {metric_outputs[1].element_wise_scores[0].values}") - # print(f"metric_outputs: {metric_outputs}") - for metric in self._metrics: metric_results = metric.compute( @@ -143,3 +111,9 @@ def _get_dataset_metrics( metric_outputs.extend(metric_results) return metric_outputs + + def _get_sharded_dataset(self, dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset: + rank_device = accelerator.process_index + window_size = math.ceil(len(dataset) / accelerator.num_processes) + + return dataset.get_slice(rank_device * window_size, rank_device * window_size + window_size) diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index ebdede9e..f1ac6051 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -2,7 +2,7 @@ from abc import ABC from itertools import accumulate from pathlib import Path -from typing import Any, overload +from typing import Any, Self, overload import numpy as np import numpy.typing as npt @@ -361,3 +361,21 @@ def __init__( def convert_records(self, records: list[ChatDatasetRecord]) -> list[dict[str, Any] | None]: return self._encode(records, inference=True, random_cut=self._random_cut) + + def get_slice(self, start: int, end: int) -> Self: + new_instance = self.__class__( + source=self.source, + settings=self.settings, + tokenizer=self.tokenizer, + read=False, + random_cut=self._random_cut, + ) + + new_instance.records = self.records[start:end] + + new_instance.original_records_map = { + record['id']: self.get_original_record_by_id(record['id']) + for record in new_instance.records + } + + return new_instance diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py index e4d94ada..5415daf5 100755 --- a/turbo_alignment/generators/base.py +++ b/turbo_alignment/generators/base.py @@ -71,38 +71,19 @@ def generate_from_dataset(self, dataset: BaseDataset) -> list[InferenceOutputT]: ) ) else: - # with self._accelerator.split_between_processes( - # list(zip(records_batches, original_records_batches)), apply_padding=True - # ) as accelerator_records: - len_records_batches = len(records_batches) - world_size = self._accelerator.num_processes - rank_device = self._accelerator.process_index - window_size = math.ceil(len_records_batches / world_size) - - print(f"len records_batches", len(records_batches)) - print("original_records_batches", len(original_records_batches)) - print("rank_device", rank_device) - print("world_size", world_size) - print(f"slice [{rank_device * window_size} : {rank_device * window_size + window_size}]") - - slice_records_batches = records_batches[rank_device * window_size : rank_device * window_size + window_size] - slice_original_records_batches = original_records_batches[rank_device * window_size : (rank_device + 1) * window_size] - - flattened_records_batches = [r for batch in slice_records_batches for r in batch] - flattened_original_records_batches = [r for batch in slice_original_records_batches for r in batch] - - generations = self._generate_from_batch( - flattened_records_batches, - flattened_original_records_batches, - dataset.source.name, - ) - # for records_batch, original_records_batch in accelerator_records - # ][: len(records_batches)] - # ] - # print(generations) - - # return sum(generations, []) - return generations + with self._accelerator.split_between_processes( + list(zip(records_batches, original_records_batches)), apply_padding=True + ) as accelerator_records: + generations = [ + self._generate_from_batch( + records_batch, + original_records_batch, + dataset.source.name, + ) + for records_batch, original_records_batch in accelerator_records + ][: len(records_batches)] + + return sum(generations, []) class ChatGeneratorBase(BaseGenerator, Generic[DatasetRecordT, InferenceOutputT]): From d2ed78191589c46df8efc9d601868429a67e538c Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Sat, 23 Nov 2024 23:58:03 +0300 Subject: [PATCH 03/12] del line --- turbo_alignment/cherry_picks/chat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py index 543ccafd..d79d8af7 100755 --- a/turbo_alignment/cherry_picks/chat.py +++ b/turbo_alignment/cherry_picks/chat.py @@ -56,7 +56,6 @@ def _get_dataset_metrics( generations = generator.generate_from_dataset(dataset) prompts = [record['prompt'] for record in dataset] - string_answers = [[answer.content for answer in g.answers] for g in generations] string_labels = [[g.messages[-1].content] * len(g.answers) for g in generations] From 13fa7e40348c07f187f158fc2a9dfb15f3f9bac2 Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Sat, 23 Nov 2024 23:59:41 +0300 Subject: [PATCH 04/12] rename --- turbo_alignment/cherry_picks/chat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py index d79d8af7..e729538f 100755 --- a/turbo_alignment/cherry_picks/chat.py +++ b/turbo_alignment/cherry_picks/chat.py @@ -113,6 +113,6 @@ def _get_dataset_metrics( def _get_sharded_dataset(self, dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset: rank_device = accelerator.process_index - window_size = math.ceil(len(dataset) / accelerator.num_processes) + slice_size = math.ceil(len(dataset) / accelerator.num_processes) - return dataset.get_slice(rank_device * window_size, rank_device * window_size + window_size) + return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) From 82fb83874656d02c77421fbe5a6303c4bbf8c01f Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Sun, 24 Nov 2024 00:13:40 +0300 Subject: [PATCH 05/12] self from typing_extensions --- turbo_alignment/dataset/chat/chat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index f1ac6051..5c5e09ab 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -2,7 +2,8 @@ from abc import ABC from itertools import accumulate from pathlib import Path -from typing import Any, Self, overload +from typing import Any, overload +from typing_extensions import Self import numpy as np import numpy.typing as npt From 32b5c169ea288060723c9fb80204ee03a946aae3 Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Sat, 23 Nov 2024 21:16:14 +0000 Subject: [PATCH 06/12] pretty --- turbo_alignment/dataset/chat/chat.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index 5c5e09ab..dd19c1bc 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -3,12 +3,12 @@ from itertools import accumulate from pathlib import Path from typing import Any, overload -from typing_extensions import Self import numpy as np import numpy.typing as npt import torch from transformers import PreTrainedTokenizerBase +from typing_extensions import Self from turbo_alignment.common.data.io import read_jsonl from turbo_alignment.common.logging import get_project_logger @@ -375,8 +375,7 @@ def get_slice(self, start: int, end: int) -> Self: new_instance.records = self.records[start:end] new_instance.original_records_map = { - record['id']: self.get_original_record_by_id(record['id']) - for record in new_instance.records + record['id']: self.get_original_record_by_id(record['id']) for record in new_instance.records } return new_instance From b38e9d9429d90e4e5f8250b746bc2fbbb135201d Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Sun, 24 Nov 2024 00:38:47 +0300 Subject: [PATCH 07/12] dataset records separate --- turbo_alignment/dataset/chat/chat.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index dd19c1bc..2cd23e14 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -372,10 +372,11 @@ def get_slice(self, start: int, end: int) -> Self: random_cut=self._random_cut, ) - new_instance.records = self.records[start:end] + dataset_records = [self[idx] for idx in range(len(self))] + new_instance.records = self.records[start:end] new_instance.original_records_map = { - record['id']: self.get_original_record_by_id(record['id']) for record in new_instance.records + record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records } return new_instance From a8d8f10e25882e9a3d274e9f99a5f2eb7c483dd1 Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Sat, 23 Nov 2024 22:02:37 +0000 Subject: [PATCH 08/12] update --- turbo_alignment/dataset/chat/chat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index 2cd23e14..c3a76561 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -376,7 +376,8 @@ def get_slice(self, start: int, end: int) -> Self: new_instance.records = self.records[start:end] new_instance.original_records_map = { - record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records + record['id']: self.get_original_record_by_id(record['id']) + for record in dataset_records } return new_instance From 6d0e565b078db42eeada6edf1fa5e6dbb059a521 Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Sun, 24 Nov 2024 01:03:10 +0300 Subject: [PATCH 09/12] linters fix --- turbo_alignment/dataset/chat/chat.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index c3a76561..2cd23e14 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -376,8 +376,7 @@ def get_slice(self, start: int, end: int) -> Self: new_instance.records = self.records[start:end] new_instance.original_records_map = { - record['id']: self.get_original_record_by_id(record['id']) - for record in dataset_records + record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records } return new_instance From 4d84b45fbb390dc73c99c7ea2088e3be260018ae Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Mon, 25 Nov 2024 15:16:39 +0300 Subject: [PATCH 10/12] expand to all cherry picks in progress --- turbo_alignment/cherry_picks/base.py | 9 +++++++++ turbo_alignment/cherry_picks/chat.py | 7 ------- turbo_alignment/cherry_picks/classification.py | 7 ++++++- turbo_alignment/cherry_picks/rag.py | 7 ++++++- turbo_alignment/cherry_picks/rm.py | 6 ++++++ turbo_alignment/dataset/base/base.py | 5 +++++ .../dataset/classification/classification.py | 17 +++++++++++++++++ .../pair_preferences/pair_preference.py | 18 ++++++++++++++++++ 8 files changed, 67 insertions(+), 9 deletions(-) diff --git a/turbo_alignment/cherry_picks/base.py b/turbo_alignment/cherry_picks/base.py index 14065346..995dd62d 100755 --- a/turbo_alignment/cherry_picks/base.py +++ b/turbo_alignment/cherry_picks/base.py @@ -1,6 +1,8 @@ from abc import abstractmethod +import math from typing import Generic, Iterable, TypeVar +from accelerate import Accelerator from transformers import ( PreTrainedModel, PreTrainedTokenizerBase, @@ -72,3 +74,10 @@ def on_evaluate( model.train() return dataset_metrics + + @staticmethod + def _get_sharded_dataset(dataset: InferenceDatasetT, accelerator: Accelerator) -> InferenceDatasetT: + rank_device = accelerator.process_index + slice_size = math.ceil(len(dataset) / accelerator.num_processes) + + return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py index e729538f..c6f4631b 100755 --- a/turbo_alignment/cherry_picks/chat.py +++ b/turbo_alignment/cherry_picks/chat.py @@ -1,4 +1,3 @@ -import math from typing import Iterable from accelerate import Accelerator @@ -110,9 +109,3 @@ def _get_dataset_metrics( metric_outputs.extend(metric_results) return metric_outputs - - def _get_sharded_dataset(self, dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset: - rank_device = accelerator.process_index - slice_size = math.ceil(len(dataset) / accelerator.num_processes) - - return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) diff --git a/turbo_alignment/cherry_picks/classification.py b/turbo_alignment/cherry_picks/classification.py index b1d8bf16..cb2e0b1e 100755 --- a/turbo_alignment/cherry_picks/classification.py +++ b/turbo_alignment/cherry_picks/classification.py @@ -33,9 +33,14 @@ def _get_dataset_metrics( generator = ClassificationGenerator( model=model, tokenizer=tokenizer, - accelerator=accelerator, ) + if accelerator is not None: + dataset = self._get_sharded_dataset( + dataset=dataset, + accelerator=accelerator, + ) + generations = generator.generate_from_dataset(dataset) predictions = [record.predicted_label for record in generations] labels = [record['labels'] for record in dataset] diff --git a/turbo_alignment/cherry_picks/rag.py b/turbo_alignment/cherry_picks/rag.py index b1c2c4d1..53d332e6 100755 --- a/turbo_alignment/cherry_picks/rag.py +++ b/turbo_alignment/cherry_picks/rag.py @@ -22,9 +22,14 @@ def _get_dataset_metrics( tokenizer=tokenizer, transformers_settings=self._generator_transformers_settings, custom_generation_settings=self._custom_generation_settings, - accelerator=accelerator, ) + if accelerator is not None: + dataset = self._get_sharded_dataset( + dataset=dataset, + accelerator=accelerator, + ) + generations = generator.generate_from_dataset(dataset) prompts = [dataset[i]['prompt'] for i in range(len(dataset))] diff --git a/turbo_alignment/cherry_picks/rm.py b/turbo_alignment/cherry_picks/rm.py index 6cdcbcc4..829f2214 100755 --- a/turbo_alignment/cherry_picks/rm.py +++ b/turbo_alignment/cherry_picks/rm.py @@ -35,6 +35,12 @@ def _get_dataset_metrics( accelerator=accelerator, ) + if accelerator is not None: + dataset = self._get_sharded_dataset( + dataset=dataset, + accelerator=accelerator, + ) + generations = generator.generate_from_dataset(dataset) generations_w = [gen.reward_w for gen in generations] generations_l = [gen.reward_l for gen in generations] diff --git a/turbo_alignment/dataset/base/base.py b/turbo_alignment/dataset/base/base.py index 71e353e5..aac204fd 100755 --- a/turbo_alignment/dataset/base/base.py +++ b/turbo_alignment/dataset/base/base.py @@ -6,6 +6,7 @@ import torch from torch.utils.data import Dataset from transformers import PreTrainedTokenizerBase +from typing_extensions import Self from turbo_alignment.common.logging import get_project_logger from turbo_alignment.dataset.base.models import DatasetRecord @@ -102,6 +103,10 @@ def _read_records(records: list[dict]) -> list[RecordT]: def _read_records(records): ... + @abstractmethod + def get_slice(self, start: int, end: int) -> Self: + ... + class AlignmentDataset(BaseDataset, ABC, Generic[RecordT]): def __init__( diff --git a/turbo_alignment/dataset/classification/classification.py b/turbo_alignment/dataset/classification/classification.py index 5c6d9084..030bd69c 100755 --- a/turbo_alignment/dataset/classification/classification.py +++ b/turbo_alignment/dataset/classification/classification.py @@ -3,6 +3,7 @@ from typing import Any, overload from transformers import PreTrainedTokenizerBase +from typing_extensions import Self from turbo_alignment.common.data.io import read_jsonl from turbo_alignment.common.logging import get_project_logger @@ -81,6 +82,22 @@ def _read_records(records) -> list[ClassificationDatasetRecord]: return [ClassificationDatasetRecord(**record) for record in records] raise NotImplementedError + def get_slice(self, start: int, end: int) -> Self: + new_instance = self.__class__( + source=self.source, + settings=self.settings, + tokenizer=self.tokenizer, + ) + + dataset_records = [self[idx] for idx in range(len(self))] + + new_instance.records = self.records[start:end] + new_instance.original_records_map = { + record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records + } + + return new_instance + @ClassificationDatasetTypeRegistry.register(DatasetStrategy.TRAIN) class TrainClassificationDataset(ClassificationDataset): diff --git a/turbo_alignment/dataset/pair_preferences/pair_preference.py b/turbo_alignment/dataset/pair_preferences/pair_preference.py index fc6ff70f..e163e4bd 100755 --- a/turbo_alignment/dataset/pair_preferences/pair_preference.py +++ b/turbo_alignment/dataset/pair_preferences/pair_preference.py @@ -2,6 +2,7 @@ from typing import Any, overload from transformers import PreTrainedTokenizerBase +from typing_extensions import Self from turbo_alignment.common.data.io import read_jsonl from turbo_alignment.common.logging import get_project_logger @@ -107,3 +108,20 @@ def _read_records(records) -> list[PairPreferenceRecord]: if isinstance(records, list): return [PairPreferenceRecord(**record) for record in records] raise NotImplementedError + + def get_slice(self, start: int, end: int) -> Self: + new_instance = self.__class__( + source=self.source, + settings=self.settings, + tokenizer=self.tokenizer, + read=False, + ) + + dataset_records = [self[idx] for idx in range(len(self))] + + new_instance.records = self.records[start:end] + new_instance.original_records_map = { + record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records + } + + return new_instance \ No newline at end of file From d8d811c20f939ddd8c1380c686ab4621266b7929 Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Mon, 25 Nov 2024 16:11:50 +0000 Subject: [PATCH 11/12] refactor --- turbo_alignment/cherry_picks/base.py | 7 ++---- turbo_alignment/cherry_picks/chat.py | 8 +++++++ .../cherry_picks/classification.py | 20 ++++++++++++---- turbo_alignment/cherry_picks/multimodal.py | 2 ++ turbo_alignment/cherry_picks/rag.py | 9 +++++++ turbo_alignment/cherry_picks/rm.py | 8 +++++++ turbo_alignment/dataset/base/base.py | 5 ---- .../dataset/classification/classification.py | 24 +++++++++---------- .../pipelines/train/classification.py | 10 ++++---- 9 files changed, 63 insertions(+), 30 deletions(-) diff --git a/turbo_alignment/cherry_picks/base.py b/turbo_alignment/cherry_picks/base.py index 995dd62d..93b0ffdc 100755 --- a/turbo_alignment/cherry_picks/base.py +++ b/turbo_alignment/cherry_picks/base.py @@ -1,5 +1,4 @@ from abc import abstractmethod -import math from typing import Generic, Iterable, TypeVar from accelerate import Accelerator @@ -76,8 +75,6 @@ def on_evaluate( return dataset_metrics @staticmethod + @abstractmethod def _get_sharded_dataset(dataset: InferenceDatasetT, accelerator: Accelerator) -> InferenceDatasetT: - rank_device = accelerator.process_index - slice_size = math.ceil(len(dataset) / accelerator.num_processes) - - return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) + ... diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py index c6f4631b..3229940f 100755 --- a/turbo_alignment/cherry_picks/chat.py +++ b/turbo_alignment/cherry_picks/chat.py @@ -1,3 +1,4 @@ +import math from typing import Iterable from accelerate import Accelerator @@ -109,3 +110,10 @@ def _get_dataset_metrics( metric_outputs.extend(metric_results) return metric_outputs + + @staticmethod + def _get_sharded_dataset(dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset: + rank_device = accelerator.process_index + slice_size = math.ceil(len(dataset) / accelerator.num_processes) + + return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) diff --git a/turbo_alignment/cherry_picks/classification.py b/turbo_alignment/cherry_picks/classification.py index cb2e0b1e..e811b0e2 100755 --- a/turbo_alignment/cherry_picks/classification.py +++ b/turbo_alignment/cherry_picks/classification.py @@ -1,3 +1,4 @@ +import math from typing import Iterable from accelerate import Accelerator @@ -5,25 +6,27 @@ from turbo_alignment.cherry_picks.base import CherryPickCallbackBase from turbo_alignment.dataset.chat.conversation import Conversation -from turbo_alignment.dataset.classification.classification import ClassificationDataset +from turbo_alignment.dataset.classification.classification import ( + InferenceClassificationDataset, +) from turbo_alignment.generators.classification import ClassificationGenerator from turbo_alignment.metrics.metric import Metric from turbo_alignment.settings.cherry_pick import ClassificationCherryPickSettings from turbo_alignment.settings.metric import ElementWiseScores, MetricResults -class ClassificationCherryPickCallback(CherryPickCallbackBase[ClassificationDataset]): +class ClassificationCherryPickCallback(CherryPickCallbackBase[InferenceClassificationDataset]): def __init__( self, cherry_pick_settings: ClassificationCherryPickSettings, - datasets: Iterable[ClassificationDataset], + datasets: Iterable[InferenceClassificationDataset], metrics: list[Metric], ) -> None: super().__init__(cherry_pick_settings=cherry_pick_settings, datasets=datasets, metrics=metrics) def _get_dataset_metrics( self, - dataset: ClassificationDataset, + dataset: InferenceClassificationDataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, **kwargs, @@ -69,3 +72,12 @@ def _get_dataset_metrics( ] return metric_outputs + + @staticmethod + def _get_sharded_dataset( + dataset: InferenceClassificationDataset, accelerator: Accelerator + ) -> InferenceClassificationDataset: + rank_device = accelerator.process_index + slice_size = math.ceil(len(dataset) / accelerator.num_processes) + + return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) diff --git a/turbo_alignment/cherry_picks/multimodal.py b/turbo_alignment/cherry_picks/multimodal.py index 80c60226..a4f94bdd 100755 --- a/turbo_alignment/cherry_picks/multimodal.py +++ b/turbo_alignment/cherry_picks/multimodal.py @@ -14,6 +14,8 @@ class MultimodalCherryPickCallback(CherryPickCallbackBase[InferenceMultimodalDataset]): + # pylint: disable=abstract-method + # TODO: add _get_sharded_dataset method def __init__( self, cherry_pick_settings: MultimodalCherryPickSettings, diff --git a/turbo_alignment/cherry_picks/rag.py b/turbo_alignment/cherry_picks/rag.py index 53d332e6..2c0cd276 100755 --- a/turbo_alignment/cherry_picks/rag.py +++ b/turbo_alignment/cherry_picks/rag.py @@ -1,3 +1,5 @@ +import math + from accelerate import Accelerator from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -62,3 +64,10 @@ def _get_dataset_metrics( metric_outputs.extend(metric_results) return metric_outputs + + @staticmethod + def _get_sharded_dataset(dataset, accelerator: Accelerator) -> InferenceChatDataset: + rank_device = accelerator.process_index + slice_size = math.ceil(len(dataset) / accelerator.num_processes) + + return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) diff --git a/turbo_alignment/cherry_picks/rm.py b/turbo_alignment/cherry_picks/rm.py index 829f2214..11e79f28 100755 --- a/turbo_alignment/cherry_picks/rm.py +++ b/turbo_alignment/cherry_picks/rm.py @@ -1,3 +1,4 @@ +import math from typing import Iterable from accelerate import Accelerator @@ -75,3 +76,10 @@ def _get_dataset_metrics( ] return metric_outputs + + @staticmethod + def _get_sharded_dataset(dataset: PairPreferenceDataset, accelerator: Accelerator) -> PairPreferenceDataset: + rank_device = accelerator.process_index + slice_size = math.ceil(len(dataset) / accelerator.num_processes) + + return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) diff --git a/turbo_alignment/dataset/base/base.py b/turbo_alignment/dataset/base/base.py index aac204fd..71e353e5 100755 --- a/turbo_alignment/dataset/base/base.py +++ b/turbo_alignment/dataset/base/base.py @@ -6,7 +6,6 @@ import torch from torch.utils.data import Dataset from transformers import PreTrainedTokenizerBase -from typing_extensions import Self from turbo_alignment.common.logging import get_project_logger from turbo_alignment.dataset.base.models import DatasetRecord @@ -103,10 +102,6 @@ def _read_records(records: list[dict]) -> list[RecordT]: def _read_records(records): ... - @abstractmethod - def get_slice(self, start: int, end: int) -> Self: - ... - class AlignmentDataset(BaseDataset, ABC, Generic[RecordT]): def __init__( diff --git a/turbo_alignment/dataset/classification/classification.py b/turbo_alignment/dataset/classification/classification.py index 030bd69c..77fac165 100755 --- a/turbo_alignment/dataset/classification/classification.py +++ b/turbo_alignment/dataset/classification/classification.py @@ -82,6 +82,18 @@ def _read_records(records) -> list[ClassificationDatasetRecord]: return [ClassificationDatasetRecord(**record) for record in records] raise NotImplementedError + +@ClassificationDatasetTypeRegistry.register(DatasetStrategy.TRAIN) +class TrainClassificationDataset(ClassificationDataset): + def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]: + return self._encode(records, inference=False) + + +@ClassificationDatasetTypeRegistry.register(DatasetStrategy.INFERENCE) +class InferenceClassificationDataset(ClassificationDataset): + def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]: + return self._encode(records, inference=True) + def get_slice(self, start: int, end: int) -> Self: new_instance = self.__class__( source=self.source, @@ -97,15 +109,3 @@ def get_slice(self, start: int, end: int) -> Self: } return new_instance - - -@ClassificationDatasetTypeRegistry.register(DatasetStrategy.TRAIN) -class TrainClassificationDataset(ClassificationDataset): - def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]: - return self._encode(records, inference=False) - - -@ClassificationDatasetTypeRegistry.register(DatasetStrategy.INFERENCE) -class InferenceClassificationDataset(ClassificationDataset): - def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]: - return self._encode(records, inference=True) diff --git a/turbo_alignment/pipelines/train/classification.py b/turbo_alignment/pipelines/train/classification.py index 1bc09fe1..5b07f929 100755 --- a/turbo_alignment/pipelines/train/classification.py +++ b/turbo_alignment/pipelines/train/classification.py @@ -7,7 +7,9 @@ from turbo_alignment.cherry_picks.classification import ClassificationCherryPickCallback from turbo_alignment.common.logging import get_project_logger from turbo_alignment.constants import TRAINER_LOGS_FOLDER -from turbo_alignment.dataset.classification.classification import ClassificationDataset +from turbo_alignment.dataset.classification.classification import ( + InferenceClassificationDataset, +) from turbo_alignment.dataset.loader import DatasetLoader from turbo_alignment.metrics.metric import Metric from turbo_alignment.metrics.registry import MetricSettingsRegistry @@ -47,9 +49,9 @@ def _get_cherry_pick_callback( ) -> ClassificationCherryPickCallback: cherry_pick_settings = experiment_settings.cherry_pick_settings - cherry_pick_datasets = DatasetLoader[ClassificationDataset](ClassificationDataset).load_datasets( - cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE - ) + cherry_pick_datasets = DatasetLoader[InferenceClassificationDataset]( + InferenceClassificationDataset + ).load_datasets(cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE) metrics = [ Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters)) From 0ba6c6aa27d56d78c9e4d85dd49c2c91c5e0a82f Mon Sep 17 00:00:00 2001 From: "a.khokhulin" Date: Mon, 25 Nov 2024 16:13:31 +0000 Subject: [PATCH 12/12] last line --- turbo_alignment/dataset/pair_preferences/pair_preference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/turbo_alignment/dataset/pair_preferences/pair_preference.py b/turbo_alignment/dataset/pair_preferences/pair_preference.py index e163e4bd..bc2498eb 100755 --- a/turbo_alignment/dataset/pair_preferences/pair_preference.py +++ b/turbo_alignment/dataset/pair_preferences/pair_preference.py @@ -43,6 +43,7 @@ def __init__( read=False, ) super().__init__(source=source, settings=settings, tokenizer=tokenizer) + self.settings: PairPreferenceDatasetSettings = settings if read: self._read() @@ -124,4 +125,4 @@ def get_slice(self, start: int, end: int) -> Self: record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records } - return new_instance \ No newline at end of file + return new_instance