diff --git a/Octokit.Tests/Http/RedirectHandlerTests.cs b/Octokit.Tests/Http/RedirectHandlerTests.cs new file mode 100644 index 00000000..beb55b87 --- /dev/null +++ b/Octokit.Tests/Http/RedirectHandlerTests.cs @@ -0,0 +1,225 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Octokit.Internal; +using Xunit; + +namespace Octokit.Tests.Http +{ + public class RedirectHandlerTests + { + + [Fact] + public async Task OkStatusShouldPassThrough() + { + var invoker = CreateInvoker(new HttpResponseMessage(HttpStatusCode.OK)); + var httpRequestMessage = CreateRequest(HttpMethod.Get); + + var response = await invoker.SendAsync(httpRequestMessage, new CancellationToken()); + + Assert.Equal(response.StatusCode, HttpStatusCode.OK); + Assert.Same(response.RequestMessage, httpRequestMessage); + } + + + [Theory] + [InlineData(HttpStatusCode.MovedPermanently)] // 301 + [InlineData(HttpStatusCode.Found)] // 302 + [InlineData(HttpStatusCode.TemporaryRedirect)] // 307 + public async Task ShouldRedirectSameMethod(HttpStatusCode statusCode) + { + var redirectResponse = new HttpResponseMessage(statusCode); + redirectResponse.Headers.Location = new Uri("http://example.org/bar"); + + var invoker = CreateInvoker(redirectResponse, + new HttpResponseMessage(HttpStatusCode.OK)); + var httpRequestMessage = CreateRequest(HttpMethod.Post); + + var response = await invoker.SendAsync(httpRequestMessage, new CancellationToken()); + + Assert.Equal(response.RequestMessage.Method, httpRequestMessage.Method); + Assert.NotSame(response.RequestMessage, httpRequestMessage); + } + + [Fact] + public async Task Status303ShouldRedirectChangeMethod() + { + var redirectResponse = new HttpResponseMessage(HttpStatusCode.SeeOther); + redirectResponse.Headers.Location = new Uri("http://example.org/bar"); + + var invoker = CreateInvoker(redirectResponse, + new HttpResponseMessage(HttpStatusCode.OK)); + var httpRequestMessage = CreateRequest(HttpMethod.Post); + + var response = await invoker.SendAsync(httpRequestMessage, new CancellationToken()); + + Assert.Equal(HttpMethod.Get, response.RequestMessage.Method); + Assert.NotSame(response.RequestMessage, httpRequestMessage); + } + + [Fact] + public async Task RedirectWithSameHostShouldKeepAuthHeader() + { + var redirectResponse = new HttpResponseMessage(HttpStatusCode.Redirect); + redirectResponse.Headers.Location = new Uri("http://example.org/bar"); + + var invoker = CreateInvoker(redirectResponse, + new HttpResponseMessage(HttpStatusCode.OK)); + var httpRequestMessage = CreateRequest(HttpMethod.Get); + httpRequestMessage.Headers.Authorization = new AuthenticationHeaderValue("fooAuth", "aparam"); + + var response = await invoker.SendAsync(httpRequestMessage, new CancellationToken()); + + Assert.NotSame(response.RequestMessage, httpRequestMessage); + Assert.Equal("fooAuth", response.RequestMessage.Headers.Authorization.Scheme); + } + + + [Theory] + [InlineData(HttpStatusCode.MovedPermanently)] // 301 + [InlineData(HttpStatusCode.Found)] // 302 + [InlineData(HttpStatusCode.SeeOther)] // 303 + [InlineData(HttpStatusCode.TemporaryRedirect)] // 307 + public async Task RedirectWithDifferentHostShouldLoseAuthHeader(HttpStatusCode statusCode) + { + var redirectResponse = new HttpResponseMessage(statusCode); + redirectResponse.Headers.Location = new Uri("http://example.net/bar"); + + var invoker = CreateInvoker(redirectResponse, + new HttpResponseMessage(HttpStatusCode.OK)); + var httpRequestMessage = CreateRequest(HttpMethod.Get); + httpRequestMessage.Headers.Authorization = new AuthenticationHeaderValue("fooAuth", "aparam"); + + var response = await invoker.SendAsync(httpRequestMessage, new CancellationToken()); + + Assert.NotSame(response.RequestMessage, httpRequestMessage); + Assert.Null(response.RequestMessage.Headers.Authorization); + + } + + [Fact] + public async Task DisabledRedirectShouldPassThrough() + { + var invoker = CreateInvoker(new HttpResponseMessage(HttpStatusCode.Found)); + var httpRequestMessage = CreateRequest(HttpMethod.Get); + httpRequestMessage.Properties[RedirectHandler.AllowAutoRedirectKey] = false; + var response = await invoker.SendAsync(httpRequestMessage, new CancellationToken()); + + Assert.Equal(response.StatusCode, HttpStatusCode.Redirect); + Assert.Same(response.RequestMessage, httpRequestMessage); + } + + [Theory] + [InlineData(HttpStatusCode.MovedPermanently)] // 301 + [InlineData(HttpStatusCode.Found)] // 302 + [InlineData(HttpStatusCode.TemporaryRedirect)] // 307 + public async Task Status301ShouldRedirectPOSTWithBody(HttpStatusCode statusCode) + { + var redirectResponse = new HttpResponseMessage(statusCode); + redirectResponse.Headers.Location = new Uri("http://example.org/bar"); + + var invoker = CreateInvoker(redirectResponse, + new HttpResponseMessage(HttpStatusCode.OK)); + var httpRequestMessage = CreateRequest(HttpMethod.Post); + httpRequestMessage.Content = new StringContent("Hello World"); + + var response = await invoker.SendAsync(httpRequestMessage, new CancellationToken()); + + Assert.Equal(response.RequestMessage.Method, httpRequestMessage.Method); + Assert.NotSame(response.RequestMessage, httpRequestMessage); + Assert.Equal("Hello World", await response.RequestMessage.Content.ReadAsStringAsync()); + } + + // POST see other with content + [Fact] + public async Task Status303ShouldRedirectToGETWithoutBody() + { + var redirectResponse = new HttpResponseMessage(HttpStatusCode.SeeOther); + redirectResponse.Headers.Location = new Uri("http://example.org/bar"); + + var invoker = CreateInvoker(redirectResponse, + new HttpResponseMessage(HttpStatusCode.OK)); + var httpRequestMessage = CreateRequest(HttpMethod.Post); + httpRequestMessage.Content = new StringContent("Hello World"); + + var response = await invoker.SendAsync(httpRequestMessage, new CancellationToken()); + + Assert.Equal(HttpMethod.Get, response.RequestMessage.Method); + Assert.NotSame(response.RequestMessage, httpRequestMessage); + Assert.Null(response.RequestMessage.Content); + } + + [Fact] + public async Task Exceed3RedirectsShouldReturn() + { + var redirectResponse = new HttpResponseMessage(HttpStatusCode.Found); + redirectResponse.Headers.Location = new Uri("http://example.org/bar"); + + var redirectResponse2 = new HttpResponseMessage(HttpStatusCode.Found); + redirectResponse2.Headers.Location = new Uri("http://example.org/foo"); + + + var invoker = CreateInvoker(redirectResponse, redirectResponse2); + + var httpRequestMessage = CreateRequest(HttpMethod.Get); + + var response = await invoker.SendAsync(httpRequestMessage, new CancellationToken()); + + Assert.NotSame(response.RequestMessage, httpRequestMessage); + Assert.Equal(4, (int)response.RequestMessage.Properties[RedirectHandler.RedirectCountKey]); + } + + static HttpRequestMessage CreateRequest(HttpMethod method) + { + var httpRequestMessage = new HttpRequestMessage(); + httpRequestMessage.RequestUri = new Uri("http://example.org/foo"); + httpRequestMessage.Properties[RedirectHandler.AllowAutoRedirectKey] = true; + httpRequestMessage.Method = method; + return httpRequestMessage; + } + + static HttpMessageInvoker CreateInvoker(HttpResponseMessage httpResponseMessage1, HttpResponseMessage httpResponseMessage2 = null) + { + + var redirectHandler = new RedirectHandler() + { + InnerHandler = new MockRedirectHandler(httpResponseMessage1, httpResponseMessage2) + }; + var invoker = new HttpMessageInvoker(redirectHandler); + return invoker; + } + } + + public class MockRedirectHandler : HttpMessageHandler + { + readonly HttpResponseMessage _response1; + readonly HttpResponseMessage _response2; + private bool _Response1Sent = false; + public MockRedirectHandler(HttpResponseMessage response1, HttpResponseMessage response2 = null) + { + _response1 = response1; + _response2 = response2; + } + + protected async override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + if (!_Response1Sent) + { + _Response1Sent = true; + _response1.RequestMessage = request; + return _response1; + } + else + { + _response2.RequestMessage = request; + return _response2; + } + } + } +} diff --git a/Octokit.Tests/Octokit.Tests.csproj b/Octokit.Tests/Octokit.Tests.csproj index e1cbb6f4..96aed6bd 100644 --- a/Octokit.Tests/Octokit.Tests.csproj +++ b/Octokit.Tests/Octokit.Tests.csproj @@ -141,6 +141,7 @@ + diff --git a/Octokit/Http/HttpClientAdapter.cs b/Octokit/Http/HttpClientAdapter.cs index 8f4f1a1c..c6794dab 100644 --- a/Octokit/Http/HttpClientAdapter.cs +++ b/Octokit/Http/HttpClientAdapter.cs @@ -207,44 +207,74 @@ namespace Octokit.Internal public class RedirectHandler : DelegatingHandler { + public const string AllowAutoRedirectKey = "AllowAutoRedirect"; + public const string RedirectCountKey = "RedirectCount"; public bool Enabled { get; set; } protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { var response = await base.SendAsync(request, cancellationToken); - var allowAutoRedirect = (bool)request.Properties["AllowAutoRedirect"]; - if (!allowAutoRedirect) return response; + + // Can't redirect without somewhere to redirect too. Throw? + if (response.Headers.Location == null) return response; + + // Don't redirect if redirection has been disabled for this request + var allowAutoRedirect = (bool)request.Properties[AllowAutoRedirectKey]; + if (!allowAutoRedirect) return response; // Throw? + + // Don't redirect if we exceed max number of redirects + var redirectCount = 0; + if (request.Properties.Keys.Contains(RedirectCountKey)) + { + redirectCount = (int)request.Properties[RedirectCountKey]; + } + if (redirectCount > 3) return response; // Throw? + request.Properties[RedirectCountKey] = ++redirectCount; if (response.StatusCode == HttpStatusCode.MovedPermanently - || response.StatusCode == HttpStatusCode.Moved || response.StatusCode == HttpStatusCode.Redirect || response.StatusCode == HttpStatusCode.Found || response.StatusCode == HttpStatusCode.SeeOther - || response.StatusCode == HttpStatusCode.RedirectKeepVerb || response.StatusCode == HttpStatusCode.TemporaryRedirect || (int)response.StatusCode == 308) { var newRequest = CopyRequest(response.RequestMessage); - if (response.StatusCode == HttpStatusCode.Redirect - || response.StatusCode == HttpStatusCode.Found - || response.StatusCode == HttpStatusCode.SeeOther) + if (response.StatusCode == HttpStatusCode.SeeOther) { newRequest.Content = null; newRequest.Method = HttpMethod.Get; } + else + { + if (request.Content != null && request.Content.Headers.ContentLength != 0) { + var stream = await request.Content.ReadAsStreamAsync(); + if (stream.CanSeek) + { + stream.Position = 0; + } + else + { + throw new Exception("Cannot redirect a request with an unbuffered body"); + } + newRequest.Content = new StreamContent(stream); + } + } newRequest.RequestUri = response.Headers.Location; - - response = await base.SendAsync(newRequest, cancellationToken); + if (String.Compare(newRequest.RequestUri.Host,request.RequestUri.Host,StringComparison.OrdinalIgnoreCase) != 0) + { + newRequest.Headers.Authorization = null; + } + response = await this.SendAsync(newRequest, cancellationToken); } return response; } [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Reliability", "CA2000:Dispose objects before losing scope")] - static HttpRequestMessage CopyRequest(HttpRequestMessage oldRequest) + private static HttpRequestMessage CopyRequest(HttpRequestMessage oldRequest) { var newrequest = new HttpRequestMessage(oldRequest.Method, oldRequest.RequestUri); @@ -256,7 +286,7 @@ namespace Octokit.Internal { newrequest.Properties.Add(property); } - if (oldRequest.Content != null) newrequest.Content = new StreamContent(oldRequest.Content.ReadAsStreamAsync().Result); + return newrequest; } }