From f9242d8716766af88f8f8ea7a7c276e00deffb16 Mon Sep 17 00:00:00 2001 From: Abhijeet Krishnan Date: Thu, 15 Feb 2024 23:00:57 -0500 Subject: [PATCH] (WIP) Refactor bot and main entry --- .pre-commit-config.yaml | 1 - src/framedb/__init__.py | 2 +- src/framedb/const.py | 2 +- src/framedb/framedb.py | 10 +- src/heihachi/bot.py | 158 ++++++++++++++---------- src/heihachi/embed.py | 47 ++++--- src/heihachi/tests/test_bot.py | 5 +- src/heihachi/tests/test_configurator.py | 8 +- src/heihachi/tests/test_embed.py | 25 ++-- src/main.py | 94 +++++++------- 10 files changed, 201 insertions(+), 151 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 690f6d7..849b165 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,6 @@ repos: hooks: - id: trailing-whitespace - id: end-of-file-fixer - - id: check-yaml - id: check-added-large-files - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. diff --git a/src/framedb/__init__.py b/src/framedb/__init__.py index 1fb1be9..08d70d5 100644 --- a/src/framedb/__init__.py +++ b/src/framedb/__init__.py @@ -1,4 +1,4 @@ from .character import Character, DiscordMd, Move, Url -from .const import * # TODO: use specific imports +from .const import CHARACTER_ALIAS, MOVE_TYPE_ALIAS, REPLACE, SORT_ORDER, CharacterName, MoveType from .frame_service import FrameService from .framedb import FrameDb diff --git a/src/framedb/const.py b/src/framedb/const.py index 4992c4b..b8afcf0 100644 --- a/src/framedb/const.py +++ b/src/framedb/const.py @@ -9,7 +9,7 @@ class CharacterName(enum.Enum): AZUCENA = "azucena" BRYAN = "bryan" CLAUDIO = "claudio" - DEVIL_JIN = "devil jin" + DEVIL_JIN = "devil_jin" DRAGUNOV = "dragunov" FENG = "feng" HWOARANG = "hwoarang" diff --git a/src/framedb/framedb.py b/src/framedb/framedb.py index ca1b19e..e3ff5c9 100644 --- a/src/framedb/framedb.py +++ b/src/framedb/framedb.py @@ -43,6 +43,12 @@ def load(self, frame_service: FrameService) -> None: else: logger.warning(f"Could not load frame data for {character}") + def refresh(self, frame_service: FrameService, export_dir_path: str, format: str = "json") -> None: + "Refresh the frame database using a frame service." + + self.load(frame_service) + self.export(export_dir_path, format=format) + @staticmethod def _simplify_input(input_query: str) -> str: """Removes bells and whistles from a move input query""" @@ -69,7 +75,7 @@ def _is_command_in_alias(command: str, move: Move) -> bool: return False @staticmethod - def _correct_character_name(char_name_query: str) -> str | None: + def _correct_character_name(char_name_query: str) -> str | None: # TODO: overlap with get_character_by_name? "Check if input in dictionary or in dictionary values" if char_name_query in CHARACTER_ALIAS: @@ -159,7 +165,7 @@ def get_character_by_name(self, name_query: str) -> Character | None: return character return None - def get_move_type(self, move_type_query: str) -> MoveType | None: + def get_move_type(self, move_type_query: str) -> MoveType | None: # TODO: overlap with get_moves_by_move_type? """Given a move type query, return the corresponding move type""" for move_type, aliases in MOVE_TYPE_ALIAS.items(): diff --git a/src/heihachi/bot.py b/src/heihachi/bot.py index 4b012c3..20a52d8 100644 --- a/src/heihachi/bot.py +++ b/src/heihachi/bot.py @@ -1,32 +1,48 @@ import datetime import logging import sched +from typing import Any, Callable, Coroutine, List, Tuple import discord +import discord.ext.commands -from framedb.const import CharacterName +from framedb import CharacterName, FrameDb, FrameService from heihachi import Configurator, button, embed -from heihachi.embed import create_frame_data_embed +from heihachi.embed import get_frame_data_embed logger = logging.getLogger(__name__) -class FrameDataBot(discord.Client): - def __init__(self, config: Configurator, intents: discord.Intents): - super().__init__(intents=intents) +class FrameDataBot(discord.ext.commands.Bot): + def __init__( + self, + command_prefix: str, + framedb: FrameDb, + frame_service: FrameService, + config: Configurator, + description: str | None = None, + intents: discord.Intents = discord.Intents.default(), + ) -> None: + super().__init__(command_prefix, description=description, intents=intents) + + self.framedb = framedb + self.frame_service = frame_service self.config = config self.synced = False - async def on_ready(self, tree: discord.app_commands.CommandTree) -> None: + self._add_bot_commands() + + async def on_ready(self) -> None: await self.wait_until_ready() if not self.synced: - await tree.sync() + await self.tree.sync() self.synced = True logger.info(f"Logged on as {self.user}") def is_user_blacklisted(self, user_id: str | int) -> bool: "Check if a user is blacklisted" + blacklist: List[str] | List[int] | None if isinstance(user_id, str): blacklist = self.config.blacklist else: @@ -44,68 +60,78 @@ def is_author_newly_created(self, interaction: discord.Interaction) -> bool: age = today - interaction.user.created_at.replace(tzinfo=None) return age.days < self.config.new_author_age_limit + async def on_message(self, message: discord.Message) -> None: + if self.user: + if not self.is_user_blacklisted(message.author.id) and message.content and message.author.id != self.user.id: + user_command = message.content.split(" ", 1)[1] + parameters = user_command.strip().split(" ", 1) + character_name = parameters[0] + character_move = parameters[1] -# TODO: fix all this -@hei.event -async def on_message(message) -> None: - if not self.is_user_blacklisted(message.author.id) and message.content and message.author.id != hei.user.id: - user_command = message.content.split(" ", 1)[1] - parameters = user_command.strip().split(" ", 1) - character_name = parameters[0].lower() - character_move = parameters[1] - - embed = create_frame_data_embed(character_name, character_move) - await message.channel.send(embed=embed) - - -@tree.command(name="fd", description="Frame data from a character move") -async def self(interaction: discord.Interaction, character_name: str, move: str) -> None: - if not (self.is_user_blacklisted(str(interaction.user.id)) or self.is_author_newly_created(interaction)): - embed = create_frame_data_embed(character_name, move) - await interaction.response.send_message(embed=embed, ephemeral=False) - - -def character_command_factory(name: str): - async def command(interaction: discord.Interaction, move: str) -> None: - if not (self.is_user_blacklisted(str(interaction.user.id)) or self.is_author_newly_created(interaction)): - embed = create_frame_data_embed(name, move) - await interaction.response.send_message(embed=embed, ephemeral=False) - - return command - - -for character in CharacterName: - name = character.value - tree.command(name=name, description=f"Frame data from {name}")(character_command_factory(name)) - -if self.config.feedback_channel_id: - - @tree.command(name="feedback", description="Send feedback incase of wrong data") - async def self(interaction: discord.Interaction, message: str) -> None: - if not (self.is_user_blacklisted(str(interaction.user.id)) or util.is_author_newly_created(interaction)): - try: - feedback_message = "Feedback from **{}** with ID **{}** in **{}** \n- {}\n".format( - str(interaction.user.name), - interaction.user.id, - interaction.guild, - message, - ) - try: - channel = hei.get_channel(config.feedback_channel_id) - actioned_channel = hei.get_channel(config.action_channel_id) - except Exception as e: - logger.error(f"Error getting channel: {e}") - await channel.send(content=feedback_message, view=button.DoneButton(actioned_channel)) - result = embed.success_embed("Feedback sent") - except Exception as e: - result = embed.error_embed(f"Feedback couldn't be sent, caused by: {str(e)}") + embed = get_frame_data_embed(self.framedb, self.frame_service, character_name, character_move) + await message.channel.send(embed=embed) + else: + logger.warning(f"Received a {message=} when the bot is not logged in") + + def _character_command_factory(self, name: str) -> Callable[[discord.Interaction, str], Coroutine[Any, Any, None]]: + "A factory function to create /character command functions" + + async def _character_command(interaction: discord.Interaction, move: str) -> None: + if not (self.is_user_blacklisted(str(interaction.user.id)) or self.is_author_newly_created(interaction)): + embed = get_frame_data_embed(self.framedb, self.frame_service, name, move) + await interaction.response.send_message(embed=embed, ephemeral=False) + + return _character_command + + def _add_bot_commands(self) -> None: + "Add all frame commands to the bot" + + @self.command(name="fd", description="Frame data from a character move") + async def _frame_data_cmd(interaction: discord.Interaction, character_name_query: str, move_query: str) -> None: + if not (self.is_user_blacklisted(str(interaction.user.id)) or self.is_author_newly_created(interaction)): + embed = get_frame_data_embed(self.framedb, self.frame_service, character_name_query, move_query) + await interaction.response.send_message(embed=embed, ephemeral=False) + + for character in CharacterName: + char_name = character.value + self.command(name=char_name, description=f"Frame data from {char_name}")( + self._character_command_factory(char_name) + ) + + if self.config.feedback_channel_id and self.config.action_channel_id: + + @self.command(name="feedback", description="Send feedback incase of wrong data") + async def _feedback_cmd(interaction: discord.Interaction, message: str) -> None: + if not (self.is_user_blacklisted(str(interaction.user.id)) or self.is_author_newly_created(interaction)): + try: + feedback_message = "Feedback from **{}** with ID **{}** in **{}** \n- {}\n".format( + str(interaction.user.name), + interaction.user.id, + interaction.guild, + message, + ) + try: + assert self.config.feedback_channel_id and self.config.action_channel_id + channel = self.get_channel(self.config.feedback_channel_id) + actioned_channel = self.get_channel(self.config.action_channel_id) + except Exception as e: + logger.error(f"Error getting channel: {e}") + assert channel and actioned_channel + await channel.send(content=feedback_message, view=button.DoneButton(actioned_channel)) + result = embed.get_success_embed("Feedback sent") + except Exception as e: + result = embed.get_error_embed(f"Feedback couldn't be sent, caused by: {str(e)}") + + await interaction.response.send_message(embed=result, ephemeral=False) + else: + logger.warning("Feedback or Action channel ID is not set. Disabling feedback command.") - await interaction.response.send_message(embed=result, ephemeral=False) -else: - logger.warning("Feedback channel ID is not set. Disabling feedback command.") +def periodic_function( + scheduler: sched.scheduler, interval: float, function: sched._ActionCallback, args: Tuple[Any, ...] +) -> None: + "Run a function periodically" -def periodic_function(scheduler: sched.scheduler, interval: float, function: sched._ActionCallback, character_list_path: str): while True: - scheduler.enter(interval, 1, function, (character_list_path,)) + scheduler.enter(interval, 1, function, args) scheduler.run() diff --git a/src/heihachi/embed.py b/src/heihachi/embed.py index cd7ba30..3d4fda8 100644 --- a/src/heihachi/embed.py +++ b/src/heihachi/embed.py @@ -2,7 +2,7 @@ import discord -from framedb import Character, CharacterName, FrameService, Move, MoveType +from framedb import Character, CharacterName, FrameDb, FrameService, Move, MoveType MOVE_NOT_FOUND_TITLE = "Move not found" @@ -11,7 +11,7 @@ ERROR_COLOR = discord.Colour.from_rgb(220, 20, 60) -def similar_moves_embed( +def get_similar_moves_embed( frame_service: FrameService, similar_moves: List[Move], character_name: CharacterName ) -> discord.Embed: """Returns the embed message for similar moves.""" @@ -26,7 +26,7 @@ def similar_moves_embed( return embed -def move_list_embed( +def get_move_list_embed( frame_service: FrameService, character: Character, moves: List[Move], move_type: MoveType ) -> discord.Embed: """Returns the embed message for a list of moves matching to a special move type.""" @@ -41,17 +41,17 @@ def move_list_embed( return embed -def error_embed(message) -> discord.Embed: +def get_error_embed(message) -> discord.Embed: embed = discord.Embed(title="Error", colour=ERROR_COLOR, description=message) return embed -def success_embed(message) -> discord.Embed: +def get_success_embed(message) -> discord.Embed: embed = discord.Embed(title="Success", colour=SUCCESS_COLOR, description=message) return embed -def move_embed(frame_service: FrameService, character: Character, move: Move) -> discord.Embed: +def get_move_embed(frame_service: FrameService, character: Character, move: Move) -> discord.Embed: """Returns the embed message for character and move.""" embed = discord.Embed( @@ -78,28 +78,27 @@ def move_embed(frame_service: FrameService, character: Character, move: Move) -> return embed -def create_frame_data_embed(name: str, move: str) -> discord.Embed: - character_name = util.correct_character_name(name.lower()) # TODO: fix all this - if character_name: - character = util.get_character_by_name(character_name, character_list) - assert character is not None - move_list = json_directory.get_movelist(character_name, JSON_PATH) - move_type = util.get_move_type(move) +def get_frame_data_embed(framedb: FrameDb, frame_service: FrameService, char_query: str, move_query: str) -> discord.Embed: + """Creates an embed for the frame data of a character and move.""" + + character = framedb.get_character_by_name(char_query) + if character: + move_type = framedb.get_move_type(move_query) if move_type: - moves = json_directory.get_by_move_type(move_type, move_list) - moves_embed = embed.move_list_embed(character, moves, move_type) - return moves_embed + moves = framedb.get_moves_by_move_type(character.name, move_type.value) + moves_embed = get_move_list_embed(frame_service, character, moves, move_type) + embed = moves_embed else: - character_move = json_directory.get_move(move, move_list) + character_move = framedb.get_move_by_input(character.name, move_query) if character_move: - move_embed = embed.move_embed(character, character_move) - return move_embed + move_embed = get_move_embed(frame_service, character, character_move) + embed = move_embed else: - similar_moves = json_directory.get_similar_moves(move, move_list) - similar_moves_embed = embed.similar_moves_embed(similar_moves, character_name) - return similar_moves_embed + similar_moves = framedb.get_similar_moves(character.name, move_query) + similar_moves_embed = get_similar_moves_embed(frame_service, similar_moves, character.name) + embed = similar_moves_embed else: - error_embed = embed.error_embed(f"Could not locate character {name}.") - return error_embed + embed = get_error_embed(f"Could not locate character {char_query}.") + return embed diff --git a/src/heihachi/tests/test_bot.py b/src/heihachi/tests/test_bot.py index da6ad0c..346bb14 100644 --- a/src/heihachi/tests/test_bot.py +++ b/src/heihachi/tests/test_bot.py @@ -1,3 +1,6 @@ import pytest -# TODO: Add tests for the bot + +@pytest.mark.skip(reason="Not implemented") +def test_create_bot(): + pass diff --git a/src/heihachi/tests/test_configurator.py b/src/heihachi/tests/test_configurator.py index 4012842..806d0c4 100644 --- a/src/heihachi/tests/test_configurator.py +++ b/src/heihachi/tests/test_configurator.py @@ -10,7 +10,13 @@ @pytest.fixture def config(): - return Configurator(discord_token="123456789", feedback_channel_id=123456789, action_channel_id=987654321) + return Configurator( + discord_token="123456789", + feedback_channel_id=123456789, + action_channel_id=987654321, + blacklist=None, + id_blacklist=None, + ) def test_from_file(): diff --git a/src/heihachi/tests/test_embed.py b/src/heihachi/tests/test_embed.py index aa7d53a..56be7c2 100644 --- a/src/heihachi/tests/test_embed.py +++ b/src/heihachi/tests/test_embed.py @@ -1,26 +1,31 @@ import pytest -@pytest.mark.skip(reason="Not implemented yet.") -def test_similar_moves_embed(): +@pytest.mark.skip(reason="Not implemented") +def test_get_similar_moves_embed(): pass -@pytest.mark.skip(reason="Not implemented yet.") -def test_move_list_embed(): +@pytest.mark.skip(reason="Not implemented") +def test_get_move_list_embed(): pass -@pytest.mark.skip(reason="Not implemented yet.") -def test_error_embed(): +@pytest.mark.skip(reason="Not implemented") +def test_get_error_embed(): pass -@pytest.mark.skip(reason="Not implemented yet.") -def test_success_embed(): +@pytest.mark.skip(reason="Not implemented") +def test_get_success_embed(): pass -@pytest.mark.skip(reason="Not implemented yet.") -def test_move_embed(): +@pytest.mark.skip(reason="Not implemented") +def test_get_move_embed(): + pass + + +@pytest.mark.skip(reason="Not implemented") +def test_get_frame_data_embed(): pass diff --git a/src/main.py b/src/main.py index aa2fad6..f4fc285 100644 --- a/src/main.py +++ b/src/main.py @@ -6,57 +6,63 @@ import logging import os import sched +import sys import threading import time -import discord - from frame_service import Wavu from framedb import FrameDb from heihachi import configurator -from heihachi.bot import FrameDataBot +from heihachi.bot import FrameDataBot, periodic_function + +"How often to update the bot's frame data from the external service and write to file." +UPDATE_INTERVAL_SEC = 3600 logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # TODO: format logger output to include timestamps, line numbers and file names -try: - config = configurator.Configurator.from_file(os.path.abspath("config.json")) # TODO: take as cmdline arg - assert config is not None -except FileNotFoundError: - logger.error("Config file not found. Exiting.") - exit(1) - -CHARACTER_LIST_PATH = os.path.abspath(os.path.join("src", "resources", "character_list.json")) -JSON_PATH = os.path.abspath(os.path.join("json_movelist")) - -try: - hei = FrameDataBot(config, intents=discord.Intents.default()) - tree = discord.app_commands.CommandTree(hei) - -except Exception as e: - time_now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - logger.error(f"{time_now} \n Error: {e}") - -wavu = Wavu() -framedb = FrameDb() -try: - character_list = util.create_json_movelists(CHARACTER_LIST_PATH) - scheduler = sched.scheduler(time.time, time.sleep) - - # Repeat importing move list of all character from frame service once an hour - scheduler_thread = threading.Thread( - target=util.periodic_function, - args=( - scheduler, - 3600, - framedb.load(wavu), - CHARACTER_LIST_PATH, - ), # TODO: schedule calling framedb.load(frame_service) to rebuild frame data - # TODO: also schedule saving newly loaded frame data to json based on configurable path - ) - scheduler_thread.start() - hei.run(config.discord_token) - -except Exception as e: - time_now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - logger.error(f"{time_now} \n Error: {e}") + +def main() -> None: + # retrieve config + try: + config = configurator.Configurator.from_file(sys.argv[1]) # TODO: potentially use argparse + assert config is not None + except FileNotFoundError: + logger.error(f"Config file not found at {sys.argv[1]}. Exiting...") + exit(1) + + export_dir_path = os.path.join(os.getcwd(), "json_movelist") + + # initialize bot + try: + frame_service = Wavu() + framedb = FrameDb() + framedb.load(frame_service) + hei = FrameDataBot("/", framedb, frame_service, config) + + except Exception as e: + # time_now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + logger.error(f"Failed to initialize bot: {e}") + exit(1) + + # schedule and start the frame refresh thread + try: + scheduler = sched.scheduler(time.time, time.sleep) + scheduler_thread = threading.Thread( + target=periodic_function, + args=(scheduler, UPDATE_INTERVAL_SEC, framedb.refresh, (frame_service, export_dir_path, "json")), + ) + scheduler_thread.start() + + except Exception as e: + logger.error(f"Error in scheduling the frame refresh thread: {e}") + + # start the bot + try: + hei.run(config.discord_token) + except Exception as e: + logger.error(f"Error in running the bot: {e}") + + +if __name__ == "__main__": + main()