Skip to content

Commit

Permalink
Type hints, cleanup, error handling, minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
LagoLunatic committed Jul 5, 2024
1 parent 50f1a13 commit 23b787d
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 65 deletions.
22 changes: 13 additions & 9 deletions asm/disassemble.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from randomizer import WWRandomizer

from subprocess import call
from subprocess import DEVNULL
Expand Down Expand Up @@ -30,7 +34,7 @@ def get_bin(name):
return os.path.join(devkitpath(), name + ".exe")


def disassemble_all_code(self):
def disassemble_all_code(self: WWRandomizer):
if not os.path.isfile(get_bin("powerpc-eabi-objdump")):
raise Exception(r"Failed to disassemble code: Could not find devkitPPC. devkitPPC should be installed to: C:\devkitPro\devkitPPC")

Expand Down Expand Up @@ -75,7 +79,7 @@ def disassemble_all_code(self):
data.seek(0)
f.write(data.read())

all_rels_by_path = {}
all_rels_by_path: dict[str, REL] = {}
all_rel_symbols_by_path = {}
for file_path_in_gcm in all_rel_paths:
basename_with_ext = os.path.basename(file_path_in_gcm)
Expand Down Expand Up @@ -141,7 +145,7 @@ def disassemble_file(bin_path, asm_path):
if result != 0:
raise Exception("Disassembler call failed")

def add_relocations_and_symbols_to_rel(asm_path, rel_path, file_path_in_gcm, main_symbols, all_rel_symbols_by_path, all_rels_by_path):
def add_relocations_and_symbols_to_rel(asm_path, rel_path, file_path_in_gcm: str, main_symbols, all_rel_symbols_by_path, all_rels_by_path: dict[str, REL]):
rel = all_rels_by_path[file_path_in_gcm]
rel_symbol_names = all_rel_symbols_by_path[file_path_in_gcm]

Expand Down Expand Up @@ -303,7 +307,7 @@ def add_relocations_and_symbols_to_rel(asm_path, rel_path, file_path_in_gcm, mai
"stfdux",
]

def add_symbols_to_main(self, asm_path, main_symbols):
def add_symbols_to_main(self: WWRandomizer, asm_path, main_symbols):
out_str = ""
with open(asm_path) as f:
last_lis_match = None
Expand Down Expand Up @@ -436,8 +440,8 @@ def add_symbols_to_main(self, asm_path, main_symbols):
f.write(out_str)


def get_list_of_all_rels(self):
all_rel_paths = []
def get_list_of_all_rels(self: WWRandomizer):
all_rel_paths: list[str] = []

for file_path in self.gcm.files_by_path:
if file_path.startswith("files/rels/"):
Expand All @@ -450,7 +454,7 @@ def get_list_of_all_rels(self):

return all_rel_paths

def find_rel_by_module_num(all_rels_by_path, module_num):
def find_rel_by_module_num(all_rels_by_path: dict[str, REL], module_num):
for rel_path, rel in all_rels_by_path.items():
if rel.id == module_num:
return (rel, rel_path)
Expand All @@ -465,7 +469,7 @@ def get_main_symbols(framework_map_contents):
main_symbols[address] = name
return main_symbols

def get_rel_symbols(rel, rel_map_data):
def get_rel_symbols(rel: REL, rel_map_data: str):
rel_map_lines = rel_map_data.splitlines()
found_memory_map = False
next_section_index = 0
Expand Down Expand Up @@ -526,7 +530,7 @@ def get_padded_comment_string_for_line(line):

return (" "*spaces_needed) + "; "

