diff --git a/petastorm/tests/test_unischema.py b/petastorm/tests/test_unischema.py
index 5ac3750ba..3d97ee145 100644
--- a/petastorm/tests/test_unischema.py
+++ b/petastorm/tests/test_unischema.py
@@ -25,7 +25,9 @@
 
 from petastorm.codecs import ScalarCodec, NdarrayCodec
 from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row, \
-    insert_explicit_nulls, match_unischema_fields, _new_gt_255_compatible_namedtuple, _fullmatch
+    insert_explicit_nulls, match_unischema_fields, _new_gt_255_compatible_namedtuple, _fullmatch, encode_row
+
+from concurrent.futures import ThreadPoolExecutor
 
 try:
     from unittest import mock
@@ -107,6 +109,28 @@ def test_as_spark_schema_unspecified_codec_type_unknown_scalar_type_raises():
         TestSchema.as_spark_schema()
 
 
+@pytest.mark.parametrize("pool_executor", [None, ThreadPoolExecutor(2)])
+def test_encode_row(pool_executor):
+    """Test various validations done on data types when encoding a row dictionary with encode_row"""
+    TestSchema = Unischema('TestSchema', [
+        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
+        UnischemaField('int8_matrix', np.int8, (2, 2), NdarrayCodec(), False),
+    ])
+
+    row = {'string_field': 'abc', 'int8_matrix': np.asarray([[1, 2], [3, 4]], dtype=np.int8)}
+    encoded_row = encode_row(TestSchema, row, pool_executor)
+    assert set(row.keys()) == set(encoded_row)
+    assert isinstance(encoded_row['int8_matrix'], bytearray)
+
+    extra_field_row = {'string_field': 'abc', 'int8_matrix': [[1, 2], [3, 4]], 'bogus': 'value'}
+    with pytest.raises(ValueError, match='.*not found.*bogus.*'):
+        encode_row(TestSchema, extra_field_row, pool_executor)
+
+    extra_field_row = {'string_field': 'abc'}
+    with pytest.raises(ValueError, match='int8_matrix is not found'):
+        encode_row(TestSchema, extra_field_row, pool_executor)
+
+
 def test_dict_to_spark_row_field_validation_scalar_types():
     """Test various validations done on data types when converting a dictionary to a spark row"""
     TestSchema = Unischema('TestSchema', [
diff --git a/petastorm/unischema.py b/petastorm/unischema.py
index 65d8ff33b..e9e6827e9 100644
--- a/petastorm/unischema.py
+++ b/petastorm/unischema.py
@@ -353,7 +353,7 @@ def from_arrow_schema(cls, parquet_dataset, omit_unsupported_fields=False):
     return Unischema('inferred_schema', unischema_fields)
 
 
-def dict_to_spark_row(unischema, row_dict):
+def dict_to_spark_row(unischema, row_dict, pool_executor=None):
     """Converts a single row into a spark Row object.
 
     Verifies that the data confirms with unischema definition types and encodes the data using the codec specified
@@ -363,44 +363,78 @@
 
     :param unischema: an instance of Unischema object
     :param row_dict: a dictionary where the keys match name of fields in the unischema.
+    :param pool_executor: if not None, encoding of row fields will be performed using the pool_executor
     :return: a single pyspark.Row object
     """
     # Lazy loading pyspark to avoid creating pyspark dependency on data reading code path
     # (currently works only with make_batch_reader)
     import pyspark
 
+    encoded_dict = encode_row(unischema, row_dict, pool_executor)
+
+    field_list = list(unischema.fields.keys())
+    # generate a value list which matches the schema column order.
+    value_list = [encoded_dict[name] for name in field_list]
+    # create a row by value list
+    row = pyspark.Row(*value_list)
+    # set row fields
+    row.__fields__ = field_list
+
+    return row
+
+
+def encode_row(unischema, row_dict, pool_executor=None):
+    """Verifies that the data conforms to unischema definition types and encodes the data using the codec specified
+    by the unischema.
+
+    :param unischema: an instance of Unischema object
+    :param row_dict: a dictionary where the keys match name of fields in the unischema.
+    :param pool_executor: if not None, encoding of row fields will be performed using the pool_executor
+    :return: a dictionary of encoded fields
+    """
+
+    # Note: encode_row does not import pyspark, so it remains usable on code paths
+    # that do not have a pyspark dependency.
     assert isinstance(unischema, Unischema)
 
     # Add null fields. Be careful not to mutate the input dictionary - that would be an unexpected side effect
     copy_row_dict = copy.copy(row_dict)
     insert_explicit_nulls(unischema, copy_row_dict)
 
-    if set(copy_row_dict.keys()) != set(unischema.fields.keys()):
-        raise ValueError('Dictionary fields \n{}\n do not match schema fields \n{}'.format(
-            '\n'.join(sorted(copy_row_dict.keys())), '\n'.join(unischema.fields.keys())))
+    input_field_names = set(copy_row_dict.keys())
+    unischema_field_names = set(unischema.fields.keys())
+
+    unknown_field_names = input_field_names - unischema_field_names
 
-    encoded_dict = {}
+    if unknown_field_names:
+        raise ValueError('Following fields of row_dict are not found in '
+                         'unischema: {}'.format(', '.join(sorted(unknown_field_names))))
+
+    encoded_dict = dict()
+    futures_dict = dict()
     for field_name, value in copy_row_dict.items():
         schema_field = unischema.fields[field_name]
         if value is None:
             if not schema_field.nullable:
                 raise ValueError('Field {} is not "nullable", but got passes a None value')
         if schema_field.codec:
-            encoded_dict[field_name] = schema_field.codec.encode(schema_field, value) if value is not None else None
+            if value is None:
+                encoded_dict[field_name] = None
+            else:
+                if pool_executor:
+                    futures_dict[field_name] = pool_executor.submit(schema_field.codec.encode, schema_field, value)
+                else:
+                    encoded_dict[field_name] = schema_field.codec.encode(schema_field, value)
         else:
             if isinstance(value, (np.generic,)):
                 encoded_dict[field_name] = value.tolist()
             else:
                 encoded_dict[field_name] = value
 
-    field_list = list(unischema.fields.keys())
-    # generate a value list which match the schema column order.
-    value_list = [encoded_dict[name] for name in field_list]
-    # create a row by value list
-    row = pyspark.Row(*value_list)
-    # set row fields
-    row.__fields__ = field_list
-    return row
+    for k, v in futures_dict.items():
+        encoded_dict[k] = v.result()
+
+    return encoded_dict
 
 
 def insert_explicit_nulls(unischema, row_dict):
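Usage note (outside the patch): with this change, callers can hand dict_to_spark_row / encode_row a concurrent.futures executor so that per-field codec encoding runs on a thread pool, while the returned values stay the same as before. A minimal sketch, reusing the schema shape from the new test; the schema name and field values below are illustrative only:

from concurrent.futures import ThreadPoolExecutor

import numpy as np
from pyspark.sql.types import StringType

from petastorm.codecs import NdarrayCodec, ScalarCodec
from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row

# Hypothetical schema, mirroring the one constructed in test_encode_row above.
SampleSchema = Unischema('SampleSchema', [
    UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
    UnischemaField('int8_matrix', np.int8, (2, 2), NdarrayCodec(), False),
])

executor = ThreadPoolExecutor(max_workers=4)
row = {'string_field': 'abc', 'int8_matrix': np.zeros((2, 2), dtype=np.int8)}

# Codec encoding for each field is submitted to the executor; the call still
# blocks until all futures resolve and returns a fully encoded pyspark.Row.
spark_row = dict_to_spark_row(SampleSchema, row, pool_executor=executor)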