Skip to content

Commit

Permalink
[yaml] Fix edge cases in C++ string writing (#22299)
Browse files Browse the repository at this point in the history
  • Loading branch information
jwnimmer-tri authored Dec 12, 2024
1 parent dac7c79 commit 303ae30
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 39 deletions.
19 changes: 19 additions & 0 deletions bindings/pydrake/common/test/yaml_typed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,9 +728,28 @@ def test_write_float(self):
self.assertEqual(actual_doc, expected_doc)

def test_write_string(self):
# We'll use this abbreviation to help make our expected values clear.
dq = '"' # double quote
cases = [
# Plain string.
("a", "a"),
# Needs quoting for special characters.
("'", f"''''"),
('"', f"'{dq}'"),
# Needs quoting to avoid being misinterpreted as another data type.
("1", "'1'"),
("1.0", "'1.0'"),
(".NaN", "'.NaN'"),
("true", "'true'"),
("null", "'null'"),
("NO", "'NO'"),
("null", "'null'"),
("190:20:30", "'190:20:30'"), # YAML has sexagesimal literals.
# Similar to things that would be misinterpreted but actually a-ok.
("nonnull", "nonnull"),
("NaN", "NaN"),
("=1.0", "=1.0"),
("00:1A:2B:3C:4D:5E", "00:1A:2B:3C:4D:5E"),
]
for value, expected_str in cases:
actual_doc = yaml_dump_typed(StringStruct(value=value))
Expand Down
22 changes: 17 additions & 5 deletions common/yaml/test/yaml_read_archive_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ using internal::YamlReadArchive;

// TODO(jwnimmer-tri) Add a test case for reading NonPodVectorStruct.
// TODO(jwnimmer-tri) Add a test case for reading OuterWithBlankInner.
// TODO(jwnimmer-tri) Add a test case for reading StringStruct.
// TODO(jwnimmer-tri) Add a test case for reading UnorderedMapStruct.

// A test fixture with common helpers.
Expand Down Expand Up @@ -119,6 +118,19 @@ class YamlReadArchiveTest : public ::testing::TestWithParam<LoadYamlOptions> {
}
};

// TODO(jwnimmer-tri) This test case is extremely basic. We should add many more
// corner cases & etc. here.
TEST_P(YamlReadArchiveTest, String) {
const auto test = [](const std::string& value, const std::string& expected) {
const auto& x = AcceptNoThrow<StringStruct>(LoadSingleValue(value));
EXPECT_EQ(x.value, expected);
};

test("foo", "foo");
test("''''", "'");
test("'\"'", "\"");
}

