diff --git a/reccmp/isledecomp/compare/core.py b/reccmp/isledecomp/compare/core.py index 0acba9cb..d51db0e6 100644 --- a/reccmp/isledecomp/compare/core.py +++ b/reccmp/isledecomp/compare/core.py @@ -144,6 +144,8 @@ def _load_cvdump(self): # In the rare case we have duplicate symbols for an address, ignore them. dataset = {} + batch = self._db.batch() + for sym in self.cvdump_analysis.nodes: # Skip nodes where we have almost no information. # These probably came from SECTION CONTRIBUTIONS. @@ -215,15 +217,15 @@ def _load_cvdump(self): except UnicodeDecodeError: pass - dataset[addr] = { - "type": sym.node_type, - "name": sym.name(), - "symbol": sym.decorated_name, - "size": sym.size(), - } + batch.set_recomp( + addr, + type=sym.node_type, + name=sym.name(), + symbol=sym.decorated_name, + size=sym.size(), + ) - # Convert dict of dicts (keyed by addr) to list of dicts (that contains the addr) - self._db.bulk_recomp_insert(dataset.items()) + batch.commit() for (section, offset), ( filename, @@ -233,7 +235,9 @@ def _load_cvdump(self): self._lines_db.add_line(filename, line_no, addr) # The _entry symbol is referenced in the PE header so we get this match for free. - self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry) + with self._db.batch() as batch: + batch.set_recomp(self.recomp_bin.entry, type=EntityType.FUNCTION) + batch.match(self.orig_bin.entry, self.recomp_bin.entry) def _load_markers(self): codefiles = list(walk_source_dir(self.code_dir)) @@ -311,8 +315,8 @@ def _match_array_elements(self): Note that there is no recursion, so an array of arrays would not be handled entirely. This step is necessary e.g. for `0x100f0a20` (LegoRacers.cpp). """ - dataset: dict[int, dict[str, str]] = {} - orig_by_recomp: dict[int, int] = {} + seen_recomp = set() + batch = self._db.batch() # Helper function def _add_match_in_array( @@ -320,12 +324,13 @@ def _add_match_in_array( ): # pylint: disable=unused-argument # TODO: Previously used scalar_type_pointer(type_id) to set whether this is a pointer - if recomp_addr in dataset: + if recomp_addr in seen_recomp: return - dataset[recomp_addr] = {"name": name} + seen_recomp.add(recomp_addr) + batch.set_recomp(recomp_addr, name=name) if orig_addr < max_orig: - orig_by_recomp[recomp_addr] = orig_addr + batch.match(orig_addr, recomp_addr) # Indexed by recomp addr. Need to preload this data because it is not stored alongside the db rows. cvdump_lookup = {x.addr: x for x in self.cvdump_analysis.nodes} @@ -386,17 +391,7 @@ def _add_match_in_array( upper_bound, ) - # Upsert here to update the starting address of variables already in the db. - self._db.bulk_recomp_insert( - ((addr, {"name": values["name"]}) for addr, values in dataset.items()), - upsert=True, - ) - self._db.bulk_match( - ( - (orig_addr, recomp_addr) - for recomp_addr, orig_addr in orig_by_recomp.items() - ) - ) + batch.commit() def _find_original_strings(self): """Go to the original binary and look for the specified string constants @@ -434,31 +429,33 @@ def is_real_string(s: str) -> bool: # We could try to match the string addrs if there is only one in orig and recomp. # When we sanitize the asm, the result is the same regardless. if self.orig_bin.is_debug: - for addr, string in self.orig_bin.iter_string("latin1"): - if is_real_string(string): - self._db.set_orig_symbol( - addr, type=EntityType.STRING, name=string, size=len(string) - ) + with self._db.batch() as batch: + for addr, string in self.orig_bin.iter_string("latin1"): + if is_real_string(string): + batch.insert_orig( + addr, type=EntityType.STRING, name=string, size=len(string) + ) - for addr, string in self.recomp_bin.iter_string("latin1"): - if is_real_string(string): - self._db.set_recomp_symbol( - addr, type=EntityType.STRING, name=string, size=len(string) - ) + for addr, string in self.recomp_bin.iter_string("latin1"): + if is_real_string(string): + batch.insert_recomp( + addr, type=EntityType.STRING, name=string, size=len(string) + ) def _find_float_const(self): """Add floating point constants in each binary to the database. We are not matching anything right now because these values are not deduped like strings.""" - for addr, size, float_value in find_float_consts(self.orig_bin): - self._db.set_orig_symbol( - addr, type=EntityType.FLOAT, name=str(float_value), size=size - ) + with self._db.batch() as batch: + for addr, size, float_value in find_float_consts(self.orig_bin): + batch.insert_orig( + addr, type=EntityType.FLOAT, name=str(float_value), size=size + ) - for addr, size, float_value in find_float_consts(self.recomp_bin): - self._db.set_recomp_symbol( - addr, type=EntityType.FLOAT, name=str(float_value), size=size - ) + for addr, size, float_value in find_float_consts(self.recomp_bin): + batch.insert_recomp( + addr, type=EntityType.FLOAT, name=str(float_value), size=size + ) def _match_imports(self): """We can match imported functions based on the DLL name and diff --git a/reccmp/isledecomp/compare/db.py b/reccmp/isledecomp/compare/db.py index f35d78f0..bd268586 100644 --- a/reccmp/isledecomp/compare/db.py +++ b/reccmp/isledecomp/compare/db.py @@ -141,6 +141,89 @@ def matched_entity_factory(_, row: object) -> ReccmpMatch: logger = logging.getLogger(__name__) +class EntityBatch: + base: "EntityDb" + + # To be inserted only if the address is unused + _orig_insert: dict[int, dict[str, Any]] + _recomp_insert: dict[int, dict[str, Any]] + + # To be upserted + _orig: dict[int, dict[str, Any]] + _recomp: dict[int, dict[str, Any]] + + # Matches + _orig_to_recomp: dict[int, int] + _recomp_to_orig: dict[int, int] + + def __init__(self, backref: "EntityDb") -> None: + self.base = backref + self._orig_insert = {} + self._recomp_insert = {} + self._orig = {} + self._recomp = {} + self._orig_to_recomp = {} + self._recomp_to_orig = {} + + def reset(self): + """Clear all pending changes""" + self._orig_insert.clear() + self._recomp_insert.clear() + self._orig.clear() + self._recomp.clear() + self._orig_to_recomp.clear() + self._recomp_to_orig.clear() + + def insert_orig(self, addr: int, **kwargs): + self._orig_insert.setdefault(addr, {}).update(kwargs) + + def insert_recomp(self, addr: int, **kwargs): + self._recomp_insert.setdefault(addr, {}).update(kwargs) + + def set_orig(self, addr: int, **kwargs): + self._orig.setdefault(addr, {}).update(kwargs) + + def set_recomp(self, addr: int, **kwargs): + self._recomp.setdefault(addr, {}).update(kwargs) + + def match(self, orig: int, recomp: int): + # Integrity check: orig and recomp addr must be used only once + if (used_orig := self._recomp_to_orig.pop(recomp, None)) is not None: + self._orig_to_recomp.pop(used_orig, None) + + self._orig_to_recomp[orig] = recomp + self._recomp_to_orig[recomp] = orig + + def commit(self): + # SQL transaction + with self.base.sql: + if self._orig_insert: + self.base.bulk_orig_insert(self._orig_insert.items()) + + if self._recomp_insert: + self.base.bulk_recomp_insert(self._recomp_insert.items()) + + if self._orig: + self.base.bulk_orig_insert(self._orig.items(), upsert=True) + + if self._recomp: + self.base.bulk_recomp_insert(self._recomp.items(), upsert=True) + + if self._orig_to_recomp: + self.base.bulk_match(self._orig_to_recomp.items()) + + self.reset() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + if exc_type is not None: + self.reset() + else: + self.commit() + + class EntityDb: # pylint: disable=too-many-public-methods def __init__(self): @@ -152,6 +235,9 @@ def __init__(self): def sql(self) -> sqlite3.Connection: return self._sql + def batch(self) -> EntityBatch: + return EntityBatch(self) + def set_orig_symbol(self, addr: int, **kwargs): self.bulk_orig_insert(iter([(addr, kwargs)])) @@ -191,10 +277,28 @@ def bulk_recomp_insert( ) def bulk_match(self, pairs: Iterable[tuple[int, int]]): - """Expects iterable of `(orig_addr, recomp_addr)`.""" - self._sql.executemany( - "UPDATE or ignore entities SET orig_addr = ? WHERE recomp_addr = ?", pairs - ) + """Expects iterable of (orig_addr, recomp_addr).""" + # We need to iterate over this multiple times. + pairlist = list(pairs) + + with self._sql: + # Copy orig information to recomp side. Prefer recomp information except for NULLS. + # json_patch(X, Y) copies keys from Y into X and replaces existing values. + # From inner-most to outer-most: + # - json_patch('{}', entities.kvstore) Eliminate NULLS on recomp side (so orig will replace) + # - json_patch(o.kvstore, ^) Merge orig and recomp keys. Prefer recomp values. + self._sql.executemany( + """UPDATE entities + SET kvstore = json_patch(o.kvstore, json_patch('{}', entities.kvstore)) + FROM (SELECT kvstore FROM entities WHERE orig_addr = ? and recomp_addr is null) o + WHERE recomp_addr = ? AND orig_addr is null""", + pairlist, + ) + # Patch orig address into recomp and delete orig entry. + self._sql.executemany( + "UPDATE OR REPLACE entities SET orig_addr = ? WHERE recomp_addr = ? AND orig_addr is null", + pairlist, + ) def get_unmatched_strings(self) -> list[str]: """Return any strings not already identified by `STRING` markers.""" diff --git a/tests/test_compare_db.py b/tests/test_compare_db.py index 8d1ab063..c1268d0b 100644 --- a/tests/test_compare_db.py +++ b/tests/test_compare_db.py @@ -1,5 +1,7 @@ """Testing compare database behavior, particularly matching""" +import sqlite3 +from unittest.mock import patch import pytest from reccmp.isledecomp.compare.db import EntityDb @@ -84,3 +86,261 @@ def test_dynamic_metadata(db): # Should preserve boolean type assert isinstance(obj.get("option"), bool) assert obj.get("option") is True + + +#### Testing new batch API #### + + +def test_batch(db): + """Demonstrate batch with context manager""" + with db.batch() as batch: + batch.set_orig(100, name="Hello") + batch.set_recomp(200, name="Test") + + assert db.get_by_orig(100).name == "Hello" + assert db.get_by_recomp(200).name == "Test" + + +def test_batch_replace(db): + """Calling the set or insert methods again on the same address and data will replace the pending value.""" + with db.batch() as batch: + batch.set_orig(100, name="") + batch.insert_orig(200, name="") + batch.set_recomp(100, name="") + batch.insert_recomp(200, name="") + + batch.set_orig(100, name="Orig100") + batch.insert_orig(200, name="Orig200") + batch.set_recomp(100, name="Recomp100") + batch.insert_recomp(200, name="Recomp200") + + assert db.get_by_orig(100).name == "Orig100" + assert db.get_by_orig(200).name == "Orig200" + assert db.get_by_recomp(100).name == "Recomp100" + assert db.get_by_recomp(200).name == "Recomp200" + + +def test_batch_insert_overwrite(db): + """Inserts and sets on the same address in the same batch will result in the + 'insert' values being replaced.""" + with db.batch() as batch: + batch.insert_orig(100, name="Test") + batch.set_orig(100, name="Hello", test=123) + batch.insert_recomp(100, name="Test") + batch.set_recomp(100, name="Hello", test=123) + + assert db.get_by_orig(100).name == "Hello" + assert db.get_by_orig(100).get("test") == 123 + + assert db.get_by_recomp(100).name == "Hello" + assert db.get_by_recomp(100).get("test") == 123 + + +def test_batch_insert(db): + """The 'insert' methods will abort if any data exists for the address""" + db.set_orig_symbol(100, name="Hello") + db.set_recomp_symbol(200, name="Test") + + with db.batch() as batch: + batch.insert_orig(100, name="abc") + batch.insert_recomp(200, name="xyz") + + assert db.get_by_orig(100).name != "abc" + assert db.get_by_recomp(200).name != "xyz" + + +def test_batch_upsert(db): + """The 'set' methods overwrite existing values""" + db.set_orig_symbol(100, name="Hello") + db.set_recomp_symbol(200, name="Test") + + with db.batch() as batch: + batch.set_orig(100, name="abc") + batch.set_recomp(200, name="xyz") + + assert db.get_by_orig(100).name == "abc" + assert db.get_by_recomp(200).name == "xyz" + + +def test_batch_match_attach(db): + """Match example with new orig addr. + There is no existing entity with the orig addr being matched.""" + with db.batch() as batch: + batch.set_recomp(200, name="Hello") + batch.match(100, 200) + + # Confirm match + assert db.get_by_orig(100).name == "Hello" + + +def test_batch_match_combine(db): + """Match example with existing orig addr.""" + with db.batch() as batch: + batch.set_orig(100, name="Test") + batch.set_recomp(200, name="Hello") + + # Two entities + assert len([*db.get_all()]) == 2 + + # Use separate batches to demonstrate + with db.batch() as batch: + batch.match(100, 200) + + # Should combine + assert len([*db.get_all()]) == 1 + + # Confirm match. Both entities have the "name" attribute. Should use recomp value. + assert db.get_by_orig(100).recomp_addr == 200 + assert db.get_by_orig(100).name == "Hello" + + +def test_batch_match_combine_except_null(db): + """We prefer recomp attributes when combining two entities. + The exception is when the recomp entity has a NULL. We should use the orig attribute in this case. + """ + with db.batch() as batch: + batch.set_orig(100, name="Test", test=123) + batch.set_recomp(200, name="Hello", test=None) + batch.match(100, 200) + + assert db.get_by_recomp(200).get("test") == 123 + + +def test_batch_match_combine_replace_null(db): + """Confirm that we will replace a NULL on the orig side with a recomp value.""" + with db.batch() as batch: + batch.set_orig(100, name="Test", test=None) + batch.set_recomp(200, name="Hello", test=123) + batch.match(100, 200) + + assert db.get_by_recomp(200).get("test") == 123 + + +@pytest.mark.xfail(reason="Known limitation.") +def test_batch_match_create(db): + """Matching requires either the orig or recomp entity to exist. It does not create entities.""" + with db.batch() as batch: + batch.match(100, 200) + + assert db.get_by_orig(100).recomp_addr == 200 + + +def test_batch_commit_twice(db): + """Calling commit() clears the pending updates. + Calling commit() again without adding new changes will not alter the database.""" + batch = db.batch() + batch.set_orig(100, name="Test") + + with patch("reccmp.isledecomp.compare.db.EntityDb.bulk_orig_insert") as mock: + batch.commit() + batch.commit() + mock.assert_called_once() + + with patch("reccmp.isledecomp.compare.db.EntityDb.bulk_orig_insert") as mock: + batch.commit() + mock.assert_not_called() + + +def test_batch_cannot_alter_matched(db): + """batch.match() will not change an entity that is already matched.""" + + # Set up the match + with db.batch() as batch: + batch.set_recomp(200, name="Test") + batch.match(100, 200) + + # Confirm it is there + assert db.get_by_orig(100).recomp_addr == 200 + + # Try to change recomp=200 to match orig=101 + with db.batch() as batch: + batch.match(101, 200) + + # Should not change it + assert db.get_by_recomp(200).orig_addr == 100 + + +def test_batch_change_staged_match(db): + """You can change an unsaved match by calling match() again on the same orig addr.""" + with db.batch() as batch: + batch.set_recomp(200, name="Hello") + batch.set_recomp(201, name="Test") + batch.match(100, 200) + batch.match(100, 201) + + assert db.get_by_orig(100).recomp_addr == 201 + assert db.get_by_recomp(200).orig_addr is None + + +def test_batch_match_repeat_recomp_addr(db): + """Calling match() with the same recomp addr should work the same as the orig addr case. + Discard the first match in favor of the new one.""" + with db.batch() as batch: + batch.set_recomp(200, name="Hello") + batch.set_recomp(201, name="Test") + batch.match(100, 200) + batch.match(101, 200) + + assert db.get_by_recomp(200).orig_addr == 101 + assert db.get_by_orig(100) is None + + +def test_batch_exception_uncaught(db): + """When using batch context manager, an uncaught exception should clear the staged changes.""" + try: + with db.batch() as batch: + batch.set_orig(100, name="Test") + batch.set_recomp(200, test=123) + batch.match(100, 200) + _ = 1 / 0 + except ZeroDivisionError: + pass + + assert db.get_by_orig(100) is None + assert db.get_by_orig(200) is None + + +def test_batch_exception_caught(db): + """If the exception is caught, allow the batch to go through.""" + with db.batch() as batch: + batch.set_orig(100, name="Test") + batch.set_recomp(200, test=123) + batch.match(100, 200) + try: + _ = 1 / 0 + except ZeroDivisionError: + pass + + assert db.get_by_orig(100) is not None + assert db.get_by_recomp(200) is not None + + +def test_batch_sqlite_exception(db): + """Should rollback if an exception occurs during the commit.""" + + # Not using batch context for clarity + batch = db.batch() + batch.set_orig(100, name="Test") + batch.set_recomp(200, test=123) + + # Insert bad data that will cause a binding error + batch.match(100, ("bogus",)) + + with pytest.raises(sqlite3.Error): + batch.commit() + + # Should rollback everything + assert db.get_by_orig(100) is None + assert db.get_by_recomp(200) is None + + +def test_batch_sqlite_exception_insert_only(db): + """Should rollback even if we don't start the explicit transaction in match()""" + batch = db.batch() + batch.insert_orig(100, name="Test") + batch.insert_orig(("bogus",), name="Test") + + with pytest.raises(sqlite3.Error): + batch.commit() + + assert db.get_by_orig(100) is None