Skip to content

Commit

Permalink
Fix tests, add truss config option.
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten committed Aug 29, 2024
1 parent f181538 commit a0677b9
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 34 deletions.
49 changes: 30 additions & 19 deletions truss/templates/server/common/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
HONEYCOMB_DATASET = "HONEYCOMB_DATASET"
HONEYCOMB_API_KEY = "HONEYCOMB_API_KEY"

DEFAULT_ENABLE_TRACING_DATA = False # This should be in sync with truss_config.py.


class JSONFileExporter(trace_export.SpanExporter):
"""Writes spans to newline-delimited JSON file for debugging / testing."""
Expand All @@ -45,14 +47,18 @@ def shutdown(self) -> None:
_truss_tracer: Optional[trace.Tracer] = None


def get_truss_tracer(secrets: secrets_resolver.SecretsResolver) -> trace.Tracer:
def get_truss_tracer(secrets: secrets_resolver.SecretsResolver, config) -> trace.Tracer:
"""Creates a cached tracer (i.e. runtime-singleton) to be used for truss
internal tracing.
The goal is to separate truss-internal tracing instrumentation
completely from potential user-defined tracing - see also `detach_context` below.
"""
enable_tracing_data = config.get("runtime", {}).get(
"enable_tracing_data", DEFAULT_ENABLE_TRACING_DATA
)

global _truss_tracer
if _truss_tracer:
return _truss_tracer
Expand All @@ -65,36 +71,41 @@ def get_truss_tracer(secrets: secrets_resolver.SecretsResolver) -> trace.Tracer:
span_processors.append(otlp_processor)

if tracing_log_file := os.getenv(OTEL_TRACING_NDJSON_FILE):
logger.info("Exporting trace data to `tracing_log_file`.")
logger.info(f"Exporting trace data to file `{tracing_log_file}`.")
json_file_exporter = JSONFileExporter(pathlib.Path(tracing_log_file))
file_processor = sdk_trace.export.SimpleSpanProcessor(json_file_exporter)
span_processors.append(file_processor)

if honeycomb_dataset := os.getenv(HONEYCOMB_DATASET):
if HONEYCOMB_API_KEY in secrets:
honeycomb_api_key = secrets[HONEYCOMB_API_KEY]
logger.info("Exporting trace data to honeycomb.")
honeycomb_exporter = oltp_exporter.OTLPSpanExporter(
endpoint="https://api.honeycomb.io/v1/traces",
headers={
"x-honeycomb-team": honeycomb_api_key,
"x-honeycomb-dataset": honeycomb_dataset,
},
)
honeycomb_processor = sdk_trace.export.BatchSpanProcessor(
honeycomb_exporter
)
span_processors.append(honeycomb_processor)
if (
honeycomb_dataset := os.getenv(HONEYCOMB_DATASET)
) and HONEYCOMB_API_KEY in secrets:
honeycomb_api_key = secrets[HONEYCOMB_API_KEY]
logger.info("Exporting trace data to honeycomb.")
honeycomb_exporter = oltp_exporter.OTLPSpanExporter(
endpoint="https://api.honeycomb.io/v1/traces",
headers={
"x-honeycomb-team": honeycomb_api_key,
"x-honeycomb-dataset": honeycomb_dataset,
},
)
honeycomb_processor = sdk_trace.export.BatchSpanProcessor(honeycomb_exporter)
span_processors.append(honeycomb_processor)

if span_processors:
if span_processors and enable_tracing_data:
logger.info("Instantiating truss tracer.")
resource = resources.Resource.create({resources.SERVICE_NAME: "TrussServer"})
trace_provider = sdk_trace.TracerProvider(resource=resource)
for sp in span_processors:
trace_provider.add_span_processor(sp)
tracer = trace_provider.get_tracer("truss_server")
else:
logger.info("Using no-op tracing.")
if enable_tracing_data:
logger.info(
"Using no-op tracing (tracing is enabled, but no exporters confiugred)."
)
else:
logger.info("Using no-op tracing (tracing was disabled).")

tracer = sdk_trace.NoOpTracer()

