diff --git a/plyj/model.py b/plyj/model.py index 446b38d..fa9b4cb 100644 --- a/plyj/model.py +++ b/plyj/model.py @@ -643,15 +643,6 @@ def __init__(self, block, catches=None, _finally=None, resources=None): self._finally = _finally self.resources = resources - def accept(self, visitor): - if visitor.visit_Try(self): - for s in self.block: - s.accept(visitor) - for c in self.catches: - visitor.visit_Catch(c) - if self._finally: - self._finally.accept(visitor) - class Catch(SourceElement): diff --git a/test/visitors.py b/test/visitors.py new file mode 100644 index 0000000..d3a0189 --- /dev/null +++ b/test/visitors.py @@ -0,0 +1,34 @@ +import unittest + +import plyj.parser as plyj +import plyj.model as model + + +class VisitorTest(unittest.TestCase): + + def setUp(self): + self.parser = plyj.Parser() + + def test_visit_expressions_in_try_catch(self): + statement = ''' + try { + b = c; + } catch(Exception e) { + method(1, 2); + } + ''' + s = self.parser.parse_statement(statement) + + class TestVisitor(model.Visitor): + def __init__(self): + super(TestVisitor, self).__init__() + self._count = 0 + + def visit_ExpressionStatement(self, expression_statement): + self._count += 1 + return True + + visitor = TestVisitor() + s.accept(visitor) + self.assertEqual(visitor._count, 2, + 'for {} \nNumber of Expressions got: {}, expected: {}'.format(statement, visitor._count, 2))