Skip to content

Commit

Permalink
Create C++ types for Souffle values which can be printed to Datalog. (#…
Browse files Browse the repository at this point in the history
…528)

This change adds C++ types which directly map to Souffle's types. These
types use templates to allow a user to create C++ types that correspond
directly to Souffle types and which the C++ compiler can type check at
compile time. Because some of the types that we create are rather
complex, involving both ADTs and records, this should be useful to catch
errors before having to run the Souffle compiler.

Closes #528

COPYBARA_INTEGRATE_REVIEW=#528 from google-research:souffle_base_value@winterrowd 3b54cb8
PiperOrigin-RevId: 443699109
  • Loading branch information
markww authored and arcs-c3po committed Apr 22, 2022
1 parent 433db3a commit 67d8731
Show file tree
Hide file tree
Showing 3 changed files with 340 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/ir/datalog/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,26 @@ cc_test(
"//src/common/testing:gtest",
],
)

# This library contains our C++ model of the value types of Souffle.
# It is header-only: the whole model lives in value.h, so no `srcs` are listed.
cc_library(
name = "value",
hdrs = [
"value.h",
],
deps = [
# NOTE(review): value.h as shown does not include any logging header;
# confirm that this dependency is actually needed.
"//src/common/logging",
"@absl//absl/strings",
"@absl//absl/strings:str_format",
],
)

# Unit tests for the Souffle value model declared in value.h.
cc_test(
name = "value_test",
srcs = ["value_test.cc"],
deps = [
":value",
"//src/common/testing:gtest",
"@absl//absl/strings",
],
)
114 changes: 114 additions & 0 deletions src/ir/datalog/value.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
//-----------------------------------------------------------------------------
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//----------------------------------------------------------------------------

#ifndef SRC_IR_DATALOG_VALUE_H_
#define SRC_IR_DATALOG_VALUE_H_

#include <cstdint>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"

namespace raksha::ir::datalog {

// Abstract root of every "fully fledged" Souffle value type. Anything
// descending from `Value` should be usable in a `.type` declaration: `number`,
// `symbol`, any record, and any ADT qualify, whereas an individual branch of
// an ADT does not.
class Value {
 public:
  // Renders this value as Datalog source text.
  virtual std::string ToDatalogString() const = 0;

  // Virtual so that values can be destroyed through a `Value` pointer.
  virtual ~Value() = default;
};

// C++ counterpart of Souffle's `number` type; wraps a single `int64_t`.
class Number : public Value {
 public:
  explicit Number(int64_t value) : value_(value) {}

  // Prints the wrapped integer in decimal, exactly as `std::to_string` does.
  std::string ToDatalogString() const override {
    std::string decimal_text = std::to_string(value_);
    return decimal_text;
  }

 private:
  int64_t value_;
};

// C++ counterpart of Souffle's `symbol` type; owns a copy of the symbol text.
class Symbol : public Value {
 public:
  explicit Symbol(absl::string_view value) : text_(value) {}

  // Prints the symbol text wrapped in double quotes.
  // NOTE(review): the text is inserted verbatim, so an embedded `"` or `\` is
  // not escaped. Confirm that callers never pass such symbols.
  std::string ToDatalogString() const override {
    return absl::StrFormat("\"%s\"", text_);
  }

 private:
  std::string text_;
};

// A Record is an owning, nullable handle to a tuple of field values rather
// than the tuple itself. Souffle conflates a record reference with the record
// it refers to, and a recursively defined record (common with linked lists)
// could not be laid out if its fields were stored inline, so the fields live
// behind a pointer. A null pointer stands for the `nil` record.
template <class... RecordFieldValueTypes>
class Record : public Value {
 public:
  // Creates the `nil` record; no field tuple is allocated.
  explicit Record() : record_arguments_(nullptr) {}

