Skip to content

Commit

Permalink
(WIP) Refactor bot and main entry
Browse files Browse the repository at this point in the history
  • Loading branch information
AbhijeetKrishnan committed Feb 16, 2024
1 parent 2d90209 commit f9242d8
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 151 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ repos:
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
Expand Down
2 changes: 1 addition & 1 deletion src/framedb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .character import Character, DiscordMd, Move, Url
from .const import * # TODO: use specific imports
from .const import CHARACTER_ALIAS, MOVE_TYPE_ALIAS, REPLACE, SORT_ORDER, CharacterName, MoveType
from .frame_service import FrameService
from .framedb import FrameDb
2 changes: 1 addition & 1 deletion src/framedb/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class CharacterName(enum.Enum):
AZUCENA = "azucena"
BRYAN = "bryan"
CLAUDIO = "claudio"
DEVIL_JIN = "devil jin"
DEVIL_JIN = "devil_jin"
DRAGUNOV = "dragunov"
FENG = "feng"
HWOARANG = "hwoarang"
Expand Down
10 changes: 8 additions & 2 deletions src/framedb/framedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def load(self, frame_service: FrameService) -> None:
else:
logger.warning(f"Could not load frame data for {character}")

def refresh(self, frame_service: FrameService, export_dir_path: str, format: str = "json") -> None:
"Refresh the frame database using a frame service."

self.load(frame_service)
self.export(export_dir_path, format=format)

@staticmethod
def _simplify_input(input_query: str) -> str:
"""Removes bells and whistles from a move input query"""
Expand All @@ -69,7 +75,7 @@ def _is_command_in_alias(command: str, move: Move) -> bool:
return False

@staticmethod
def _correct_character_name(char_name_query: str) -> str | None:
def _correct_character_name(char_name_query: str) -> str | None: # TODO: overlap with get_character_by_name?
"Check if input in dictionary or in dictionary values"

if char_name_query in CHARACTER_ALIAS:
Expand Down Expand Up @@ -159,7 +165,7 @@ def get_character_by_name(self, name_query: str) -> Character | None:
return character
return None

def get_move_type(self, move_type_query: str) -> MoveType | None:
def get_move_type(self, move_type_query: str) -> MoveType | None: # TODO: overlap with get_moves_by_move_type?
"""Given a move type query, return the corresponding move type"""

for move_type, aliases in MOVE_TYPE_ALIAS.items():
Expand Down
158 changes: 92 additions & 66 deletions src/heihachi/bot.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,48 @@
import datetime
import logging
import sched
from typing import Any, Callable, Coroutine, List, Tuple

import discord
import discord.ext.commands

from framedb.const import CharacterName
from framedb import CharacterName, FrameDb, FrameService
from heihachi import Configurator, button, embed
from heihachi.embed import create_frame_data_embed
from heihachi.embed import get_frame_data_embed

logger = logging.getLogger(__name__)


class FrameDataBot(discord.Client):
def __init__(self, config: Configurator, intents: discord.Intents):
super().__init__(intents=intents)
class FrameDataBot(discord.ext.commands.Bot):
def __init__(
self,
command_prefix: str,
framedb: FrameDb,
frame_service: FrameService,
config: Configurator,
description: str | None = None,
intents: discord.Intents = discord.Intents.default(),
) -> None:
super().__init__(command_prefix, description=description, intents=intents)

self.framedb = framedb
self.frame_service = frame_service
self.config = config
self.synced = False

async def on_ready(self, tree: discord.app_commands.CommandTree) -> None:
self._add_bot_commands()

async def on_ready(self) -> None:
await self.wait_until_ready()
if not self.synced:
await tree.sync()
await self.tree.sync()
self.synced = True
logger.info(f"Logged on as {self.user}")

def is_user_blacklisted(self, user_id: str | int) -> bool:
"Check if a user is blacklisted"

blacklist: List[str] | List[int] | None
if isinstance(user_id, str):
blacklist = self.config.blacklist
else:
Expand All @@ -44,68 +60,78 @@ def is_author_newly_created(self, interaction: discord.Interaction) -> bool:
age = today - interaction.user.created_at.replace(tzinfo=None)
return age.days < self.config.new_author_age_limit

async def on_message(self, message: discord.Message) -> None:
if self.user:
if not self.is_user_blacklisted(message.author.id) and message.content and message.author.id != self.user.id:
user_command = message.content.split(" ", 1)[1]
parameters = user_command.strip().split(" ", 1)
character_name = parameters[0]
character_move = parameters[1]

