diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java index dbd1c2c8..724fd8aa 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java @@ -1,7 +1,9 @@ package org.biscuitsec.biscuit.token.builder.parser; import biscuit.format.schema.Schema; +import io.vavr.collection.Stream; import org.biscuitsec.biscuit.crypto.PublicKey; +import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.token.Policy; import io.vavr.Tuple2; import io.vavr.Tuple4; @@ -10,12 +12,48 @@ import java.time.OffsetDateTime; import java.time.format.DateTimeParseException; -import java.util.ArrayList; -import java.util.List; -import java.util.HashSet; +import java.util.*; import java.util.function.Function; public class Parser { + /** + * Takes a datalog string with \n as datalog line separator. It tries to parse + * each line using fact, rule, check and scope sequentially. + * + * If one succeed it returns Right(Block) + * else it returns a Map[lineNumber, List[Error]] + * + * @param index block index + * @param baseSymbols symbols table + * @param s datalog string to parse + * @return Either>, Block> + */ + public static Either>, Block> datalog(long index, SymbolTable baseSymbols, String s) { + Block blockBuilder = new Block(index, baseSymbols); + Map> errors = new HashMap<>(); + + Stream.of(s.split("\n")).zipWithIndex().forEach(indexedLine -> { + Integer lineNumber = indexedLine._2; + String codeLine = indexedLine._1; + List lineErrors = new ArrayList<>(); + + fact(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_fact); + rule(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_rule); + check(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_check); + scope(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_scope); + + if (lineErrors.size() > 3) { + errors.put(lineNumber, lineErrors); + } + }); + + if (!errors.isEmpty()) { + return Either.left(errors); + } + + return Either.right(blockBuilder); + } + public static Either> fact(String s) { Either> res = fact_predicate(s); if (res.isLeft()) { diff --git a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java index 3c76890b..fff15f76 100644 --- a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java +++ b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java @@ -5,6 +5,7 @@ import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.datalog.TemporarySymbolTable; import org.biscuitsec.biscuit.datalog.expressions.Op; +import org.biscuitsec.biscuit.token.Biscuit; import org.biscuitsec.biscuit.token.builder.parser.Error; import org.biscuitsec.biscuit.token.builder.parser.Parser; import io.vavr.Tuple2; @@ -15,7 +16,7 @@ import org.junit.jupiter.api.Test; import static org.biscuitsec.biscuit.datalog.Check.Kind.One; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.*; import java.util.*; @@ -360,4 +361,40 @@ void testParens() throws org.biscuitsec.biscuit.error.Error.Execution { assertEquals(new org.biscuitsec.biscuit.datalog.Term.Integer(9), value2); assertEquals("(1 + 2) * 3", ex2.print(s2).get()); } + + @Test + void testDatalogSucceeds() throws org.biscuitsec.biscuit.error.Error.Parser { + SymbolTable symbols = Biscuit.default_symbol_table(); + + String l1 = "fact1(1)"; + String l2 = "fact2(\"2\")"; + String l3 = "rule1(2) <- fact2(\"2\")"; + String l4 = "check if rule1(2)"; + String toParse = String.join("\n", Arrays.asList(l1, l2, l3, l4)); + + Either>, Block> output = Parser.datalog(1, symbols, toParse); + assertTrue(output.isRight()); + + Block validBlock = new Block(1, symbols); + validBlock.add_fact(l1); + validBlock.add_fact(l2); + validBlock.add_rule(l3); + validBlock.add_check(l4); + + output.forEach(block -> + assertArrayEquals(block.build().to_bytes().get(), validBlock.build().to_bytes().get()) + ); + } + + @Test + void testDatalogFailed() { + SymbolTable symbols = Biscuit.default_symbol_table(); + + String l1 = "fact(1)"; + String l2 = "check fact(1)"; // typo missing "if" + String toParse = String.join("\n", Arrays.asList(l1, l2)); + + Either>, Block> output = Parser.datalog(1, symbols, toParse); + assertTrue(output.isLeft()); + } } \ No newline at end of file