Skip to content

Commit

Permalink
Address multiple issues with AST traversal. (VirusTotal#51)
Browse files Browse the repository at this point in the history
* This commit addresses multiple issues with AST traversal.


* Removes the `Quantifier` type, which is not really required and only adds an additional level in AST trees.
* Removes `pb/traversal.go`, which is the legacy AST traversal code, based on protobuf.
* Forces all AST expression types to explicitly implement the Expression interface by removing embedded structs.

* Always return nil when the node has no children, never an empty list.
  • Loading branch information
plusvic authored Feb 16, 2022
1 parent 569d025 commit 56b111c
Show file tree
Hide file tree
Showing 6 changed files with 553 additions and 794 deletions.
151 changes: 76 additions & 75 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (

// Node is the interface implemented by all types of nodes in the AST.
type Node interface {
// Writes the source of the node to a writer.
// WriteSource writes the source of the node to a writer.
WriteSource(io.Writer) error
// Returns the node's children. The children are returned left to right,
// if the node represents the operation A + B + C, the children will
// appear as A, B, C.
// Children returns the node's children. The children are returned left to
// right, if the node represents the operation A + B + C, the children will
// appear as A, B, C. The result can be nil if the Node does not have
// children.
Children() []Node
}

Expand All @@ -42,9 +43,9 @@ const (
KeywordTrue Keyword = "true"
)

// Group is an Expression that encloses another Expression in parenthesis.
// Group is an Expression that encloses another Expression in parentheses.
type Group struct {
Expression
Expression Expression
}

// LiteralInteger is an Expression that represents a literal integer.
Expand Down Expand Up @@ -84,22 +85,22 @@ type LiteralRegexp struct {

// Minus is an Expression that represents the unary minus operation.
type Minus struct {
Expression
Expression Expression
}

// Not is an Expression that represents the "not" operation.
type Not struct {
Expression
Expression Expression
}

// Defined is an Expression that represents the "defined" operation.
type Defined struct {
Expression
Expression Expression
}

// BitwiseNot is an Expression that represents the bitwise not operation.
type BitwiseNot struct {
Expression
Expression Expression
}

// Range is a Node that represents an integer range. Example: (1..10).
Expand Down Expand Up @@ -178,22 +179,16 @@ type Subscripting struct {
Index Expression
}

// Quantifier is an Expression used in for loops, it can be either a numeric
// expression or the keywords "any" or "all".
type Quantifier struct {
Expression
}

// Percentage is an Expression used in evaluating string sets. Example:
// <expression>% of <string set>
type Percentage struct {
Expression
Expression Expression
}

// ForIn is an Expression representing a "for in" loop. Example:
// for <quantifier> <variables> in <iterator> : ( <condition> )
type ForIn struct {
Quantifier *Quantifier
Quantifier Expression
Variables []string
Iterator Node
Condition Expression
Expand All @@ -202,7 +197,7 @@ type ForIn struct {
// ForOf is an Expression representing a "for of" loop. Example:
// for <quantifier> of <string_set> : ( <condition> )
type ForOf struct {
Quantifier *Quantifier
Quantifier Expression
Strings Node
Condition Expression
}
Expand All @@ -212,7 +207,7 @@ type ForOf struct {
// <quantifier> of <string_set> in <range>
// If "In" is non-nil there is an "in" condition: 3 of them in (0..100)
type Of struct {
Quantifier *Quantifier
Quantifier Expression
Strings Node
Rules Node
In *Range
Expand Down Expand Up @@ -573,35 +568,41 @@ func (o *Operation) WriteSource(w io.Writer) error {
return nil
}

// Children returns an empty list of nodes as a keyword never has children,
// this function is required anyways in order to satisfy the Node interface.
// Children returns nil as a keyword never has children, this function is
// required anyways in order to satisfy the Node interface.
func (k Keyword) Children() []Node {
return []Node{}
return nil
}

// Children returns a single-element list holding the expression enclosed by
// the group.
func (g *Group) Children() []Node {
	inner := g.Expression
	return []Node{inner}
}

// Children returns the Node's children.
func (l *LiteralInteger) Children() []Node {
return []Node{}
return nil
}

// Children returns the Node's children.
func (l *LiteralFloat) Children() []Node {
return []Node{}
return nil
}

// Children returns the Node's children.
func (l *LiteralString) Children() []Node {
return []Node{}
return nil
}

// Children returns the Node's children.
func (l *LiteralRegexp) Children() []Node {
return []Node{}
return nil
}

// Children returns the Node's children.
func (i *Identifier) Children() []Node {
return []Node{}
return nil
}

// Children returns the Node's children.
Expand All @@ -620,6 +621,9 @@ func (e *Enum) Children() []Node {

// Children returns the Node's children.
func (s *StringIdentifier) Children() []Node {
if s.At == nil && s.In == nil {
return nil
}
children := make([]Node, 0)
if s.At != nil {
children = append(children, s.At)
Expand All @@ -632,29 +636,26 @@ func (s *StringIdentifier) Children() []Node {

// Children returns the Node's children.
func (s *StringCount) Children() []Node {
nodes := []Node{}

if s.In != nil {
nodes = append(nodes, s.In)
return []Node{s.In}
}

return nodes
return nil
}

// Children returns the Node's children.
func (s *StringOffset) Children() []Node {
if s.Index != nil {
return []Node{s.Index}
}
return []Node{}
return nil
}

// Children returns the Node's children.
func (s *StringLength) Children() []Node {
if s.Index != nil {
return []Node{s.Index}
}
return []Node{}
return nil
}

// Children returns the Node's children.
Expand Down Expand Up @@ -693,7 +694,6 @@ func (o *Of) Children() []Node {
if o.Rules != nil {
nodes = append(nodes, o.Rules)
}

return nodes
}

Expand All @@ -706,6 +706,26 @@ func (o *Operation) Children() []Node {
return nodes
}

// Children returns the operand of the "not" operation as the node's only
// child.
func (n *Not) Children() []Node {
	operand := n.Expression
	return []Node{operand}
}

// Children returns the operand of the unary minus operation as the node's
// only child.
func (m *Minus) Children() []Node {
	operand := m.Expression
	return []Node{operand}
}

// Children returns the operand of the bitwise not operation as the node's
// only child.
func (b *BitwiseNot) Children() []Node {
	operand := b.Expression
	return []Node{operand}
}

// Children returns the operand of the "defined" operation as the node's only
// child.
func (d *Defined) Children() []Node {
	operand := d.Expression
	return []Node{operand}
}

// Children returns the expression that yields the percentage value as the
// node's only child.
func (p *Percentage) Children() []Node {
	value := p.Expression
	return []Node{value}
}

// AsProto returns the Expression serialized as a pb.Expression.
func (k Keyword) AsProto() *pb.Expression {
switch k {
Expand Down Expand Up @@ -734,6 +754,10 @@ func (k Keyword) AsProto() *pb.Expression {
}
}

// AsProto returns the Expression serialized as a pb.Expression. The result is
// the serialization of the enclosed expression; the group itself adds no
// wrapper message.
func (g *Group) AsProto() *pb.Expression {
	inner := g.Expression
	return inner.AsProto()
}

// AsProto returns the Expression serialized as a pb.Expression.
func (l *LiteralInteger) AsProto() *pb.Expression {
return &pb.Expression{
Expand Down Expand Up @@ -800,6 +824,17 @@ func (d *Defined) AsProto() *pb.Expression {
}
}

// AsProto returns the Expression serialized as a pb.Expression. The operand
// is wrapped in a unary expression whose operator is BITWISE_NOT.
func (b *BitwiseNot) AsProto() *pb.Expression {
	unary := &pb.UnaryExpression{
		Operator:   pb.UnaryExpression_BITWISE_NOT.Enum(),
		Expression: b.Expression.AsProto(),
	}
	return &pb.Expression{
		Expression: &pb.Expression_UnaryExpression{UnaryExpression: unary},
	}
}

// AsProto returns the Expression serialized as a pb.Expression.
func (n *Not) AsProto() *pb.Expression {
return &pb.Expression{
Expand Down Expand Up @@ -1002,42 +1037,6 @@ func (p *Percentage) AsProto() *pb.Expression {
}
}

// AsProto returns the Expression serialized as a pb.Expression.
func (q *Quantifier) AsProto() *pb.ForExpression {
var expr *pb.ForExpression
switch v := q.Expression.(type) {
case *Percentage:
expr = &pb.ForExpression{
For: &pb.ForExpression_Expression{
Expression: v.AsProto(),
},
}
case Keyword:
var pbkw pb.ForKeyword
if v == KeywordAll {
pbkw = pb.ForKeyword_ALL
} else if v == KeywordAny {
pbkw = pb.ForKeyword_ANY
} else if v == KeywordNone {
pbkw = pb.ForKeyword_NONE
} else {
panic(fmt.Sprintf("unexpected keyword in for: %s", v))
}
expr = &pb.ForExpression{
For: &pb.ForExpression_Keyword{
Keyword: pbkw,
},
}
default:
expr = &pb.ForExpression{
For: &pb.ForExpression_Expression{
Expression: q.Expression.AsProto(),
},
}
}
return expr
}

// AsProto returns the Expression serialized as a pb.Expression.
func (f *ForIn) AsProto() *pb.Expression {
var iterator *pb.Iterator
Expand Down Expand Up @@ -1075,7 +1074,7 @@ func (f *ForIn) AsProto() *pb.Expression {
return &pb.Expression{
Expression: &pb.Expression_ForInExpression{
ForInExpression: &pb.ForInExpression{
ForExpression: f.Quantifier.AsProto(),
ForExpression: quantifierToProto(f.Quantifier),
Identifiers: f.Variables,
Iterator: iterator,
Expression: f.Condition.AsProto(),
Expand Down Expand Up @@ -1117,7 +1116,7 @@ func (f *ForOf) AsProto() *pb.Expression {
return &pb.Expression{
Expression: &pb.Expression_ForOfExpression{
ForOfExpression: &pb.ForOfExpression{
ForExpression: f.Quantifier.AsProto(),
ForExpression: quantifierToProto(f.Quantifier),
StringSet: s,
Expression: f.Condition.AsProto(),
},
Expand Down Expand Up @@ -1187,7 +1186,7 @@ func (o *Of) AsProto() *pb.Expression {
return &pb.Expression{
Expression: &pb.Expression_ForOfExpression{
ForOfExpression: &pb.ForOfExpression{
ForExpression: o.Quantifier.AsProto(),
ForExpression: quantifierToProto(o.Quantifier),
StringSet: s,
Range: r,
RuleEnumeration: rule_enumeration,
Expand Down Expand Up @@ -1234,3 +1233,5 @@ func (o *Operation) AsProto() *pb.Expression {
}
return expr
}


39 changes: 37 additions & 2 deletions ast/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func enumFromProto(e *pb.IntegerEnumeration) *Enum {
}
}

func quantifierFromProto(expr *pb.ForExpression) *Quantifier {
func quantifierFromProto(expr *pb.ForExpression) Expression {
if expr == nil {
return nil
}
Expand All @@ -304,7 +304,42 @@ func quantifierFromProto(expr *pb.ForExpression) *Quantifier {
case *pb.ForExpression_Expression:
q = expressionFromProto(v.Expression)
}
return &Quantifier{q}
return q
}

// quantifierToProto serializes the quantifier of a "for"/"of" expression as a
// pb.ForExpression.
//
// A nil quantifier yields a nil result, mirroring quantifierFromProto, which
// maps a nil pb.ForExpression to a nil Expression. The keywords "all", "any"
// and "none" are encoded as a pb.ForKeyword; any other keyword is a
// programming error and causes a panic. Every other expression (including
// *Percentage) is encoded through its own AsProto method.
func quantifierToProto(expr Expression) *pb.ForExpression {
	if expr == nil {
		return nil
	}
	kw, isKeyword := expr.(Keyword)
	if !isKeyword {
		return &pb.ForExpression{
			For: &pb.ForExpression_Expression{
				Expression: expr.AsProto(),
			},
		}
	}
	var pbkw pb.ForKeyword
	switch kw {
	case KeywordAll:
		pbkw = pb.ForKeyword_ALL
	case KeywordAny:
		pbkw = pb.ForKeyword_ANY
	case KeywordNone:
		pbkw = pb.ForKeyword_NONE
	default:
		panic(fmt.Sprintf("unexpected keyword in for: %s", kw))
	}
	return &pb.ForExpression{
		For: &pb.ForExpression_Keyword{
			Keyword: pbkw,
		},
	}
}

func forInExpressionFromProto(expr *pb.ForInExpression) *ForIn {
Expand Down
Loading

0 comments on commit 56b111c

Please sign in to comment.