Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize recursive protobuf fields more efficiently. #1548

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions domain_tests/arbitrary_domains_protobuf_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,22 @@ TEST(ProtocolBufferWithRequiredFields, ShrinkingNeverRemovesRequiredFields) {
}
}

// Builds an arbitrary domain for the recursive test proto with repeated
// fields always set and the `siblings` field configured as unset, then checks
// that the generated value is initialized and that, across many mutations, it
// stays initialized and `siblings` is never populated.
TEST(ProtocolBufferWithRecursiveFields, InfiniteleyRecursiveFieldsAreNotSet) {
  constexpr int kMutationRounds = 1000;

  auto domain = Arbitrary<internal::TestProtobufWithRecursion>()
                    .WithRepeatedFieldsAlwaysSet()
                    .WithProtobufFieldUnset("siblings");
  absl::BitGen bitgen;
  Value val(domain, bitgen);

  // The freshly initialized value must already be a valid message.
  ASSERT_TRUE(val.user_value.IsInitialized()) << val.user_value;

  for (int remaining = kMutationRounds; remaining > 0; --remaining) {
    val.Mutate(domain, bitgen, {}, false);
    // Every mutation must keep the message valid and leave `siblings` alone.
    ASSERT_TRUE(val.user_value.IsInitialized()) << val.user_value;
    ASSERT_FALSE(val.user_value.has_siblings()) << val.user_value;
  }
}

TEST(ProtocolBuffer, CanUsePerFieldDomains) {
Domain<TestProtobuf> domain =
Arbitrary<TestProtobuf>()
Expand Down
1 change: 1 addition & 0 deletions fuzztest/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
Expand Down
1 change: 1 addition & 0 deletions fuzztest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ fuzztest_cc_library(
absl::status
absl::statusor
absl::strings
absl::str_format
absl::synchronization
absl::span
)
Expand Down
231 changes: 153 additions & 78 deletions fuzztest/internal/domains/protobuf_domain_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "absl/random/bit_gen_ref.h"
#include "absl/random/random.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -233,7 +234,10 @@ class ProtoPolicy {
public:
ProtoPolicy()
: optional_policies_({{/*filter=*/IncludeAll<FieldDescriptor>(),
/*value=*/OptionalPolicy::kWithNull}}) {}
/*value=*/OptionalPolicy::kWithNull}}) {
static int64_t next_id = 0;
id_ = next_id++;
}

void SetOptionalPolicy(OptionalPolicy optional_policy) {
SetOptionalPolicy(IncludeAll<FieldDescriptor>(), optional_policy);
Expand Down Expand Up @@ -314,7 +318,11 @@ class ProtoPolicy {
return max;
}

// Returns the numeric identifier stored in `id_` by the constructor. Callers
// in this file pair it with a `FieldDescriptor` pointer to build cache keys
// (e.g. `{policy_.id(), field}`).
int64_t id() const { return id_; }

private:
int64_t id_;

template <typename T>
struct FilterToValue {
Filter filter;
Expand Down Expand Up @@ -459,10 +467,11 @@ class ProtobufDomainUntypedImpl

corpus_type Init(absl::BitGenRef prng) {
if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
FUZZTEST_INTERNAL_CHECK(
!IsCustomizedRecursivelyOnly() || !IsInfinitelyRecursive(),
"Cannot set recursive fields by default.");
const auto* descriptor = prototype_.Get()->GetDescriptor();
FUZZTEST_INTERNAL_CHECK(
!IsCustomizedRecursivelyOnly() || !IsInfinitelyRecursive(descriptor),
absl::StrCat("Cannot set recursive fields for ",
descriptor->full_name(), " by default."));
corpus_type val;
absl::flat_hash_map<int, int> oneof_to_field;

Expand All @@ -474,15 +483,18 @@ class ProtobufDomainUntypedImpl
SelectAFieldIndexInOneof(oneof, prng);
}
if (oneof_to_field[oneof->index()] != field->index()) continue;
} else if (!MustBeSet(field) && IsCustomizedRecursivelyOnly() &&
IsFieldFinitelyRecursive(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
// may never terminate. If a proto has only non-required recursive
// fields, the initialization will be deterministic, which violates the
// assumption on domain Init. However, such cases should be extremely
// rare and breaking the assumption would not have severe consequences.
continue;
} else if (IsCustomizedRecursivelyOnly()) {
if (!MustBeSet(field) && IsRecursionBreaker(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
// may never terminate. If a proto has only non-required recursive
// fields, the initialization will be deterministic, which violates
// the assumption on domain Init. However, such cases should be
// extremely rare and breaking the assumption would not have severe
// consequences.
continue;
}
if (MustBeUnset(field)) continue;
}
VisitProtobufField(field, InitializeVisitor{prng, *this, val});
}
Expand Down Expand Up @@ -602,6 +614,7 @@ class ProtobufDomainUntypedImpl
GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) {
continue;
}
if (IsCustomizedRecursivelyOnly() && MustBeUnset(field)) continue;
++total_weight;

if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
Expand Down Expand Up @@ -639,6 +652,7 @@ class ProtobufDomainUntypedImpl
GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) {
continue;
}
if (IsCustomizedRecursivelyOnly() && MustBeUnset(field)) continue;
++field_counter;
if (field_counter == selected_field_index) {
VisitProtobufField(
Expand Down Expand Up @@ -951,9 +965,9 @@ class ProtobufDomainUntypedImpl
for (int i = 0; i < oneof->field_count(); ++i) {
OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i));
if (policy == OptionalPolicy::kAlwaysNull) continue;
if (IsCustomizedRecursivelyOnly() &&
IsFieldFinitelyRecursive(oneof->field(i))) {
continue;
if (IsCustomizedRecursivelyOnly()) {
if (IsRecursionBreaker(oneof->field(i))) continue;
if (MustBeUnset(oneof->field(i))) continue;
}
fields.push_back(i);
}
Expand Down Expand Up @@ -1690,38 +1704,33 @@ class ProtobufDomainUntypedImpl
return GetDomainForField<T, is_repeated>(field, /*use_policy=*/false);
}