TEST_P(YamlReadArchiveTest, Double) {
const auto test = [](const std::string& value, double expected) {
const auto& x = AcceptNoThrow<DoubleStruct>(LoadSingleValue(value));
Expand Down Expand Up @@ -157,8 +169,8 @@ TEST_P(YamlReadArchiveTest, AllScalars) {
const std::string doc = R"""(
doc:
some_bool: true
some_double: 100.0
some_float: 101.0
some_float: 100.0
some_double: 101.0
some_int32: 102
some_uint32: 103
some_int64: 104
Expand All @@ -167,8 +179,8 @@ TEST_P(YamlReadArchiveTest, AllScalars) {
)""";
const auto& x = AcceptNoThrow<AllScalarsStruct>(Load(doc));
EXPECT_EQ(x.some_bool, true);
EXPECT_EQ(x.some_double, 100.0);
EXPECT_EQ(x.some_float, 101.0);
EXPECT_EQ(x.some_float, 100.0);
EXPECT_EQ(x.some_double, 101.0);
EXPECT_EQ(x.some_int32, 102);
EXPECT_EQ(x.some_uint32, 103);
EXPECT_EQ(x.some_int64, 104);
Expand Down
54 changes: 52 additions & 2 deletions common/yaml/test/yaml_write_archive_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <vector>

#include <fmt/args.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

Expand Down Expand Up @@ -73,8 +74,57 @@ TEST_F(YamlWriteArchiveTest, String) {
EXPECT_EQ(Save(x), WrapDoc(expected));
};

// We'll use these named fmt args to help make our expected values clear.
fmt::dynamic_format_arg_store<fmt::format_context> args;
args.push_back(fmt::arg("bs", '\\')); // backslash
args.push_back(fmt::arg("dq", '"')); // double quote

// Plain string.
test("a", "a");
test("1", "1");

// Needs quoting for special characters. Note that there are several valid
// ways to quote and/or escape these, but for now we just check against the
// exact choice that yaml-cpp uses. In the future if we see new outputs, we
// could allow them too.
test("'", fmt::vformat("{dq}'{dq}", args));
test("\"", fmt::vformat("{dq}{bs}{dq}{dq}", args));

// Needs quoting to avoid being misinterpreted as another data type.
test("1", "'1'");
test("1.0", "'1.0'");
test(".NaN", "'.NaN'");
test("true", "'true'");
test("NO", "'NO'");
test("null", "'null'");
test("190:20:30", "'190:20:30'"); // YAML has sexagesimal literals.

// Similar to things that would be misinterpreted but actually a-ok.
test("nonnull", "nonnull");
test("NaN", "NaN");
test("=1.0", "=1.0");
test("00:1A:2B:3C:4D:5E", "00:1A:2B:3C:4D:5E");
}

TEST_F(YamlWriteArchiveTest, AllScalars) {
AllScalarsStruct x;
x.some_bool = true;
x.some_float = 100.0;
x.some_double = 101.0;
x.some_int32 = 102;
x.some_uint32 = 103;
x.some_int64 = 104;
x.some_uint64 = 105;
x.some_string = "foo";
EXPECT_EQ(Save(x), R"""(doc:
some_bool: true
some_float: 100.0
some_double: 101.0
some_int32: 102
some_uint32: 103
some_int64: 104
some_uint64: 105
some_string: foo
)""");
}

TEST_F(YamlWriteArchiveTest, StdArray) {
Expand Down Expand Up @@ -208,7 +258,7 @@ TEST_F(YamlWriteArchiveTest, Variant) {
EXPECT_EQ(Save(x), WrapDoc(expected));
};

test(Variant4(std::string()), "\"\"");
test(Variant4(std::string()), "''");
test(Variant4(std::string("foo")), "foo");
test(Variant4(1.0), "!!float 1.0");
test(Variant4(DoubleStruct{1.0}), "!DoubleStruct\n value: 1.0");
Expand Down
115 changes: 99 additions & 16 deletions common/yaml/yaml_write_archive.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "drake/common/yaml/yaml_write_archive.h"

#include <algorithm>
#include <regex>
#include <sstream>
#include <utility>
#include <vector>
Expand All @@ -9,6 +10,7 @@
#include <yaml-cpp/yaml.h>

#include "drake/common/drake_assert.h"
#include "drake/common/never_destroyed.h"
#include "drake/common/overloaded.h"
#include "drake/common/unused.h"

Expand All @@ -19,6 +21,75 @@ namespace {

constexpr const char* const kKeyOrder = "__key_order";

// Returns true iff the `value` looks like a null, int, float, or bool literal:
// specifically, returns true iff the given value (when parsed as an untagged
// "plain scalar") would resolve to a core YAML type (i.e., a type like
// "tag:yaml.org,2002:...") but not the string type "tag:yaml.org,2002:str".
//
// When loading a yaml document, there is syntax called a "plain scalar" which
// is basically just bare word(s) without any quoting. When loading a plain
// scalar that doesn't have any explicit tag given in the document, the type of
// the scalar needs to be "resolved" by the application loading the document,
// and YAML recommends that the application do so by using the "core schema"
// suite of regexes -- e.g., plain scalars that look like integers resolve to
// "tag:yaml.org,2002:int". Therefore, when writing out a document we need to be
// careful that when emitting a string we must not emit it as a plain scalar
// when the core schema would misinterpret the value as a non-string type.
//
// Ideally yaml-cpp would have some way for us to signal to its emitter that
// this is what we want (rather than implementing this logic ourselves), but
// so far we haven't been able to find anything.
bool DoesPlainScalarResolveToNonStrInYamlCoreSchema(const std::string& value) {
static const never_destroyed<std::regex> regex_null_bool_int_float{
// ----------------------------------------------------------------------
// Regexes adapted from https://yaml.org/spec/1.2.2/#1032-tag-resolution:
// ----------------------------------------------------------------------
// tag: null (literal)
"null|Null|NULL|~"
// tag: null (empty)
"|"
// tag: bool
"|true|True|TRUE|false|False|FALSE"
// tag: int (base 10)
// Note that https://yaml.org/type/int.html subsequently narrow this to
// forbid multiple leading zeros in base 10, but we stick with the more
// generous spelling here to err on the side of caution when writing.
"|[-+]?[0-9]+"
// tag: int (base 8)
"|0o[0-7]+"
// tag: int (base 16)
"|0x[0-9a-fA-F]+"
// tag: float (number)
"|[-+]?(\\.[0-9]+|[0-9]+(\\.[0-9]*)?)([eE][-+]?[0-9]+)?"
// tag: float (infinity)
"|[-+]?\\.(inf|Inf|INF)"
// tag: float (nan)
"|\\.(nan|NaN|NAN)"
//
// ----------------------------------------------------------------------
// Additional regexes from https://yaml.org/type/bool.html:
// ----------------------------------------------------------------------
"|y|Y|yes|Yes|YES|n|N|no|No|NO"
"|on|On|ON|off|Off|OFF"
//
// ----------------------------------------------------------------------
// Additional regexes from https://yaml.org/type/int.html:
// ----------------------------------------------------------------------
"|[-+]?0b[0-1_]+" // (base 2)
"|[-+]?[1-9][0-9_]*(:[0-5]?[0-9])+" // (base 60)
//
// ----------------------------------------------------------------------
// Additional regexes from https://yaml.org/type/float.html:
// ----------------------------------------------------------------------
"|[-+]?([0-9][0-9_]*)?\\.[0-9.]*([eE][-+][0-9]+)?" // (base 10)
"|[-+]?[0-9][0-9_]*(:[0-5]?[0-9])+\\.[0-9_]*" // (base 60)
// N.B. There is also https://yaml.org/type/null.html but it doesn't
// contain any regexs beyond what #1032-tag-resolution says.
};
std::smatch ignored;
return std::regex_match(value, ignored, regex_null_bool_int_float.access());
}

// This function uses the same approach as YAML::NodeEvents::Emit.
// https://github.com/jbeder/yaml-cpp/blob/release-0.5.2/src/nodeevents.cpp#L55
//
Expand All @@ -27,28 +98,40 @@ constexpr const char* const kKeyOrder = "__key_order";
// end sequence, end mapping) and then its job is to spit out the equivalent
// YAML syntax for that stream (e.g., "foo: [1, 2]") with appropriately matched
// delimiters (i.e., `:` or `{}` or `[]`) and horizontal indentation levels.
void RecursiveEmit(const internal::Node& node, YAML::EmitFromEvents* sink) {
void RecursiveEmit(const internal::Node& node, YAML::Emitter* emitter,
YAML::EmitFromEvents* sink) {
const YAML::Mark no_mark;
const YAML::anchor_t no_anchor = YAML::NullAnchor;
std::string tag{node.GetTag()};
if ((tag == internal::Node::kTagNull) || (tag == internal::Node::kTagBool) ||
(tag == internal::Node::kTagInt) || (tag == internal::Node::kTagFloat) ||
(tag == internal::Node::kTagStr)) {
const std::string_view node_tag = node.GetTag();
std::string emitted_tag;
if ((node_tag == internal::Node::kTagNull) ||
(node_tag == internal::Node::kTagBool) ||
(node_tag == internal::Node::kTagInt) ||
(node_tag == internal::Node::kTagFloat) ||
(node_tag == internal::Node::kTagStr)) {
// In most cases we don't need to emit the "JSON Schema" tags for YAML data,
// because they are implied by default. However, YamlWriteArchive on variant
// types sometimes marks the tag as important.
if (node.IsTagImportant()) {
DRAKE_DEMAND(tag.size() > 0);
// The `internal::Node::kTagFoo` all look like "tag:yaml.org,2002:foo".
// We only want the "foo" part (after the second colon).
tag = "!!" + tag.substr(18);
} else {
tag.clear();
emitted_tag = std::string("!!");
emitted_tag.append(node_tag.substr(18));
}
} else {
emitted_tag = node_tag;
}
node.Visit(overloaded{
[&](const internal::Node::ScalarData& data) {
sink->OnScalar(no_mark, tag, no_anchor, data.scalar);
if (emitted_tag.empty() && node_tag == internal::Node::kTagStr &&
DoesPlainScalarResolveToNonStrInYamlCoreSchema(data.scalar)) {
// We need to force this scalar to be seen as a string, so we'll turn
// off "auto" string format by asking for "single quoted" instead. If
// the value can't be single quoted, yaml-cpp will fall back to using
// double quotes automatically.
emitter->SetLocalValue(YAML::SingleQuoted);
}
sink->OnScalar(no_mark, emitted_tag, no_anchor, data.scalar);
},
[&](const internal::Node::SequenceData& data) {
// If all children are scalars, then format this sequence onto a
Expand All @@ -59,9 +142,9 @@ void RecursiveEmit(const internal::Node& node, YAML::EmitFromEvents* sink) {
style = YAML::EmitterStyle::Block;
}
}
sink->OnSequenceStart(no_mark, tag, no_anchor, style);
sink->OnSequenceStart(no_mark, emitted_tag, no_anchor, style);
for (const auto& child : data.sequence) {
RecursiveEmit(child, sink);
RecursiveEmit(child, emitter, sink);
}
sink->OnSequenceEnd();
},
Expand All @@ -72,7 +155,7 @@ void RecursiveEmit(const internal::Node& node, YAML::EmitFromEvents* sink) {
if (data.mapping.empty()) {
style = YAML::EmitterStyle::Flow;
}
sink->OnMapStart(no_mark, tag, no_anchor, style);
sink->OnMapStart(no_mark, emitted_tag, no_anchor, style);
// If there is a __key_order node inserted (as part of the Accept()
// member function in our header file), use it to specify output order;
// otherwise, use alphabetical order.
Expand All @@ -95,8 +178,8 @@ void RecursiveEmit(const internal::Node& node, YAML::EmitFromEvents* sink) {
}
}
for (const auto& string_key : key_order) {
RecursiveEmit(internal::Node::MakeScalar(string_key), sink);
RecursiveEmit(data.mapping.at(string_key), sink);
RecursiveEmit(internal::Node::MakeScalar(string_key), emitter, sink);
RecursiveEmit(data.mapping.at(string_key), emitter, sink);
}
sink->OnMapEnd();
},
Expand All @@ -114,7 +197,7 @@ std::string YamlWriteArchive::YamlDumpWithSortedMaps(
const internal::Node& document) {
YAML::Emitter emitter;
YAML::EmitFromEvents sink(emitter);
RecursiveEmit(document, &sink);
RecursiveEmit(document, &emitter, &sink);
return emitter.c_str();
}

Expand Down
38 changes: 22 additions & 16 deletions common/yaml/yaml_write_archive.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,23 +222,29 @@ class YamlWriteArchive final {
void VisitScalar(const NVP& nvp) {
using T = typename NVP::value_type;
const T& value = *nvp.value();
if constexpr (std::is_floating_point_v<T>) {
std::string value_str = std::isfinite(value) ? fmt_floating_point(value)
: std::isnan(value) ? ".nan"
: (value > 0) ? ".inf"
: "-.inf";
auto scalar = internal::Node::MakeScalar(std::move(value_str));
scalar.SetTag(internal::JsonSchemaTag::kFloat);
root_.Add(nvp.name(), std::move(scalar));
return;
}
auto scalar = internal::Node::MakeScalar(fmt::format("{}", value));
if constexpr (std::is_same_v<T, bool>) {
scalar.SetTag(internal::JsonSchemaTag::kBool);
}
if constexpr (std::is_integral_v<T>) {
scalar.SetTag(internal::JsonSchemaTag::kInt);
std::string text;
JsonSchemaTag tag;
if constexpr (std::is_same_v<T, std::string>) {
text = value;
tag = internal::JsonSchemaTag::kStr;
} else if constexpr (std::is_same_v<T, bool>) {
text = value ? "true" : "false";
tag = internal::JsonSchemaTag::kBool;
} else if constexpr (std::is_integral_v<T>) {
text = fmt::to_string(value);
tag = internal::JsonSchemaTag::kInt;
} else if constexpr (std::is_floating_point_v<T>) {
text = std::isfinite(value) ? fmt_floating_point(value)
: std::isnan(value) ? ".nan"
: (value > 0) ? ".inf"
: "-.inf";
tag = internal::JsonSchemaTag::kFloat;
} else {
text = fmt::format("{}", value);
tag = internal::JsonSchemaTag::kStr;
}
auto scalar = internal::Node::MakeScalar(std::move(text));
scalar.SetTag(tag);
root_.Add(nvp.name(), std::move(scalar));
}

Expand Down

0 comments on commit 303ae30

Please sign in to comment.