diff --git a/README.md b/README.md index a5b997a..799d83b 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ The bot supports the following slash commands - | Command | Description | | --- | --- | | `/fd ` | Get frame data of a particular character's move | +| `/ms ` | Find a character's moves that match a particular frame scenario | | `/` | Get frame data for a particular character's move | | `/feedback ` | Send feedback to the bot owner | | `/help` | Get help on the bot's usage | diff --git a/src/framedb/const.py b/src/framedb/const.py index 3cbbc1c..0a79cc2 100644 --- a/src/framedb/const.py +++ b/src/framedb/const.py @@ -1,5 +1,5 @@ import enum -from typing import Dict, List +from typing import Callable, Dict, List NUM_CHARACTERS = 34 @@ -95,6 +95,12 @@ class MoveType(enum.Enum): HS = "Heat Smash" +class FrameSituation(enum.Enum): + STARTUP = "startup" + BLOCK = "block" + HIT = "hit" + + MOVE_TYPE_ALIAS: Dict[MoveType, List[str]] = { MoveType.RA: ["ra", "rage_art", "rageart", "rage art"], MoveType.T: ["screw", "t!", "t", "screws", "tor", "tornado"], @@ -154,3 +160,11 @@ class MoveType(enum.Enum): "4\ufe0f\u20e3", "5\ufe0f\u20e3", ] + +CONDITION_MAP: Dict[str, Callable[[int, int], bool]] = { + ">": lambda x, y: x > y, + ">=": lambda x, y: x >= y, + "<": lambda x, y: x < y, + "<=": lambda x, y: x <= y, + "==": lambda x, y: x == y, +} diff --git a/src/framedb/framedb.py b/src/framedb/framedb.py index 817cb43..4644e50 100644 --- a/src/framedb/framedb.py +++ b/src/framedb/framedb.py @@ -1,5 +1,6 @@ import logging import os +import re from typing import Dict, List import requests @@ -8,7 +9,7 @@ from fast_autocomplete import AutoComplete from .character import Character, Move -from .const import CHARACTER_ALIAS, MOVE_TYPE_ALIAS, REPLACE, CharacterName, MoveType +from .const import CHARACTER_ALIAS, CONDITION_MAP, MOVE_TYPE_ALIAS, REPLACE, CharacterName, FrameSituation, MoveType from .frame_service import FrameService MATCH_SCORE_CUTOFF = 95 @@ -103,6 +104,21 @@ def _is_command_in_alt(move_query: str, move: Move) -> bool: return True return False + @staticmethod + def _sanitize_frame_data(frame_data: str) -> int | None: + """Removes bells and whistles from a move's frame data result'""" + + # Step 1: Remove any leading non-numeric characters (like 'i' or ',') + frame_data = re.sub(r"^[^-\d]+", "", frame_data) + + # Step 2: Match the first number or range (e.g., "-5", "5", "5~10") + match = re.match(r"([-+]?\d+)", frame_data) + + # Step 3: If a match is found, return the first number (and remove any '+' sign) + if match: + return int(match.group(1)) + return None + def get_move_by_input(self, character: CharacterName, input_query: str) -> Move | None: """Given an input move query for a known character, retrieve the move from the database.""" @@ -130,6 +146,46 @@ def get_move_by_input(self, character: CharacterName, input_query: str) -> Move # couldn't match anything :-( return None + def get_move_by_frame( + self, character: CharacterName, condition: str, frame_value: int, situation: FrameSituation + ) -> List[Move]: + """Given a frame value query for a known character, retrieve the moves from the database that matches that frame value.""" + + character_movelist = self.frames[character].movelist.values() + result = [] + + # Get the comparison function based on the condition + compare_func = CONDITION_MAP.get(condition.strip()) + if compare_func is None: + raise ValueError(f"Unsupported condition: {condition}") + + match situation: + case FrameSituation.STARTUP: + result = [ + entry + for entry in character_movelist + if entry.startup.strip() != "" # Ignore moves with no frame data + if (sanitized_value := FrameDb._sanitize_frame_data(entry.startup)) is not None + if compare_func(sanitized_value, frame_value) + ] + case FrameSituation.BLOCK: + result = [ + entry + for entry in character_movelist + if entry.on_block.strip() != "" # Ignores moves with no frame data + if (sanitized_value := FrameDb._sanitize_frame_data(entry.on_block)) is not None + if compare_func(sanitized_value, frame_value) + ] + case FrameSituation.HIT: + result = [ + entry + for entry in character_movelist + if entry.on_hit.strip() != "" # Ignores moves with no frame data + if (sanitized_value := FrameDb._sanitize_frame_data(entry.on_hit)) is not None + if compare_func(sanitized_value, frame_value) + ] + return result + def get_moves_by_move_name(self, character: CharacterName, move_name_query: str) -> List[Move]: """ Gets a list of moves that match a move name query, by comparing the move name and its aliases diff --git a/src/framedb/tests/conftest.py b/src/framedb/tests/conftest.py new file mode 100644 index 0000000..4e879d7 --- /dev/null +++ b/src/framedb/tests/conftest.py @@ -0,0 +1,7 @@ +import os +import sys + +import pytest + +# Add the project root directory to sys.path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) diff --git a/src/framedb/tests/test_framedb.py b/src/framedb/tests/test_framedb.py index 6711134..5e8bae8 100644 --- a/src/framedb/tests/test_framedb.py +++ b/src/framedb/tests/test_framedb.py @@ -2,7 +2,7 @@ from frame_service.json_directory.tests.test_json_directory import json_directory from framedb import FrameDb, FrameService -from framedb.const import CharacterName, MoveType +from framedb.const import CharacterName, FrameSituation, MoveType @pytest.fixture @@ -85,3 +85,223 @@ def test_search_move() -> None: def test_all_autocomplete_words_match() -> None: "Test that all words in the autocomplete list can be matched to a CharacterName" pass + + +def test_sanitize_frame_data(frameDb: FrameDb) -> None: + assert frameDb._sanitize_frame_data("i59~61") == 59 + assert frameDb._sanitize_frame_data(",i13,14,15") == 13 + assert frameDb._sanitize_frame_data("i13~14,i25 i35 i39 i42") == 13 + assert frameDb._sanitize_frame_data("-5") == -5 + assert frameDb._sanitize_frame_data("+5") == 5 + assert frameDb._sanitize_frame_data("+5~10") == 5 + assert frameDb._sanitize_frame_data("i16") == 16 + assert frameDb._sanitize_frame_data("i5~8") == 5 + assert frameDb._sanitize_frame_data("-5~-10") == -5 + assert frameDb._sanitize_frame_data("+67a (+51)") == 67 + assert frameDb._sanitize_frame_data("invalid data") is None + assert frameDb._sanitize_frame_data("") is None + + +def test_get_move_by_frame_on_block(frameDb: FrameDb) -> None: + # == Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, "==", -5, FrameSituation.BLOCK) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-BT.4", + "Raven-BT.f+2", + "Raven-BT.f+3", + "Raven-d+1", + "Raven-FC.1", + "Raven-H.BT.4,F", + ] + assert set(move_ids) == set(expected_move_ids) + + # > Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, ">", 5, FrameSituation.BLOCK) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = ["Raven-(Back to wall).b,b,UB", "Raven-b+1", "Raven-H.f,f,F+3,4", "Raven-H.2+3"] + assert set(move_ids) == set(expected_move_ids) + + # >= Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, ">=", 5, FrameSituation.BLOCK) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-H.1+2,F", + "Raven-H.f+1+2,F", + "Raven-H.SZN.2,F", + "Raven-H.ws3+4,F", + "Raven-(Back to wall).b,b,UB", + "Raven-b+1", + "Raven-H.f,f,F+3,4", + "Raven-H.2+3", + ] + assert set(move_ids) == set(expected_move_ids) + + # < Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, "<", -25, FrameSituation.BLOCK) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = ["Raven-3~4,F", "Raven-b+2,2,1+2", "Raven-BT.3,4,4,F", "Raven-BT.d+3", "Raven-BT.f+3+4,F"] + assert set(move_ids) == set(expected_move_ids) + + # <= Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, "<=", -25, FrameSituation.BLOCK) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-3~4,F", + "Raven-b+2,2,1+2", + "Raven-BT.3,4,4,F", + "Raven-BT.d+3", + "Raven-BT.f+3+4,F", + "Raven-uf+3+4,3+4", + "Raven-b+4,B+4~3,3+4", + ] + assert set(move_ids) == set(expected_move_ids) + + +def test_get_move_by_frame_on_hit(frameDb: FrameDb) -> None: + # == Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, "==", 12, FrameSituation.HIT) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-b+2,2,3", + "Raven-df+3", + "Raven-SZN.1~F", + ] + assert set(move_ids) == set(expected_move_ids) + + # > Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, ">", 45, FrameSituation.HIT) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-H.ws3+4,F", + "Raven-ws3,2", + "Raven-H.f,f,F+3,4", + "Raven-H.ws3,2", + ] + assert set(move_ids) == set(expected_move_ids) + + # >= Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, ">=", 44, FrameSituation.HIT) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-H.1+2,F", + "Raven-BT.4", + "Raven-H.ws3+4,F", + "Raven-ws3,2", + "Raven-H.f,f,F+3,4", + "Raven-H.ws3,2", + ] + assert set(move_ids) == set(expected_move_ids) + + # < Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, "<", -5, FrameSituation.HIT) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = ["Raven-FC.3", "Raven-3~4,F", "Raven-UB,b,3+4", "Raven-BT.f+3+4,F", "Raven-b+1+3,P"] + assert set(move_ids) == set(expected_move_ids) + + # <= Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, "<=", -5, FrameSituation.HIT) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-FC.3", + "Raven-3~4,F", + "Raven-UB,b,3+4", + "Raven-BT.f+3+4,F", + "Raven-b+1+3,P", + "Raven-FC.df+3+4", + "Raven-H.FC.df+3+4", + ] + assert set(move_ids) == set(expected_move_ids) + + +def test_get_move_by_frame_startup(frameDb: FrameDb) -> None: + # == Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, "==", 11, FrameSituation.STARTUP) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-d+2", + "Raven-df+1+4", + "Raven-FC.2", + "Raven-ws4", + ] + assert set(move_ids) == set(expected_move_ids) + + # > Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, ">", 26, FrameSituation.STARTUP) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-b+3", + "Raven-u+3", + "Raven-u+3,3", + "Raven-u+3,3,3", + "Raven-b+1+2", + "Raven-db+3", + "Raven-uf+3", + "Raven-(Back to wall).b,b,UB", + ] + assert set(move_ids) == set(expected_move_ids) + + # >= Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, ">=", 26, FrameSituation.STARTUP) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-H.f,f,F+3,4", + "Raven-f,f,F+3", + "Raven-b+3", + "Raven-u+3", + "Raven-u+3,3", + "Raven-u+3,3,3", + "Raven-b+1+2", + "Raven-db+3", + "Raven-uf+3", + "Raven-(Back to wall).b,b,UB", + ] + assert set(move_ids) == set(expected_move_ids) + + # < Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, "<", 10, FrameSituation.STARTUP) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-BT.1", + "Raven-BT.1,4", + ] + assert set(move_ids) == set(expected_move_ids) + + # <= Case + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, "<=", 10, FrameSituation.STARTUP) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids = [ + "Raven-1", + "Raven-1,2", + "Raven-1,2,3+4", + "Raven-1,2,4", + "Raven-2", + "Raven-2,3", + "Raven-2,4", + "Raven-BT.1", + "Raven-BT.1,4", + "Raven-BT.2", + "Raven-BT.2,1", + "Raven-BT.2,1~F", + "Raven-BT.2,2", + "Raven-BT.2,2,1", + "Raven-BT.2,2,3+4", + "Raven-BT.3", + "Raven-BT.3,4", + "Raven-BT.3,4,3", + "Raven-BT.3,4,4", + "Raven-BT.3,4,4,F", + "Raven-BT.d+1", + "Raven-d+1", + "Raven-FC.1", + "Raven-H.BT.2,2,1", + ] + + assert set(move_ids) == set(expected_move_ids) + + +def test_get_move_by_frame_no_results(frameDb: FrameDb) -> None: + returned_moves = frameDb.get_move_by_frame(CharacterName.RAVEN, "<", 1, FrameSituation.STARTUP) + move_ids = list(map(lambda move: move.id, returned_moves)) + expected_move_ids: list[str] = [] + assert set(move_ids) == set(expected_move_ids) diff --git a/src/heihachi/bot.py b/src/heihachi/bot.py index 5b7c09c..ab84627 100644 --- a/src/heihachi/bot.py +++ b/src/heihachi/bot.py @@ -8,10 +8,10 @@ from discord import Interaction from framedb import FrameDb, FrameService -from framedb.const import CharacterName +from framedb.const import CHARACTER_ALIAS, CharacterName from heihachi import button, embed from heihachi.configurator import Configurator -from heihachi.embed import get_frame_data_embed +from heihachi.embed import get_frame_data_embed, get_move_search_embed logger = logging.getLogger("main") @@ -34,10 +34,11 @@ def __init__( self.tree = discord.app_commands.CommandTree(self) self._add_bot_commands() - for char in CharacterName: - self.tree.command(name=char.value, description=f"Frame data from {char.value}")( - self._character_command_factory(char.value) - ) + char_names = [char.value for char in CharacterName] + flattened_aliases = [alias for sublist in CHARACTER_ALIAS.values() for alias in sublist] + char_names_and_alias = set(char_names + flattened_aliases) + for char in char_names_and_alias: + self.tree.command(name=char, description=f"Frame data from {char}")(self._character_command_factory(char)) logger.debug(f"Bot command tree: {[command.name for command in self.tree.get_commands()]}") @@ -112,9 +113,30 @@ async def _character_name_autocomplete( :25 ] # Discord has a max choice number of 25 (https://github.com/Rapptz/discord.py/discussions/9241) + async def _condition_autocomplete( + self, interaction: discord.Interaction["FrameDataBot"], current: str + ) -> List[discord.app_commands.Choice[str]]: + """Autocomplete function for the condition argument""" + conditions = [">", "<", ">=", "<=", "=="] + + # Return matching conditions that start with the current input + return [discord.app_commands.Choice(name=cond, value=cond) for cond in conditions] + + async def _situation_autocomplete( + self, interaction: discord.Interaction["FrameDataBot"], current: str + ) -> List[discord.app_commands.Choice[str]]: + """Autocomplete function for the situation argument""" + # List of valid frame situations + current = current.lower() + situations = ["startup", "block", "hit"] + + # Return matching situations that start with the current input + return [discord.app_commands.Choice(name=situation, value=situation) for situation in situations] + def _add_bot_commands(self) -> None: "Add all frame commands to the bot" + # Frame Data Command @self.tree.command(name="fd", description="Frame data from a character move") @discord.app_commands.autocomplete(character=self._character_name_autocomplete) async def _frame_data_cmd(interaction: discord.Interaction["FrameDataBot"], character: str, move: str) -> None: @@ -125,6 +147,26 @@ async def _frame_data_cmd(interaction: discord.Interaction["FrameDataBot"], char embed = get_frame_data_embed(self.framedb, self.frame_service, character_name_query, move_query) await interaction.response.send_message(embed=embed, ephemeral=False) + # Move Search Command + @self.tree.command(name="ms", description="Search for a character's move based on its frame data") + @discord.app_commands.autocomplete(character=self._character_name_autocomplete) + @discord.app_commands.autocomplete(condition=self._condition_autocomplete) + @discord.app_commands.autocomplete(situation=self._situation_autocomplete) + async def _move_search_cmd( + interaction: discord.Interaction["FrameDataBot"], character: str, condition: str, frames: str, situation: str + ) -> None: + logger.info( + f"Received command from {interaction.user.name} in {interaction.guild}: /ms {character} {condition} {frames} {situation}" + ) + character_name_query = character + frame_query = frames + if not (self._is_user_blacklisted(str(interaction.user.id)) or self._is_author_newly_created(interaction)): + embed = get_move_search_embed( + self.framedb, self.frame_service, character_name_query, condition, frame_query, situation + ) + await interaction.response.send_message(embed=embed, ephemeral=False) + + # Feedback Command if self.config.feedback_channel_id and self.config.action_channel_id: @self.tree.command(name="feedback", description="Send feedback to the authors in case of incorrect data") diff --git a/src/heihachi/embed.py b/src/heihachi/embed.py index 7ac364c..0525ac9 100644 --- a/src/heihachi/embed.py +++ b/src/heihachi/embed.py @@ -6,6 +6,7 @@ import discord from framedb import Character, FrameDb, FrameService, Move +from framedb.const import FrameSituation logger = logging.getLogger("main") @@ -134,6 +135,49 @@ def get_frame_data_embed(framedb: FrameDb, frame_service: FrameService, char_que return embed +def get_move_search_embed( + framedb: FrameDb, frame_service: FrameService, char_query: str, condition: str, frame_query: str, situation: str +) -> discord.Embed: + """ + Creates an embed for the move(s) that match the frame data provided + """ + character = framedb.get_character_by_name(char_query) + logger.debug(f"Character: {character}") + frame_situation = get_frame_situation_by_value(situation) + frames = int(frame_query) + matching_moves: Move | List[Move] = [] + + if character: + if frame_situation: + matching_moves = framedb.get_move_by_frame(character.name, condition, frames, frame_situation) + if not matching_moves: + embed = get_error_embed(f"Could not locate move that is {condition} {frame_query} on {situation}.") + + elif len(matching_moves) == 1: + embed = get_move_embed(frame_service, character, matching_moves[0]) + else: + if condition == ">": + condition = "\\>" # Escapes the '>' sign since Discord interprets it as the beginning of a quote + if frame_situation == FrameSituation.STARTUP: + embed = get_success_movelist_embed( + frame_service, character, matching_moves, f"Moves {condition} i{frame_query} on {situation}" + ) + else: + embed = get_success_movelist_embed( + frame_service, character, matching_moves, f"Moves {condition} {frame_query} frames on {situation}" + ) + else: + embed = get_error_embed(f"Could not locate character {char_query}.") + return embed + + +def get_frame_situation_by_value(value: str) -> FrameSituation: + try: + return FrameSituation(value) + except ValueError: + raise ValueError(f"Invalid value: {value}") + + def get_help_embed(frame_service: FrameService) -> discord.Embed: """Returns the help embed message for the bot.""" @@ -148,6 +192,11 @@ def get_help_embed(frame_service: FrameService) -> discord.Embed: value="Get frame data for a particular character's move.", inline=False, ) + embed.add_field( + name="/ms `` `` `` `` ", + value="Find character's move based on the frame data ex. /ms Jin >= -5 block", + inline=False, + ) embed.add_field( name="/feedback `message`", value="Send feedback to the bot authors in case of incorrect frame data (or any other reason).",