mirror of https://github.com/restic/restic.git
backend: add tests for watchdogRoundTripper
This commit is contained in:
parent
8778670232
commit
ebd01a4675
|
@ -0,0 +1,201 @@
|
|||
package backend
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
rtest "github.com/restic/restic/internal/test"
|
||||
)
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
data := []byte("abcdef")
|
||||
var ctr int
|
||||
kick := func() {
|
||||
ctr++
|
||||
}
|
||||
var closed bool
|
||||
onClose := func() {
|
||||
closed = true
|
||||
}
|
||||
|
||||
wd := newWatchdogReadCloser(io.NopCloser(bytes.NewReader(data)), 1, kick, onClose)
|
||||
|
||||
out, err := io.ReadAll(wd)
|
||||
rtest.OK(t, err)
|
||||
rtest.Equals(t, data, out, "data mismatch")
|
||||
// the EOF read also triggers the kick function
|
||||
rtest.Equals(t, len(data)*2+2, ctr, "unexpected number of kick calls")
|
||||
|
||||
rtest.Equals(t, false, closed, "close function called too early")
|
||||
rtest.OK(t, wd.Close())
|
||||
rtest.Equals(t, true, closed, "close function not called")
|
||||
}
|
||||
|
||||
func TestRoundtrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// at the higher delay values, it takes longer to transmit the request/response body
|
||||
// than the roundTripper timeout
|
||||
for _, delay := range []int{0, 1, 10, 20} {
|
||||
t.Run(fmt.Sprintf("%v", delay), func(t *testing.T) {
|
||||
msg := []byte("ping-pong-data")
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
w.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(200)
|
||||
|
||||
// slowly send the reply
|
||||
for len(data) >= 2 {
|
||||
_, _ = w.Write(data[:2])
|
||||
w.(http.Flusher).Flush()
|
||||
data = data[2:]
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
}
|
||||
_, _ = w.Write(data)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rt := newWatchdogRoundtripper(http.DefaultTransport, 50*time.Millisecond, 2)
|
||||
req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(newSlowReader(bytes.NewReader(msg), time.Duration(delay)*time.Millisecond)))
|
||||
rtest.OK(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
rtest.OK(t, err)
|
||||
rtest.Equals(t, 200, resp.StatusCode, "unexpected status code")
|
||||
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
rtest.OK(t, err)
|
||||
rtest.Equals(t, msg, response, "unexpected response")
|
||||
|
||||
rtest.OK(t, resp.Body.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanceledRoundtrip(t *testing.T) {
|
||||
rt := newWatchdogRoundtripper(http.DefaultTransport, time.Second, 2)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://some.random.url.dfdgsfg", nil)
|
||||
rtest.OK(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
rtest.Equals(t, context.Canceled, err)
|
||||
// make linter happy
|
||||
if resp != nil {
|
||||
rtest.OK(t, resp.Body.Close())
|
||||
}
|
||||
}
|
||||
|
||||
type slowReader struct {
|
||||
data io.Reader
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func newSlowReader(data io.Reader, delay time.Duration) *slowReader {
|
||||
return &slowReader{
|
||||
data: data,
|
||||
delay: delay,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slowReader) Read(p []byte) (n int, err error) {
|
||||
time.Sleep(s.delay)
|
||||
return s.data.Read(p)
|
||||
}
|
||||
|
||||
func TestUploadTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msg := []byte("ping")
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
w.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
t.Error("upload should have been canceled")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rt := newWatchdogRoundtripper(http.DefaultTransport, 10*time.Millisecond, 1024)
|
||||
req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(newSlowReader(bytes.NewReader(msg), 100*time.Millisecond)))
|
||||
rtest.OK(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
rtest.Equals(t, context.Canceled, err)
|
||||
// make linter happy
|
||||
if resp != nil {
|
||||
rtest.OK(t, resp.Body.Close())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessingTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msg := []byte("ping")
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
w.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rt := newWatchdogRoundtripper(http.DefaultTransport, 10*time.Millisecond, 1024)
|
||||
req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(bytes.NewReader(msg)))
|
||||
rtest.OK(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
rtest.Equals(t, context.Canceled, err)
|
||||
// make linter happy
|
||||
if resp != nil {
|
||||
rtest.OK(t, resp.Body.Close())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msg := []byte("ping")
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
w.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(200)
|
||||
_, _ = w.Write(data[:2])
|
||||
w.(http.Flusher).Flush()
|
||||
data = data[2:]
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_, _ = w.Write(data)
|
||||
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rt := newWatchdogRoundtripper(http.DefaultTransport, 10*time.Millisecond, 1024)
|
||||
req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(bytes.NewReader(msg)))
|
||||
rtest.OK(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
rtest.OK(t, err)
|
||||
rtest.Equals(t, 200, resp.StatusCode, "unexpected status code")
|
||||
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
rtest.Equals(t, context.Canceled, err, "response download not canceled")
|
||||
rtest.OK(t, resp.Body.Close())
|
||||
}
|
Loading…
Reference in New Issue