// Analysis type for protobuf recursions.
enum class RecursionType {
// The proto contains a proto of type P, that must contain another P.
kInfinitelyRecursive,
// The proto contains a proto of type P, that can contain another P.
kFinitelyRecursive,
};

bool IsInfinitelyRecursive() {
absl::flat_hash_set<decltype(prototype_.Get()->GetDescriptor())> parents;
return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents,
RecursionType::kInfinitelyRecursive);
// Returns true if there are subprotos in the `descriptor` that form an
// infinite recursion.
bool IsInfinitelyRecursive(const Descriptor* descriptor) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsProtoRecursive(/*field=*/nullptr, parents, descriptor);
}

bool IsFieldFinitelyRecursive(const FieldDescriptor* field) {
if (!field->message_type()) return false;
// Returns true if there are subfields in the `field` that form an
// infinite recursion of the form: F0 -> F1 -> ... -> Fs -> ... -> Fn -> Fs,
// because all Fi-s have to be set (e.g., Fi is a required field, or is
// customized using `WithFieldsAlwaysSet`).
bool IsInfinitelyRecursive(const FieldDescriptor* field) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit);
static absl::NoDestructor<absl::flat_hash_map<const FieldDescriptor*, bool>>
static absl::NoDestructor<
absl::flat_hash_map<std::pair<int64_t, const FieldDescriptor*>, bool>>
cache ABSL_GUARDED_BY(mutex);
bool can_use_cache = IsCustomizedRecursivelyOnly();
if (can_use_cache) {
{
absl::MutexLock l(&mutex);
auto it = cache->find(field);
auto it = cache->find({policy_.id(), field});
if (it != cache->end()) return it->second;
}
absl::flat_hash_set<decltype(field->message_type())> parents;
bool result = IsProtoRecursive(field->message_type(), parents,
RecursionType::kFinitelyRecursive);
if (can_use_cache) {
absl::MutexLock l(&mutex);
cache->insert({field, result});
}
absl::flat_hash_set<const FieldDescriptor*> parents;
bool result = IsProtoRecursive(field, parents);
absl::MutexLock l(&mutex);
cache->insert({{policy_.id(), field}, result});
return result;
}

Expand All @@ -1733,32 +1742,24 @@ class ProtobufDomainUntypedImpl
return index == kFieldCountIndex;
}

bool IsOneofRecursive(const OneofDescriptor* oneof,
absl::flat_hash_set<const Descriptor*>& parents,
RecursionType recursion_type) const {
bool IsOneofRecursive(
const OneofDescriptor* oneof,
absl::flat_hash_set<const FieldDescriptor*>& parents) const {
bool is_oneof_recursive = false;
for (int i = 0; i < oneof->field_count(); ++i) {
const auto* field = oneof->field(i);
const auto field_policy = policy_.GetOptionalPolicy(field);
if (field_policy == OptionalPolicy::kAlwaysNull) continue;
const auto* child = field->message_type();
if (recursion_type == RecursionType::kInfinitelyRecursive) {
is_oneof_recursive = field_policy != OptionalPolicy::kWithNull &&
child &&
IsProtoRecursive(child, parents, recursion_type);
if (!is_oneof_recursive) {
return false;
}
} else {
if (child && IsProtoRecursive(child, parents, recursion_type)) {
return true;
}
}
is_oneof_recursive = field_policy != OptionalPolicy::kWithNull &&
field->message_type() &&
IsProtoRecursive(field, parents);
if (!is_oneof_recursive) return false;
}
return is_oneof_recursive;
}

bool MustBeSet(const FieldDescriptor* field) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (IsRequired(field)) {
return true;
} else if (field->containing_oneof()) {
Expand All @@ -1775,6 +1776,14 @@ class ProtobufDomainUntypedImpl
}

bool MustBeUnset(const FieldDescriptor* field) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (field->message_type() && IsInfinitelyRecursive(field)) {
absl::FPrintF(
GetStderr(),
"[!] Infinite recursion detected for %s and it remains unset.\n",
field->full_name());
return true;
}
if (IsRequired(field)) {
return false;
} else if (field->containing_oneof()) {
Expand All @@ -1790,39 +1799,105 @@ class ProtobufDomainUntypedImpl
return false;
}

template <typename Descriptor>
bool IsProtoRecursive(const Descriptor* descriptor,
absl::flat_hash_set<const Descriptor*>& parents,
RecursionType recursion_type) const {
if (parents.contains(descriptor)) return true;
parents.insert(descriptor);
// If `field` is nullptr, all fields of `descriptor` are checked.
bool IsProtoRecursive(const FieldDescriptor* field,
absl::flat_hash_set<const FieldDescriptor*>& parents,
const Descriptor* descriptor = nullptr) const {
if (field != nullptr) {
if (parents.contains(field)) return true;
parents.insert(field);
descriptor = field->message_type();
} else {
FUZZTEST_INTERNAL_CHECK(descriptor,
"one of field or descriptor must be non-null!");
}
for (int i = 0; i < descriptor->oneof_decl_count(); ++i) {
const auto* oneof = descriptor->oneof_decl(i);
if (IsOneofRecursive(oneof, parents, recursion_type)) {
parents.erase(descriptor);
if (IsOneofRecursive(oneof, parents)) {
if (field != nullptr) parents.erase(field);
return true;
}
}
for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
if (field->containing_oneof()) continue;
const auto* child = field->message_type();
if (!child) continue;
if (policy_.GetDefaultDomainForProtobufs(field) != std::nullopt) {
for (const FieldDescriptor* subfield : GetProtobufFields(descriptor)) {
if (subfield->containing_oneof()) continue;
if (!subfield->message_type()) continue;
if (auto default_domain = policy_.GetDefaultDomainForProtobufs(subfield);
default_domain != std::nullopt) { // For handling WithProtobufFields.
// If this field is recursive, it will be detected when initializing
// its default domain. Otherwise, this field can always be set safely.
absl::BitGen prng;
default_domain->Init(prng);
continue;
}
if (recursion_type == RecursionType::kInfinitelyRecursive) {
if (!MustBeSet(field)) continue;
} else {
if (MustBeUnset(field)) continue;
}
if (IsProtoRecursive(child, parents, recursion_type)) {
parents.erase(descriptor);
if (!MustBeSet(subfield)) continue;
if (IsProtoRecursive(subfield, parents)) {
if (field != nullptr) parents.erase(field);
return true;
}
}
parents.erase(descriptor);
if (field != nullptr) parents.erase(field);
return false;
}

// A subset of proto types are considered as recursion breakers and during
// domain initialization, won't get recursively initialized to avoid
// non-terminating initialization.
//
// Returns true if the `field` (F0) does not have to be set, and there are
// subfields in the form: F0 -> F1 -> ... -> Fn -> F0 or F20 -> F19 ... -> F0
// and none of other Fi-s are marked as recursion breakers so far. In other
// words, this method computes recursion breakers and check membership of
// `field` in the set of recursion breakers.
bool IsRecursionBreaker(const FieldDescriptor* field) {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (!field->message_type()) return false;
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsRecursionBreaker(field, field, parents);
}

bool IsRecursionBreaker(
const FieldDescriptor* root, const FieldDescriptor* field,
absl::flat_hash_set<const FieldDescriptor*>& parents) const {
ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit);
static absl::NoDestructor<
absl::flat_hash_map<std::pair<int64_t, const FieldDescriptor*>, bool>>
cache ABSL_GUARDED_BY(mutex);
{
absl::MutexLock l(&mutex);
auto it = cache->find({policy_.id(), field});
if (it != cache->end()) return it->second;
}
// Cannot break the recursion for required fields.
bool can_be_unset = !MustBeSet(field);
if (field->containing_oneof() && !can_be_unset) { // oneof must be set
// We check whether `field` is infinitely recursive without considering
// other oneof fields. If it is, there's another field in the oneof that
// can be set.
absl::flat_hash_set<const FieldDescriptor*> subfield_parents;
subfield_parents.insert(field);
can_be_unset = IsProtoRecursive(field, subfield_parents);
}
if (can_be_unset) {
// Break recursion for deeply nested or recursive protos.
if (parents.size() > 20 || parents.contains(field)) {
absl::MutexLock l(&mutex);
cache->insert({{policy_.id(), field}, true});
return true;
}
parents.insert(field);
}
for (const FieldDescriptor* subfield :
GetProtobufFields(field->message_type())) {
if (!subfield->message_type()) continue;
if (MustBeUnset(subfield)) continue;
IsRecursionBreaker(root, subfield, parents);
}
if (can_be_unset) parents.erase(field);
absl::MutexLock l(&mutex);
cache->insert({{policy_.id(), field}, false});
auto it = cache->find({policy_.id(), field});
if (it != cache->end()) return it->second;
cache->insert({{policy_.id(), field}, false});
return false;
}

Expand Down
Loading