Skip to content

Commit

Permalink
Database batch update API (#81)
Browse files Browse the repository at this point in the history
* Database batch update

* Retest

* Handle exceptions in batch context manager

* Use explicit transaction for batch commit

* Use sqlite3.Error parent class for 3.10
  • Loading branch information
disinvite authored Jan 30, 2025
1 parent 5db929e commit 0fbd24c
Show file tree
Hide file tree
Showing 3 changed files with 408 additions and 47 deletions.
83 changes: 40 additions & 43 deletions reccmp/isledecomp/compare/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def _load_cvdump(self):
# In the rare case we have duplicate symbols for an address, ignore them.
dataset = {}

batch = self._db.batch()

for sym in self.cvdump_analysis.nodes:
# Skip nodes where we have almost no information.
# These probably came from SECTION CONTRIBUTIONS.
Expand Down Expand Up @@ -215,15 +217,15 @@ def _load_cvdump(self):
except UnicodeDecodeError:
pass

dataset[addr] = {
"type": sym.node_type,
"name": sym.name(),
"symbol": sym.decorated_name,
"size": sym.size(),
}
batch.set_recomp(
addr,
type=sym.node_type,
name=sym.name(),
symbol=sym.decorated_name,
size=sym.size(),
)

# Convert dict of dicts (keyed by addr) to list of dicts (that contains the addr)
self._db.bulk_recomp_insert(dataset.items())
batch.commit()

for (section, offset), (
filename,
Expand All @@ -233,7 +235,9 @@ def _load_cvdump(self):
self._lines_db.add_line(filename, line_no, addr)

# The _entry symbol is referenced in the PE header so we get this match for free.
self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry)
with self._db.batch() as batch:
batch.set_recomp(self.recomp_bin.entry, type=EntityType.FUNCTION)
batch.match(self.orig_bin.entry, self.recomp_bin.entry)

def _load_markers(self):
codefiles = list(walk_source_dir(self.code_dir))
Expand Down Expand Up @@ -311,21 +315,22 @@ def _match_array_elements(self):
Note that there is no recursion, so an array of arrays would not be handled entirely.
This step is necessary e.g. for `0x100f0a20` (LegoRacers.cpp).
"""
dataset: dict[int, dict[str, str]] = {}
orig_by_recomp: dict[int, int] = {}
seen_recomp = set()
batch = self._db.batch()

# Helper function
def _add_match_in_array(
name: str, type_id: str, orig_addr: int, recomp_addr: int, max_orig: int
):
# pylint: disable=unused-argument
# TODO: Previously used scalar_type_pointer(type_id) to set whether this is a pointer
if recomp_addr in dataset:
if recomp_addr in seen_recomp:
return

dataset[recomp_addr] = {"name": name}
seen_recomp.add(recomp_addr)
batch.set_recomp(recomp_addr, name=name)
if orig_addr < max_orig:
orig_by_recomp[recomp_addr] = orig_addr
batch.match(orig_addr, recomp_addr)

# Indexed by recomp addr. Need to preload this data because it is not stored alongside the db rows.
cvdump_lookup = {x.addr: x for x in self.cvdump_analysis.nodes}
Expand Down Expand Up @@ -386,17 +391,7 @@ def _add_match_in_array(
upper_bound,
)

# Upsert here to update the starting address of variables already in the db.
self._db.bulk_recomp_insert(
((addr, {"name": values["name"]}) for addr, values in dataset.items()),
upsert=True,
)
self._db.bulk_match(
(
(orig_addr, recomp_addr)
for recomp_addr, orig_addr in orig_by_recomp.items()
)
)
batch.commit()

def _find_original_strings(self):
"""Go to the original binary and look for the specified string constants
Expand Down Expand Up @@ -434,31 +429,33 @@ def is_real_string(s: str) -> bool:
# We could try to match the string addrs if there is only one in orig and recomp.
# When we sanitize the asm, the result is the same regardless.
if self.orig_bin.is_debug:
for addr, string in self.orig_bin.iter_string("latin1"):
if is_real_string(string):
self._db.set_orig_symbol(
addr, type=EntityType.STRING, name=string, size=len(string)
)
with self._db.batch() as batch:
for addr, string in self.orig_bin.iter_string("latin1"):
if is_real_string(string):
batch.insert_orig(
addr, type=EntityType.STRING, name=string, size=len(string)
)

for addr, string in self.recomp_bin.iter_string("latin1"):
if is_real_string(string):
self._db.set_recomp_symbol(
addr, type=EntityType.STRING, name=string, size=len(string)
)
for addr, string in self.recomp_bin.iter_string("latin1"):
if is_real_string(string):
batch.insert_recomp(
addr, type=EntityType.STRING, name=string, size=len(string)
)

def _find_float_const(self):
    """Add floating point constants in each binary to the database.
    We are not matching anything right now because these values are not
    deduped like strings."""
    # One batch so both binaries' constants commit in a single transaction.
    with self._db.batch() as batch:
        for addr, size, float_value in find_float_consts(self.orig_bin):
            batch.insert_orig(
                addr, type=EntityType.FLOAT, name=str(float_value), size=size
            )

        for addr, size, float_value in find_float_consts(self.recomp_bin):
            batch.insert_recomp(
                addr, type=EntityType.FLOAT, name=str(float_value), size=size
            )

def _match_imports(self):
"""We can match imported functions based on the DLL name and
Expand Down
112 changes: 108 additions & 4 deletions reccmp/isledecomp/compare/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,89 @@ def matched_entity_factory(_, row: object) -> ReccmpMatch:
logger = logging.getLogger(__name__)


class EntityBatch:
    """Accumulates entity inserts, upserts, and matches and applies them to
    the backing EntityDb in a single SQL transaction on commit().

    May be used as a context manager: commits on clean exit, discards all
    pending changes if an exception was raised inside the block."""

    base: "EntityDb"

    # To be inserted only if the address is unused
    _orig_insert: dict[int, dict[str, Any]]
    _recomp_insert: dict[int, dict[str, Any]]

    # To be upserted
    _orig: dict[int, dict[str, Any]]
    _recomp: dict[int, dict[str, Any]]

    # Matches
    _orig_to_recomp: dict[int, int]
    _recomp_to_orig: dict[int, int]

    def __init__(self, backref: "EntityDb") -> None:
        self.base = backref
        self._orig_insert = {}
        self._recomp_insert = {}
        self._orig = {}
        self._recomp = {}
        self._orig_to_recomp = {}
        self._recomp_to_orig = {}

    def reset(self):
        """Clear all pending changes"""
        self._orig_insert.clear()
        self._recomp_insert.clear()
        self._orig.clear()
        self._recomp.clear()
        self._orig_to_recomp.clear()
        self._recomp_to_orig.clear()

    def insert_orig(self, addr: int, **kwargs):
        """Stage attributes for the orig addr, applied only if the addr is unused."""
        self._orig_insert.setdefault(addr, {}).update(kwargs)

    def insert_recomp(self, addr: int, **kwargs):
        """Stage attributes for the recomp addr, applied only if the addr is unused."""
        self._recomp_insert.setdefault(addr, {}).update(kwargs)

    def set_orig(self, addr: int, **kwargs):
        """Stage attributes for the orig addr, upserted on commit."""
        self._orig.setdefault(addr, {}).update(kwargs)

    def set_recomp(self, addr: int, **kwargs):
        """Stage attributes for the recomp addr, upserted on commit."""
        self._recomp.setdefault(addr, {}).update(kwargs)

    def match(self, orig: int, recomp: int):
        """Stage a (orig_addr, recomp_addr) pairing. The most recent call
        using either address wins."""
        # Integrity check: orig and recomp addr must be used only once.
        # Remove any previous pairing that uses either address, on BOTH maps.
        # (Clearing only one side would leave a stale reverse entry that could
        # evict a valid match in a later call.)
        if (used_orig := self._recomp_to_orig.pop(recomp, None)) is not None:
            self._orig_to_recomp.pop(used_orig, None)

        if (used_recomp := self._orig_to_recomp.pop(orig, None)) is not None:
            self._recomp_to_orig.pop(used_recomp, None)

        self._orig_to_recomp[orig] = recomp
        self._recomp_to_orig[recomp] = orig

    def commit(self):
        """Apply all staged changes inside one SQL transaction, then clear
        the staging area. Order: inserts, upserts, matches."""
        # SQL transaction
        with self.base.sql:
            if self._orig_insert:
                self.base.bulk_orig_insert(self._orig_insert.items())

            if self._recomp_insert:
                self.base.bulk_recomp_insert(self._recomp_insert.items())

            if self._orig:
                self.base.bulk_orig_insert(self._orig.items(), upsert=True)

            if self._recomp:
                self.base.bulk_recomp_insert(self._recomp.items(), upsert=True)

            if self._orig_to_recomp:
                self.base.bulk_match(self._orig_to_recomp.items())

        self.reset()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Discard staged changes on error; exceptions propagate to the caller.
        if exc_type is not None:
            self.reset()
        else:
            self.commit()


class EntityDb:
# pylint: disable=too-many-public-methods
def __init__(self):
Expand All @@ -152,6 +235,9 @@ def __init__(self):
def sql(self) -> sqlite3.Connection:
return self._sql

def batch(self) -> EntityBatch:
    """Return a new EntityBatch that stages changes against this database."""
    return EntityBatch(self)

def set_orig_symbol(self, addr: int, **kwargs):
self.bulk_orig_insert(iter([(addr, kwargs)]))

Expand Down Expand Up @@ -191,10 +277,28 @@ def bulk_recomp_insert(
)

def bulk_match(self, pairs: Iterable[tuple[int, int]]):
"""Expects iterable of `(orig_addr, recomp_addr)`."""
self._sql.executemany(
"UPDATE or ignore entities SET orig_addr = ? WHERE recomp_addr = ?", pairs
)
"""Expects iterable of (orig_addr, recomp_addr)."""
# We need to iterate over this multiple times.
pairlist = list(pairs)

with self._sql:
# Copy orig information to recomp side. Prefer recomp information except for NULLS.
# json_patch(X, Y) copies keys from Y into X and replaces existing values.
# From inner-most to outer-most:
# - json_patch('{}', entities.kvstore) Eliminate NULLS on recomp side (so orig will replace)
# - json_patch(o.kvstore, ^) Merge orig and recomp keys. Prefer recomp values.
self._sql.executemany(
"""UPDATE entities
SET kvstore = json_patch(o.kvstore, json_patch('{}', entities.kvstore))
FROM (SELECT kvstore FROM entities WHERE orig_addr = ? and recomp_addr is null) o
WHERE recomp_addr = ? AND orig_addr is null""",
pairlist,
)
# Patch orig address into recomp and delete orig entry.
self._sql.executemany(
"UPDATE OR REPLACE entities SET orig_addr = ? WHERE recomp_addr = ? AND orig_addr is null",
pairlist,
)

def get_unmatched_strings(self) -> list[str]:
"""Return any strings not already identified by `STRING` markers."""
Expand Down
Loading

0 comments on commit 0fbd24c

Please sign in to comment.