Skip to content

Commit

Permalink
Add support for more dice, larger dice and function calls
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jan 9, 2019
1 parent 0e30d1b commit de2c38e
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 14 deletions.
82 changes: 69 additions & 13 deletions dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Match
from typing import Match, Union, Any
import operator
import random
import math
import ast
import re

from maubot import Plugin, MessageEvent
from maubot.handlers import command

ARG_PATTERN = "$pattern"
COMMAND_ROLL = f"roll {ARG_PATTERN}"
COMMAND_ROLL_DEFAULT = "roll"

pattern_regex = re.compile("([0-9]{0,2})d([0-9]{1,2})")
pattern_regex = re.compile("([0-9]{0,9})d([0-9]{1,9})")

_OP_MAP = {
ast.Add: operator.add,
Expand All @@ -46,6 +43,9 @@
ast.LShift: operator.lshift,
}

_NUM_MAX = 1_000_000_000_000_000
_NUM_MIN = -_NUM_MAX

_OP_LIMITS = {
ast.Pow: (1000, 1000),
ast.LShift: (1000, 1000),
Expand All @@ -55,10 +55,27 @@
ast.Mod: (1_000_000_000_000_000, 1_000_000_000_000_000),
}

_ALLOWED_FUNCS = ["ceil", "copysign", "fabs", "factorial", "gcd", "remainder", "trunc",
"exp", "log", "log1p", "log2", "log10", "sqrt",
"acos", "asin", "atan", "atan2", "cos", "hypot", "sin", "tan",
"degrees", "radians",
"acosh", "asinh", "atanh", "cosh", "sinh", "tanh",
"erf", "erfc", "gamma", "lgamma"]

_FUNC_MAP = {func: getattr(math, func) for func in _ALLOWED_FUNCS if hasattr(math, func)}

_FUNC_LIMITS = {
"factorial": 1000,
"exp": 709,
"sqrt": 1_000_000_000_000_000,
}

_ARG_COUNT_LIMIT = 5


# AST-based calculator from https://stackoverflow.com/a/33030616/2120293
class Calc(ast.NodeVisitor):
def visit_BinOp(self, node):
def visit_BinOp(self, node: ast.BinOp) -> Any:
left = self.visit(node.left)
right = self.visit(node.right)
op_type = type(node.op)
Expand All @@ -74,22 +91,55 @@ def visit_BinOp(self, node):
raise SyntaxError(f"Operator {op_type.__name__} not allowed")
return op(left, right)

def visit_UnaryOp(self, node):
def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
operand = self.visit(node.operand)
try:
op = _OP_MAP[type(node.op)]
except KeyError:
raise SyntaxError(f"Operator {type(node.op).__name__} not allowed")
return op(operand)

def visit_Num(self, node):
def visit_Num(self, node: ast.Num) -> Any:
if node.n > _NUM_MAX or node.n < _NUM_MIN:
raise ValueError(f"Number out of bounds")
return node.n

def visit_Expr(self, node):
def visit_Name(self, node: ast.Name) -> Any:
if node.id == "pi":
return math.pi
elif node.id == "tau":
return math.tau
elif node.id == "e":
return math.e

def visit_Call(self, node: ast.Call) -> Any:
if isinstance(node.func, ast.Name):
try:
func = _FUNC_MAP[node.func.id]
except KeyError:
raise NameError(f"Function {node.func.id} is not defined")
args = [self.visit(arg) for arg in node.args]
kwargs = {kwarg.arg: self.visit(kwarg.value) for kwarg in node.keywords}
if len(args) + len(kwargs) > _ARG_COUNT_LIMIT:
raise ValueError("Too many arguments")
try:
limit = _FUNC_LIMITS[node.func.id]
for value in args:
if value > limit:
raise ValueError(f"Value over bounds for function {node.func.id}")
for value in kwargs.values():
if value > limit:
raise ValueError(f"Value over bounds for function {node.func.id}")
except KeyError:
pass
return func(*args, **kwargs)
raise SyntaxError("Indirect call")

def visit_Expr(self, node: ast.Expr) -> Any:
return self.visit(node.value)

@classmethod
def evaluate(cls, expression):
def evaluate(cls, expression: str) -> Union[int, float]:
tree = ast.parse(expression)
return cls().visit(tree.body[0])

Expand All @@ -104,8 +154,14 @@ def randomize(number: int, size: int) -> int:
elif size == 1:
return number
result = 0
for i in range(number):
result += random.randint(1, size)
if number < 100:
for i in range(number):
result += random.randint(1, size)
else:
mean = number * (size + 1) / 2
variance = number * (size ** 2 - 1) / 12
while result < number or result > number * size:
result = int(random.gauss(mean, math.sqrt(variance)))
return result

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion maubot.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
maubot: 0.1.0
id: xyz.maubot.dice
version: 1.0.0
license: AGPL-3.0-or-later
modules:
- dice
main_class: DiceBot
database: true

0 comments on commit de2c38e

Please sign in to comment.