Skip to content

Commit

Permalink
Do not serialize non-default fields by default (#1452)
Browse files Browse the repository at this point in the history
Added a configuration entry (enabled by default) that serializes only
the modified fields in an SDFG. This leads to a reduction in size.

Merging this PR is contingent on updating the SDFG renderer to use the
defaults/metadata for properties.

---------

Co-authored-by: Philipp Schaad <[email protected]>
  • Loading branch information
tbennun and phschaad authored Dec 4, 2023
1 parent b0cd25b commit 6374843
Show file tree
Hide file tree
Showing 11 changed files with 68 additions and 25 deletions.
14 changes: 11 additions & 3 deletions dace/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def generate_code(sdfg, validate=True) -> List[CodeObject]:

if Config.get_bool('testing', 'serialization'):
from dace.sdfg import SDFG
import difflib
import filecmp
import shutil
import tempfile
Expand All @@ -174,9 +175,16 @@ def generate_code(sdfg, validate=True) -> List[CodeObject]:
sdfg2.save(f'{tmp_dir}/test2.sdfg', hash=False)
print('Testing SDFG serialization...')
if not filecmp.cmp(f'{tmp_dir}/test.sdfg', f'{tmp_dir}/test2.sdfg'):
shutil.move(f"{tmp_dir}/test.sdfg", "test.sdfg")
shutil.move(f"{tmp_dir}/test2.sdfg", "test2.sdfg")
raise RuntimeError('SDFG serialization failed - files do not match')
with open(f'{tmp_dir}/test.sdfg', 'r') as f1:
with open(f'{tmp_dir}/test2.sdfg', 'r') as f2:
diff = difflib.unified_diff(f1.readlines(),
f2.readlines(),
fromfile='test.sdfg (first save)',
tofile='test2.sdfg (after roundtrip)')
diff = ''.join(diff)
shutil.move(f'{tmp_dir}/test.sdfg', 'test.sdfg')
shutil.move(f'{tmp_dir}/test2.sdfg', 'test2.sdfg')
raise RuntimeError(f'SDFG serialization failed - files do not match:\n{diff}')

# Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops.
# TODO (later): Adapt codegen to deal with hierarchical CFGs instead.
Expand Down
4 changes: 2 additions & 2 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,7 +1914,7 @@ def _generate_ConsumeEntry(
'size_t')

# Take quiescence condition into account
if node.consume.condition.code is not None:
if node.consume.condition is not None:
condition_string = "[&]() { return %s; }, " % cppunparse.cppunparse(node.consume.condition.code, False)
else:
condition_string = ""
Expand All @@ -1933,7 +1933,7 @@ def _generate_ConsumeEntry(
"{num_pes}, {condition}"
"[&](int {pe_index}, {element_or_chunk}) {{".format(
chunksz=node.consume.chunksize,
cond="" if node.consume.condition.code is None else "_cond",
cond="" if node.consume.condition is None else "_cond",
condition=condition_string,
stream_in=input_stream.data, # TODO: stream arrays
element_or_chunk=chunk,
Expand Down
8 changes: 8 additions & 0 deletions dace/config_schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,14 @@ required:
When an exception is raised in a deserialization process (e.g., due to missing library node),
by default a warning is issued. If this setting is True, the exception will be raised as-is.
serialize_all_fields:
type: bool
default: false
title: Serialize all unmodified fields in SDFG files
description: >
If False (default), saving an SDFG keeps only the modified non-default properties. If True,
saves all fields.
#############################################
# DaCe library settings

Expand Down
10 changes: 10 additions & 0 deletions dace/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,14 @@ def as_string(self, code):
else:
self.code = code

def __eq__(self, other):
if isinstance(other, str) or other is None:
return self.as_string == other
elif isinstance(other, CodeBlock):
return self.as_string == other.as_string and self.language == other.language
else:
return super().__eq__(other)

def to_json(self):
# Two roundtrips to avoid issues in AST parsing/unparsing of negative
# numbers, i.e., "(-1)" becomes "(- 1)"
Expand Down Expand Up @@ -1382,6 +1390,8 @@ def to_json(self, obj):
def from_json(obj, context=None):
if obj is None:
return None
elif isinstance(obj, typeclass):
return obj
elif isinstance(obj, str):
return TypeClassProperty.from_string(obj)
elif isinstance(obj, dict):
Expand Down
8 changes: 5 additions & 3 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,8 +1005,10 @@ def __str__(self):
@property
def free_symbols(self) -> Set[str]:
dyn_inputs = set(c for c in self.in_connectors if not c.startswith('IN_'))
return ((set(self._consume.num_pes.free_symbols)
| set(self._consume.condition.get_free_symbols())) - dyn_inputs)
result = set(self._consume.num_pes.free_symbols)
if self._consume.condition is not None:
result |= set(self._consume.condition.get_free_symbols())
return result - dyn_inputs

def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]:
from dace.codegen.tools.type_inference import infer_expr_type
Expand Down Expand Up @@ -1094,7 +1096,7 @@ class Consume(object):
label = Property(dtype=str, desc="Name of the consume node")
pe_index = Property(dtype=str, desc="Processing element identifier")
num_pes = SymbolicProperty(desc="Number of processing elements", default=1)
condition = CodeProperty(desc="Quiescence condition", allow_none=True)
condition = CodeProperty(desc="Quiescence condition", allow_none=True, default=None)
schedule = EnumProperty(dtype=dtypes.ScheduleType, desc="Consume schedule", default=dtypes.ScheduleType.Default)
chunksize = Property(dtype=int, desc="Maximal size of elements to consume at a time", default=1)
debuginfo = DebugInfoProperty()
Expand Down
10 changes: 8 additions & 2 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,8 @@ def to_json(self, hash=False):
tmp = super().to_json()

# Ensure properties are serialized correctly
tmp['attributes']['constants_prop'] = json.loads(dace.serialize.dumps(tmp['attributes']['constants_prop']))
if 'constants_prop' in tmp['attributes']:
tmp['attributes']['constants_prop'] = json.loads(dace.serialize.dumps(tmp['attributes']['constants_prop']))

tmp['sdfg_list_id'] = int(self.sdfg_id)
tmp['start_state'] = self._start_block
Expand All @@ -604,8 +605,13 @@ def from_json(cls, json_obj, context_info=None):
nodes = json_obj['nodes']
edges = json_obj['edges']

if 'constants_prop' in attrs:
constants_prop = dace.serialize.loads(dace.serialize.dumps(attrs['constants_prop']))
else:
constants_prop = None

ret = SDFG(name=attrs['name'],
constants=dace.serialize.loads(dace.serialize.dumps(attrs['constants_prop'])),
constants=constants_prop,
parent=context_info['sdfg'])

dace.serialize.set_properties_from_json(ret,
Expand Down
4 changes: 3 additions & 1 deletion dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,7 +1644,9 @@ def add_consume(self,
pe_tuple = (elements[0], SymbolicProperty.from_string(elements[1]))

debuginfo = _getdebuginfo(debuginfo or self._default_lineinfo)
consume = nd.Consume(name, pe_tuple, CodeBlock(condition, language), schedule, chunksize, debuginfo=debuginfo)
if condition is not None:
condition = CodeBlock(condition, language)
consume = nd.Consume(name, pe_tuple, condition, schedule, chunksize, debuginfo=debuginfo)
entry = nd.ConsumeEntry(consume)
exit = nd.ConsumeExit(consume)

Expand Down
3 changes: 3 additions & 0 deletions dace/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,11 @@ def dump(*args, **kwargs):


def all_properties_to_json(object_with_properties):
save_all_fields = config.Config.get_bool('testing', 'serialize_all_fields')
retdict = {}
for x, v in object_with_properties.properties():
if not save_all_fields and v == x.default: # Skip default fields
continue
if x.optional and not x.optional_condition(object_with_properties):
continue
retdict[x.attr_name] = x.to_json(v)
Expand Down
16 changes: 9 additions & 7 deletions dace/transformation/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,12 +391,13 @@ def from_json(json_obj: Dict[str, Any], context: Dict[str, Any] = None) -> 'Patt
if ext.__name__ == json_obj['transformation'])

# Recreate subgraph
expr = xform.expressions()[json_obj['expr_index']]
subgraph = {expr.node(int(k)): int(v) for k, v in json_obj['_subgraph'].items()}
expr = xform.expressions()[json_obj.get('expr_index', 0)]
subgraph = {expr.node(int(k)): int(v) for k, v in json_obj.get('_subgraph', {}).items()}

# Reconstruct transformation
ret = xform()
ret.setup_match(None, json_obj['sdfg_id'], json_obj['state_id'], subgraph, json_obj['expr_index'])
ret.setup_match(None, json_obj.get('sdfg_id', 0), json_obj.get('state_id', 0), subgraph,
json_obj.get('expr_index', 0))
context = context or {}
context['transformation'] = ret
serialize.set_properties_from_json(ret, json_obj, context=context, ignore_properties={'transformation', 'type'})
Expand Down Expand Up @@ -652,12 +653,13 @@ def from_json(json_obj: Dict[str, Any], context: Dict[str, Any] = None) -> 'Expa
xform = pydoc.locate(json_obj['classpath'])

# Recreate subgraph
expr = xform.expressions()[json_obj['expr_index']]
subgraph = {expr.node(int(k)): int(v) for k, v in json_obj['_subgraph'].items()}
expr = xform.expressions()[json_obj.get('expr_index', 0)]
subgraph = {expr.node(int(k)): int(v) for k, v in json_obj.get('_subgraph', {}).items()}

# Reconstruct transformation
ret = xform()
ret.setup_match(None, json_obj['sdfg_id'], json_obj['state_id'], subgraph, json_obj['expr_index'])
ret.setup_match(None, json_obj.get('sdfg_id', 0), json_obj.get('state_id', 0), subgraph,
json_obj.get('expr_index', 0))
context = context or {}
context['transformation'] = ret
serialize.set_properties_from_json(ret,
Expand Down Expand Up @@ -864,7 +866,7 @@ def from_json(json_obj: Dict[str, Any], context: Dict[str, Any] = None) -> 'Subg

# Reconstruct transformation
ret = xform()
ret.setup_match(json_obj['subgraph'], json_obj['sdfg_id'], json_obj['state_id'])
ret.setup_match(json_obj.get('subgraph', {}), json_obj.get('sdfg_id', 0), json_obj.get('state_id', 0))
context = context or {}
context['transformation'] = ret
serialize.set_properties_from_json(ret, json_obj, context=context, ignore_properties={'transformation', 'type'})
Expand Down
10 changes: 6 additions & 4 deletions tests/openmp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ def test_omp_props():
break

mapnode.schedule = dtypes.ScheduleType.CPU_Multicore
json = sdfg.to_json()
assert (key_exists(json, 'omp_num_threads'))
assert (key_exists(json, 'omp_schedule'))
assert (key_exists(json, 'omp_chunk_size'))

code = sdfg.generate_code()[0].clean_code
assert ("#pragma omp parallel for" in code)

Expand All @@ -73,6 +70,11 @@ def test_omp_props():
code = sdfg.generate_code()[0].clean_code
assert ("#pragma omp parallel for schedule(guided, 5) num_threads(10)" in code)

json = sdfg.to_json()
assert (key_exists(json, 'omp_num_threads'))
assert (key_exists(json, 'omp_schedule'))
assert (key_exists(json, 'omp_chunk_size'))


def test_omp_parallel():

Expand Down
6 changes: 3 additions & 3 deletions tests/transformations/local_storage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_in_local_storage_implicit():

# Check array was set correctly
serialized = sdfg.transformation_hist[0].to_json()
assert serialized["array"] == None
assert "array" not in serialized or serialized["array"] is None


def test_out_local_storage_explicit():
Expand Down Expand Up @@ -217,7 +217,7 @@ def test_out_local_storage_implicit():

# Check array was set correctly
serialized = sdfg.transformation_hist[0].to_json()
assert serialized["array"] == None
assert "array" not in serialized or serialized["array"] is None


@dace.program
Expand Down Expand Up @@ -250,8 +250,8 @@ def test_uneven(self):


if __name__ == '__main__':
unittest.main()
test_in_local_storage_explicit()
test_in_local_storage_implicit()
test_out_local_storage_explicit()
test_out_local_storage_implicit()
unittest.main()

0 comments on commit 6374843

Please sign in to comment.