def check_offset_in_executable_dol_section(self, offset):
def check_offset_in_executable_dol_section(self: WWRandomizer, offset):
section_index = self.dol.convert_offset_to_section_index(offset)
if section_index is None:
return False
Expand Down
13 changes: 8 additions & 5 deletions logic/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def get_num_progression_locations(self):
@staticmethod
def get_num_progression_locations_static(item_locations: dict[str, dict], options: Options):
progress_locations = Logic.filter_locations_for_progression_static(
item_locations.keys(),
list(item_locations.keys()),
item_locations,
options,
filter_sunken_treasure=True
Expand All @@ -224,7 +224,7 @@ def get_max_required_bosses_banned_locations(self):
if not self.options.required_bosses:
return 0

all_locations = self.item_locations.keys()
all_locations = list(self.item_locations.keys())
progress_locations = self.filter_locations_for_progression(all_locations)
location_counts_by_dungeon = {}

Expand Down Expand Up @@ -259,7 +259,7 @@ def get_max_required_bosses_banned_locations(self):
return max_banned_locations

def get_progress_and_non_progress_locations(self):
all_locations = self.item_locations.keys()
all_locations = list(self.item_locations.keys())
progress_locations = self.filter_locations_for_progression(all_locations, filter_sunken_treasure=True)
nonprogress_locations = []
for location_name in all_locations:
Expand Down Expand Up @@ -481,7 +481,7 @@ def check_item_is_useful(self, item_name, inaccessible_undone_item_locations):
self.cached_items_are_useful[item_name] = False
return False

def filter_locations_for_progression(self, locations_to_filter, filter_sunken_treasure=False):
def filter_locations_for_progression(self, locations_to_filter: list[str], filter_sunken_treasure=False):
return Logic.filter_locations_for_progression_static(
locations_to_filter,
self.item_locations,
Expand Down Expand Up @@ -765,7 +765,7 @@ def make_useless_progress_items_nonprogress(self):
if self.options.progression_triforce_charts or self.options.progression_treasure_charts:
filter_sunken_treasure = False
progress_locations = Logic.filter_locations_for_progression_static(
self.item_locations.keys(),
list(self.item_locations.keys()),
self.item_locations,
self.options,
filter_sunken_treasure=filter_sunken_treasure
Expand Down Expand Up @@ -1069,6 +1069,7 @@ def get_items_needed_from_logical_expression_req(self, logical_expression, reqs_

def check_progressive_item_req(self, req_name: str):
match = re.search(r"^(Progressive .+) x(\d+)$", req_name)
assert match
item_name = match.group(1)
num_required = int(match.group(2))

Expand All @@ -1077,6 +1078,7 @@ def check_progressive_item_req(self, req_name: str):

def check_small_key_req(self, req_name: str):
match = re.search(r"^(.+ Small Key) x(\d+)$", req_name)
assert match
small_key_name = match.group(1)
num_keys_required = int(match.group(2))

Expand All @@ -1085,6 +1087,7 @@ def check_small_key_req(self, req_name: str):

def check_item_location_requirement(self, req_name: str):
match = re.search(r"^Can Access Item Location \"([^\"]+)\"$", req_name)
assert match
item_location_name = match.group(1)

return self.check_location_accessible(item_location_name)
Expand Down
12 changes: 9 additions & 3 deletions randomizers/base_randomizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, rando: WWRandomizer):
self.rando = rando
self.logic = rando.logic
self.options = rando.options
self.rng = None
self._rng = None
self.made_any_changes = False

def is_enabled(self) -> bool:
Expand Down Expand Up @@ -48,13 +48,19 @@ def progress_save_text(self) -> str:
"""The message displayed to the user during the save step."""
return "Applying changes..."

@property
def rng(self):
if self._rng is None:
raise Exception("Attempted to use the RNG outside of the randomization step.")
return self._rng

def reset_rng(self):
self.rng = self.rando.get_new_rng()
self._rng = self.rando.get_new_rng()

def randomize(self):
self.reset_rng()
self._randomize()
self.rng = None
self._rng = None
self.made_any_changes = True

def _randomize(self):
Expand Down
2 changes: 1 addition & 1 deletion randomizers/boss_reqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def randomize_required_bosses(self):
if len(possible_boss_item_locations) != 6:
raise Exception("Number of boss item locations is incorrect: " + ", ".join(possible_boss_item_locations))
if num_required_bosses > 6 or num_required_bosses < 1:
raise Exception(f"Number of required bosses is invalid: {len(num_required_bosses)}")
raise Exception(f"Number of required bosses is invalid: {num_required_bosses}")

self.required_boss_item_locations = self.rng.sample(possible_boss_item_locations, num_required_bosses)

Expand Down
18 changes: 9 additions & 9 deletions randomizers/entrances.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ class ZoneEntrance:
scls_exit_index: int
spawn_id: int
entrance_name: str
island_name: str = None
warp_out_stage_name: str = None
warp_out_room_num: int = None
warp_out_spawn_id: int = None
nested_in: 'ZoneExit' = None
island_name: str | None = None
warp_out_stage_name: str | None = None
warp_out_room_num: int | None = None
warp_out_spawn_id: int | None = None
nested_in: 'ZoneExit | None' = None

@property
def is_nested(self):
Expand All @@ -41,12 +41,12 @@ def __post_init__(self):
class ZoneExit:
stage_name: str
room_num: int
scls_exit_index: int
scls_exit_index: int | None
spawn_id: int
unique_name: str
_: KW_ONLY
boss_stage_name: str = None
zone_name: str = None
boss_stage_name: str | None = None
zone_name: str | None = None
# If zone_name is specified, this exit will assume by default that it owns all item locations in
# that zone which are behind randomizable entrances. If a single zone has multiple randomizable
# entrances, only one of them at most can use zone_name. The rest must have their item locations
Expand Down Expand Up @@ -1047,7 +1047,7 @@ def get_entrance_zone_for_item_location(self, location_name: str) -> str:
outermost_entrance = self.get_outermost_entrance_for_exit(zone_exit)
return outermost_entrance.island_name

def get_all_zones_for_item_location(self, location_name: str) -> list[str]:
def get_all_zones_for_item_location(self, location_name: str) -> set[str]:
# Helper function to return a set of zone names that include the location.
#
# All returned zones are either an island name or a dungeon name - that is, if the entrance to
Expand Down
32 changes: 22 additions & 10 deletions randomizers/hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ class ItemImportance(Enum):


class Hint:
def __init__(self, type: HintType, place, reward=None, importance=None):
type: HintType
place: str
reward: str | None
importance: ItemImportance | None

def __init__(self, type: HintType, place: str, reward: str | None = None, importance: ItemImportance | None = None):
assert place is not None
if type == HintType.BARREN: assert reward is None
if type != HintType.BARREN: assert reward is not None
Expand Down Expand Up @@ -148,13 +153,13 @@ class HintsRandomizer(BaseRandomizer):
#endregion


cryptic_item_hints = None
cryptic_zone_hints = None
location_hints = None
cryptic_item_hints: dict[str, str] = None
cryptic_zone_hints: dict[str, str] = None
location_hints: dict[str, dict] = None

def __init__(self, rando):
super().__init__(rando)
self.path_logic = None
self._path_logic = None
self.path_logic_initial_state = None

# Define instance variable shortcuts for hint distribution options.
Expand Down Expand Up @@ -213,8 +218,14 @@ def progress_randomize_text(self) -> str:
def progress_save_text(self) -> str:
return "Saving hints..."

@property
def path_logic(self):
if self._path_logic is None:
raise Exception("Hints randomizer attempted to use uninitialized path logic.")
return self._path_logic

def _randomize(self):
self.path_logic = Logic(self.rando)
self._path_logic = Logic(self.rando)
self.path_logic_initial_state = self.path_logic.save_simulated_playthrough_state()

# Generate the hints that will be distributed over the hint placement options
Expand Down Expand Up @@ -871,9 +882,9 @@ def get_barren_hint(self, unhinted_zones, zone_weights):
return barren_hint


def filter_out_hinted_barren_locations(self, hintable_locations, hinted_barren_zones):
def filter_out_hinted_barren_locations(self, hintable_locations: list[str], hinted_barren_zones: list[Hint]):
# Remove locations in hinted barren areas.
new_hintable_locations = []
new_hintable_locations: list[str] = []
barrens = [hint.place for hint in hinted_barren_zones]
for location_name in hintable_locations:
entrance_zones = self.rando.entrances.get_all_zones_for_item_location(location_name)
Expand Down Expand Up @@ -916,6 +927,7 @@ def get_importance_for_location(self, location_name):

def check_is_legal_item_hint(self, location_name, progress_locations, previously_hinted_locations):
item_name = self.logic.done_item_locations[location_name]
assert item_name is not None

if not self.check_item_can_be_hinted_at(item_name):
return False
Expand All @@ -934,7 +946,7 @@ def check_is_legal_item_hint(self, location_name, progress_locations, previously

return True

def get_legal_item_hints(self, progress_locations, hinted_barren_zones, previously_hinted_locations):
def get_legal_item_hints(self, progress_locations, hinted_barren_zones: list[Hint], previously_hinted_locations):
# Helper function to build a list of locations which may be hinted as item hints in this seed.

# Filter out locations which are invalid to be hinted at for item hints.
Expand Down Expand Up @@ -1157,7 +1169,7 @@ def generate_hints(self):
# We select at most `self.max_barren_hints` zones at random to hint as barren. Barren zones are weighted by the
# square root of the number of locations at that zone.
unhinted_barren_zones = self.get_barren_zones(progress_locations, [hint.place for hint in hinted_remote_locations])
hinted_barren_zones = []
hinted_barren_zones: list[Hint] = []
while len(unhinted_barren_zones) > 0 and len(hinted_barren_zones) < self.max_barren_hints:
# Weight each barren zone by the square root of the number of locations there.
zone_weights = [math.sqrt(location_counter[zone]) for zone in unhinted_barren_zones]
Expand Down
2 changes: 1 addition & 1 deletion test/test_dry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_trick_logic_checks():

def test_parse_string_option_to_enum():
options = Options()
options.logic_precision = "Normal"
options.logic_precision = "Normal" # pyright: ignore [reportAttributeAccessIssue]
rando = dry_rando_with_options(options)
assert isinstance(rando.options.logic_precision, StrEnum)

Expand Down
Loading

0 comments on commit 23b787d

Please sign in to comment.