_truss_tracer = tracer
Expand Down
2 changes: 1 addition & 1 deletion truss/templates/server/common/truss_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def __init__(
setup_json_logger: bool = True,
):
secrets = SecretsResolver.get_secrets(config)
tracer = tracing.get_truss_tracer(secrets)
tracer = tracing.get_truss_tracer(secrets, config)
self.http_port = http_port
self._config = config
self._model = ModelWrapper(self._config, tracer)
Expand Down
21 changes: 13 additions & 8 deletions truss/templates/server/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,7 @@ async def write_response_to_queue(
finally:
await queue.put(None)

async def _gather_generator(self, response: Any, span: trace.Span) -> str:
# In the case of gathering, it might make more sense to apply the postprocess
# to the gathered result, but that would be inconsistent with streaming.
# In general, it might even be better to strictly forbid postprocessing
# for generators.
async def _streaming_post_process(self, response: Any, span: trace.Span) -> Any:
if hasattr(self._model, "postprocess"):
logging.warning(
"Predict returned a streaming response, while a postprocess is defined."
Expand All @@ -317,6 +313,14 @@ async def _gather_generator(self, response: Any, span: trace.Span) -> str:
), tracing.detach_context():
response = await self.postprocess(response)

return response

async def _gather_generator(self, response: Any, span: trace.Span) -> str:
    """Drain a streaming response and return it as one concatenated string.

    NOTE: postprocessing is applied to the stream *before* gathering to stay
    consistent with the streaming code path; arguably it would be cleaner to
    forbid postprocessing for generators altogether.
    """
    processed = await self._streaming_post_process(response, span)
    stream = _force_async_generator(processed)
    return await _convert_streamed_response_to_string(stream)
Expand All @@ -333,6 +337,7 @@ async def _stream_with_background_task(
streaming_read_timeout = self._config.get("runtime", {}).get(
"streaming_read_timeout", STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS
)
response = await self._streaming_post_process(response, span)
async_generator = _force_async_generator(response)
# To ensure that a partial read from a client does not keep the semaphore
# claimed, we write all the data from the stream to the queue as it is produced,
Expand All @@ -349,13 +354,13 @@ async def _stream_with_background_task(
gen_task.add_done_callback(lambda _: release_and_end())

# The gap between responses in a stream must be < streaming_read_timeout
async def _response_generator():
async def _buffered_response_generator():
# `span` is tied to the "producer" `gen_task` which might complete before
# "consume" part here finishes, therefore a dedicated span is required.
# Because all of this code is inside a `detach_context` block, we
# explicitly propagate the tracing context for this span.
with self._tracer.start_as_current_span(
"response_generator", context=trace_ctx
"buffered-response-generator", context=trace_ctx
):
while True:
chunk = await asyncio.wait_for(
Expand All @@ -366,7 +371,7 @@ async def _response_generator():
return
yield chunk.value

return _response_generator()
return _buffered_response_generator()

async def __call__(
self, body: Any, headers: Optional[Mapping[str, str]] = None
Expand Down
21 changes: 15 additions & 6 deletions truss/tests/test_model_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import concurrent
import dataclasses
import inspect
import json
import logging
Expand Down Expand Up @@ -878,12 +879,21 @@ def _make_otel_headers() -> Mapping[str, str]:


@pytest.mark.integration
def test_streaming_truss_with_user_tracing():
@pytest.mark.parametrize("enable_tracing_data", [True, False])
def test_streaming_truss_with_user_tracing(enable_tracing_data):
with ensure_kill_all():
truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
truss_dir = truss_root / "test_data" / "test_streaming_truss_with_tracing"
tr = TrussHandle(truss_dir)

def enable_gpu_fn(conf):
new_runtime = dataclasses.replace(
conf.runtime, enable_tracing_data=enable_tracing_data
)
return dataclasses.replace(conf, runtime=new_runtime)

tr._update_config(enable_gpu_fn)

container = tr.docker_run(
local_port=8090, detach=True, wait_for_server_ready=True
)
Expand Down Expand Up @@ -930,11 +940,10 @@ def test_streaming_truss_with_user_tracing():
json.loads(s) for s in user_traces_file.read_text().splitlines()
]

# for x in truss_traces:
# print(x)
# print("#" * 30)
# for x in user_traces:
# print(x)
if not enable_tracing_data:
assert len(truss_traces) == 0
assert len(user_traces) > 0
return

assert sum(1 for x in truss_traces if x["name"] == "predict-endpoint") == 3
assert sum(1 for x in user_traces if x["name"] == "load_model") == 1
Expand Down
5 changes: 5 additions & 0 deletions truss/truss_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
DEFAULT_PREDICT_CONCURRENCY = 1
DEFAULT_NUM_WORKERS = 1
DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT = 60
DEFAULT_ENABLE_TRACING_DATA = False # This should be in sync with tracing.py.

DEFAULT_CPU = "1"
DEFAULT_MEMORY = "2Gi"
Expand Down Expand Up @@ -143,6 +144,7 @@ class Runtime:
predict_concurrency: int = DEFAULT_PREDICT_CONCURRENCY
num_workers: int = DEFAULT_NUM_WORKERS
streaming_read_timeout: int = DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT
enable_tracing_data: bool = DEFAULT_ENABLE_TRACING_DATA

@staticmethod
def from_dict(d):
Expand All @@ -151,18 +153,21 @@ def from_dict(d):
streaming_read_timeout = d.get(
"streaming_read_timeout", DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT
)
enable_tracing_data = d.get("enable_tracing_data", DEFAULT_ENABLE_TRACING_DATA)

return Runtime(
predict_concurrency=predict_concurrency,
num_workers=num_workers,
streaming_read_timeout=streaming_read_timeout,
enable_tracing_data=enable_tracing_data,
)

def to_dict(self):
    """Serialize the runtime settings into a plain dict (e.g. for config dumping)."""
    field_names = (
        "predict_concurrency",
        "num_workers",
        "streaming_read_timeout",
        "enable_tracing_data",
    )
    return {name: getattr(self, name) for name in field_names}


Expand Down

0 comments on commit a0677b9

Please sign in to comment.