From de2c38eaa622468abf2da8f46a9d1c7e6c87517f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 9 Jan 2019 23:43:13 +0200 Subject: [PATCH] Add support for more dice, larger dice and function calls --- dice.py | 82 ++++++++++++++++++++++++++++++++++++++++++++--------- maubot.yaml | 2 +- 2 files changed, 70 insertions(+), 14 deletions(-) diff --git a/dice.py b/dice.py index 4a1dfa6..5252e85 100644 --- a/dice.py +++ b/dice.py @@ -13,20 +13,17 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Match +from typing import Match, Union, Any import operator import random +import math import ast import re from maubot import Plugin, MessageEvent from maubot.handlers import command -ARG_PATTERN = "$pattern" -COMMAND_ROLL = f"roll {ARG_PATTERN}" -COMMAND_ROLL_DEFAULT = "roll" - -pattern_regex = re.compile("([0-9]{0,2})d([0-9]{1,2})") +pattern_regex = re.compile("([0-9]{0,9})d([0-9]{1,9})") _OP_MAP = { ast.Add: operator.add, @@ -46,6 +43,9 @@ ast.LShift: operator.lshift, } +_NUM_MAX = 1_000_000_000_000_000 +_NUM_MIN = -_NUM_MAX + _OP_LIMITS = { ast.Pow: (1000, 1000), ast.LShift: (1000, 1000), @@ -55,10 +55,27 @@ ast.Mod: (1_000_000_000_000_000, 1_000_000_000_000_000), } +_ALLOWED_FUNCS = ["ceil", "copysign", "fabs", "factorial", "gcd", "remainder", "trunc", + "exp", "log", "log1p", "log2", "log10", "sqrt", + "acos", "asin", "atan", "atan2", "cos", "hypot", "sin", "tan", + "degrees", "radians", + "acosh", "asinh", "atanh", "cosh", "sinh", "tanh", + "erf", "erfc", "gamma", "lgamma"] + +_FUNC_MAP = {func: getattr(math, func) for func in _ALLOWED_FUNCS if hasattr(math, func)} + +_FUNC_LIMITS = { + "factorial": 1000, + "exp": 709, + "sqrt": 1_000_000_000_000_000, +} + +_ARG_COUNT_LIMIT = 5 + # AST-based calculator from https://stackoverflow.com/a/33030616/2120293 class Calc(ast.NodeVisitor): - def visit_BinOp(self, node): + def visit_BinOp(self, node: ast.BinOp) -> Any: left = self.visit(node.left) right = self.visit(node.right) op_type = type(node.op) @@ -74,7 +91,7 @@ def visit_BinOp(self, node): raise SyntaxError(f"Operator {op_type.__name__} not allowed") return op(left, right) - def visit_UnaryOp(self, node): + def visit_UnaryOp(self, node: ast.UnaryOp) -> Any: operand = self.visit(node.operand) try: op = _OP_MAP[type(node.op)] @@ -82,14 +99,47 @@ def visit_UnaryOp(self, node): raise SyntaxError(f"Operator {type(node.op).__name__} not allowed") return op(operand) - def visit_Num(self, node): + def visit_Num(self, node: ast.Num) -> Any: + if node.n > _NUM_MAX or node.n < _NUM_MIN: + raise ValueError(f"Number out of bounds") return node.n - def visit_Expr(self, node): + def visit_Name(self, node: ast.Name) -> Any: + if node.id == "pi": + return math.pi + elif node.id == "tau": + return math.tau + elif node.id == "e": + return math.e + + def visit_Call(self, node: ast.Call) -> Any: + if isinstance(node.func, ast.Name): + try: + func = _FUNC_MAP[node.func.id] + except KeyError: + raise NameError(f"Function {node.func.id} is not defined") + args = [self.visit(arg) for arg in node.args] + kwargs = {kwarg.arg: self.visit(kwarg.value) for kwarg in node.keywords} + if len(args) + len(kwargs) > _ARG_COUNT_LIMIT: + raise ValueError("Too many arguments") + try: + limit = _FUNC_LIMITS[node.func.id] + for value in args: + if value > limit: + raise ValueError(f"Value over bounds for function {node.func.id}") + for value in kwargs.values(): + if value > limit: + raise ValueError(f"Value over bounds for function {node.func.id}") + except KeyError: + pass + return func(*args, **kwargs) + raise SyntaxError("Indirect call") + + def visit_Expr(self, node: ast.Expr) -> Any: return self.visit(node.value) @classmethod - def evaluate(cls, expression): + def evaluate(cls, expression: str) -> Union[int, float]: tree = ast.parse(expression) return cls().visit(tree.body[0]) @@ -104,8 +154,14 @@ def randomize(number: int, size: int) -> int: elif size == 1: return number result = 0 - for i in range(number): - result += random.randint(1, size) + if number < 100: + for i in range(number): + result += random.randint(1, size) + else: + mean = number * (size + 1) / 2 + variance = number * (size ** 2 - 1) / 12 + while result < number or result > number * size: + result = int(random.gauss(mean, math.sqrt(variance))) return result @classmethod diff --git a/maubot.yaml b/maubot.yaml index b4eaf59..5e50ccd 100644 --- a/maubot.yaml +++ b/maubot.yaml @@ -1,7 +1,7 @@ +maubot: 0.1.0 id: xyz.maubot.dice version: 1.0.0 license: AGPL-3.0-or-later modules: - dice main_class: DiceBot -database: true