Skip to content

Commit

Permalink
Merge pull request #19 from ppittle/http-client-improvements
Browse files Browse the repository at this point in the history
Http client improvements
  • Loading branch information
ppittle authored Mar 4, 2018
2 parents 43a9c41 + 27fbe0e commit c696237
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Net;
using System;
using System.Net;
using System.Net.Http;
using System.Reflection;
using System.Threading.Tasks;
Expand Down Expand Up @@ -27,12 +28,20 @@ internal static class HttpClientHandlerExtensions
typeof(HttpClientHandler)
.GetMethod("SetContentHeaders", BindingFlags.NonPublic | BindingFlags.Static);

private static readonly FieldInfo _getRequestStreamCallback =
typeof(HttpClientHandler)
.GetField("getRequestStreamCallback", BindingFlags.NonPublic | BindingFlags.Instance);

private static readonly MethodInfo _startGettingResponse =
typeof(HttpClientHandler)
.GetMethod("StartGettingResponse", BindingFlags.NonPublic | BindingFlags.Instance);

/// <summary>
/// This is basically just <see cref="T:System.Net.Http.HttpClientHandler.CreateAndPrepareWebRequest"/>,
/// except it uses <paramref name="webRequest"/> rather than creating a <see cref="HttpWebRequest"/>
/// directly.
/// </summary>
public static async Task PrepareWebRequest(
public static void PrepareWebRequest(
this HttpClientHandler httpClientHandler,
HttpWebRequest webRequest,
HttpRequestMessage requestMessage)
Expand All @@ -50,16 +59,20 @@ public static async Task PrepareWebRequest(
_setRequestHeaders.Invoke(null, new object[] { webRequest, requestMessage });
// HttpClientHandler.SetContentHeaders(HttpWebRequest webRequest, HttpRequestMessage request);
_setContentHeaders.Invoke(null, new object[] { webRequest, requestMessage });
}

// copy request stream
if (null == requestMessage.Content)
{
webRequest.ContentLength = 0;
}
else
{
await requestMessage.Content.CopyToAsync(webRequest.GetRequestStream());
}
public static void SetGetRequestStreamCallback(
this HttpClientHandler httpClientHandler,
AsyncCallback getRequestStreamCallback)
{
_getRequestStreamCallback.SetValue(httpClientHandler, getRequestStreamCallback);
}

public static void StartGettingResponse(
this HttpClientHandler httpClientHandler,
object requestState)
{
_startGettingResponse.Invoke(httpClientHandler, new[] {requestState});
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Reflection;
Expand Down Expand Up @@ -66,14 +67,51 @@ public void Visit(Task task)
var handler = (HttpClientHandler)taskAction.Target;

// copy the request message to the http web request we just built
handler.PrepareWebRequest(httpWebRequest, requestMessage).Wait();
handler.PrepareWebRequest(httpWebRequest, requestMessage);

// save the http web request back to the request state -
// now when the intercepted handler continues executing
// StartRequest, it will be using our custom HttpWebRequest!!
requestStateWrapper.SetHttpWebRequest(httpWebRequest);

// we'll also need to intercept how the HttpClientHandler
// acquires the request stream to make sure we can intercept
// the request body, so replace the GetRequestStreamCallback
// continuation
handler.SetGetRequestStreamCallback(asyncResult => CustomGetRequestStreamCallback(asyncResult, handler));
}

/// <summary>
/// <see cref="HttpClientHandler"/>'s GetRequestStreamCallback uses an
/// overload of EndGetRequestStream that <see cref="HttpWebRequestWrapperRecorder"/>
/// can't intercept (method isn't virtual). So we need to intercept the call
/// and force using <see cref="HttpWebRequestWrapperRecorder.EndGetRequestStream(System.IAsyncResult)"/>
/// (which is intercepted)
/// </summary>
private void CustomGetRequestStreamCallback(IAsyncResult ar, HttpClientHandler httpClientHandler)
{
// build a wrapper around the Request State stored in AsyncState
var requestStateWrapper = new HttpClientHandlerRequestStateWrapper(ar.AsyncState);

// get the HttpRequestMessage we've intercepted
var requestMessage = requestStateWrapper.GetHttpRequestMessage();

// load the HttpWebRequest we've already intercpeted and replaced
var httpWebRequest = requestStateWrapper.GetHttpWebRequest();

// get a copy of the request streams
var requestStream = httpWebRequest.GetRequestStream();

// copy the request message content to the request stream
requestMessage.Content.CopyToAsync(requestStream).Wait();

// save the request stream to the requet state
requestStateWrapper.SetRequestStream(requestStream);

// continue on with StartGettingResponse
httpClientHandler.StartGettingResponse(ar.AsyncState);
}

/// <summary>
/// Reflection helper for working with the nested private class
/// <see cref="T:System.Net.Http.HttpClientHandler.RequestState"/>
Expand All @@ -82,6 +120,7 @@ private class HttpClientHandlerRequestStateWrapper
{
private static readonly FieldInfo _httpRequestMessageField;
private static readonly FieldInfo _httpWebRequestField;
private static readonly FieldInfo _requestStreamField;

static HttpClientHandlerRequestStateWrapper()
{
Expand All @@ -102,6 +141,12 @@ static HttpClientHandlerRequestStateWrapper()
.GetField(
"webRequest",
BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);

_requestStreamField =
httpClientHandlerRequestStateType
.GetField(
"requestStream",
BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
}

private readonly object _requestState;
Expand All @@ -117,10 +162,20 @@ public HttpRequestMessage GetHttpRequestMessage()
_httpRequestMessageField.GetValue(_requestState);
}

public HttpWebRequest GetHttpWebRequest()
{
return (HttpWebRequest) _httpWebRequestField.GetValue(_requestState);
}

public void SetHttpWebRequest(HttpWebRequest webRequest)
{
_httpWebRequestField.SetValue(_requestState, webRequest);
}

public void SetRequestStream(Stream stream)
{
_requestStreamField.SetValue(_requestState, stream);
}
}
}
}
109 changes: 106 additions & 3 deletions src/HttpWebRequestWrapper.Tests/HttpClientTests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;
using HttpWebRequestWrapper.HttpClient;
using Should;
Expand All @@ -20,13 +21,14 @@ static HttpClientTests()
}