# TODO: fix all this
@hei.event
async def on_message(message) -> None:
if not self.is_user_blacklisted(message.author.id) and message.content and message.author.id != hei.user.id:
user_command = message.content.split(" ", 1)[1]
parameters = user_command.strip().split(" ", 1)
character_name = parameters[0].lower()
character_move = parameters[1]

embed = create_frame_data_embed(character_name, character_move)
await message.channel.send(embed=embed)


@tree.command(name="fd", description="Frame data from a character move")
async def self(interaction: discord.Interaction, character_name: str, move: str) -> None:
if not (self.is_user_blacklisted(str(interaction.user.id)) or self.is_author_newly_created(interaction)):
embed = create_frame_data_embed(character_name, move)
await interaction.response.send_message(embed=embed, ephemeral=False)


def character_command_factory(name: str):
async def command(interaction: discord.Interaction, move: str) -> None:
if not (self.is_user_blacklisted(str(interaction.user.id)) or self.is_author_newly_created(interaction)):
embed = create_frame_data_embed(name, move)
await interaction.response.send_message(embed=embed, ephemeral=False)

return command


for character in CharacterName:
name = character.value
tree.command(name=name, description=f"Frame data from {name}")(character_command_factory(name))

if self.config.feedback_channel_id:

@tree.command(name="feedback", description="Send feedback incase of wrong data")
async def self(interaction: discord.Interaction, message: str) -> None:
if not (self.is_user_blacklisted(str(interaction.user.id)) or util.is_author_newly_created(interaction)):
try:
feedback_message = "Feedback from **{}** with ID **{}** in **{}** \n- {}\n".format(
str(interaction.user.name),
interaction.user.id,
interaction.guild,
message,
)
try:
channel = hei.get_channel(config.feedback_channel_id)
actioned_channel = hei.get_channel(config.action_channel_id)
except Exception as e:
logger.error(f"Error getting channel: {e}")
await channel.send(content=feedback_message, view=button.DoneButton(actioned_channel))
result = embed.success_embed("Feedback sent")
except Exception as e:
result = embed.error_embed(f"Feedback couldn't be sent, caused by: {str(e)}")
embed = get_frame_data_embed(self.framedb, self.frame_service, character_name, character_move)
await message.channel.send(embed=embed)
else:
logger.warning(f"Received a {message=} when the bot is not logged in")

def _character_command_factory(self, name: str) -> Callable[[discord.Interaction, str], Coroutine[Any, Any, None]]:
"A factory function to create /character command functions"

async def _character_command(interaction: discord.Interaction, move: str) -> None:
if not (self.is_user_blacklisted(str(interaction.user.id)) or self.is_author_newly_created(interaction)):
embed = get_frame_data_embed(self.framedb, self.frame_service, name, move)
await interaction.response.send_message(embed=embed, ephemeral=False)

return _character_command

def _add_bot_commands(self) -> None:
"Add all frame commands to the bot"

@self.command(name="fd", description="Frame data from a character move")
async def _frame_data_cmd(interaction: discord.Interaction, character_name_query: str, move_query: str) -> None:
if not (self.is_user_blacklisted(str(interaction.user.id)) or self.is_author_newly_created(interaction)):
embed = get_frame_data_embed(self.framedb, self.frame_service, character_name_query, move_query)
await interaction.response.send_message(embed=embed, ephemeral=False)

for character in CharacterName:
char_name = character.value
self.command(name=char_name, description=f"Frame data from {char_name}")(
self._character_command_factory(char_name)
)

if self.config.feedback_channel_id and self.config.action_channel_id:

@self.command(name="feedback", description="Send feedback incase of wrong data")
async def _feedback_cmd(interaction: discord.Interaction, message: str) -> None:
if not (self.is_user_blacklisted(str(interaction.user.id)) or self.is_author_newly_created(interaction)):
try:
feedback_message = "Feedback from **{}** with ID **{}** in **{}** \n- {}\n".format(
str(interaction.user.name),
interaction.user.id,
interaction.guild,
message,
)
try:
assert self.config.feedback_channel_id and self.config.action_channel_id
channel = self.get_channel(self.config.feedback_channel_id)
actioned_channel = self.get_channel(self.config.action_channel_id)
except Exception as e:
logger.error(f"Error getting channel: {e}")
assert channel and actioned_channel
await channel.send(content=feedback_message, view=button.DoneButton(actioned_channel))
result = embed.get_success_embed("Feedback sent")
except Exception as e:
result = embed.get_error_embed(f"Feedback couldn't be sent, caused by: {str(e)}")

await interaction.response.send_message(embed=result, ephemeral=False)
else:
logger.warning("Feedback or Action channel ID is not set. Disabling feedback command.")

await interaction.response.send_message(embed=result, ephemeral=False)
else:
logger.warning("Feedback channel ID is not set. Disabling feedback command.")

