Feat: External_trace_id compatible with OpenTelemetry (#23918)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
heyszt 2025-08-15 09:13:41 +08:00 committed by GitHub
parent 11fdcb18c6
commit aa71173dbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 113 additions and 10 deletions

View File

@ -1,6 +1,7 @@
import logging import logging
import flask_login import flask_login
from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
@ -24,6 +25,7 @@ from core.errors.error import (
ProviderTokenNotInitError, ProviderTokenNotInitError,
QuotaExceededError, QuotaExceededError,
) )
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs import helper from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value
@ -115,6 +117,10 @@ class ChatMessageApi(Resource):
streaming = args["response_mode"] != "blocking" streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
account = flask_login.current_user account = flask_login.current_user
try: try:

View File

@ -23,6 +23,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory, variable_factory from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@ -185,6 +186,10 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
args = parser.parse_args() args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
@ -373,6 +378,10 @@ class DraftWorkflowRunApi(Resource):
parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, app_model=app_model,

View File

@ -16,15 +16,33 @@ def get_external_trace_id(request: Any) -> Optional[str]:
""" """
Retrieve the trace_id from the request. Retrieve the trace_id from the request.
Priority: header ('X-Trace-Id'), then parameters, then JSON body. Returns None if not provided or invalid. Priority:
1. header ('X-Trace-Id')
2. parameters
3. JSON body
4. Current OpenTelemetry context (if enabled)
5. OpenTelemetry traceparent header (if present and valid)
Returns None if no valid trace_id is provided.
""" """
trace_id = request.headers.get("X-Trace-Id") trace_id = request.headers.get("X-Trace-Id")
if not trace_id: if not trace_id:
trace_id = request.args.get("trace_id") trace_id = request.args.get("trace_id")
if not trace_id and getattr(request, "is_json", False): if not trace_id and getattr(request, "is_json", False):
json_data = getattr(request, "json", None) json_data = getattr(request, "json", None)
if json_data: if json_data:
trace_id = json_data.get("trace_id") trace_id = json_data.get("trace_id")
if not trace_id:
trace_id = get_trace_id_from_otel_context()
if not trace_id:
traceparent = request.headers.get("traceparent")
if traceparent:
trace_id = parse_traceparent_header(traceparent)
if isinstance(trace_id, str) and is_valid_trace_id(trace_id): if isinstance(trace_id, str) and is_valid_trace_id(trace_id):
return trace_id return trace_id
return None return None
@ -40,3 +58,49 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
if trace_id: if trace_id:
return {"external_trace_id": trace_id} return {"external_trace_id": trace_id}
return {} return {}
def get_trace_id_from_otel_context() -> Optional[str]:
"""
Retrieve the current trace ID from the active OpenTelemetry trace context.
Returns None if:
1. OpenTelemetry SDK is not installed or enabled.
2. There is no active span or trace context.
"""
try:
from opentelemetry.trace import SpanContext, get_current_span
from opentelemetry.trace.span import INVALID_TRACE_ID
span = get_current_span()
if not span:
return None
span_context: SpanContext = span.get_span_context()
if not span_context or span_context.trace_id == INVALID_TRACE_ID:
return None
trace_id_hex = f"{span_context.trace_id:032x}"
return trace_id_hex
except Exception:
return None
def parse_traceparent_header(traceparent: str) -> Optional[str]:
"""
Parse the `traceparent` header to extract the trace_id.
Expected format:
'version-trace_id-span_id-flags'
Reference:
W3C Trace Context Specification: https://www.w3.org/TR/trace-context/
"""
try:
parts = traceparent.split("-")
if len(parts) == 4 and len(parts[1]) == 32:
return parts[1]
except Exception:
pass
return None

View File

