
Commit

Better batching, handling of default model, simple flow to get responses (#6)

* feat: Add a simple flow to get granite model responses

Signed-off-by: shiv-sr <[email protected]>

* refactor: Split dataset into smaller batches for improved processing

This commit modifies the SDG class in sdg.py to split the dataset into smaller batches for improved processing. Instead of materializing all batches up front, the code now records the start and end indices of each batch and creates the batch on the fly, which is necessary for large datasets.

Signed-off-by: shiv-sr <[email protected]>
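A minimal sketch of the on-the-fly batching described above, assuming a Hugging Face datasets.Dataset; the function name and the usage loop are illustrative, not the exact code in sdg.py:

from datasets import Dataset

def split_into_index_ranges(dataset: Dataset, batch_size: int) -> list:
    """Return (start, end) row-index pairs instead of materialized slices."""
    total_size = len(dataset)
    num_batches = (total_size + batch_size - 1) // batch_size  # ceiling division
    return [
        (i * batch_size, min((i + 1) * batch_size, total_size))
        for i in range(num_batches)
    ]

# Each worker materializes its batch only when it is about to process it.
ds = Dataset.from_dict({"question": [f"q{i}" for i in range(10)]})
for start, end in split_into_index_ranges(ds, batch_size=4):
    batch = ds.select(range(start, end))  # batch sizes: 4, 4, 2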

* refactor: Update LLMBlock to use default model if model_id is not provided

This commit modifies the LLMBlock class in llmblock.py to use the default model if the model_id is not provided. It checks if the model_id is None and if so, retrieves the default model id from the client. This ensures that the LLMBlock can still function even if the model_id is not explicitly specified.

Signed-off-by: shiv-sr <[email protected]>
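A minimal sketch of the fallback, assuming an OpenAI-compatible client like the one LLMBlock receives; the helper function is illustrative rather than the exact method in llmblock.py:

def resolve_model_id(client, model_id=None) -> str:
    """Use the given model_id, or fall back to the server's first listed model."""
    if model_id:
        return model_id
    # Ask the server which models it serves and treat the first one as the default.
    return client.models.list().data[0].id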

* refactor: Remove commented out code for dataset batching

The code changes in sdg.py remove the commented-out code that was previously used for dataset batching. It is no longer needed and can be safely removed.

Signed-off-by: shiv-sr <[email protected]>

---------

Signed-off-by: shiv-sr <[email protected]>
Shivchander Sudalairaj authored and GitHub Enterprise committed Aug 14, 2024
1 parent 0fea5ec commit 3c1c491
Showing 5 changed files with 50 additions and 14 deletions.
8 changes: 8 additions & 0 deletions src/instructlab/sdg/configs/skills/respond.yaml
@@ -0,0 +1,8 @@
system:
introduction:
principles:
examples:
generation: |
  {question}
start_tags: [""]
end_tags: [""]
19 changes: 13 additions & 6 deletions src/instructlab/sdg/default_flows.py
@@ -11,31 +11,35 @@

# Local
from .filterblock import FilterByValueBlock
from .llmblock import LLMBlock, ConditionalLLMBlock
from .llmblock import ConditionalLLMBlock, LLMBlock
from .utilblocks import (
CombineColumnsBlock,
SamplePopulatorBlock,
SelectorBlock,
DuplicateColumns,
RenameColumns,
FlattenColumnsBlock,
RenameColumns,
SamplePopulatorBlock,
SelectorBlock,
SetToMajorityValue,
)

MODEL_FAMILY_MIXTRAL = "mixtral"
MODEL_FAMILY_MERLINITE = "merlinite"
MODEL_FAMILY_BLANK = "blank"

MODEL_FAMILY_IBM = "ibm"
MODEL_FAMILY_RHELAI = "rhelai"

_MODEL_PROMPT_MIXTRAL = "<s> [INST] {prompt} [/INST]"
_MODEL_PROMPT_MERLINITE = "<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.\n<|user|>\n{prompt}\n<|assistant|>\n"
_MODEL_PROMPT_RHELAI = "<|system|>\nI am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant.\n<|user|>\n{prompt}\n<|assistant|>\n"
_BLANK_PROMPT = "{prompt}"


_MODEL_PROMPTS = {
MODEL_FAMILY_MIXTRAL: _MODEL_PROMPT_MIXTRAL,
MODEL_FAMILY_MERLINITE: _MODEL_PROMPT_MERLINITE,
MODEL_FAMILY_BLANK: _BLANK_PROMPT,
MODEL_FAMILY_IBM: _MODEL_PROMPT_MERLINITE,
MODEL_FAMILY_RHELAI: _MODEL_PROMPT_RHELAI,
}


@@ -88,6 +92,9 @@ def get_flow_from_file(self, yaml_path: str) -> list:
with open(yaml_path, "r", encoding="utf-8") as yaml_file:
flow = yaml.safe_load(yaml_file)
for block in flow:
if "LLMBlock" in block["block_type"]:
block["block_config"]["client"] = self.client

block["block_type"] = BLOCK_TYPE_MAP[block["block_type"]]

if "config_path" in block["block_config"]:
@@ -121,7 +128,6 @@ def get_flow_from_file(self, yaml_path: str) -> list:
)

if "model_id" in block["block_config"]:
block["block_config"]["client"] = self.client
model_id = block["block_config"]["model_id"]
if "model_family" in block["block_config"]:
model_family = block["block_config"]["model_family"]
@@ -158,5 +164,6 @@ def get_flow_from_file(self, yaml_path: str) -> list:
"SynthSkillsFlow": "flows/synth_skills.yaml",
"SynthGroundedSkillsFlow": "flows/synth_grounded_skills.yaml",
"SynthKnowledgeFlow1.5": "flows/synth_knowledge1.5.yaml",
"GraniteResponsesFlow": "flows/granite_responses.yaml",
"AgenticImproveFlow": "flows/agentic_improve_skill.yaml",
}
10 changes: 10 additions & 0 deletions src/instructlab/sdg/flows/granite_responses.yaml
@@ -0,0 +1,10 @@
- block_type: LLMBlock
block_config:
block_name: gen_response
config_path: configs/skills/respond.yaml
model_family: rhelai
output_cols:
- gen_response
gen_kwargs:
temperature: 0
max_tokens: 8
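
In effect, for each input row this flow formats {question} with the rhelai prompt template and requests a completion using the gen_kwargs above. A rough, hypothetical equivalent against an OpenAI-compatible server is sketched below; the base URL is a placeholder, and the use of the completions endpoint is an assumption about LLMBlock's internals:

from openai import OpenAI

# _MODEL_PROMPT_RHELAI is the template added in default_flows.py above.
from instructlab.sdg.default_flows import _MODEL_PROMPT_RHELAI

# Placeholder endpoint; point this at whichever server hosts the model.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
model_id = client.models.list().data[0].id  # default-model fallback, as in llmblock.py

prompt = _MODEL_PROMPT_RHELAI.format(prompt="What is the capital of France?")
response = client.completions.create(
    model=model_id, prompt=prompt, temperature=0, max_tokens=8
)
print(response.choices[0].text)  # what the flow writes to the gen_response column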
10 changes: 8 additions & 2 deletions src/instructlab/sdg/llmblock.py
@@ -41,10 +41,10 @@ def __init__(
block_name,
config_path,
client,
model_id,
output_cols,
parser_kwargs={},
model_prompt="{prompt}",
model_id=None,
**batch_kwargs,
) -> None:
super().__init__(block_name)
@@ -54,7 +54,13 @@ def __init__(
)
self.prompt_template = self.prompt_struct.format(**self.block_config)
self.client = client
self.model = model_id
if model_id:
self.model = model_id
else:
# get the default model id from client
self.model = self.client.models.list().data[0].id

logger.info(f"Using model: {self.model}")
self.model_prompt = model_prompt
self.output_cols = output_cols
self.batch_params = batch_kwargs.get("batch_kwargs", {})
17 changes: 11 additions & 6 deletions src/instructlab/sdg/sdg.py
@@ -32,10 +32,12 @@ def _split_dataset(self, dataset: Dataset, batch_size: int) -> List[Dataset]:
"""Split the dataset into smaller batches."""
total_size = len(dataset)
num_batches = (total_size + batch_size - 1) // batch_size

batches = [
dataset.select(range(i * batch_size, min((i + 1) * batch_size, total_size)))
for i in range(num_batches)
(i * batch_size, min((i + 1) * batch_size, total_size))
for i in tqdm(range(num_batches))
]

return batches

def _get_missing_data(self, seed_data, generated_data):
@@ -69,8 +71,9 @@ def _save_intermediate_checkpoint(self, dataset, checkpoint_dir):
dataset.to_json(checkpoint_file, orient="records", lines=True)

@staticmethod
def _generate_data(pipelines, input_split, i=None):
def _generate_data(pipelines, input_split, ds, i=None):
logger.info(f"Processing split {i}")
input_split = ds.select(range(input_split[0], input_split[1]))
try:
for pipeline in pipelines:
input_split = pipeline.generate(input_split)
@@ -116,13 +119,13 @@ def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
for pipeline in self.pipelines:
generated_dataset = pipeline.generate(seed_data)
return generated_dataset


logger.info("Splitting the dataset into smaller batches")
input_splits = (
self._split_dataset(seed_data, self.batch_size)
if self.batch_size
else [seed_data]
)

logger.info(
f"Generating dataset with {len(input_splits)} splits, batch size {self.batch_size}, and {self.num_workers} workers"
)
@@ -132,7 +135,9 @@ def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:

with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
futures = [
executor.submit(self._generate_data, self.pipelines, input_split, i)
executor.submit(
self._generate_data, self.pipelines, input_split, seed_data, i
)
for i, input_split in enumerate(input_splits)
]
