diff --git a/src/NzbDrone.Common.Test/Http/HttpClientFixture.cs b/src/NzbDrone.Common.Test/Http/HttpClientFixture.cs index 04ad3b08b..39d72fa82 100644 --- a/src/NzbDrone.Common.Test/Http/HttpClientFixture.cs +++ b/src/NzbDrone.Common.Test/Http/HttpClientFixture.cs @@ -226,6 +226,29 @@ namespace NzbDrone.Common.Test.Http ExceptionVerification.IgnoreErrors(); } + + [Test] + public void should_overwrite_response_cookie() + { + var requestSet = new HttpRequest("http://eu.httpbin.org/cookies/set?my=cookie"); + requestSet.AllowAutoRedirect = false; + requestSet.StoreResponseCookie = true; + requestSet.AddCookie("my", "oldcookie"); + + var responseSet = Subject.Get(requestSet); + + var request = new HttpRequest("http://eu.httpbin.org/get"); + + var response = Subject.Get(request); + + response.Resource.Headers.Should().ContainKey("Cookie"); + + var cookie = response.Resource.Headers["Cookie"].ToString(); + + cookie.Should().Contain("my=cookie"); + + ExceptionVerification.IgnoreErrors(); + } } public class HttpBinResource diff --git a/src/NzbDrone.Common/Http/HttpClient.cs b/src/NzbDrone.Common/Http/HttpClient.cs index adeed3533..df6485488 100644 --- a/src/NzbDrone.Common/Http/HttpClient.cs +++ b/src/NzbDrone.Common/Http/HttpClient.cs @@ -73,31 +73,12 @@ namespace NzbDrone.Common.Http AddRequestHeaders(webRequest, request.Headers); } - var cookieContainer = _cookieContainerCache.Get("container", () => new CookieContainer()); - - if (request.Cookies.Count != 0) - { - foreach (var pair in request.Cookies) - { - cookieContainer.Add(new Cookie(pair.Key, pair.Value, "/", request.Url.Host) - { - Expires = DateTime.UtcNow.AddHours(1) - }); - } - } - - if (request.StoreResponseCookie) - { - webRequest.CookieContainer = cookieContainer; - } - else - { - webRequest.CookieContainer = new CookieContainer(); - webRequest.CookieContainer.Add(cookieContainer.GetCookies(request.Url)); - } + PrepareRequestCookies(request, webRequest); var response = ExecuteRequest(request, webRequest); + HandleResponseCookies(request, webRequest); + stopWatch.Stop(); _logger.Trace("{0} ({1:n0} ms)", response, stopWatch.ElapsedMilliseconds); @@ -119,6 +100,56 @@ namespace NzbDrone.Common.Http return response; } + private void PrepareRequestCookies(HttpRequest request, HttpWebRequest webRequest) + { + lock (_cookieContainerCache) + { + var persistentCookieContainer = _cookieContainerCache.Get("container", () => new CookieContainer()); + + if (request.Cookies.Count != 0) + { + foreach (var pair in request.Cookies) + { + persistentCookieContainer.Add(new Cookie(pair.Key, pair.Value, "/", request.Url.Host) + { + Expires = DateTime.UtcNow.AddHours(1) + }); + } + } + + var requestCookies = persistentCookieContainer.GetCookies(request.Url); + + if (requestCookies.Count == 0 && !request.StoreResponseCookie) + { + return; + } + + if (webRequest.CookieContainer == null) + { + webRequest.CookieContainer = new CookieContainer(); + } + + webRequest.CookieContainer.Add(requestCookies); + } + } + + private void HandleResponseCookies(HttpRequest request, HttpWebRequest webRequest) + { + if (!request.StoreResponseCookie) + { + return; + } + + lock (_cookieContainerCache) + { + var persistentCookieContainer = _cookieContainerCache.Get("container", () => new CookieContainer()); + + var cookies = webRequest.CookieContainer.GetCookies(request.Url); + + persistentCookieContainer.Add(cookies); + } + } + private HttpResponse ExecuteRequest(HttpRequest request, HttpWebRequest webRequest) { if (OsInfo.IsMonoRuntime && webRequest.RequestUri.Scheme == "https")