Skip to content

Commit

Permalink
Merge pull request #143 from markmc/dtype-default
Browse files Browse the repository at this point in the history
filterblock: add default_value for use with convert_dtype
  • Loading branch information
russellb authored Jul 16, 2024
2 parents eb78549 + 2a04cbe commit c083f98
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 47 deletions.
96 changes: 56 additions & 40 deletions src/instructlab/sdg/filterblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,10 @@ class FilterByValueBlockError(Exception):

def _get_operator_func(op):
    """Resolve an operation name to the matching function in ``operator``.

    Args:
        op: Name of an attribute of the standard ``operator`` module,
            for example ``"eq"`` or ``"le"``.

    Returns:
        The attribute of the ``operator`` module with that name.

    Raises:
        FilterByValueBlockError: If ``op`` is not listed in ``dir(operator)``.
    """
    if op not in dir(operator):
        # The message has to be an f-string: a plain literal would print the
        # placeholder text "{op}" rather than the offending operation name.
        raise FilterByValueBlockError(f"Unknown FilterByValueBlock operation '{op}'")
    return getattr(operator, op)


def _get_convert_dtype(convert_dtype):
    """Translate a dtype name into the Python type used for conversion.

    Args:
        convert_dtype: One of ``"int"``, ``"float"`` or ``"bool"``, or a
            falsy value (``None``, ``""``) meaning "do not convert".

    Returns:
        The matching builtin type, or ``None`` when ``convert_dtype`` is falsy.

    Raises:
        FilterByValueBlockError: If the name is not one of the supported types.
    """
    if not convert_dtype:
        return None

    type_mapping = {
        "int": int,
        "float": float,
        "bool": bool,
    }

    if convert_dtype not in type_mapping:
        # f-string so the rejected name actually appears in the message.
        raise FilterByValueBlockError(
            f"Unknown FilterByValueBlock convert_dtype '{convert_dtype}'"
        )

    return type_mapping[convert_dtype]


# Note - this is not a method on the class below in order to avoid
# serializing the object itself when multi-processing is used.
# In particular, SSLContext - embedded in the OpenAI client object -
Expand All @@ -51,24 +33,52 @@ def _filter_by_values(samples, column, op, values, num_proc=1):
)


class DTypeConverter:
    """Callable that converts a value to a target type, with a fallback.

    Calling an instance applies ``dtype`` to the value. When the conversion
    fails, ``default_value`` is returned instead of raising. An instance built
    with ``dtype=None`` returns every value unchanged.

    Attributes are plain and public:
        dtype: The conversion callable (``int``, ``float``, ``bool``) or ``None``.
        default_value: The value returned when conversion fails.
    """

    def __init__(self, dtype, default_value=None):
        self.dtype = dtype
        self.default_value = default_value

    def __bool__(self):
        # A converter without a target type is a no-op. Reporting it as falsy
        # lets ``if converter:`` guards skip the conversion pass entirely
        # instead of always being true for any instance.
        return self.dtype is not None

    def __call__(self, value):
        """Return ``value`` converted to ``dtype``, or the default on failure."""
        if self.dtype is None:
            return value
        try:
            return self.dtype(value)
        except (ValueError, TypeError) as e:
            # TypeError covers inputs such as None, which int()/float() reject
            # with TypeError rather than ValueError.
            logger.debug(
                "Error converting to %s: %s, filling with %s",
                self.dtype,
                e,
                self.default_value,
            )
            return self.default_value

    @classmethod
    def get(cls, dtype, default_value):
        """Build a converter from a dtype name and an optional default.

        Args:
            dtype: ``"int"``, ``"float"``, ``"bool"``, or a falsy value for a
                pass-through converter.
            default_value: Fallback used when conversion fails. ``None``
                selects the per-type default (``0``, ``0.0`` or ``False``);
                any other value is itself converted with the target type.

        Returns:
            A converter instance of this class.

        Raises:
            FilterByValueBlockError: If ``dtype`` is not a supported name.
            ValueError: If ``default_value`` cannot be converted to the
                target type (for example ``"abc"`` for ``"int"``).
        """
        if not dtype:
            return cls(None, None)

        type_mapping = {
            "int": (int, 0),
            "float": (float, 0.0),
            "bool": (bool, False),
        }
        if dtype not in type_mapping:
            raise FilterByValueBlockError(
                f"Unknown FilterByValueBlock convert_dtype '{dtype}'"
            )

        target_type, builtin_default = type_mapping[dtype]
        if default_value is None:
            return cls(target_type, builtin_default)

        # NOTE(review): for "bool", bool() of any non-empty string - including
        # "False" - is True; confirm this is the intended default handling.
        return cls(target_type, target_type(default_value))


