diff --git a/truss/templates/server/common/tracing.py b/truss/templates/server/common/tracing.py index f2f40ba79..efb980816 100644 --- a/truss/templates/server/common/tracing.py +++ b/truss/templates/server/common/tracing.py @@ -21,6 +21,8 @@ HONEYCOMB_DATASET = "HONEYCOMB_DATASET" HONEYCOMB_API_KEY = "HONEYCOMB_API_KEY" +DEFAULT_ENABLE_TRACING_DATA = False # This should be in sync with truss_config.py. + class JSONFileExporter(trace_export.SpanExporter): """Writes spans to newline-delimited JSON file for debugging / testing.""" @@ -45,7 +47,7 @@ def shutdown(self) -> None: _truss_tracer: Optional[trace.Tracer] = None -def get_truss_tracer(secrets: secrets_resolver.SecretsResolver) -> trace.Tracer: +def get_truss_tracer(secrets: secrets_resolver.SecretsResolver, config) -> trace.Tracer: """Creates a cached tracer (i.e. runtime-singleton) to be used for truss internal tracing. @@ -53,6 +55,10 @@ def get_truss_tracer(secrets: secrets_resolver.SecretsResolver) -> trace.Tracer: completely from potential user-defined tracing - see also `detach_context` below. 
""" + enable_tracing_data = config.get("runtime", {}).get( + "enable_tracing_data", DEFAULT_ENABLE_TRACING_DATA + ) + global _truss_tracer if _truss_tracer: return _truss_tracer @@ -65,28 +71,27 @@ def get_truss_tracer(secrets: secrets_resolver.SecretsResolver) -> trace.Tracer: span_processors.append(otlp_processor) if tracing_log_file := os.getenv(OTEL_TRACING_NDJSON_FILE): - logger.info("Exporting trace data to `tracing_log_file`.") + logger.info(f"Exporting trace data to file `{tracing_log_file}`.") json_file_exporter = JSONFileExporter(pathlib.Path(tracing_log_file)) file_processor = sdk_trace.export.SimpleSpanProcessor(json_file_exporter) span_processors.append(file_processor) - if honeycomb_dataset := os.getenv(HONEYCOMB_DATASET): - if HONEYCOMB_API_KEY in secrets: - honeycomb_api_key = secrets[HONEYCOMB_API_KEY] - logger.info("Exporting trace data to honeycomb.") - honeycomb_exporter = oltp_exporter.OTLPSpanExporter( - endpoint="https://api.honeycomb.io/v1/traces", - headers={ - "x-honeycomb-team": honeycomb_api_key, - "x-honeycomb-dataset": honeycomb_dataset, - }, - ) - honeycomb_processor = sdk_trace.export.BatchSpanProcessor( - honeycomb_exporter - ) - span_processors.append(honeycomb_processor) + if ( + honeycomb_dataset := os.getenv(HONEYCOMB_DATASET) + ) and HONEYCOMB_API_KEY in secrets: + honeycomb_api_key = secrets[HONEYCOMB_API_KEY] + logger.info("Exporting trace data to honeycomb.") + honeycomb_exporter = oltp_exporter.OTLPSpanExporter( + endpoint="https://api.honeycomb.io/v1/traces", + headers={ + "x-honeycomb-team": honeycomb_api_key, + "x-honeycomb-dataset": honeycomb_dataset, + }, + ) + honeycomb_processor = sdk_trace.export.BatchSpanProcessor(honeycomb_exporter) + span_processors.append(honeycomb_processor) - if span_processors: + if span_processors and enable_tracing_data: logger.info("Instantiating truss tracer.") resource = resources.Resource.create({resources.SERVICE_NAME: "TrussServer"}) trace_provider = 
sdk_trace.TracerProvider(resource=resource) @@ -94,7 +99,13 @@ def get_truss_tracer(secrets: secrets_resolver.SecretsResolver) -> trace.Tracer: trace_provider.add_span_processor(sp) tracer = trace_provider.get_tracer("truss_server") else: - logger.info("Using no-op tracing.") + if enable_tracing_data: + logger.info( + "Using no-op tracing (tracing is enabled, but no exporters configured)." + ) + else: + logger.info("Using no-op tracing (tracing was disabled).") + tracer = sdk_trace.NoOpTracer() _truss_tracer = tracer diff --git a/truss/templates/server/common/truss_server.py b/truss/templates/server/common/truss_server.py index 55cab7a17..7db57a99c 100644 --- a/truss/templates/server/common/truss_server.py +++ b/truss/templates/server/common/truss_server.py @@ -227,7 +227,7 @@ def __init__( setup_json_logger: bool = True, ): secrets = SecretsResolver.get_secrets(config) - tracer = tracing.get_truss_tracer(secrets) + tracer = tracing.get_truss_tracer(secrets, config) self.http_port = http_port self._config = config self._model = ModelWrapper(self._config, tracer) diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index e81d7ae1e..32bb9f6df 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -302,11 +302,7 @@ async def write_response_to_queue( finally: await queue.put(None) - async def _gather_generator(self, response: Any, span: trace.Span) -> str: - # In the case of gathering, it might make more sense to apply the postprocess - # to the gathered result, but that would be inconsistent with streaming. - # In general, it might even be better to strictly forbid postprocessing - # for generators. + async def _streaming_post_process(self, response: Any, span: trace.Span) -> Any: if hasattr(self._model, "postprocess"): logging.warning( "Predict returned a streaming response, while a postprocess is defined." 
@@ -317,6 +313,14 @@ async def _gather_generator(self, response: Any, span: trace.Span) -> str: ), tracing.detach_context(): response = await self.postprocess(response) + return response + + async def _gather_generator(self, response: Any, span: trace.Span) -> str: + # In the case of gathering, it might make more sense to apply the postprocess + # to the gathered result, but that would be inconsistent with streaming. + # In general, it might even be better to strictly forbid postprocessing + # for generators. + response = await self._streaming_post_process(response, span) return await _convert_streamed_response_to_string( _force_async_generator(response) ) @@ -333,6 +337,7 @@ async def _stream_with_background_task( streaming_read_timeout = self._config.get("runtime", {}).get( "streaming_read_timeout", STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS ) + response = await self._streaming_post_process(response, span) async_generator = _force_async_generator(response) # To ensure that a partial read from a client does not keep the semaphore # claimed, we write all the data from the stream to the queue as it is produced, @@ -349,13 +354,13 @@ async def _stream_with_background_task( gen_task.add_done_callback(lambda _: release_and_end()) # The gap between responses in a stream must be < streaming_read_timeout - async def _response_generator(): + async def _buffered_response_generator(): # `span` is tied to the "producer" `gen_task` which might complete before # "consume" part here finishes, therefore a dedicated span is required. # Because all of this code is inside a `detach_context` block, we # explicitly propagate the tracing context for this span. 
with self._tracer.start_as_current_span( - "response_generator", context=trace_ctx + "buffered-response-generator", context=trace_ctx ): while True: chunk = await asyncio.wait_for( @@ -366,7 +371,7 @@ async def _response_generator(): return yield chunk.value - return _response_generator() + return _buffered_response_generator() async def __call__( self, body: Any, headers: Optional[Mapping[str, str]] = None diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py index b9fb6a464..c3283627d 100644 --- a/truss/tests/test_model_inference.py +++ b/truss/tests/test_model_inference.py @@ -1,4 +1,5 @@ import concurrent +import dataclasses import inspect import json import logging @@ -878,12 +879,21 @@ def _make_otel_headers() -> Mapping[str, str]: @pytest.mark.integration -def test_streaming_truss_with_user_tracing(): +@pytest.mark.parametrize("enable_tracing_data", [True, False]) +def test_streaming_truss_with_user_tracing(enable_tracing_data): with ensure_kill_all(): truss_root = Path(__file__).parent.parent.parent.resolve() / "truss" truss_dir = truss_root / "test_data" / "test_streaming_truss_with_tracing" tr = TrussHandle(truss_dir) + def enable_tracing_fn(conf): + new_runtime = dataclasses.replace( + conf.runtime, enable_tracing_data=enable_tracing_data + ) + return dataclasses.replace(conf, runtime=new_runtime) + + tr._update_config(enable_tracing_fn) + container = tr.docker_run( local_port=8090, detach=True, wait_for_server_ready=True ) @@ -930,11 +940,10 @@ def test_streaming_truss_with_user_tracing(): json.loads(s) for s in user_traces_file.read_text().splitlines() ] - # for x in truss_traces: - # print(x) - # print("#" * 30) - # for x in user_traces: - # print(x) + if not enable_tracing_data: + assert len(truss_traces) == 0 + assert len(user_traces) > 0 + return assert sum(1 for x in truss_traces if x["name"] == "predict-endpoint") == 3 assert sum(1 for x in user_traces if x["name"] == "load_model") == 1 diff --git a/truss/truss_config.py 
b/truss/truss_config.py index f62cb33bf..778d454ec 100644 --- a/truss/truss_config.py +++ b/truss/truss_config.py @@ -34,6 +34,7 @@ DEFAULT_PREDICT_CONCURRENCY = 1 DEFAULT_NUM_WORKERS = 1 DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT = 60 +DEFAULT_ENABLE_TRACING_DATA = False # This should be in sync with tracing.py. DEFAULT_CPU = "1" DEFAULT_MEMORY = "2Gi" @@ -143,6 +144,7 @@ class Runtime: predict_concurrency: int = DEFAULT_PREDICT_CONCURRENCY num_workers: int = DEFAULT_NUM_WORKERS streaming_read_timeout: int = DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT + enable_tracing_data: bool = DEFAULT_ENABLE_TRACING_DATA @staticmethod def from_dict(d): @@ -151,11 +153,13 @@ def from_dict(d): streaming_read_timeout = d.get( "streaming_read_timeout", DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT ) + enable_tracing_data = d.get("enable_tracing_data", DEFAULT_ENABLE_TRACING_DATA) return Runtime( predict_concurrency=predict_concurrency, num_workers=num_workers, streaming_read_timeout=streaming_read_timeout, + enable_tracing_data=enable_tracing_data, ) def to_dict(self): @@ -163,6 +167,7 @@ def to_dict(self): "predict_concurrency": self.predict_concurrency, "num_workers": self.num_workers, "streaming_read_timeout": self.streaming_read_timeout, + "enable_tracing_data": self.enable_tracing_data, }