Skip to content

Commit

Permalink
Fix curl callback data latency for clients with response stream (#3245)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbera87 authored Jan 17, 2025
1 parent 8fd786a commit 573ecfe
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace Model
// so we can not get operation's name from response.
inline virtual const char* GetServiceRequestName() const override { return "SubscribeToShard"; }

inline virtual bool HasEventStreamResponse() const override { return true; }
AWS_KINESIS_API Aws::String SerializePayload() const override;

AWS_KINESIS_API Aws::Http::HeaderValueCollection GetRequestSpecificHeaders() const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ namespace Aws
* Defaults to false, if this is set to true in derived class, it's an event stream request, which means the payload is consisted by multiple structured events.
*/
inline virtual bool IsEventStreamRequest() const { return false; }

/**
* Defaults to false, if this is set to true in derived class, the operation using this request will return an event stream response.
*/
inline virtual bool HasEventStreamResponse() const { return false; }

/**
* Defaults to true, if this is set to false, then signers, if they support body signing, will not do so
*/
Expand Down
4 changes: 4 additions & 0 deletions src/aws-cpp-sdk-core/include/aws/core/http/HttpRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,9 @@ namespace Aws

bool IsEventStreamRequest() { return m_isEvenStreamRequest; }
void SetEventStreamRequest(bool eventStreamRequest) { m_isEvenStreamRequest = eventStreamRequest; }

bool HasEventStreamResponse() { return m_hasEvenStreamResponse; }
void SetHasEventStreamResponse(bool hasEventStreamResponse) { m_hasEvenStreamResponse = hasEventStreamResponse; }

virtual std::shared_ptr<Aws::Crt::Http::HttpRequest> ToCrtHttpRequest();

Expand All @@ -606,6 +609,7 @@ namespace Aws
URI m_uri;
HttpMethod m_method;
bool m_isEvenStreamRequest = false;
bool m_hasEvenStreamResponse{false};
HeadersReceivedEventHandler m_onHeadersReceived;
DataReceivedEventHandler m_onDataReceived;
DataSentEventHandler m_onDataSent;
Expand Down
1 change: 1 addition & 0 deletions src/aws-cpp-sdk-core/source/client/AWSClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ HttpResponseOutcome AWSClient::AttemptExhaustively(const Aws::Http::URI& uri,

};
httpRequest->SetEventStreamRequest(request.IsEventStreamRequest());
httpRequest->SetHasEventStreamResponse(request.HasEventStreamResponse());

outcome = AttemptOneRequest(httpRequest, request, signerName, signerRegion, signerServiceNameOverride);
outcome.SetRetryCount(retries);
Expand Down
3 changes: 2 additions & 1 deletion src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ static size_t WriteData(char* ptr, size_t size, size_t nmemb, void* userdata)
<< " at " << cur << " (eof: " << ref.eof() << ", bad: " << ref.bad() << ")");
return 0;
}
if (context->m_request->IsEventStreamRequest() && !response->HasHeader(Aws::Http::X_AMZN_ERROR_TYPE))
if ((context->m_request->IsEventStreamRequest() || context->m_request->HasEventStreamResponse() )
&& !response->HasHeader(Aws::Http::X_AMZN_ERROR_TYPE))
{
response->GetResponseBody().flush();
if (response->GetResponseBody().fail()) {
Expand Down
95 changes: 94 additions & 1 deletion tests/aws-cpp-sdk-kinesis-integration-tests/KinesisTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <aws/kinesis/model/DescribeStreamConsumerRequest.h>
#include <aws/kinesis/model/ListShardsRequest.h>
#include <aws/testing/TestingEnvironment.h>
#include <aws/kinesis/model/PutRecordRequest.h>

#include <thread>
#include <chrono>
Expand Down Expand Up @@ -49,7 +50,7 @@ class KinesisTest : public ::testing::Test
m_client.reset(Aws::New<KinesisClient>(ALLOC_TAG, config));

// Create stream
auto createStream = m_client->CreateStream(CreateStreamRequest().WithStreamName(streamName));
auto createStream = m_client->CreateStream(CreateStreamRequest().WithStreamName(streamName).WithShardCount(1));
AWS_ASSERT_SUCCESS(createStream);

// Wait 2 minutes for stream to be ready
Expand Down Expand Up @@ -171,4 +172,96 @@ TEST_F(KinesisTest, EnhancedFanOut)
AWS_ASSERT_SUCCESS(m_client->DeregisterStreamConsumer(deregisterRequest));
}


bool WriteDataToStream(Aws::Kinesis::KinesisClient &kinesis_client, const Aws::String &streamName, const Aws::String &data, const Aws::String &partitionKey)
{
Aws::Kinesis::Model::PutRecordRequest putRecordRequest;
putRecordRequest.SetStreamName(streamName);

putRecordRequest.SetPartitionKey(partitionKey);

Aws::Utils::ByteBuffer dataBuffer((unsigned char*)data.c_str(), data.size());
putRecordRequest.SetData(dataBuffer);

// Send the record to the stream
auto putRecordOutcome = kinesis_client.PutRecord(putRecordRequest);

return putRecordOutcome.IsSuccess();
}

TEST_F(KinesisTest, testSubscribe)
{
// Get the Stream ARN (different between accounts)
DescribeStreamRequest describeStreamRequest;
describeStreamRequest.SetStreamName(streamName);
auto describeStreamOutcome = m_client->DescribeStream(describeStreamRequest);
AWS_ASSERT_SUCCESS(describeStreamOutcome);
const auto streamARN = describeStreamOutcome.GetResult().GetStreamDescription().GetStreamARN();

// Register a consumer for enhanced fan-out
RegisterStreamConsumerRequest registerRequest;
const auto consumerName = BuildResourceName("sdktest");
registerRequest.WithConsumerName(consumerName).WithStreamARN(streamARN);
auto registerConsumerOutcome = m_client->RegisterStreamConsumer(registerRequest);
AWS_ASSERT_SUCCESS(registerConsumerOutcome);
const auto consumerARN = registerConsumerOutcome.GetResult().GetConsumer().GetConsumerARN();
WaitUntilConsumerIsActive(consumerARN);
// Get the shard id
ListShardsRequest listShardRequest;
listShardRequest.SetStreamName(streamName);
auto listShardsOutcome = m_client->ListShards(listShardRequest);
AWS_ASSERT_SUCCESS(listShardsOutcome);
const auto& shards = listShardsOutcome.GetResult().GetShards();
ASSERT_FALSE(shards.empty());
const auto shardId = shards[0].GetShardId();
Aws::String partitionKey = "shard0Key"; // Use a consistent partition key for Shard 0

const Aws::Vector<Aws::String> inputs = {
"Hello, this is the first test record for Shard 0!",
"Here's another test record for Shard 0!",
"Final record for Shard 0."
};

ASSERT_TRUE(WriteDataToStream(*m_client, streamName, inputs[0], partitionKey));
ASSERT_TRUE(WriteDataToStream(*m_client, streamName, inputs[1], partitionKey));
ASSERT_TRUE(WriteDataToStream(*m_client, streamName, inputs[2], partitionKey));

Aws::Kinesis::Model::StartingPosition start_position;
start_position.SetType(Aws::Kinesis::Model::ShardIteratorType::TRIM_HORIZON);

Aws::Kinesis::Model::SubscribeToShardRequest subscribe_request;
subscribe_request.SetConsumerARN(consumerARN);
subscribe_request.SetShardId(shardId);
subscribe_request.SetStartingPosition(start_position);

Aws::Kinesis::Model::SubscribeToShardHandler handler;
auto t_start = std::chrono::high_resolution_clock::now();
handler.SetSubscribeToShardEventCallback([&](const Aws::Kinesis::Model::SubscribeToShardEvent &event)
{
auto t_end = std::chrono::high_resolution_clock::now();
if(event.GetRecords().size())
{

double elapsed_time = std::chrono::duration<double, std::milli>(t_end-t_start).count();
SCOPED_TRACE("SetSubscribeToShardEventCallback called at time: " + std::to_string(elapsed_time) + " ms" );
EXPECT_EQ(event.GetRecords().size(), 3u);
}
t_start = t_end;
for (auto idx = 0u; idx < event.GetRecords().size(); ++idx)
{
const auto& record = event.GetRecords()[idx];
Aws::String record_str((char *) record.GetData().GetUnderlyingData(), record.GetData().GetLength());
SCOPED_TRACE("Record: " + record_str );
EXPECT_EQ(record_str, inputs[idx]);
}
});

subscribe_request.SetEventStreamHandler(handler);

SCOPED_TRACE("calling SubscribeToShard" );
auto subscribeOutcome = m_client->SubscribeToShard(subscribe_request);

EXPECT_TRUE(subscribeOutcome.IsSuccess());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,14 @@ protected SdkFileEntry generateModelHeaderFile(ServiceModel serviceModel, Map.En
for (Map.Entry<String, Operation> opEntry : serviceModel.getOperations().entrySet()) {
String key = opEntry.getKey();
Operation op = opEntry.getValue();
if (op.getRequest() != null && op.getRequest().getShape().getName() == shape.getName()) {
if (op.getRequest() != null && op.getRequest().getShape().getName() == shape.getName())
{
context.put("operation", op);
context.put("operationName", key);
if((op.getResult() != null) && op.getResult().getShape().hasEventStreamMembers())
{
context.put("hasEventStreamResponse", true);
}
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ namespace Model
#if($shape.hasEventStreamMembers())
inline virtual bool IsEventStreamRequest() const override { return true; }
#end
#if($hasEventStreamResponse)
inline virtual bool HasEventStreamResponse() const override { return true; }
#end
#if(!$shape.hasStreamMembers() && !$shape.hasEventStreamMembers())
${exportMacro} Aws::String SerializePayload() const override;

Expand Down

0 comments on commit 573ecfe

Please sign in to comment.