Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate trace IDs as ULIDs by default #783

Merged
merged 20 commits into from
Jan 9, 2025
2 changes: 1 addition & 1 deletion logfire/_internal/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def _load_configuration(
# This is particularly for deserializing from a dict as in executors.py
advanced = AdvancedOptions(**advanced) # type: ignore
id_generator = advanced.id_generator
if isinstance(id_generator, dict) and list(id_generator.keys()) == ['seed']: # type: ignore # pragma: no branch
if isinstance(id_generator, dict) and list(id_generator.keys()) == ['seed', '_ms_timestamp_generator']: # type: ignore # pragma: no branch
advanced.id_generator = SeededRandomIdGenerator(**id_generator) # type: ignore
elif advanced is None:
advanced = AdvancedOptions(base_url=param_manager.load_param('base_url'))
Expand Down
40 changes: 40 additions & 0 deletions logfire/_internal/ulid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from random import Random
from typing import Callable


def ulid(random: Random, ms_timestamp_generator: Callable[[], int]) -> int:
    """Generate an integer ULID compatible with UUID v4.

    Layout per the [ULID spec](https://github.com/ulid/spec):

        01AN4Z07BY      79KA1307SR9X4MV3
       |----------|    |----------------|
        Timestamp          Randomness
          48 bits            80 bits

    We deliberately do NOT set UUIDv4 version/variant bits: doing so would
    leave only 7 bytes of randomness, while the Python SDK's sampler expects
    8 random bytes. Once OTEL adopts
    https://www.w3.org/TR/trace-context-2/#random-trace-id-flag (which relies
    only on the lower 7 bytes of the trace ID) that restriction can be
    revisited. Today only our SDK / Python SDKs and the OTEL collector matter,
    and both tolerate 8 bytes of randomness because trace IDs were originally
    64 bits wide, so nothing in OTEL may assume more than 8 random bytes in a
    trace ID it did not generate itself.
    """
    # High 48 bits: a millisecond timestamp. Precision and uniqueness are not
    # required here — it only needs to be roughly monotonic so IDs sort by
    # creation time. `.to_bytes(6, ...)` doubles as a range check: a value
    # that doesn't fit in 6 bytes raises OverflowError, as before.
    high48 = int.from_bytes(ms_timestamp_generator().to_bytes(6, byteorder='big'), byteorder='big')
    # Low 80 bits: randomness from the caller-supplied PRNG.
    return (high48 << 80) | random.getrandbits(80)
16 changes: 12 additions & 4 deletions logfire/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from time import time
from types import TracebackType
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Sequence, Tuple, TypedDict, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Sequence, Tuple, TypedDict, TypeVar, Union

from opentelemetry import context, trace as trace_api
from opentelemetry.sdk.resources import Resource
Expand All @@ -22,6 +23,7 @@
from requests import RequestException, Response

from logfire._internal.stack_info import is_user_code
from logfire._internal.ulid import ulid

if TYPE_CHECKING:
from packaging.version import Version
Expand Down Expand Up @@ -358,7 +360,11 @@ def is_asgi_send_receive_span_name(name: str) -> bool:
return name.endswith((' http send', ' http receive', ' websocket send', ' websocket receive'))


@dataclass(repr=True)
def _default_ms_timestamp_generator() -> int:
return int(time() * 1000)


@dataclass(repr=True, eq=True)
class SeededRandomIdGenerator(IdGenerator):
"""Generate random span/trace IDs from a seed for deterministic tests.

Expand All @@ -371,6 +377,8 @@ class SeededRandomIdGenerator(IdGenerator):
"""

seed: int | None = 0
_ms_timestamp_generator: Callable[[], int] = _default_ms_timestamp_generator
"""Private argument, do not set this directly."""

def __post_init__(self) -> None:
self.random = random.Random(self.seed)
Expand All @@ -384,7 +392,7 @@ def generate_span_id(self) -> int:
return span_id

def generate_trace_id(self) -> int:
    """Return a 128-bit ULID-based trace ID (time-sortable, seeded randomness).

    Previously this assigned ``self.random.getrandbits(128)`` and then
    immediately overwrote it with the ULID — a dead store that also consumed
    128 bits of the seeded PRNG stream per call; the redundant assignments
    are removed.
    """
    trace_id = ulid(self.random, self._ms_timestamp_generator)
    # An all-zero ID is invalid per the OTEL spec; regenerate in that
    # (astronomically unlikely) case.
    while trace_id == trace_api.INVALID_TRACE_ID:  # pragma: no cover
        trace_id = ulid(self.random, self._ms_timestamp_generator)
    return trace_id
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from opentelemetry import trace
from opentelemetry.sdk.metrics.export import InMemoryMetricReader
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.id_generator import IdGenerator

import logfire
from logfire import configure
Expand Down Expand Up @@ -55,7 +56,7 @@ def metrics_reader() -> InMemoryMetricReader:
@pytest.fixture
def config_kwargs(
exporter: TestExporter,
id_generator: IncrementalIdGenerator,
id_generator: IdGenerator,
time_generator: TimeGenerator,
) -> dict[str, Any]:
"""
Expand Down
41 changes: 40 additions & 1 deletion tests/test_logfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from logfire._internal.formatter import FormattingFailedWarning, InspectArgumentsFailedWarning
from logfire._internal.main import NoopSpan
from logfire._internal.tracer import record_exception
from logfire._internal.utils import is_instrumentation_suppressed
from logfire._internal.utils import SeededRandomIdGenerator, is_instrumentation_suppressed
from logfire.integrations.logging import LogfireLoggingHandler
from logfire.testing import TestExporter
from tests.test_metrics import get_collected_metrics
Expand Down Expand Up @@ -3213,3 +3213,42 @@ def test_exit_ended_span(exporter: TestExporter):
}
]
)


# Fake nanosecond clock for deterministic tests (fixes typo: "currnet" -> "current").
_ns_current_ts = 0


def incrementing_ms_ts_generator() -> int:
    """Deterministic stand-in for the wall clock used by ``SeededRandomIdGenerator``.

    Advances a fake nanosecond counter by 420_000 ns per call and returns the
    elapsed time in whole milliseconds, so consecutive calls frequently land
    in the same millisecond — exercising the tie-breaking behavior of
    ULID-based trace IDs.
    """
    global _ns_current_ts
    _ns_current_ts += 420_000  # some random number that results in non-whole ms
    return _ns_current_ts // 1_000_000


# Swap the fixture's id_generator for a seeded one driven by a fake clock,
# so the ULID timestamp bits are deterministic across test runs.
@pytest.mark.parametrize(
    'id_generator',
    [SeededRandomIdGenerator(_ms_timestamp_generator=incrementing_ms_ts_generator)],
)
def test_default_id_generator(exporter: TestExporter) -> None:
    """Test that SeededRandomIdGenerator generates trace and span ids without errors.

    Also checks that ULID trace IDs sort in the same order as span start times.
    """
    for i in range(1024):
        logfire.info('log', i=i)

    exported = exporter.exported_spans_as_dict()

    # sanity check: there are 1024 trace ids
    assert len({export['context']['trace_id'] for export in exported}) == 1024
    # sanity check: there are multiple milliseconds (first 6 bytes)
    # (>> 80 drops the 80 random bits, leaving the 48-bit ms timestamp)
    assert len({export['context']['trace_id'] >> 80 for export in exported}) == snapshot(431)

    # Check that trace ids are sortable and unique
    # We use ULIDs to generate trace ids, so they should be sortable.
    sorted_by_trace_id = [
        export['attributes']['i']
        # sort by trace_id and start_time so that if two trace ids were generated in the same ms and thus may sort randomly
        # we disambiguate with the start time
        for export in sorted(exported, key=lambda span: (span['context']['trace_id'] >> 80, span['start_time']))
    ]
    sorted_by_start_timestamp = [
        export['attributes']['i'] for export in sorted(exported, key=lambda span: span['start_time'])
    ]
    assert sorted_by_trace_id == sorted_by_start_timestamp
2 changes: 1 addition & 1 deletion tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_sample_rate_config(exporter: TestExporter, config_kwargs: dict[str, Any

# 1000 iterations of 2 spans -> 2000 spans
# 30% sampling -> 600 spans (approximately)
assert len(exporter.exported_spans_as_dict()) == 634
assert len(exporter.exported_spans_as_dict()) == 588, len(exporter.exported_spans_as_dict())
alexmojaki marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.skipif(
Expand Down
12 changes: 6 additions & 6 deletions tests/test_tail_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def test_background_rate(config_kwargs: dict[str, Any], exporter: TestExporter):
# None of them meet the tail sampling criteria.
for _ in range(1000):
logfire.info('info')
assert len(exporter.exported_spans) == 100 + 321
assert len(exporter.exported_spans) - 100 == snapshot(299)


class TestSampler(Sampler):
Expand Down Expand Up @@ -406,7 +406,7 @@ def test_raw_head_sampler_with_tail_sampling(config_kwargs: dict[str, Any], expo
# None of them meet the tail sampling criteria.
for _ in range(1000):
logfire.info('info')
assert len(exporter.exported_spans) == 100 + 315
assert len(exporter.exported_spans) - 100 == snapshot(293)


def test_custom_head_and_tail(config_kwargs: dict[str, Any], exporter: TestExporter):
Expand All @@ -432,20 +432,20 @@ def get_tail_sample_rate(span_info: TailSamplingSpanInfo) -> float:

for _ in range(1000):
logfire.warn('warn')
assert span_counts == snapshot({'start': 720, 'end': 617})
assert len(exporter.exported_spans) == snapshot(103)
assert span_counts == snapshot({'start': 719, 'end': 611})
assert len(exporter.exported_spans) == snapshot(108)
assert span_counts['end'] + len(exporter.exported_spans) == span_counts['start']

exporter.clear()
for _ in range(1000):
with logfire.span('span'):
pass
assert len(exporter.exported_spans_as_dict()) == snapshot(505)
assert len(exporter.exported_spans_as_dict()) == snapshot(511)

exporter.clear()
for _ in range(1000):
logfire.error('error')
assert len(exporter.exported_spans) == snapshot(282)
assert len(exporter.exported_spans) == snapshot(298)


def test_span_levels():
Expand Down
Loading