// WARNING!! Makes live request
[Fact]
[Fact(Timeout = 10000)]
public async Task CanRecord()
{
// ARRANGE
var url = "http://www.github.com/";
var url = "https://www.github.com/";

var recordingSession = new RecordingSession();

HttpResponseMessage response;

// ACT
Expand All @@ -48,7 +50,7 @@ public async Task CanRecord()
}

// WARNING!! Makes live request
[Fact]
[Fact(Timeout = 10000)]
public async Task CanRecordWebRequestException()
{
// ARRANGE
Expand Down Expand Up @@ -101,6 +103,107 @@ public async Task CanInterceptAndSpoofResponse()
}

[Fact]
public async Task CanInterceptPost()
{
// ARRANGE
var requestBody = @"<xml>test</xml>";
var requestUrl = new Uri("http://stackoverflow.com");
var responseBody = "Test Response";

var responseCreator = new Func<InterceptedRequest, HttpWebResponse>(req =>
{
if (req.HttpWebRequest.RequestUri == requestUrl &&
req.HttpWebRequest.Method == "POST" &&
req.RequestPayload == requestBody)
{
return req.HttpWebResponseCreator.Create(responseBody);
}

throw new Exception("Coulnd't match request");
});

HttpResponseMessage response;

// ACT
using (new HttpClientAndRequestWrapperSession(
new HttpWebRequestWrapperInterceptorCreator(responseCreator)))
{
var httpClient = new System.Net.Http.HttpClient();

response = await httpClient.PostAsync(requestUrl, new StringContent(requestBody));
}

// ASSERT
response.ShouldNotBeNull();

(await response.Content.ReadAsStringAsync()).ShouldEqual(responseBody);
}

[Fact]
public void CanInterceptWhenHttpClientUsesWebRequestHandler()
{

}