@ -4,15 +4,15 @@ from collections.abc import Sequence
from typing import Optional from typing import Optional
from urllib.parse import urljoin from urllib.parse import urljoin
from opentelemetry.trace import Status, StatusCode from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import ( from core.ops.aliyun_trace.data_exporter.traceclient import (
TraceClient, TraceClient,
convert_datetime_to_nanoseconds, convert_datetime_to_nanoseconds,
convert_string_to_id,
convert_to_span_id, convert_to_span_id,
convert_to_trace_id, convert_to_trace_id,
create_link,
generate_span_id, generate_span_id,
) )
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
@ -103,10 +103,11 @@ class AliyunDataTrace(BaseTraceInstance):
def workflow_trace(self, trace_info: WorkflowTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = convert_to_trace_id(trace_info.workflow_run_id) trace_id = convert_to_trace_id(trace_info.workflow_run_id)
links = []
if trace_info.trace_id: if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id) links.append(create_link(trace_id_str=trace_info.trace_id))
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow") workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
self.add_workflow_span(trace_id, workflow_span_id, trace_info) self.add_workflow_span(trace_id, workflow_span_id, trace_info, links)
workflow_node_executions = self.get_workflow_node_executions(trace_info) workflow_node_executions = self.get_workflow_node_executions(trace_info)
for node_execution in workflow_node_executions: for node_execution in workflow_node_executions:
@ -132,8 +133,9 @@ class AliyunDataTrace(BaseTraceInstance):
status = Status(StatusCode.ERROR, trace_info.error) status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id) trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id: if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id) links.append(create_link(trace_id_str=trace_info.trace_id))
message_span_id = convert_to_span_id(message_id, "message") message_span_id = convert_to_span_id(message_id, "message")
message_span = SpanData( message_span = SpanData(
@ -152,6 +154,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: str(trace_info.outputs), OUTPUT_VALUE: str(trace_info.outputs),
}, },
status=status, status=status,
links=links,
) )
self.trace_client.add_span(message_span) self.trace_client.add_span(message_span)
@ -192,8 +195,9 @@ class AliyunDataTrace(BaseTraceInstance):
message_id = trace_info.message_id message_id = trace_info.message_id
trace_id = convert_to_trace_id(message_id) trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id: if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id) links.append(create_link(trace_id_str=trace_info.trace_id))
documents_data = extract_retrieval_documents(trace_info.documents) documents_data = extract_retrieval_documents(trace_info.documents)
dataset_retrieval_span = SpanData( dataset_retrieval_span = SpanData(
@ -211,6 +215,7 @@ class AliyunDataTrace(BaseTraceInstance):
INPUT_VALUE: str(trace_info.inputs), INPUT_VALUE: str(trace_info.inputs),
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False), OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
}, },
links=links,
) )
self.trace_client.add_span(dataset_retrieval_span) self.trace_client.add_span(dataset_retrieval_span)
@ -224,8 +229,9 @@ class AliyunDataTrace(BaseTraceInstance):
status = Status(StatusCode.ERROR, trace_info.error) status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id) trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id: if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id) links.append(create_link(trace_id_str=trace_info.trace_id))
tool_span = SpanData( tool_span = SpanData(
trace_id=trace_id, trace_id=trace_id,
@ -244,6 +250,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: str(trace_info.tool_outputs), OUTPUT_VALUE: str(trace_info.tool_outputs),
}, },
status=status, status=status,
links=links,
) )
self.trace_client.add_span(tool_span) self.trace_client.add_span(tool_span)
@ -413,7 +420,9 @@ class AliyunDataTrace(BaseTraceInstance):
status=self.get_workflow_node_status(node_execution), status=self.get_workflow_node_status(node_execution),
) )
def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo): def add_workflow_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, links: Sequence[Link]
):
message_span_id = None message_span_id = None
if trace_info.message_id: if trace_info.message_id:
message_span_id = convert_to_span_id(trace_info.message_id, "message") message_span_id = convert_to_span_id(trace_info.message_id, "message")
@ -438,6 +447,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
}, },
status=status, status=status,
links=links,
) )
self.trace_client.add_span(message_span) self.trace_client.add_span(message_span)
@ -456,6 +466,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
}, },
status=status, status=status,
links=links,
) )
self.trace_client.add_span(workflow_span) self.trace_client.add_span(workflow_span)
@ -466,8 +477,9 @@ class AliyunDataTrace(BaseTraceInstance):
status = Status(StatusCode.ERROR, trace_info.error) status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id) trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id: if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id) links.append(create_link(trace_id_str=trace_info.trace_id))
suggested_question_span = SpanData( suggested_question_span = SpanData(
trace_id=trace_id, trace_id=trace_id,
@ -487,6 +499,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False), OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
}, },
status=status, status=status,
links=links,
) )
self.trace_client.add_span(suggested_question_span) self.trace_client.add_span(suggested_question_span)

View File

@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv.resource import ResourceAttributes from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
@ -166,6 +167,16 @@ class SpanBuilder:
return span return span
def create_link(trace_id_str: str) -> Link:
placeholder_span_id = 0x0000000000000000
trace_id = int(trace_id_str, 16)
span_context = SpanContext(
trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED)
)
return Link(span_context)
def generate_span_id() -> int: def generate_span_id() -> int:
span_id = random.getrandbits(64) span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID: while span_id == INVALID_SPAN_ID: