Skip to content

Commit

Permalink
Improvements to FastAPI submission endpoint for x-trace-id (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdye64 authored Oct 4, 2024
1 parent a5fd5f6 commit d504bad
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ COPY src/pipeline.py ./
COPY pyproject.toml ./
COPY ./docker/scripts/entrypoint_source_ext.sh /opt/docker/bin/entrypoint_source

# Start both the core nv-ingest pipeline service and teh FastAPI microservice in parallel
# Start both the core nv-ingest pipeline service and the FastAPI microservice in parallel
CMD ["sh", "-c", "python /workspace/pipeline.py & uvicorn nv_ingest.main:app --workers 32 --host 0.0.0.0 --port 7670 & wait"]

FROM base AS development
Expand Down
6 changes: 3 additions & 3 deletions client/src/nv_ingest_client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,20 +418,20 @@ def _submit_job(
try:
message = json.dumps(job_state.job_spec.to_dict())

job_id = self._message_client.submit_message(job_queue_id, message)
x_trace_id, job_id = self._message_client.submit_message(job_queue_id, message)

job_state.state = JobStateEnum.SUBMITTED
job_state.job_id = job_id

# Free up memory -- payload should never be used again, and we don't want to keep it around.
job_state.job_spec.payload = None

return x_trace_id
except Exception as err:
logger.error(f"Failed to submit job {job_index} to queue {job_queue_id}: {err}")
job_state.state = JobStateEnum.FAILED
raise

return None

def submit_job(self, job_indices: Union[str, List[str]], job_queue_id: str) -> List[Union[Dict, None]]:
if isinstance(job_indices, str):
job_indices = [job_indices]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ def submit_message(self, _: str, message: str) -> str:
logger.debug(f"JobSpec successfully submitted to http \
endpoint {self._submit_endpoint}, Resulting JobId: {result.json()}")
# The REST interface returns a JobId, so we capture that here

return result.json()
x_trace_id = result.headers['x-trace-id']
return x_trace_id, result.json()
else:
# We could just let this exception bubble, but we capture for clarity
# we may also choose to use more specific exceptions in the future
Expand Down
3 changes: 2 additions & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ services:
test: curl --fail http://nv-ingest-ms-runtime:7670/v1/health/ready || exit 1
interval: 10s
timeout: 5s
retries: 500
retries: 5
deploy:
resources:
reservations:
Expand Down Expand Up @@ -305,3 +305,4 @@ services:
# - "3001:3000"
# depends_on:
# - "milvus"

9 changes: 8 additions & 1 deletion src/nv_ingest/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import logging

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
Expand All @@ -26,6 +28,8 @@
span_processor = BatchSpanProcessor(exporter)
trace.get_tracer_provider().add_span_processor(span_processor)

logger = logging.getLogger("uvicorn")

# nv-ingest FastAPI app declaration
app = FastAPI()

Expand All @@ -44,7 +48,10 @@ async def add_trace_id_header(request, call_next):
# Inject the current x-trace-id into the HTTP headers response
span = trace.get_current_span()
if span:
trace_id = format(span.get_span_context().trace_id, '032x')
raw_trace_id = span.get_span_context().trace_id
trace_id = format(raw_trace_id, '032x')
logger.debug(f"MIDDLEWARE add_trace_id_header Raw \
Trace Id: {raw_trace_id} - Formatted Trace Id: {trace_id}")
response.headers["x-trace-id"] = trace_id

return response
16 changes: 14 additions & 2 deletions src/nv_ingest/api/v1/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# pylint: skip-file

import base64
import copy
import json
from io import BytesIO
import logging
Expand All @@ -29,6 +30,7 @@
from nv_ingest.schemas.message_wrapper_schema import MessageWrapper
from nv_ingest.service.impl.ingest.redis_ingest_service import RedisIngestService
from nv_ingest.service.meta.ingest.ingest_service_meta import IngestServiceMeta
from nv_ingest.schemas.ingest_job_schema import DocumentTypeEnum

logger = logging.getLogger("uvicorn")
tracer = trace.get_tracer(__name__)
Expand Down Expand Up @@ -108,7 +110,6 @@ async def submit_job_curl_friendly(
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Nv-Ingest Internal Server Error: {str(ex)}")


# POST /submit_job
@router.post(
"/submit_job",
Expand All @@ -123,7 +124,18 @@ async def submit_job_curl_friendly(
)
async def submit_job(job_spec: MessageWrapper, ingest_service: INGEST_SERVICE_T):
try:
submitted_job_id = await ingest_service.submit_job(job_spec)
# Inject the x-trace-id into the JobSpec definition so that OpenTelemetry
# will be able to trace across uvicorn -> morpheus
current_trace_id = trace.get_current_span().get_span_context().trace_id

# Recreate the JobSpec to test what is going on ....
job_spec_dict = json.loads(job_spec.payload)
job_spec_dict['tracing_options']['trace_id'] = current_trace_id
updated_job_spec = MessageWrapper(
payload=json.dumps(job_spec_dict)
)

submitted_job_id = await ingest_service.submit_job(updated_job_spec)
return submitted_job_id
except Exception as ex:
traceback.print_exc()
Expand Down
7 changes: 5 additions & 2 deletions tests/nv_ingest_client/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,17 @@ def __init__(self, host, port):

def submit_message(self, job_queue_id, job_spec_str):
    """Mock submission: record the message and return a canned (trace_id, job_id) pair.

    Mirrors the real client's contract of returning both an x-trace-id and a
    job id, so callers that unpack a 2-tuple work against this mock.
    """
    # Keep every "submitted" message so tests can inspect what was sent.
    self.submitted_messages.append((job_queue_id, job_spec_str))
    # Fixed stand-ins for the x-trace-id response header and server-assigned job id.
    return "123456789", 0


class ExtendedMockClientWithFailure(ExtendedMockClient):
    """Mock client that simulates a submission failure for queues named 'fail_queue'.

    For any other queue it delegates to the base mock and propagates its
    (trace_id, job_id) result, matching the real client's return contract.
    """

    def submit_message(self, job_queue_id, job_spec_str):
        # Deliberately fail for designated queues so error-handling paths can be tested.
        if "fail_queue" in job_queue_id:
            raise Exception("Simulated submission failure")
        # Single delegating call: the scraped diff span showed both the old
        # (result-discarding) and new (returning) lines; only the returning
        # call is correct — calling super() twice would double-record messages.
        return super().submit_message(job_queue_id, job_spec_str)


class ExtendedMockClientWithFetch(ExtendedMockClientWithFailure):
Expand Down Expand Up @@ -407,7 +410,7 @@ def test_job_future_result_on_success(nv_ingest_client_with_jobs):
future = nv_ingest_client_with_jobs._job_states[job_id].future

result = future.result(timeout=5)
assert result == [None], "The future's result should reflect the job's success"
assert result == ["123456789"], "The future's result should reflect the job's success"


def test_job_future_result_on_failure(nv_ingest_client_with_jobs):
Expand Down

0 comments on commit d504bad

Please sign in to comment.