Skip to content

Commit

Permalink
feat: [FC-0074] add support for annotated python dicts as avro map ty…
Browse files Browse the repository at this point in the history
…pe (#433)

Enable annotated Python dicts to be mapped to the Avro map type during schema generation, expanding the kinds of event payloads that are supported. Unlike the previous approach of mapping dicts to Avro records, this prevents conflicts with data attributes and avoids errors when only the type of the dictionary's contents is known, not the contents themselves.
  • Loading branch information
mariajgrimaldi authored Jan 30, 2025
1 parent dbfe2e1 commit c3d7aca
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 36 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ Change Log
Unreleased
__________

[9.16.0] - 2025-01-30
---------------------

Added
~~~~~

* Added support for annotated Python dictionaries as the Avro Map type.

[9.15.2] - 2025-01-16
---------------------
Expand Down
2 changes: 1 addition & 1 deletion openedx_events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
more information about the project.
"""

__version__ = "9.15.2"
__version__ = "9.16.0"
17 changes: 14 additions & 3 deletions openedx_events/event_bus/avro/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,27 @@ def _deserialized_avro_record_dict_to_object(data: dict, data_type, deserializer
elif data_type in PYTHON_TYPE_TO_AVRO_MAPPING:
return data
elif data_type_origin == list:
# returns types of list contents
# if data_type == List[int], arg_data_type = (int,)
# Returns types of list contents.
# Example: if data_type == List[int], arg_data_type = (int,)
arg_data_type = get_args(data_type)
if not arg_data_type:
raise TypeError(
"List without annotation type is not supported. The argument should be a type, for eg., List[int]"
)
# check whether list items type is in basic types.
# Check whether list items type is in basic types.
if arg_data_type[0] in SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING:
return data
elif data_type_origin == dict:
# Returns types of dict contents.
# Example: if data_type == Dict[str, int], arg_data_type = (str, int)
arg_data_type = get_args(data_type)
if not arg_data_type:
raise TypeError(
"Dict without annotation type is not supported. The argument should be a type, for eg., Dict[str, int]"
)
# Check whether dict items type is in basic types.
if all(arg in SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING for arg in arg_data_type):
return data
elif hasattr(data_type, "__attrs_attrs__"):
transformed = {}
for attribute in data_type.__attrs_attrs__:
Expand Down
21 changes: 18 additions & 3 deletions openedx_events/event_bus/avro/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ def _create_avro_field_definition(data_key, data_type, previously_seen_types,
field["type"] = field_type
# Case 2: data_type is a simple type that can be converted directly to an Avro type
elif data_type in PYTHON_TYPE_TO_AVRO_MAPPING:
if PYTHON_TYPE_TO_AVRO_MAPPING[data_type] in ["record", "array"]:
if PYTHON_TYPE_TO_AVRO_MAPPING[data_type] in ["map", "array"]:
# pylint: disable-next=broad-exception-raised
raise Exception("Unable to generate Avro schema for dict or array fields without annotation types.")
avro_type = PYTHON_TYPE_TO_AVRO_MAPPING[data_type]
field["type"] = avro_type
elif data_type_origin == list:
# returns types of list contents
# if data_type == List[int], arg_data_type = (int,)
# Returns types of list contents.
# Example: if data_type == List[int], arg_data_type = (int,)
arg_data_type = get_args(data_type)
if not arg_data_type:
raise TypeError(
Expand All @@ -89,6 +89,21 @@ def _create_avro_field_definition(data_key, data_type, previously_seen_types,
f" {set(SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING.keys())}"
)
field["type"] = {"type": PYTHON_TYPE_TO_AVRO_MAPPING[data_type_origin], "items": avro_type}
elif data_type_origin == dict:
# Returns types of dict contents.
# Example: if data_type == Dict[str, int], arg_data_type = (str, int)
arg_data_type = get_args(data_type)
if not arg_data_type:
raise TypeError(
"Dict without annotation type is not supported. The argument should be a type, for eg., Dict[str, int]"
)
avro_type = SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING.get(arg_data_type[1])
if avro_type is None:
raise TypeError(
"Only following types are supported for dict arguments:"
f" {set(SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING.keys())}"
)
field["type"] = {"type": PYTHON_TYPE_TO_AVRO_MAPPING[data_type_origin], "values": avro_type}
# Case 3: data_type is an attrs class
elif hasattr(data_type, "__attrs_attrs__"):
# Inner Attrs Class
Expand Down
24 changes: 23 additions & 1 deletion openedx_events/event_bus/avro/tests/test_avro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import io
import os
from datetime import datetime
from typing import List
from typing import List, Union
from unittest import TestCase
from uuid import UUID, uuid4

Expand Down Expand Up @@ -43,6 +43,7 @@ def generate_test_data_for_schema(schema): # pragma: no cover
'string': "default",
'double': 1.0,
'null': None,
'map': {'key': 'value'},
}

def get_default_value_or_raise(schema_field_type):
Expand Down Expand Up @@ -71,6 +72,9 @@ def get_default_value_or_raise(schema_field_type):
elif sub_field_type == "record":
# if we're dealing with a record, recurse into the record
data_dict.update({key: generate_test_data_for_schema(field_type)})
elif sub_field_type == "map":
# if we're dealing with a map, "values" will be the type of values in the map
data_dict.update({key: {"key": get_default_value_or_raise(field_type["values"])}})
else:
raise Exception(f"Unsupported type {field_type}") # pylint: disable=broad-exception-raised

Expand Down Expand Up @@ -112,6 +116,24 @@ def generate_test_event_data_for_data_type(data_type): # pragma: no cover
datetime: datetime.now(),
CCXLocator: CCXLocator(org='edx', course='DemoX', run='Demo_course', ccx='1'),
UUID: uuid4(),
dict[str, str]: {'key': 'value'},
dict[str, int]: {'key': 1},
dict[str, float]: {'key': 1.0},
dict[str, bool]: {'key': True},
dict[str, CourseKey]: {'key': CourseKey.from_string("course-v1:edX+DemoX.1+2014")},
dict[str, UsageKey]: {'key': UsageKey.from_string(
"block-v1:edx+DemoX+Demo_course+type@video+block@UaEBjyMjcLW65gaTXggB93WmvoxGAJa0JeHRrDThk",
)},
dict[str, LibraryLocatorV2]: {'key': LibraryLocatorV2.from_string('lib:MITx:reallyhardproblems')},
dict[str, LibraryUsageLocatorV2]: {
'key': LibraryUsageLocatorV2.from_string('lb:MITx:reallyhardproblems:problem:problem1'),
},
dict[str, List[int]]: {'key': [1, 2, 3]},
dict[str, List[str]]: {'key': ["hi", "there"]},
dict[str, dict[str, str]]: {'key': {'key': 'value'}},
dict[str, dict[str, int]]: {'key': {'key': 1}},
dict[str, Union[str, int]]: {'key': 'value'},
dict[str, Union[str, int, float]]: {'key': 1.0},
}
data_dict = {}
for attribute in data_type.__attrs_attrs__:
Expand Down
136 changes: 111 additions & 25 deletions openedx_events/event_bus/avro/tests/test_deserializer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""Tests for avro.deserializer"""
import json
from datetime import datetime
from typing import List
from typing import Dict, List
from unittest import TestCase

import ddt
from opaque_keys.edx.keys import CourseKey, UsageKey
from opaque_keys.edx.locator import LibraryLocatorV2, LibraryUsageLocatorV2

from openedx_events.event_bus.avro.deserializer import AvroSignalDeserializer, deserialize_bytes_to_event_data
from openedx_events.event_bus.avro.tests.test_utilities import (
ComplexAttrs,
EventData,
NestedAttrsWithDefaults,
NestedNonAttrs,
Expand All @@ -23,43 +25,74 @@
from openedx_events.tests.utils import FreezeSignalCacheMixin


@ddt.ddt
class TestAvroSignalDeserializerCache(TestCase, FreezeSignalCacheMixin):
"""Test AvroSignalDeserializer"""

def setUp(self) -> None:
super().setUp()
self.maxDiff = None

def test_schema_string(self):
@ddt.data(
(
SimpleAttrs,
{
"name": "CloudEvent",
"type": "record",
"doc": "Avro Event Format for CloudEvents created with openedx_events/schema",
"namespace": "simple.signal",
"fields": [
{
"name": "data",
"type": {
"name": "SimpleAttrs",
"type": "record",
"fields": [
{"name": "boolean_field", "type": "boolean"},
{"name": "int_field", "type": "long"},
{"name": "float_field", "type": "double"},
{"name": "bytes_field", "type": "bytes"},
{"name": "string_field", "type": "string"},
],
},
},
],
}
),
(
ComplexAttrs,
{
"name": "CloudEvent",
"type": "record",
"doc": "Avro Event Format for CloudEvents created with openedx_events/schema",
"namespace": "simple.signal",
"fields": [
{
"name": "data",
"type": {
"name": "ComplexAttrs",
"type": "record",
"fields": [
{"name": "list_field", "type": {"type": "array", "items": "long"}},
{"name": "dict_field", "type": {"type": "map", "values": "long"}},
],
},
},
],
}
)
)
@ddt.unpack
def test_schema_string(self, data_cls, expected_schema):
"""
Test JSON round-trip; schema creation is tested more fully in test_schema.py.
"""
SIGNAL = create_simple_signal({
"data": SimpleAttrs
"data": data_cls
})

actual_schema = json.loads(AvroSignalDeserializer(SIGNAL).schema_string())
expected_schema = {
'name': 'CloudEvent',
'type': 'record',
'doc': 'Avro Event Format for CloudEvents created with openedx_events/schema',
'namespace': 'simple.signal',
'fields': [
{
'name': 'data',
'type': {
'name': 'SimpleAttrs',
'type': 'record',
'fields': [
{'name': 'boolean_field', 'type': 'boolean'},
{'name': 'int_field', 'type': 'long'},
{'name': 'float_field', 'type': 'double'},
{'name': 'bytes_field', 'type': 'bytes'},
{'name': 'string_field', 'type': 'string'},
]
}
}
]
}

assert actual_schema == expected_schema

def test_convert_dict_to_event_data(self):
Expand Down Expand Up @@ -233,6 +266,59 @@ def test_deserialization_of_list_without_annotation(self):
with self.assertRaises(TypeError):
deserializer.from_dict(initial_dict)

def test_deserialization_of_dict_with_annotation(self):
    """
    Verify that a field annotated as ``Dict[str, int]`` is deserialized to an equal plain dict.
    """
    signal = create_simple_signal({"dict_input": Dict[str, int]})
    payload = {"dict_input": {"key1": 1, "key2": 3}}

    result = AvroSignalDeserializer(signal).from_dict(payload)
    deserialized = result["dict_input"]

    self.assertIsInstance(deserialized, dict)
    self.assertEqual(deserialized, {"key1": 1, "key2": 3})

def test_deserialization_of_dict_without_annotation(self):
    """
    Verify that deserializing into a bare ``Dict`` annotation raises ``TypeError``.

    The deserializer is first built from a fully annotated signal, which bypasses the schema
    check performed at initialization. Its signal is then swapped for one whose dict field has
    no type arguments, so that the error raised while deserializing data can be observed.
    """
    annotated_signal = create_simple_signal({"dict_input": Dict[str, int]})
    bare_signal = create_simple_signal({"dict_input": Dict})
    payload = {"dict_input": {"key1": 1, "key2": 3}}

    deserializer = AvroSignalDeserializer(annotated_signal)
    deserializer.signal = bare_signal

    self.assertRaises(TypeError, deserializer.from_dict, payload)

def test_deserialization_of_dict_with_complex_types_fails(self):
    """
    Check that a dict annotated as ``Dict[str, list]`` raises TypeError.

    The error is expected both when a deserializer is constructed from the signal and when
    ``from_dict`` is called on a deserializer whose signal was replaced after construction.
    """
    # NOTE(review): indentation is not visible in this view; only the constructor call is
    # assumed to sit inside the first ``assertRaises`` block -- confirm against the real file.
    SIGNAL = create_simple_signal({"dict_input": Dict[str, list]})
    with self.assertRaises(TypeError):
        AvroSignalDeserializer(SIGNAL)
    initial_dict = {"dict_input": {"key1": [1, 3], "key2": [4, 5]}}
    # Create a dummy signal to bypass the schema check while initializing the deserializer.
    # This allows us to test whether the correct exceptions are raised while deserializing data.
    DUMMY_SIGNAL = create_simple_signal({"dict_input": Dict[str, int]})
    deserializer = AvroSignalDeserializer(DUMMY_SIGNAL)
    # Update the signal with incorrect type info.
    deserializer.signal = SIGNAL
    with self.assertRaises(TypeError):
        deserializer.from_dict(initial_dict)

def test_deserialization_of_dicts_with_keys_of_complex_types_fails(self):
    """
    Verify that ``from_dict`` raises ``TypeError`` for a dict annotated with ``CourseKey`` keys.
    """
    signal = create_simple_signal({"dict_input": Dict[CourseKey, int]})
    deserializer = AvroSignalDeserializer(signal)
    course_key = CourseKey.from_string("course-v1:edX+DemoX.1+2014")
    payload = {"dict_input": {course_key: 1}}

    self.assertRaises(TypeError, deserializer.from_dict, payload)

def test_deserialization_of_nested_list_fails(self):
"""
Check that deserialization raises error when nested list data is passed.
Expand Down
28 changes: 26 additions & 2 deletions openedx_events/event_bus/avro/tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Tests for event_bus.avro.schema module
"""
from typing import List
from typing import Dict, List
from unittest import TestCase

from openedx_events.event_bus.avro.schema import schema_from_signal
Expand Down Expand Up @@ -245,8 +245,9 @@ class UnextendedClass:

def test_throw_exception_to_list_or_dict_types_without_annotation(self):
LIST_SIGNAL = create_simple_signal({"list_input": list})
DICT_SIGNAL = create_simple_signal({"list_input": dict})
DICT_SIGNAL = create_simple_signal({"dict_input": dict})
LIST_WITHOUT_ANNOTATION_SIGNAL = create_simple_signal({"list_input": List})
DICT_WITHOUT_ANNOTATION_SIGNAL = create_simple_signal({"dict_input": Dict})
with self.assertRaises(Exception):
schema_from_signal(LIST_SIGNAL)

Expand All @@ -256,6 +257,14 @@ def test_throw_exception_to_list_or_dict_types_without_annotation(self):
with self.assertRaises(TypeError):
schema_from_signal(LIST_WITHOUT_ANNOTATION_SIGNAL)

with self.assertRaises(TypeError):
schema_from_signal(DICT_WITHOUT_ANNOTATION_SIGNAL)

def test_throw_exception_invalid_dict_annotation(self):
    """
    Schema generation must raise ``TypeError`` for ``Dict[str, NestedAttrsWithDefaults]``.
    """
    signal = create_simple_signal({"dict_input": Dict[str, NestedAttrsWithDefaults]})
    self.assertRaises(TypeError, schema_from_signal, signal)

def test_list_with_annotation_works(self):
LIST_SIGNAL = create_simple_signal({"list_input": List[int]})
expected_dict = {
Expand All @@ -270,3 +279,18 @@ def test_list_with_annotation_works(self):
}
schema = schema_from_signal(LIST_SIGNAL)
self.assertDictEqual(schema, expected_dict)

def test_dict_with_annotation_works(self):
    """
    Verify that a ``Dict[str, int]`` field yields an Avro ``map`` whose values are ``long``.
    """
    map_field = {
        "name": "dict_input",
        "type": {"type": "map", "values": "long"},
    }
    expected_schema = {
        "name": "CloudEvent",
        "type": "record",
        "doc": "Avro Event Format for CloudEvents created with openedx_events/schema",
        "namespace": "simple.signal",
        "fields": [map_field],
    }

    signal = create_simple_signal({"dict_input": Dict[str, int]})
    self.assertDictEqual(schema_from_signal(signal), expected_schema)
7 changes: 7 additions & 0 deletions openedx_events/event_bus/avro/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class SimpleAttrs:
string_field: str


@attr.s(auto_attribs=True)
class ComplexAttrs:
    """Attrs test class whose fields are all annotated container types (a list and a dict)."""
    # List whose items are ints.
    list_field: list[int]
    # Dict with str keys and int values.
    dict_field: dict[str, int]


@attr.s(auto_attribs=True)
class SubTestData0:
"""Subclass for testing nested attrs"""
Expand Down
2 changes: 1 addition & 1 deletion openedx_events/event_bus/avro/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@
PYTHON_TYPE_TO_AVRO_MAPPING = {
**SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING,
None: "null",
dict: "record",
dict: "map",
list: "array",
}

0 comments on commit c3d7aca

Please sign in to comment.