diff --git a/lib/opentelemetry/trace/exporter/encoder.lua b/lib/opentelemetry/trace/exporter/encoder.lua index 413cf23..a7e5937 100644 --- a/lib/opentelemetry/trace/exporter/encoder.lua +++ b/lib/opentelemetry/trace/exporter/encoder.lua @@ -38,7 +38,7 @@ function _M.for_export(span) return { trace_id = span.ctx.trace_id, span_id = span.ctx.span_id, - trace_state = span.ctx.trace_state, + trace_state = span.ctx.trace_state:as_string(), parent_span_id = span.parent_ctx.span_id or "", name = span.name, kind = span.kind, diff --git a/lib/opentelemetry/trace/propagation/text_map/trace_context_propagator.lua b/lib/opentelemetry/trace/propagation/text_map/trace_context_propagator.lua index 27a36c8..6035ed5 100644 --- a/lib/opentelemetry/trace/propagation/text_map/trace_context_propagator.lua +++ b/lib/opentelemetry/trace/propagation/text_map/trace_context_propagator.lua @@ -1,4 +1,5 @@ local span_context = require("opentelemetry.trace.span_context") +local tracestate = require("opentelemetry.trace.tracestate") local text_map_getter = require("opentelemetry.trace.propagation.text_map.getter") local text_map_setter = require("opentelemetry.trace.propagation.text_map.setter") local util = require("opentelemetry.util") @@ -44,77 +45,10 @@ function _M:inject(context, carrier, setter) span_context.trace_id, span_context.span_id, span_context.trace_flags) setter.set(carrier, traceparent_header, traceparent) if span_context.trace_state then - setter.set(carrier, tracestate_header, span_context.trace_state) + setter.set(carrier, tracestate_header, span_context.trace_state:as_string()) end end -local function validate_member_key(key) - if #key > 256 then - return nil - end - - local valid_key = string.match(key, [[^%s*([a-z][_0-9a-z%-*/]*)$]]) - if not valid_key then - local tenant_id, system_id = string.match(key, [[^%s*([a-z0-9][_0-9a-z%-*/]*)@([a-z][_0-9a-z%-*/]*)$]]) - if not tenant_id or not system_id then - return nil - end - if #tenant_id > 241 or #system_id > 14 then - return nil - end - return tenant_id .. "@" .. system_id - end - - return valid_key -end - -local function validate_member_value(value) - if #value > 256 then - return nil - end - return string.match(value, - [[^([ !"#$%%&'()*+%-./0-9:;<>?@A-Z[\%]^_`a-z{|}~]*[!"#$%%&'()*+%-./0-9:;<>?@A-Z[\%]^_`a-z{|}~])%s*$]]) -end - -function _M.parse_trace_state(trace_state) - if not trace_state then - return "" - end - if type(trace_state) == "string" then - trace_state = { trace_state } - end - - local new_trace_state = {} - local members_count = 0 - for _, item in ipairs(trace_state) do - for member in string.gmatch(item, "([^,]+)") do - if member ~= "" then - local start_pos, end_pos = string.find(member, "=", 1, true) - if not start_pos or start_pos == 1 then - return "" - end - local key = validate_member_key(string.sub(member, 1, start_pos - 1)) - if not key then - return "" - end - - local value = validate_member_value(string.sub(member, end_pos + 1)) - if not value then - return "" - end - - members_count = members_count + 1 - if members_count > 32 then - return "" - end - table.insert(new_trace_state, key .. "=" .. value) - end - end - end - - return table.concat(new_trace_state, ",") -end - local function validate_trace_id(trace_id) return type(trace_id) == "string" and #trace_id == 32 and trace_id ~= invalid_trace_id and string.match(trace_id, "^[0-9a-f]+$") @@ -182,7 +116,7 @@ function _M:extract(context, carrier, getter) return context end - local trace_state = _M.parse_trace_state(getter.get(carrier, tracestate_header)) + local trace_state = tracestate.parse_tracestate(getter.get(carrier, tracestate_header)) return context:with_span_context(span_context.new(trace_id, span_id, trace_flags, trace_state, true)) end diff --git a/lib/opentelemetry/trace/span_context.lua b/lib/opentelemetry/trace/span_context.lua index ed9a08e..5933243 100644 --- a/lib/opentelemetry/trace/span_context.lua +++ b/lib/opentelemetry/trace/span_context.lua @@ -1,3 +1,4 @@ +local tracestate = require("opentelemetry.trace.tracestate") local _M = { INVALID_TRACE_ID = "00000000000000000000000000000000", INVALID_SPAN_ID = "0000000000000000" @@ -12,7 +13,7 @@ function _M.new(tid, sid, trace_flags, trace_state, remote) trace_id = tid, span_id = sid, trace_flags = trace_flags, - trace_state = trace_state, + trace_state = trace_state or tracestate.new({}), remote = remote, } return setmetatable(self, mt) diff --git a/lib/opentelemetry/trace/tracestate.lua b/lib/opentelemetry/trace/tracestate.lua new file mode 100644 index 0000000..9d08c8b --- /dev/null +++ b/lib/opentelemetry/trace/tracestate.lua @@ -0,0 +1,160 @@ +local _M = { + MAX_KEY_LEN = 256, + MAX_VAL_LEN = 256, + MAX_ENTRIES = 32, +} + +local mt = { + __index = _M +} + +local function validate_member_key(key) + if #key > _M.MAX_KEY_LEN then + return nil + end + + local valid_key = string.match(key, [[^%s*([a-z][_0-9a-z%-*/]*)$]]) + if not valid_key then + local tenant_id, system_id = string.match(key, [[^%s*([a-z0-9][_0-9a-z%-*/]*)@([a-z][_0-9a-z%-*/]*)$]]) + if not tenant_id or not system_id then + return nil + end + if #tenant_id > 241 or #system_id > 14 then + return nil + end + return tenant_id .. "@" .. system_id + end + + return valid_key +end + +local function validate_member_value(value) + if #value > _M.MAX_VAL_LEN then + return nil + end + return string.match(value, + [[^([ !"#$%%&'()*+%-./0-9:;<>?@A-Z[\%]^_`a-z{|}~]*[!"#$%%&'()*+%-./0-9:;<>?@A-Z[\%]^_`a-z{|}~])%s*$]]) +end + +function _M.new(values) + local self = { values = values } + return setmetatable(self, mt) +end + +-------------------------------------------------------------------------------- +-- Parse tracestate header into a tracestate +-- +-- @return tracestate +-------------------------------------------------------------------------------- +function _M.parse_tracestate(tracestate) + if not tracestate then + return _M.new({}) + end + if type(tracestate) == "string" then + tracestate = { tracestate } + end + + local new_tracestate = {} + local members_count = 0 + local error_message = "failed to parse tracestate" + for _, item in ipairs(tracestate) do + for member in string.gmatch(item, "([^,]+)") do + if member ~= "" then + local start_pos, end_pos = string.find(member, "=", 1, true) + if not start_pos or start_pos == 1 then + ngx.log(ngx.WARN, error_message) + return _M.new({}) + end + local key = validate_member_key(string.sub(member, 1, start_pos - 1)) + if not key then + ngx.log(ngx.WARN, error_message) + return _M.new({}) + end + local value = validate_member_value(string.sub(member, end_pos + 1)) + if not value then + ngx.log(ngx.WARN, error_message) + return _M.new({}) + end + members_count = members_count + 1 + if members_count > _M.MAX_ENTRIES then + ngx.log(ngx.WARN, error_message) + return _M.new({}) + end + table.insert(new_tracestate, {key, value}) + end + end + end + + return _M.new(new_tracestate) +end + +-------------------------------------------------------------------------------- +-- Set the key value pair for the tracestate +-- +-- @return tracestate +-------------------------------------------------------------------------------- +function _M.set(self, key, value) + if not validate_member_key(key) then + return self + end + if not validate_member_value(value) then + return self + end + self:del(key) + if #self.values >= _M.MAX_ENTRIES then + table.remove(self.values) + ngx.log(ngx.WARN, "tracestate max values exceeded, removing rightmost entry") + end + table.insert(self.values, 1, {key, value}) + return self +end + +-------------------------------------------------------------------------------- +-- Get the value for the current key from the tracestate +-- +-- @return value +-------------------------------------------------------------------------------- +function _M.get(self, key) + for _, item in ipairs(self.values) do + local ckey = item[1] + if ckey == key then + return item[2] + end + end + return "" +end + +-------------------------------------------------------------------------------- +-- Delete the key from the tracestate +-- +-- @return tracestate +-------------------------------------------------------------------------------- +function _M.del(self, key) + local index = 0 + for i, item in ipairs(self.values) do + local ckey = item[1] + if ckey == key then + index = i + break + end + end + if index ~= 0 then + table.remove(self.values, index) + end + return self +end + +-------------------------------------------------------------------------------- +-- Return the header value of the tracestate +-- +-- @return string +-------------------------------------------------------------------------------- +function _M.as_string(self) + local output = {} + for _, item in ipairs(self.values) do + table.insert(output, item[1] .. "=" .. item[2]) + end + return table.concat(output, ",") +end + +return _M diff --git a/rockspec/opentelemetry-lua-master-0.rockspec b/rockspec/opentelemetry-lua-master-0.rockspec index 637415e..9bee0af 100644 --- a/rockspec/opentelemetry-lua-master-0.rockspec +++ b/rockspec/opentelemetry-lua-master-0.rockspec @@ -54,6 +54,7 @@ build = { ["opentelemetry.trace.span_status"] = "lib/opentelemetry/trace/span_status.lua", ["opentelemetry.trace.tracer"] = "lib/opentelemetry/trace/tracer.lua", ["opentelemetry.trace.tracer_provider"] = "lib/opentelemetry/trace/tracer_provider.lua", + ["opentelemetry.trace.tracestate"] = "lib/opentelemetry/trace/tracestate.lua", ["opentelemetry.baggage"] = "lib/opentelemetry/baggage.lua", ["opentelemetry.baggage.propagation.text_map.baggage_propagator"] = "lib/opentelemetry/baggage/propagation/text_map/baggage_propagator.lua", ["opentelemetry.util"] = "lib/opentelemetry/util.lua" diff --git a/spec/trace/tracestate_spec.lua b/spec/trace/tracestate_spec.lua new file mode 100644 index 0000000..e01ec74 --- /dev/null +++ b/spec/trace/tracestate_spec.lua @@ -0,0 +1,50 @@ +local tracestate = require("opentelemetry.trace.tracestate") + +describe("is_valid", function() + it("parse, get works", function() + local ts = tracestate.parse_tracestate("foo=bar,baz=lehrman") + assert.is_true(#ts.values == 2) + assert.is_true(ts:get("foo") == "bar") + assert.is_true(ts:get("baz") == "lehrman") + end) + it("set works", function() + local ts = tracestate.parse_tracestate("foo=bar,baz=lehrman") + assert.is_true(#ts.values == 2) + ts:set("foo", "fun") + assert.is_true(#ts.values == 2) + assert.is_true(ts:get("foo") == "fun") + ts:set("family", "values") + assert.is_true(#ts.values == 3) + assert.is_true(ts:get("family") == "values") + -- setting an invalid value leaves the old kv pair + ts:set("foo", "v=l") + assert.is_true(ts:get("foo") == "fun") + end) + it("del works", function() + local ts = tracestate.parse_tracestate("foo=bar,baz=lehrman") + ts:del("foo") + assert.is_true(#ts.values == 1) + assert.is_true(ts:get("foo") == "") + end) + it("as_string works", function() + local ts = tracestate.parse_tracestate("foo=bar,baz=lehrman") + assert.is_true(ts:as_string() == "foo=bar,baz=lehrman") + ts:set("bing", "bong") + assert.is_true(ts:as_string() == "bing=bong,foo=bar,baz=lehrman") + end) + it("max len is respected", function() + local ts = tracestate.parse_tracestate("") + for i=1,tracestate.MAX_ENTRIES,1 do + ts:set("a" .. tostring(i), "b" .. tostring(i)) + end + assert.is_true(#ts.values == tracestate.MAX_ENTRIES) + ts:set("one", "more") + assert.is_true(#ts.values == tracestate.MAX_ENTRIES) + -- First elem added is the first one lost when we add over max entries + assert.is_true(ts:get("a1") == "") + assert.is_true(ts:get("one") == "more") + -- Newest elem is prepended + assert.is_true(ts.values[1][1] == "one") + + end) +end)