  // Creates a populated record, moving every field into a freshly allocated
  // tuple.
  explicit Record(RecordFieldValueTypes &&...args)
      : record_arguments_(
            std::make_unique<std::tuple<RecordFieldValueTypes...>>(
                std::forward<RecordFieldValueTypes>(args)...)) {}

  // Prints `nil` for an empty record; otherwise prints the fields' Datalog
  // text separated by ", " and enclosed in square brackets.
  std::string ToDatalogString() const override {
    if (record_arguments_ == nullptr) {
      return "nil";
    }
    const std::string joined_fields =
        absl::StrJoin(*record_arguments_, ", ",
                      [](std::string *sink, const auto &field) {
                        absl::StrAppend(sink, field.ToDatalogString());
                      });
    return absl::StrFormat("[%s]", joined_fields);
  }

 private:
  std::unique_ptr<std::tuple<RecordFieldValueTypes...>> record_arguments_;
};

class Adt : public Value {
public:
explicit Adt(absl::string_view branch_name) : branch_name_(branch_name) {}

std::string ToDatalogString() const {
return absl::StrFormat(
"$%s{%s}", branch_name_,
absl::StrJoin(arguments_, ", ", [](std::string *str, const auto &arg) {
absl::StrAppend(str, arg->ToDatalogString());
}));
}

protected:
std::vector<std::unique_ptr<Value>> arguments_;

private:
absl::string_view branch_name_;
};

} // namespace raksha::ir::datalog

#endif // SRC_IR_DATALOG_VALUE_H_
203 changes: 203 additions & 0 deletions src/ir/datalog/value_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
//-----------------------------------------------------------------------------
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//----------------------------------------------------------------------------

#include "src/ir/datalog/value.h"

#include <limits>

#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "src/common/testing/gtest.h"