def periodic_function(
scheduler: sched.scheduler, interval: float, function: sched._ActionCallback, args: Tuple[Any, ...]
) -> None:
"Run a function periodically"

def periodic_function(scheduler: sched.scheduler, interval: float, function: sched._ActionCallback, character_list_path: str):
while True:
scheduler.enter(interval, 1, function, (character_list_path,))
scheduler.enter(interval, 1, function, args)
scheduler.run()
47 changes: 23 additions & 24 deletions src/heihachi/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import discord

from framedb import Character, CharacterName, FrameService, Move, MoveType
from framedb import Character, CharacterName, FrameDb, FrameService, Move, MoveType

MOVE_NOT_FOUND_TITLE = "Move not found"

Expand All @@ -11,7 +11,7 @@
ERROR_COLOR = discord.Colour.from_rgb(220, 20, 60)


def similar_moves_embed(
def get_similar_moves_embed(
frame_service: FrameService, similar_moves: List[Move], character_name: CharacterName
) -> discord.Embed:
"""Returns the embed message for similar moves."""
Expand All @@ -26,7 +26,7 @@ def similar_moves_embed(
return embed


def move_list_embed(
def get_move_list_embed(
frame_service: FrameService, character: Character, moves: List[Move], move_type: MoveType
) -> discord.Embed:
"""Returns the embed message for a list of moves matching to a special move type."""
Expand All @@ -41,17 +41,17 @@ def move_list_embed(
return embed


def error_embed(message) -> discord.Embed:
def get_error_embed(message) -> discord.Embed:
embed = discord.Embed(title="Error", colour=ERROR_COLOR, description=message)
return embed


def success_embed(message) -> discord.Embed:
def get_success_embed(message) -> discord.Embed:
embed = discord.Embed(title="Success", colour=SUCCESS_COLOR, description=message)
return embed


def move_embed(frame_service: FrameService, character: Character, move: Move) -> discord.Embed:
def get_move_embed(frame_service: FrameService, character: Character, move: Move) -> discord.Embed:
"""Returns the embed message for character and move."""

embed = discord.Embed(
Expand All @@ -78,28 +78,27 @@ def move_embed(frame_service: FrameService, character: Character, move: Move) ->
return embed


def create_frame_data_embed(name: str, move: str) -> discord.Embed:
character_name = util.correct_character_name(name.lower()) # TODO: fix all this
if character_name:
character = util.get_character_by_name(character_name, character_list)
assert character is not None
move_list = json_directory.get_movelist(character_name, JSON_PATH)
move_type = util.get_move_type(move)
def get_frame_data_embed(framedb: FrameDb, frame_service: FrameService, char_query: str, move_query: str) -> discord.Embed:
"""Creates an embed for the frame data of a character and move."""

character = framedb.get_character_by_name(char_query)
if character:
move_type = framedb.get_move_type(move_query)

if move_type:
moves = json_directory.get_by_move_type(move_type, move_list)
moves_embed = embed.move_list_embed(character, moves, move_type)
return moves_embed
moves = framedb.get_moves_by_move_type(character.name, move_type.value)
moves_embed = get_move_list_embed(frame_service, character, moves, move_type)
embed = moves_embed
else:
character_move = json_directory.get_move(move, move_list)
character_move = framedb.get_move_by_input(character.name, move_query)

if character_move:
move_embed = embed.move_embed(character, character_move)
return move_embed
move_embed = get_move_embed(frame_service, character, character_move)
embed = move_embed
else:
similar_moves = json_directory.get_similar_moves(move, move_list)
similar_moves_embed = embed.similar_moves_embed(similar_moves, character_name)
return similar_moves_embed
similar_moves = framedb.get_similar_moves(character.name, move_query)
similar_moves_embed = get_similar_moves_embed(frame_service, similar_moves, character.name)
embed = similar_moves_embed
else:
error_embed = embed.error_embed(f"Could not locate character {name}.")
return error_embed
embed = get_error_embed(f"Could not locate character {char_query}.")
return embed
5 changes: 4 additions & 1 deletion src/heihachi/tests/test_bot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest

# TODO: Add tests for the bot

@pytest.mark.skip(reason="Not implemented")
def test_create_bot():
pass
8 changes: 7 additions & 1 deletion src/heihachi/tests/test_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@

@pytest.fixture
def config():
return Configurator(discord_token="123456789", feedback_channel_id=123456789, action_channel_id=987654321)
return Configurator(
discord_token="123456789",
feedback_channel_id=123456789,
action_channel_id=987654321,
blacklist=None,
id_blacklist=None,
)


def test_from_file():
Expand Down
Loading

0 comments on commit f9242d8

Please sign in to comment.