Skip to content

Commit

Permalink
Rename tls_events fields for future compatibility with non handshake …
Browse files Browse the repository at this point in the history
…TLS records. Update JSON parsing to better match new structure

Signed-off-by: Dom Del Nano <[email protected]>
  • Loading branch information
ddelnano committed Feb 11, 2025
1 parent 37bcf97 commit f48ebb6
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ namespace stirling {
namespace protocols {
namespace tls {

using px::utils::JSONObjectBuilder;

constexpr size_t kTLSRecordHeaderLength = 5;
constexpr size_t kExtensionMinimumLength = 4;
constexpr size_t kSNIExtensionMinimumLength = 3;
Expand All @@ -39,11 +41,9 @@ constexpr size_t kSNIExtensionMinimumLength = 3;
// In TLS 1.2 and earlier, gmt_unix_time is 4 bytes and Random is 28 bytes.
constexpr size_t kRandomStructLength = 32;

StatusOr<ParseState> ExtractSNIExtension(std::map<std::string, std::string>* exts,
BinaryDecoder* decoder) {
StatusOr<ParseState> ExtractSNIExtension(ReqExtensions* exts, BinaryDecoder* decoder) {
PX_ASSIGN_OR(auto server_name_list_length, decoder->ExtractBEInt<uint16_t>(),
return ParseState::kInvalid);
std::vector<std::string> server_names;
while (server_name_list_length > 0) {
PX_ASSIGN_OR(auto server_name_type, decoder->ExtractBEInt<uint8_t>(),
return error::Internal("Failed to extract server name type"));
Expand All @@ -56,10 +56,9 @@ StatusOr<ParseState> ExtractSNIExtension(std::map<std::string, std::string>* ext
PX_ASSIGN_OR(auto server_name, decoder->ExtractString(server_name_length),
return error::Internal("Failed to extract server name"));

server_names.push_back(std::string(server_name));
exts->server_names.push_back(std::string(server_name));
server_name_list_length -= kSNIExtensionMinimumLength + server_name_length;
}
exts->insert({"server_name", ToJSONString(server_names)});
return ParseState::kSuccess;
}

Expand Down Expand Up @@ -162,6 +161,8 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {
return ParseState::kSuccess;
}

ReqExtensions req_extensions;
RespExtensions resp_extensions;
while (extensions_length > 0) {
PX_ASSIGN_OR(auto extension_type, decoder->ExtractBEInt<uint16_t>(),
return ParseState::kInvalid);
Expand All @@ -170,7 +171,7 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {

if (extension_length > 0) {
if (extension_type == 0x00) {
if (!ExtractSNIExtension(&frame->extensions, decoder).ok()) {
if (!ExtractSNIExtension(&req_extensions, decoder).ok()) {
return ParseState::kInvalid;
}
} else {
Expand All @@ -182,6 +183,13 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {

extensions_length -= kExtensionMinimumLength + extension_length;
}
JSONObjectBuilder req_body_builder;
req_body_builder.WriteKVRecursive("extensions", req_extensions);
frame->req_body = req_body_builder.GetString();

JSONObjectBuilder resp_body_builder;
resp_body_builder.WriteKVRecursive("extensions", resp_extensions);
frame->resp_body = resp_body_builder.GetString();

return ParseState::kSuccess;
}
Expand Down
30 changes: 24 additions & 6 deletions src/stirling/source_connectors/socket_tracer/protocols/tls/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ namespace stirling {
namespace protocols {
namespace tls {

using ::px::utils::ToJSONString;

enum class ContentType : uint8_t {
kChangeCipherSpec = 0x14,
kAlert = 0x15,
Expand Down Expand Up @@ -186,6 +184,25 @@ enum class ExtensionType : uint16_t {
kRenegotiationInfo = 65281,
};

// Extensions that are common to both the client and server side
// of a TLS handshake
struct SharedExtensions {
void ToJSON(::px::utils::JSONObjectBuilder* /*builder*/) const {}
};

struct ReqExtensions : public SharedExtensions {
std::vector<std::string> server_names;

void ToJSON(::px::utils::JSONObjectBuilder* builder) const {
SharedExtensions::ToJSON(builder);
builder->WriteKV("server_name", server_names);
}
};

struct RespExtensions : public SharedExtensions {
void ToJSON(::px::utils::JSONObjectBuilder* builder) const { SharedExtensions::ToJSON(builder); }
};

struct Frame : public FrameBase {
ContentType content_type;

Expand All @@ -200,7 +217,8 @@ struct Frame : public FrameBase {
LegacyVersion handshake_version;

std::string session_id;
std::map<std::string, std::string> extensions;
std::string req_body;
std::string resp_body;

bool consumed = false;

Expand All @@ -209,9 +227,9 @@ struct Frame : public FrameBase {
std::string ToString() const override {
return absl::Substitute(
"TLS Frame [len=$0 content_type=$1 legacy_version=$2 handshake_version=$3 "
"handshake_type=$4 extensions=$5]",
length, content_type, legacy_version, handshake_version, handshake_type,
ToJSONString(extensions));
"handshake_type=$4 req_body=$5 resp_body=$6]",
length, content_type, legacy_version, handshake_version, handshake_type, req_body,
resp_body);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ using px::utils::ToJSONString;
// Most HTTP servers support 8K headers, so we truncate after that.
// https://stackoverflow.com/questions/686217/maximum-on-http-header-values
constexpr size_t kMaxHTTPHeadersBytes = 8192;
// TLS records have a maximum size of 16KiB. While there isn't a size limit
// for the extensions, we limit it to 1 KiB to avoid excessive memory usage.
// A typical ClientHello from curl is around 500 bytes. This assumes that
// TLS records have a maximum size of 16KiB. The bulk of the body columns are extensions
// and while there isn't a size limit for them, we limit it to 1 KiB to avoid excessive
// memory usage. A typical ClientHello from curl is around 500 bytes. This assumes that
// all extensions are captured, but we won't support capturing all extensions and
// will avoid large extensions like the padding extension,
constexpr size_t kMaxTLSExtensionsBytes = 1024;
constexpr size_t kMaxTLSBodyBytes = 1024;

// Protobuf printer will limit strings to this length.
constexpr size_t kMaxPBStringLen = 64;
Expand Down Expand Up @@ -1721,9 +1721,10 @@ void SocketTraceConnector::AppendMessage(ConnectorContext* ctx, const ConnTracke
r.Append<r.ColIndex("local_addr")>(conn_tracker.local_endpoint().AddrStr());
r.Append<r.ColIndex("local_port")>(conn_tracker.local_endpoint().port());
r.Append<r.ColIndex("trace_role")>(conn_tracker.role());
r.Append<r.ColIndex("req_type")>(static_cast<uint64_t>(req_message.content_type));
r.Append<r.ColIndex("content_type")>(static_cast<uint64_t>(req_message.content_type));
r.Append<r.ColIndex("version")>(static_cast<uint64_t>(req_message.legacy_version));
r.Append<r.ColIndex("extensions")>(ToJSONString(req_message.extensions), kMaxTLSExtensionsBytes);
r.Append<r.ColIndex("req_body")>(req_message.req_body, kMaxTLSBodyBytes);
r.Append<r.ColIndex("resp_body")>(resp_message.resp_body, kMaxTLSBodyBytes);
r.Append<r.ColIndex("latency")>(
CalculateLatency(req_message.timestamp_ns, resp_message.timestamp_ns));
#ifndef NDEBUG
Expand Down
14 changes: 9 additions & 5 deletions src/stirling/source_connectors/socket_tracer/tls_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,22 @@ static constexpr DataElement kTLSElements[] = {
canonical_data_elements::kLocalAddr,
canonical_data_elements::kLocalPort,
canonical_data_elements::kTraceRole,
{"req_type", "The type of request from the TLS record (Client/ServerHello, etc.)",
{"content_type", "The content type of the TLS record (e.g. handshake, alert, heartbeat, etc)",
types::DataType::INT64,
types::SemanticType::ST_NONE,
types::PatternType::GENERAL_ENUM},
{"version", "Version of TLS record",
types::DataType::INT64,
types::SemanticType::ST_NONE,
types::PatternType::GENERAL_ENUM},
{"extensions", "Extensions in the TLS record",
{"req_body", "Request body in JSON format. Structure depends on content type (e.g. handshakes contain TLS extensions)",
types::DataType::STRING,
types::SemanticType::ST_NONE,
types::PatternType::GENERAL},
types::PatternType::STRUCTURED},
{"resp_body", "Response body in JSON format. Structure depends on content type (e.g. handshakes contain TLS extensions)",
types::DataType::STRING,
types::SemanticType::ST_NONE,
types::PatternType::STRUCTURED},
canonical_data_elements::kLatencyNS,
#ifndef NDEBUG
canonical_data_elements::kPXInfo,
Expand All @@ -61,9 +65,9 @@ static constexpr auto kTLSTable =
DEFINE_PRINT_TABLE(TLS)

constexpr int kTLSUPIDIdx = kTLSTable.ColIndex("upid");
constexpr int kTLSCmdIdx = kTLSTable.ColIndex("req_type");
constexpr int kTLSCmdIdx = kTLSTable.ColIndex("content_type");
constexpr int kTLSVersionIdx = kTLSTable.ColIndex("version");
constexpr int kTLSExtensionsIdx = kTLSTable.ColIndex("extensions");
constexpr int kTLSReqBodyIdx = kTLSTable.ColIndex("req_body");

} // namespace stirling
} // namespace px
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ using ::testing::UnorderedElementsAre;

struct TraceRecords {
std::vector<tls::Record> tls_records;
std::vector<std::string> tls_extensions;
std::vector<std::string> req_body;
};

class NginxOpenSSL_3_0_8_ContainerWrapper
Expand Down Expand Up @@ -80,11 +80,11 @@ tls::Record GetExpectedTLSRecord() {
return expected_record;
}

inline std::vector<std::string> GetExtensions(const types::ColumnWrapperRecordBatch& rb,
const std::vector<size_t>& indices) {
inline std::vector<std::string> GetRequestBody(const types::ColumnWrapperRecordBatch& rb,
const std::vector<size_t>& indices) {
std::vector<std::string> exts;
for (size_t idx : indices) {
exts.push_back(rb[kTLSExtensionsIdx]->Get<types::StringValue>(idx));
exts.push_back(rb[kTLSReqBodyIdx]->Get<types::StringValue>(idx));
}
return exts;
}
Expand Down Expand Up @@ -127,9 +127,9 @@ class TLSVersionParameterizedTest

TraceRecords records = this->GetTraceRecords(this->server_.PID());
EXPECT_THAT(records.tls_records, SizeIs(1));
EXPECT_THAT(records.tls_extensions, SizeIs(1));
auto sni_str = R"({"server_name":"[\"test-host\"]"})";
EXPECT_THAT(records.tls_extensions[0], StrEq(sni_str));
EXPECT_THAT(records.req_body, SizeIs(1));
auto sni_str = R"({"extensions":{"server_name":["test-host"]}})";
EXPECT_THAT(records.req_body[0], StrEq(sni_str));
}

// Returns the trace records of the process specified by the input pid.
Expand All @@ -144,7 +144,7 @@ class TLSVersionParameterizedTest
FindRecordIdxMatchesPID(record_batch, kTLSUPIDIdx, pid);
std::vector<tls::Record> tls_records =
ToRecordVector<tls::Record>(record_batch, server_record_indices);
std::vector<std::string> extensions = GetExtensions(record_batch, server_record_indices);
std::vector<std::string> extensions = GetRequestBody(record_batch, server_record_indices);

return {std::move(tls_records), std::move(extensions)};
}
Expand Down

0 comments on commit f48ebb6

Please sign in to comment.