diff --git a/datadog_lambda/tracing.py b/datadog_lambda/tracing.py index 9870f756..4f679fe1 100644 --- a/datadog_lambda/tracing.py +++ b/datadog_lambda/tracing.py @@ -362,7 +362,13 @@ def is_lambda_context(): def set_dd_trace_py_root(trace_context_source, merge_xray_traces): if trace_context_source == TraceContextSource.EVENT or merge_xray_traces: - headers = _context_obj_to_headers(dd_trace_context) + context = dict(dd_trace_context) + if merge_xray_traces: + xray_context = _get_xray_trace_context() + if xray_context is not None: + context["parent-id"] = xray_context["parent-id"] + + headers = _context_obj_to_headers(context) span_context = propagator.extract(headers) tracer.context_provider.activate(span_context) logger.debug( diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 82e11f3f..e0ba12c4 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -8,6 +8,7 @@ from mock import MagicMock, patch, call from ddtrace.helpers import get_correlation_ids +from ddtrace.context import Context from datadog_lambda.constants import SamplingPriority, TraceHeader, XraySubsegment from datadog_lambda.tracing import ( @@ -16,6 +17,7 @@ create_function_execution_span, get_dd_trace_context, set_correlation_ids, + set_dd_trace_py_root, _convert_xray_trace_id, _convert_xray_entity_id, _convert_xray_sampling, @@ -493,3 +495,53 @@ def test_function_with_trigger_tags(self): self.assertEqual( span.get_tag("function_trigger.event_source"), "cloudwatch-logs" ) + + +class TestSetTraceRootSpan(unittest.TestCase): + def setUp(self): + global dd_tracing_enabled + dd_tracing_enabled = False + os.environ["_X_AMZN_TRACE_ID"] = fake_xray_header_value + patcher = patch("datadog_lambda.tracing.send_segment") + self.mock_send_segment = patcher.start() + self.addCleanup(patcher.stop) + patcher = patch("datadog_lambda.tracing.is_lambda_context") + self.mock_is_lambda_context = patcher.start() + self.mock_is_lambda_context.return_value = True + self.addCleanup(patcher.stop) + patcher = patch("ddtrace.tracer.context_provider.activate") + self.mock_activate = patcher.start() + self.mock_activate.return_value = True + self.addCleanup(patcher.stop) + + def tearDown(self): + global dd_tracing_enabled + dd_tracing_enabled = False + del os.environ["_X_AMZN_TRACE_ID"] + + def test_mixed_parent_context_when_merging(self): + # When trace merging is enabled, and dd_trace headers are present, + # use the dd-trace trace-id and the x-ray parent-id + # This allows parenting relationships like dd-trace -> x-ray -> dd-trace + lambda_ctx = get_mock_context() + ctx, source = extract_dd_trace_context( + { + "headers": { + TraceHeader.TRACE_ID: "123", + TraceHeader.PARENT_ID: "321", + TraceHeader.SAMPLING_PRIORITY: "1", + } + }, + lambda_ctx, + ) + set_dd_trace_py_root( + source, True + ) # When merging is off, always use dd-trace-context + + expected_context = Context( + trace_id=123, # Trace Id from incomming context + span_id=int(fake_xray_header_value_parent_decimal), # Parent Id from x-ray + sampling_priority=1, # Sampling priority from incomming context + ) + self.mock_activate.assert_called() + self.mock_activate.assert_has_calls([call(expected_context)])