[Fact]
public void CanInterceptWhenHttpClientSetsBaseAddress()
{

}
[Fact]
public async Task CanInterceptCustomRequestMessage()
{
// ARRANGE
var requestBody = @"<xml>test</xml>";
var requestUrl = new Uri("http://stackoverflow.com");
var responseBody = "Test Response";

var responseCreator = new Func<InterceptedRequest, HttpWebResponse>(req =>
{
if (req.HttpWebRequest.RequestUri == requestUrl &&
req.HttpWebRequest.Method == "POST" &&
req.RequestPayload == requestBody)
{
return req.HttpWebResponseCreator.Create(responseBody);
}

throw new Exception("Coulnd't match request");
});

var requestMessage = new HttpRequestMessage(HttpMethod.Post, requestUrl)
{
Content = new StringContent(requestBody, Encoding.UTF8, "text/xml")
};

HttpResponseMessage response;

// ACT
using (new HttpClientAndRequestWrapperSession(
new HttpWebRequestWrapperInterceptorCreator(responseCreator)))
{
var httpClient = new System.Net.Http.HttpClient();

response = await httpClient.SendAsync(requestMessage);
}

// ASSERT
response.ShouldNotBeNull();

(await response.Content.ReadAsStringAsync()).ShouldEqual(responseBody);
}


// TODO - can intercept WebRequestHandler (inherits from HttpClientHandler)
// TODO - cna intercept when HttpClient has BaseAddress set
// TODO - test when using custom request message
// TODO - test when using Send with HttpCompletionOption
// TODO - can record post
// TODO - can record binary response stream
// TODO - can record post request payload
// TODO - can record binary request payload
// TODO - can match on binary request payload

[Fact(Timeout = 3000)]
public async Task CanSupportMultipleConcurrentHttpClients()
{
// ARRANGE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
<Reference Include="System.Core" />
<Reference Include="System.Net" />
<Reference Include="System.Net.Http" />
<Reference Include="System.Net.Http.WebRequest" />
<Reference Include="System.Threading.Tasks.Extensions, Version=4.1.0.0, Culture=neutral, PublicKeyToken=cc7b13ffcd2ddd51, processorArchitecture=MSIL">
<HintPath>..\packages\System.Threading.Tasks.Extensions.4.3.0\lib\portable-net45+win8+wp8+wpa81\System.Threading.Tasks.Extensions.dll</HintPath>
</Reference>
Expand Down
23 changes: 6 additions & 17 deletions src/HttpWebRequestWrapper.Tests/InterceptorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Net;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using HttpWebRequestWrapper.Tests.Properties;
using Should;
using Xunit;
Expand Down Expand Up @@ -562,7 +563,7 @@ public void CanSpoofPutRequest()
}

[Fact]
public void CanSpoofAsyncRequest()
public async Task CanSpoofAsyncRequest()
{
var requestPayload = "Test Request";
var responseBody = "Test Response";
Expand All @@ -577,28 +578,16 @@ public void CanSpoofAsyncRequest()

IWebRequestCreate creator = new HttpWebRequestWrapperInterceptorCreator(responseCreator);

var request = creator.Create(new Uri("http://fakeSite.fake"));
var request = (HttpWebRequest)creator.Create(new Uri("http://fakeSite.fake"));
request.Method = "POST";

// ACT
var asyncResult = request.BeginGetRequestStream(
req =>
{
var requestStream = (req.AsyncState as HttpWebRequest).EndGetRequestStream(req);

using (var sw = new StreamWriter(requestStream))
sw.Write(requestPayload);
},
request);

if (!asyncResult.IsCompleted)
Thread.Sleep(TimeSpan.FromMilliseconds(250));
using (var sw = new StreamWriter(await request.GetRequestStreamAsync()))
await sw.WriteAsync(requestPayload);

if (!asyncResult.IsCompleted)
throw new Exception("Web Response didn't come back in reasonable time frame");
var response = await request.GetResponseAsync();

var response = request.GetResponse();

// ASSERT
response.ShouldNotBeNull();

Expand Down
10 changes: 0 additions & 10 deletions src/HttpWebRequestWrapper/HttpWebRequestWrapperInterceptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,6 @@ public override Stream GetRequestStream()
return _requestStream;
}

/// <inheritdoc />
public override IAsyncResult BeginGetRequestStream(AsyncCallback callback, object state)
{
var asyncResult = new DummyAsyncResult(new ManualResetEvent(true), state);

callback(asyncResult);

return asyncResult;
}

/// <inheritdoc />
public override Stream EndGetRequestStream(IAsyncResult asyncResult)
{
Expand Down

0 comments on commit c696237

Please sign in to comment.