diff --git a/src/evaluator/evaluator.go b/src/evaluator/evaluator.go index 08a6536..cd6d735 100644 --- a/src/evaluator/evaluator.go +++ b/src/evaluator/evaluator.go @@ -12,39 +12,47 @@ var ( FALSE = &object.Boolean{Value: false} ) -func Eval(node ast.Node) object.Object { +func Eval(node ast.Node, env *object.Environment) object.Object { switch node := node.(type) { // Statements case *ast.Program: - return evalProgram(node) + return evalProgram(node, env) case *ast.ExpressionStatement: - return Eval(node.Expression) + return Eval(node.Expression, env) case *ast.PrefixExpression: - right := Eval(node.Right) + right := Eval(node.Right, env) if isError(right) { return right } return evalPrefixExpression(node.Operator, right) case *ast.InfixExpression: - left := Eval(node.Left) + left := Eval(node.Left, env) if isError(left) { return left } - right := Eval(node.Right) + right := Eval(node.Right, env) if isError(right) { return right } return evalInfixExpression(node.Operator, left, right) case *ast.BlockStatement: - return evalBlockStatement(node) + return evalBlockStatement(node, env) case *ast.IfExpression: - return evalIfExpression(node) + return evalIfExpression(node, env) case *ast.ReturnStatement: - val := Eval(node.ReturnValue) + val := Eval(node.ReturnValue, env) if isError(val) { return val } return &object.ReturnValue{Value: val} + case *ast.Identifier: + return evalIdentifier(node, env) + case *ast.AssigmentStatement: + val := Eval(node.Value, env) + if isError(val) { + return val + } + env.Set(node.Name.Value, val) // Expressions case *ast.IntegerLiteral: return &object.Integer{Value: node.Value} @@ -54,10 +62,10 @@ func Eval(node ast.Node) object.Object { return nil } -func evalProgram(program *ast.Program) object.Object { +func evalProgram(program *ast.Program, env *object.Environment) object.Object { var result object.Object for _, statement := range program.Statements { - result = Eval(statement) + result = Eval(statement, env) switch result := result.(type) { case *object.ReturnValue: return result.Value @@ -68,10 +76,10 @@ func evalProgram(program *ast.Program) object.Object { return result } -func evalBlockStatement(block *ast.BlockStatement) object.Object { +func evalBlockStatement(block *ast.BlockStatement, env *object.Environment) object.Object { var result object.Object for _, statement := range block.Statements { - result = Eval(statement) + result = Eval(statement, env) if result != nil { rt := result.Type() if rt == object.RETURN_VALUE_OBJ || rt == object.ERROR_OBJ { @@ -170,15 +178,15 @@ func evalIntegerInfixExpression( } } -func evalIfExpression(ie *ast.IfExpression) object.Object { - condition := Eval(ie.Condition) +func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Object { + condition := Eval(ie.Condition, env) if isError(condition) { return condition } if isTruthy(condition) { - return Eval(ie.Consequence) + return Eval(ie.Consequence, env) } else if ie.Alternative != nil { - return Eval(ie.Alternative) + return Eval(ie.Alternative, env) } else { return NULL } @@ -206,3 +214,14 @@ func isError(obj object.Object) bool { } return false } + +func evalIdentifier( + node *ast.Identifier, + env *object.Environment, +) object.Object { + val, ok := env.Get(node.Value) + if !ok { + return newError("identifier not found: " + node.Value) + } + return val +} diff --git a/src/evaluator/evaluator_test.go b/src/evaluator/evaluator_test.go index 36801f4..c564ac1 100644 --- a/src/evaluator/evaluator_test.go +++ b/src/evaluator/evaluator_test.go @@ -11,7 +11,8 @@ func testEval(input string) object.Object { l := lexer.New([]byte(input)) p := parser.New(l) program := p.ParseProgram() - return Eval(program) + env := object.NewEnvironment() + return Eval(program, env) } func testIntegerObject(t *testing.T, obj object.Object, expected int64) bool { @@ -197,13 +198,17 @@ func TestErrorHandling(t *testing.T) { ` if (10 > 1) { if (10 > 1) { - return zoona + bodza; + bweza zoona + bodza; } bweza 1; } `, "unknown operator: BOOLEAN + BOOLEAN", }, + { + "foobar", + "identifier not found: foobar", + }, } for _, tt := range tests { evaluated := testEval(tt.input) @@ -219,3 +224,18 @@ func TestErrorHandling(t *testing.T) { } } } + +func TestAssignmentStatements(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"nambala a = 5; a;", 5}, + {"nambala a = 5 * 5; a;", 25}, + {"nambala a = 5; nambala b = a; b;", 5}, + {"nambala a = 5; nambala b = a; nambala c = a + b + 5; c;", 15}, + } + for _, tt := range tests { + testIntegerObject(t, testEval(tt.input), tt.expected) + } +} diff --git a/src/object/object.go b/src/object/object.go index 48bcaf8..d6c5efb 100644 --- a/src/object/object.go +++ b/src/object/object.go @@ -50,3 +50,21 @@ type Error struct { func (e *Error) Type() ObjectType { return ERROR_OBJ } func (e *Error) Inspect() string { return "ERROR: " + e.Message } + +func NewEnvironment() *Environment { + s := make(map[string]Object) + return &Environment{store: s} +} + +type Environment struct { + store map[string]Object +} + +func (e *Environment) Get(name string) (Object, bool) { + obj, ok := e.store[name] + return obj, ok +} +func (e *Environment) Set(name string, val Object) Object { + e.store[name] = val + return val +} diff --git a/src/repl/repl.go b/src/repl/repl.go index 147a92c..f7435ea 100644 --- a/src/repl/repl.go +++ b/src/repl/repl.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "github.com/sevenreup/chewa/src/evaluator" + "github.com/sevenreup/chewa/src/object" "github.com/sevenreup/chewa/src/parser" "io" @@ -15,6 +16,7 @@ const ERROR_HEDEAR = "Errorr!!" func Start(in io.Reader, out io.Writer) { scanner := bufio.NewScanner(in) + env := object.NewEnvironment() for { fmt.Printf(PROMPT) scanned := scanner.Scan() @@ -29,7 +31,7 @@ func Start(in io.Reader, out io.Writer) { printParserErrors(out, p.Errors()) continue } - evaluated := evaluator.Eval(program) + evaluated := evaluator.Eval(program, env) if evaluated != nil { io.WriteString(out, evaluated.Inspect()) io.WriteString(out, "\n")