-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create C++ types for Souffle values which can be printed to Datalog. (#…
…528) This change adds C++ types which directly map to Souffle's types. These types use templates to allow a user to create C++ types that correspond directly to Souffle types and which the C++ compiler can type check at static time. Because some of the types that we create are rather complex, involving both ADTs and records, this should be useful to catch errors before having to run the Souffle compiler. Closes #528 COPYBARA_INTEGRATE_REVIEW=#528 from google-research:souffle_base_value@winterrowd 3b54cb8 PiperOrigin-RevId: 443699109
- Loading branch information
Showing
3 changed files
with
340 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
//----------------------------------------------------------------------------- | ||
// Copyright 2022 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
//---------------------------------------------------------------------------- | ||
|
||
#ifndef SRC_IR_DATALOG_VALUE_H_ | ||
#define SRC_IR_DATALOG_VALUE_H_ | ||
|
||
#include <string> | ||
#include <tuple> | ||
#include <variant> | ||
|
||
#include "absl/strings/str_format.h" | ||
#include "absl/strings/str_join.h" | ||
#include "absl/strings/string_view.h" | ||
|
||
namespace raksha::ir::datalog { | ||
|
||
// The common supertype of any "fully fledged" Souffle value type. Anything | ||
// descending form `Value` should be able to be used in a `.type` declaration. | ||
// By this logic, `number`, `symbol`, any record, and any ADT are `Value`s, but | ||
// individual branches of an ADT are not. | ||
class Value { | ||
public: | ||
virtual std::string ToDatalogString() const = 0; | ||
virtual ~Value() {} | ||
}; | ||
|
||
// Corresponds to Souffle's `number` type. | ||
class Number : public Value { | ||
public: | ||
explicit Number(int64_t value) : number_value_(value) {} | ||
|
||
std::string ToDatalogString() const override { | ||
return std::to_string(number_value_); | ||
} | ||
|
||
private: | ||
int64_t number_value_; | ||
}; | ||
|
||
// Corresponds to Souffle's `symbol` type. | ||
class Symbol : public Value { | ||
public: | ||
explicit Symbol(absl::string_view value) : symbol_value_(value) {} | ||
|
||
std::string ToDatalogString() const override { | ||
return absl::StrFormat(R"("%s")", symbol_value_); | ||
} | ||
|
||
private: | ||
std::string symbol_value_; | ||
}; | ||
|
||
// A Record is, in fact, a value pointing to a tuple of values that | ||
// constitutes a record. This is necessary because Souffle conflates references | ||
// and the objects to which those references point. In the case that a type has | ||
// a recursive definition (common with linked lists), trying to implement this | ||
// without a pointer would cause layout issues. | ||
template <class... RecordFieldValueTypes> | ||
class Record : public Value { | ||
public: | ||
explicit Record() : record_arguments_() {} | ||
explicit Record(RecordFieldValueTypes &&...args) | ||
: record_arguments_( | ||
std::make_unique<std::tuple<RecordFieldValueTypes...>>( | ||
std::forward<RecordFieldValueTypes>(args)...)) {} | ||
|
||
std::string ToDatalogString() const override { | ||
if (!record_arguments_) return "nil"; | ||
return absl::StrFormat( | ||
"[%s]", absl::StrJoin(*record_arguments_, ", ", | ||
[](std::string *out, const auto &arg) { | ||
absl::StrAppend(out, arg.ToDatalogString()); | ||
})); | ||
} | ||
|
||
private: | ||
std::unique_ptr<std::tuple<RecordFieldValueTypes...>> record_arguments_; | ||
}; | ||
|
||
class Adt : public Value { | ||
public: | ||
explicit Adt(absl::string_view branch_name) : branch_name_(branch_name) {} | ||
|
||
std::string ToDatalogString() const { | ||
return absl::StrFormat( | ||
"$%s{%s}", branch_name_, | ||
absl::StrJoin(arguments_, ", ", [](std::string *str, const auto &arg) { | ||
absl::StrAppend(str, arg->ToDatalogString()); | ||
})); | ||
} | ||
|
||
protected: | ||
std::vector<std::unique_ptr<Value>> arguments_; | ||
|
||
private: | ||
absl::string_view branch_name_; | ||
}; | ||
|
||
} // namespace raksha::ir::datalog | ||
|
||
#endif // SRC_IR_DATALOG_VALUE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
//----------------------------------------------------------------------------- | ||
// Copyright 2022 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
//---------------------------------------------------------------------------- | ||
|
||
#include "src/ir/datalog/value.h" | ||
|
||
#include <limits> | ||
|
||
#include "absl/strings/numbers.h" | ||
#include "absl/strings/string_view.h" | ||
#include "src/common/testing/gtest.h" | ||
|
||
namespace raksha::ir::datalog { | ||
|
||
using testing::Combine; | ||
using testing::TestWithParam; | ||
using testing::ValuesIn; | ||
|
||
class NumberTest : public TestWithParam<int64_t> {}; | ||
|
||
TEST_P(NumberTest, NumberTest) { | ||
int64_t num = GetParam(); | ||
Number number_value = Number(num); | ||
std::string datalog_str = number_value.ToDatalogString(); | ||
int64_t parsed_int = 0; | ||
bool conversion_succeeds = absl::SimpleAtoi(datalog_str, &parsed_int); | ||
ASSERT_TRUE(conversion_succeeds); | ||
ASSERT_EQ(parsed_int, num); | ||
} | ||
|
||
static int64_t kSampleIntegerValues[] = {0, -1, 1, | ||
std::numeric_limits<long>::max(), | ||
std::numeric_limits<long>::min()}; | ||
|
||
INSTANTIATE_TEST_SUITE_P(NumberTest, NumberTest, | ||
ValuesIn(kSampleIntegerValues)); | ||
|
||
class SymbolTest : public TestWithParam<absl::string_view> {}; | ||
|
||
TEST_P(SymbolTest, SymbolTest) { | ||
absl::string_view symbol = GetParam(); | ||
Symbol symbol_value = Symbol(symbol); | ||
std::string symbol_str = symbol_value.ToDatalogString(); | ||
ASSERT_EQ(symbol_str, "\"" + std::string(symbol) + "\""); | ||
} | ||
|
||
static absl::string_view kSampleSymbols[] = {"", "x", "foo", "hello_world"}; | ||
|
||
INSTANTIATE_TEST_SUITE_P(SymbolTest, SymbolTest, ValuesIn(kSampleSymbols)); | ||
|
||
using SimpleRecord = Record<Symbol, Number>; | ||
|
||
TEST(SimpleRecordNilTest, SimpleRecordNilTest) { | ||
ASSERT_EQ(Record<SimpleRecord>().ToDatalogString(), "nil"); | ||
} | ||
|
||
class SimpleRecordTest | ||
: public TestWithParam<std::tuple<absl::string_view, int64_t>> {}; | ||
|
||
TEST_P(SimpleRecordTest, SimpleRecordTest) { | ||
const auto [symbol, number] = GetParam(); | ||
SimpleRecord record_value = SimpleRecord(Symbol(symbol), Number(number)); | ||
ASSERT_EQ(record_value.ToDatalogString(), | ||
absl::StrFormat(R"(["%s", %d])", symbol, number)); | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P(SimpleRecordTest, SimpleRecordTest, | ||
Combine(ValuesIn(kSampleSymbols), | ||
ValuesIn(kSampleIntegerValues))); | ||
|
||
class NumList : public Record<Number, NumList> { | ||
public: | ||
using Record::Record; | ||
}; | ||
|
||
struct NumListAndExpectedDatalog { | ||
const NumList *num_list_ptr; | ||
absl::string_view expected_datalog; | ||
}; | ||
|
||
class NumListTest : public TestWithParam<NumListAndExpectedDatalog> {}; | ||
|
||
TEST_P(NumListTest, NumListTest) { | ||
const auto [num_list_ptr, expected_datalog] = GetParam(); | ||
EXPECT_EQ(num_list_ptr->ToDatalogString(), expected_datalog); | ||
} | ||
|
||
static const NumList kEmptyNumList; | ||
static const NumList kOneElementNumList = NumList(Number(5), NumList()); | ||
static const NumList kTwoElementNumList(Number(-30), | ||
NumList(Number(28), NumList())); | ||
|
||
static NumListAndExpectedDatalog kListAndExpectedDatalog[] = { | ||
{.num_list_ptr = &kEmptyNumList, .expected_datalog = "nil"}, | ||
{.num_list_ptr = &kOneElementNumList, .expected_datalog = "[5, nil]"}, | ||
{.num_list_ptr = &kTwoElementNumList, | ||
.expected_datalog = "[-30, [28, nil]]"}}; | ||
|
||
INSTANTIATE_TEST_SUITE_P(NumListTest, NumListTest, | ||
ValuesIn(kListAndExpectedDatalog)); | ||
|
||
using NumberSymbolPair = Record<Number, Symbol>; | ||
using NumberSymbolPairPair = Record<NumberSymbolPair, NumberSymbolPair>; | ||
|
||
class NumberSymbolPairPairTest | ||
: public TestWithParam<std::tuple<std::tuple<int64_t, absl::string_view>, | ||
std::tuple<int64_t, absl::string_view>>> { | ||
}; | ||
|
||
TEST_P(NumberSymbolPairPairTest, NumberSymbolPairPairTest) { | ||
auto const &[pair1, pair2] = GetParam(); | ||
auto const [number1, symbol1] = pair1; | ||
auto const [number2, symbol2] = pair2; | ||
NumberSymbolPair number_symbol_pair1 = | ||
NumberSymbolPair(Number(number1), Symbol(symbol1)); | ||
NumberSymbolPair number_symbol_pair2 = | ||
NumberSymbolPair(Number(number2), Symbol(symbol2)); | ||
NumberSymbolPairPair pair_pair(std::move(number_symbol_pair1), | ||
std::move(number_symbol_pair2)); | ||
EXPECT_EQ(pair_pair.ToDatalogString(), | ||
absl::StrFormat(R"([[%d, "%s"], [%d, "%s"]])", number1, symbol1, | ||
number2, symbol2)); | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P( | ||
NumberSymbolPairPairTest, NumberSymbolPairPairTest, | ||
Combine(Combine(ValuesIn(kSampleIntegerValues), ValuesIn(kSampleSymbols)), | ||
Combine(ValuesIn(kSampleIntegerValues), ValuesIn(kSampleSymbols)))); | ||
|
||
static constexpr char kNullBranchName[] = "Null"; | ||
static constexpr char kNumberBranchName[] = "Number"; | ||
static constexpr char kAddBranchName[] = "Add"; | ||
|
||
class ArithmeticAdt : public Adt { | ||
using Adt::Adt; | ||
}; | ||
|
||
class NullBranch : public ArithmeticAdt { | ||
public: | ||
NullBranch() : ArithmeticAdt(kNullBranchName) {} | ||
}; | ||
|
||
class NumberBranch : public ArithmeticAdt { | ||
public: | ||
NumberBranch(Number number) : ArithmeticAdt(kNumberBranchName) { | ||
arguments_.push_back(std::make_unique<Number>(number)); | ||
} | ||
}; | ||
|
||
class AddBranch : public ArithmeticAdt { | ||
public: | ||
AddBranch(ArithmeticAdt lhs, ArithmeticAdt rhs) | ||
: ArithmeticAdt(kAddBranchName) { | ||
arguments_.push_back(std::make_unique<ArithmeticAdt>(std::move(lhs))); | ||
arguments_.push_back(std::make_unique<ArithmeticAdt>(std::move(rhs))); | ||
} | ||
}; | ||
|
||
struct AdtAndExpectedDatalog { | ||
const ArithmeticAdt *adt; | ||
absl::string_view expected_datalog; | ||
}; | ||
|
||
class AdtTest : public TestWithParam<AdtAndExpectedDatalog> {}; | ||
|
||
TEST_P(AdtTest, AdtTest) { | ||
auto &[adt, expected_datalog] = GetParam(); | ||
EXPECT_EQ(adt->ToDatalogString(), expected_datalog); | ||
} | ||
|
||
static const ArithmeticAdt kNull = ArithmeticAdt(NullBranch()); | ||
static const ArithmeticAdt kFive = ArithmeticAdt(NumberBranch(Number(5))); | ||
static const ArithmeticAdt kTwo = ArithmeticAdt(NumberBranch(Number(2))); | ||
static const ArithmeticAdt kFivePlusTwo = | ||
ArithmeticAdt(AddBranch(ArithmeticAdt(NumberBranch(Number(5))), | ||
ArithmeticAdt(NumberBranch(Number(2))))); | ||
static const ArithmeticAdt kFivePlusTwoPlusNull = ArithmeticAdt( | ||
AddBranch(ArithmeticAdt(NumberBranch(Number(5))), | ||
ArithmeticAdt(AddBranch(ArithmeticAdt(NumberBranch(Number(2))), | ||
ArithmeticAdt(NullBranch()))))); | ||
|
||
static const AdtAndExpectedDatalog kAdtAndExpectedDatalog[] = { | ||
{.adt = &kNull, .expected_datalog = "$Null{}"}, | ||
{.adt = &kFive, .expected_datalog = "$Number{5}"}, | ||
{.adt = &kFivePlusTwo, .expected_datalog = "$Add{$Number{5}, $Number{2}}"}, | ||
{.adt = &kFivePlusTwoPlusNull, | ||
.expected_datalog = "$Add{$Number{5}, $Add{$Number{2}, $Null{}}}"}}; | ||
|
||
INSTANTIATE_TEST_SUITE_P(AdtTest, AdtTest, ValuesIn(kAdtAndExpectedDatalog)); | ||
|
||
} // namespace raksha::ir::datalog |