diff --git a/reccmp/isledecomp/compare/db.py b/reccmp/isledecomp/compare/db.py index fff032e3..894f2316 100644 --- a/reccmp/isledecomp/compare/db.py +++ b/reccmp/isledecomp/compare/db.py @@ -24,8 +24,6 @@ value text, primary key (addr, name) ) without rowid; - - CREATE INDEX `symbols_na` ON `symbols` (json_extract(kvstore, '$.name')); """ @@ -60,8 +58,8 @@ def size(self) -> Optional[int]: def matched(self) -> bool: return self.orig_addr is not None and self.recomp_addr is not None - def get(self, key: str) -> Any: - return self.options.get(key) + def get(self, key: str, default: Any = None) -> Any: + return self.options.get(key, default) def match_name(self) -> Optional[str]: """Combination of the name and compare type. @@ -93,6 +91,7 @@ class CompareDb: def __init__(self): self._sql = sqlite3.connect(":memory:") self._sql.executescript(_SETUP_SQL) + self._indexed = set() @property def sql(self) -> sqlite3.Connection: @@ -388,52 +387,79 @@ def is_vtordisp(self, recomp_addr: int) -> bool: return True - def _find_potential_match( - self, name: str, compare_type: SymbolType - ) -> Optional[int]: - """Name lookup""" - match_decorate = compare_type != SymbolType.STRING and name.startswith("?") - # If the index on orig_addr is unique, sqlite will prefer to use it over the name index. - # But this index will not help if we are checking for NULL, so we exclude it - # by adding the plus sign (Reference: https://www.sqlite.org/optoverview.html#uplus) - if match_decorate: - # TODO: Change when/if decorated becomes a unique column - for (recomp_addr,) in self._sql.execute( - "SELECT recomp_addr FROM symbols WHERE json_extract(kvstore, '$.symbol') = ? AND +orig_addr IS NULL LIMIT 1", - (name,), - ): - return recomp_addr + def search_symbol(self, symbol: str) -> Iterator[MatchInfo]: + if "symbol" not in self._indexed: + self._sql.execute( + "CREATE index idx_symbol on symbols(json_extract(kvstore, '$.symbol'))" + ) + self._indexed.add("symbol") - return None + cur = self._sql.execute( + """SELECT orig_addr, recomp_addr, kvstore FROM symbols + WHERE json_extract(kvstore, '$.symbol') = ?""", + (symbol,), + ) + cur.row_factory = matchinfo_factory + yield from cur - for (reccmp_addr,) in self._sql.execute( - """ - SELECT recomp_addr - FROM `symbols` - WHERE +orig_addr IS NULL - AND json_extract(kvstore, '$.name') = ? - AND (json_extract(kvstore, '$.type') IS NULL OR json_extract(kvstore, '$.type') = ?) - LIMIT 1""", - (name, compare_type), - ): - return reccmp_addr + def search_name(self, name: str, compare_type: SymbolType) -> Iterator[MatchInfo]: + if "name" not in self._indexed: + self._sql.execute( + "CREATE index idx_name on symbols(json_extract(kvstore, '$.name'))" + ) + self._indexed.add("name") - return None + # n.b. If the name matches and the type is not set, we will return the row. + # Ideally we would have perfect information on the recomp side and not need to do this + cur = self._sql.execute( + """SELECT orig_addr, recomp_addr, kvstore FROM symbols + WHERE json_extract(kvstore, '$.name') = ? + AND (json_extract(kvstore, '$.type') IS NULL OR json_extract(kvstore, '$.type') = ?)""", + (name, compare_type), + ) + cur.row_factory = matchinfo_factory + yield from cur def _match_on(self, compare_type: SymbolType, addr: int, name: str) -> bool: - # Update the compare_type here too since the marker tells us what we should do + """Search the program listing for the given name and type, then assign the + given address to the first unmatched result.""" + # If we identify the name as a linker symbol, search for that instead. + # TODO: Will need a customizable "name_is_symbol" function for other platforms + if compare_type != SymbolType.STRING and name.startswith("?"): + for obj in self.search_symbol(name): + if obj.orig_addr is None and obj.recomp_addr is not None: + return self.set_pair(addr, obj.recomp_addr, compare_type) + + return False # Truncate the name to 255 characters. It will not be possible to match a name - # longer than that because MSVC truncates the debug symbols to this length. + # longer than that because MSVC truncates to this length. # See also: warning C4786. name = name[:255] - logger.debug("Looking for %s %s", compare_type.name.lower(), name) - recomp_addr = self._find_potential_match(name, compare_type) - if recomp_addr is None: - return False + for obj in self.search_name(name, compare_type): + if obj.orig_addr is None and obj.recomp_addr is not None: + matched = self.set_pair(addr, obj.recomp_addr, compare_type) + + # Type field has been set by set_pair, so we can use it in our count query: + (count,) = self._sql.execute( + """SELECT count(rowid) from symbols + where json_extract(kvstore,'$.name') = ? + AND json_extract(kvstore,'$.type') = ?""", + (name, compare_type), + ).fetchone() + + if matched and count > 1: + logger.warning( + "Ambiguous match 0x%x on name '%s' to '%s'", + addr, + name, + obj.get("symbol"), + ) - return self.set_pair(addr, recomp_addr, compare_type) + return matched + + return False def get_next_orig_addr(self, addr: int) -> Optional[int]: """Return the original address (matched or not) that follows @@ -462,29 +488,48 @@ def match_function(self, addr: int, name: str) -> bool: return did_match def match_vtable( - self, addr: int, name: str, base_class: Optional[str] = None + self, addr: int, class_name: str, base_class: Optional[str] = None ) -> bool: - # Set up our potential match names - bare_vftable = f"{name}::`vftable'" - for_name = base_class if base_class is not None else name - for_vftable = f"{name}::`vftable'{{for `{for_name}'}}" - - # Try to match on the "vftable for X first" - recomp_addr = self._find_potential_match(for_vftable, SymbolType.VTABLE) - if recomp_addr is not None: - return self.set_pair(addr, recomp_addr, SymbolType.VTABLE) - - # Only allow a match against "Class:`vftable'" - # if this is the derived class. - if base_class is None or base_class == name: - recomp_addr = self._find_potential_match(bare_vftable, SymbolType.VTABLE) - if recomp_addr is not None: - return self.set_pair(addr, recomp_addr, SymbolType.VTABLE) + """Match the vtable for the given class name. If a base class is provided, + we will match the multiple inheritance vtable instead. + + As with other name-based searches, set the given address on the first unmatched result. + + Our search here depends on having already demangled the vtable symbol before + loading the data. For example: we want to search for "Pizza::`vftable'" + so we extract the class name from its symbol "??_7Pizza@@6B@". + + For multiple inheritance, the vtable name references the base class like this: + + - X::`vftable'{for `Y'} + + The vtable for the derived class will take one of these forms: + + - X::`vftable'{for `X'} + - X::`vftable' + + We assume only one of the above will appear for a given class.""" + # Most classes will not use multiple inheritance, so try the regular vtable + # first, unless a base class is provided. + if base_class is None or base_class == class_name: + bare_vftable = f"{class_name}::`vftable'" + + for obj in self.search_name(bare_vftable, SymbolType.VTABLE): + if obj.orig_addr is None and obj.recomp_addr is not None: + return self.set_pair(addr, obj.recomp_addr, SymbolType.VTABLE) + + # If we didn't find a match above, search for the multiple inheritance vtable. + for_name = base_class if base_class is not None else class_name + for_vftable = f"{class_name}::`vftable'{{for `{for_name}'}}" + + for obj in self.search_name(for_vftable, SymbolType.VTABLE): + if obj.orig_addr is None and obj.recomp_addr is not None: + return self.set_pair(addr, obj.recomp_addr, SymbolType.VTABLE) logger.error( "Failed to find vtable for class with annotation 0x%x and name '%s'", addr, - name, + class_name, ) return False