namespace raksha::ir::datalog {

using testing::Combine;
using testing::TestWithParam;
using testing::ValuesIn;

class NumberTest : public TestWithParam<int64_t> {};

TEST_P(NumberTest, NumberTest) {
int64_t num = GetParam();
Number number_value = Number(num);
std::string datalog_str = number_value.ToDatalogString();
int64_t parsed_int = 0;
bool conversion_succeeds = absl::SimpleAtoi(datalog_str, &parsed_int);
ASSERT_TRUE(conversion_succeeds);
ASSERT_EQ(parsed_int, num);
}

// Small and boundary integers shared by the numeric tests. The extremes are
// taken from `int64_t` itself rather than `long`, whose width differs between
// platforms, so they always match the element type.
static int64_t kSampleIntegerValues[] = {0, -1, 1,
                                         std::numeric_limits<int64_t>::max(),
                                         std::numeric_limits<int64_t>::min()};

// Runs NumberTest once for every entry of kSampleIntegerValues.
INSTANTIATE_TEST_SUITE_P(NumberTest, NumberTest,
ValuesIn(kSampleIntegerValues));

class SymbolTest : public TestWithParam<absl::string_view> {};

TEST_P(SymbolTest, SymbolTest) {
absl::string_view symbol = GetParam();
Symbol symbol_value = Symbol(symbol);
std::string symbol_str = symbol_value.ToDatalogString();
ASSERT_EQ(symbol_str, "\"" + std::string(symbol) + "\"");
}

// Sample symbol texts shared by the symbol and record tests; the empty string
// is included as an edge case.
static absl::string_view kSampleSymbols[] = {"", "x", "foo", "hello_world"};

// Runs SymbolTest once for every entry of kSampleSymbols.
INSTANTIATE_TEST_SUITE_P(SymbolTest, SymbolTest, ValuesIn(kSampleSymbols));

// A record with one `symbol` field followed by one `number` field.
using SimpleRecord = Record<Symbol, Number>;

// A default-constructed SimpleRecord is the nil record and must print as
// `nil`. The record type under test is SimpleRecord itself, matching the name
// of the test.
TEST(SimpleRecordNilTest, SimpleRecordNilTest) {
  ASSERT_EQ(SimpleRecord().ToDatalogString(), "nil");
}

class SimpleRecordTest
: public TestWithParam<std::tuple<absl::string_view, int64_t>> {};

TEST_P(SimpleRecordTest, SimpleRecordTest) {
const auto [symbol, number] = GetParam();
SimpleRecord record_value = SimpleRecord(Symbol(symbol), Number(number));
ASSERT_EQ(record_value.ToDatalogString(),
absl::StrFormat(R"(["%s", %d])", symbol, number));
}

INSTANTIATE_TEST_SUITE_P(SimpleRecordTest, SimpleRecordTest,
Combine(ValuesIn(kSampleSymbols),
ValuesIn(kSampleIntegerValues)));

// A list of numbers modelled as a recursive record: a non-nil node holds a
// `Number` and the remainder of the list (another `NumList`); the
// default-constructed value is the nil record and serves as the empty list.
class NumList : public Record<Number, NumList> {
public:
// Reuse Record's constructors instead of redeclaring them.
using Record::Record;
};

struct NumListAndExpectedDatalog {
const NumList *num_list_ptr;
absl::string_view expected_datalog;
};

class NumListTest : public TestWithParam<NumListAndExpectedDatalog> {};

TEST_P(NumListTest, NumListTest) {
const auto [num_list_ptr, expected_datalog] = GetParam();
EXPECT_EQ(num_list_ptr->ToDatalogString(), expected_datalog);
}

// List fixtures of length zero, one and two. They are file-scope objects so
// that the table below can refer to them by address.
static const NumList kEmptyNumList;
static const NumList kOneElementNumList(Number(5), NumList());
static const NumList kTwoElementNumList =
    NumList(Number(-30), NumList(Number(28), NumList()));

// Each fixture paired with the Datalog text it must print as.
static NumListAndExpectedDatalog kListAndExpectedDatalog[] = {
    {.num_list_ptr = &kEmptyNumList, .expected_datalog = "nil"},
    {.num_list_ptr = &kOneElementNumList, .expected_datalog = "[5, nil]"},
    {.num_list_ptr = &kTwoElementNumList,
     .expected_datalog = "[-30, [28, nil]]"}};

INSTANTIATE_TEST_SUITE_P(NumListTest, NumListTest,
                         ValuesIn(kListAndExpectedDatalog));

using NumberSymbolPair = Record<Number, Symbol>;
using NumberSymbolPairPair = Record<NumberSymbolPair, NumberSymbolPair>;

class NumberSymbolPairPairTest
: public TestWithParam<std::tuple<std::tuple<int64_t, absl::string_view>,
std::tuple<int64_t, absl::string_view>>> {
};

TEST_P(NumberSymbolPairPairTest, NumberSymbolPairPairTest) {
auto const &[pair1, pair2] = GetParam();
auto const [number1, symbol1] = pair1;
auto const [number2, symbol2] = pair2;
NumberSymbolPair number_symbol_pair1 =
NumberSymbolPair(Number(number1), Symbol(symbol1));
NumberSymbolPair number_symbol_pair2 =
NumberSymbolPair(Number(number2), Symbol(symbol2));
NumberSymbolPairPair pair_pair(std::move(number_symbol_pair1),
std::move(number_symbol_pair2));
EXPECT_EQ(pair_pair.ToDatalogString(),
absl::StrFormat(R"([[%d, "%s"], [%d, "%s"]])", number1, symbol1,
number2, symbol2));
}

INSTANTIATE_TEST_SUITE_P(
NumberSymbolPairPairTest, NumberSymbolPairPairTest,
Combine(Combine(ValuesIn(kSampleIntegerValues), ValuesIn(kSampleSymbols)),
Combine(ValuesIn(kSampleIntegerValues), ValuesIn(kSampleSymbols))));

// Branch names of the example arithmetic ADT exercised by the ADT tests.
static constexpr char kNullBranchName[] = "Null";
static constexpr char kNumberBranchName[] = "Number";
static constexpr char kAddBranchName[] = "Add";

// Example ADT used by the tests; its branches are the classes deriving from
// it. It adds nothing to `Adt` beyond a distinct type.
class ArithmeticAdt : public Adt {
 public:
  // Inherited constructors keep the accessibility they have in `Adt`, so the
  // access specifier here is for readability only.
  using Adt::Adt;
};

// The argument-less `Null` branch of ArithmeticAdt.
class NullBranch : public ArithmeticAdt {
public:
NullBranch() : ArithmeticAdt(kNullBranchName) {}
};

// The `Number` branch of ArithmeticAdt; carries a single `Number` argument.
class NumberBranch : public ArithmeticAdt {
 public:
  NumberBranch(Number number) : ArithmeticAdt(kNumberBranchName) {
    // The argument list owns its elements, so store a heap copy of `number`.
    std::unique_ptr<Number> operand = std::make_unique<Number>(number);
    arguments_.push_back(std::move(operand));
  }
};

// The `Add` branch of ArithmeticAdt; carries its two operands, left then
// right.
class AddBranch : public ArithmeticAdt {
 public:
  AddBranch(ArithmeticAdt lhs, ArithmeticAdt rhs)
      : ArithmeticAdt(kAddBranchName) {
    // Operands are move-only; each one is moved onto the heap and appended in
    // printing order.
    arguments_.emplace_back(std::make_unique<ArithmeticAdt>(std::move(lhs)));
    arguments_.emplace_back(std::make_unique<ArithmeticAdt>(std::move(rhs)));
  }
};

struct AdtAndExpectedDatalog {
const ArithmeticAdt *adt;
absl::string_view expected_datalog;
};

class AdtTest : public TestWithParam<AdtAndExpectedDatalog> {};

TEST_P(AdtTest, AdtTest) {
auto &[adt, expected_datalog] = GetParam();
EXPECT_EQ(adt->ToDatalogString(), expected_datalog);
}

// ADT fixtures. ADT values are move-only, so the composite fixtures build
// their own operands instead of reusing kFive and kTwo. Copy-initialization is
// used throughout because `ArithmeticAdt x(NullBranch());` would be parsed as
// a function declaration.
static const ArithmeticAdt kNull = ArithmeticAdt(NullBranch());
static const ArithmeticAdt kFive = ArithmeticAdt(NumberBranch(Number(5)));
static const ArithmeticAdt kTwo = ArithmeticAdt(NumberBranch(Number(2)));
static const ArithmeticAdt kFivePlusTwo =
    ArithmeticAdt(AddBranch(ArithmeticAdt(NumberBranch(Number(5))),
                            ArithmeticAdt(NumberBranch(Number(2)))));
static const ArithmeticAdt kFivePlusTwoPlusNull = ArithmeticAdt(
    AddBranch(ArithmeticAdt(NumberBranch(Number(5))),
              ArithmeticAdt(AddBranch(ArithmeticAdt(NumberBranch(Number(2))),
                                      ArithmeticAdt(NullBranch())))));

// Each fixture paired with the Datalog text it must print as. Every fixture
// defined above, including kTwo, is covered.
static const AdtAndExpectedDatalog kAdtAndExpectedDatalog[] = {
    {.adt = &kNull, .expected_datalog = "$Null{}"},
    {.adt = &kFive, .expected_datalog = "$Number{5}"},
    {.adt = &kTwo, .expected_datalog = "$Number{2}"},
    {.adt = &kFivePlusTwo, .expected_datalog = "$Add{$Number{5}, $Number{2}}"},
    {.adt = &kFivePlusTwoPlusNull,
     .expected_datalog = "$Add{$Number{5}, $Add{$Number{2}, $Null{}}}"}};

INSTANTIATE_TEST_SUITE_P(AdtTest, AdtTest, ValuesIn(kAdtAndExpectedDatalog));

} // namespace raksha::ir::datalog

0 comments on commit 67d8731

Please sign in to comment.