# Note - this is not a method on the class below in order to avoid
# serializing the object itself when multi-processing is used.
# In particular, SSLContext - embedded in the OpenAI client object -
# cannot be pickled.
def _map_dtype(samples, column, dtype, num_proc=1):
    """Apply ``dtype`` to ``column`` of every row via ``samples.map``.

    Args:
        samples: Object exposing ``map(function, num_proc=...)``.
        column: Key of the field to convert in each row.
        dtype: Callable applied to the current value of that field.
        num_proc: Forwarded unchanged to ``samples.map``.

    Returns:
        Whatever ``samples.map`` returns.
    """

    def _convert(row):
        converted = dtype(row[column])
        row[column] = converted
        return row

    return samples.map(_convert, num_proc=num_proc)

Expand All @@ -83,6 +93,7 @@ def __init__(
filter_value,
operation,
convert_dtype=None,
default_value=None,
) -> None:
"""
Initializes a new instance of the FilterByValueBlock class.
Expand All @@ -100,6 +111,8 @@ def __init__(
- convert_dtype (string, optional): the name of a Python type to convert
the column values to. Supported values are "int", "float", and "bool".
Defaults to None.
- default_value (string, optional): a default value that should be used
if convert_dtype fails. Defaults to 0 for int/float, and False for bool.
Returns:
None
Expand Down Expand Up @@ -137,9 +150,9 @@ def __init__(
In that case, the operation will be applied to each value in the list. The result is
considered True if the operation is True for any of the values in the list.
Example: FilterByValueBlock(ctx, "filter_by_age", "age", [30, 35], "eq", "int")
Example: FilterByValueBlock(ctx, "filter_by_age", "age", [30, 35], "eq", "int", "0")
- This block will filter the dataset to only include rows where the
"age" column is equal to 30 or 35.
"age" column is equal to 30 or 35. Non-integer values will be treated like zero.
Example: FilterByValueBlock(ctx, "filter_by_city", "city", ["boston", "charleston", "dublin", "new york"], "eq")
- This block will filter the dataset to only include rows where the
Expand All @@ -153,14 +166,17 @@ def __init__(
self.value = filter_value if isinstance(filter_value, list) else [filter_value]
self.column_name = filter_column
self.operation = _get_operator_func(operation)
self.convert_dtype = _get_convert_dtype(convert_dtype)
if self.convert_dtype:
self.value = [self.convert_dtype(value) for value in self.value]
self.dtype = DTypeConverter.get(convert_dtype, default_value)
if self.dtype:
self.value = [self.dtype(value) for value in self.value]

def generate(self, samples) -> Dataset:
if self.convert_dtype:
if self.dtype:
samples = _map_dtype(
samples, self.column_name, self.convert_dtype, self.ctx.num_procs
samples,
self.column_name,
self.dtype,
self.ctx.num_procs,
)

return _filter_by_values(
Expand Down
3 changes: 3 additions & 0 deletions src/instructlab/sdg/pipelines/schema/v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@
"type": "string",
"enum": ["float", "int", "bool"]
},
"default_value": {
"type": "string"
},
"filter_column": {
"type": "string"
},
Expand Down
10 changes: 3 additions & 7 deletions tests/test_filterblock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Standard
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import operator
import unittest

Expand Down Expand Up @@ -39,16 +39,12 @@ def setUp(self):
features=Features({"age": Value("string")}),
)

def test_generate_mixed_types(self):
    """Filtering the fixture dataset keeps exactly one row, with age 30."""
    # No logger patching: conversion failures are reported at debug level,
    # so asserting on logger.error would no longer hold.
    filtered_dataset = self.block.generate(self.dataset)
    self.assertEqual(len(filtered_dataset), 1)
    self.assertEqual(filtered_dataset["age"], [30])

def test_generate_mixed_types_multi_value(self):
    """Filtering with a list of values keeps the two rows aged 30 and 35."""
    # No logger patching: conversion failures are reported at debug level,
    # so asserting on logger.error would no longer hold.
    filtered_dataset = self.block_with_list.generate(self.dataset)
    self.assertEqual(len(filtered_dataset), 2)
    self.assertEqual(filtered_dataset["age"], [30, 35])
1 change: 1 addition & 0 deletions tests/test_importblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_generate(self):
filter_value: 40
operation: le
convert_dtype: int
default_value: 1000
- name: import_child
type: ImportBlock
config:
Expand Down

0 comments on commit c083f98

Please sign in to comment.