package backend import ( "bytes" "context" "fmt" "io" "net/http" "net/http/httptest" "testing" "time" rtest "github.com/restic/restic/internal/test" ) func TestRead(t *testing.T) { data := []byte("abcdef") var ctr int kick := func() { ctr++ } var closed bool onClose := func() { closed = true } wd := newWatchdogReadCloser(io.NopCloser(bytes.NewReader(data)), 1, kick, onClose) out, err := io.ReadAll(wd) rtest.OK(t, err) rtest.Equals(t, data, out, "data mismatch") // the EOF read also triggers the kick function rtest.Equals(t, len(data)*2+2, ctr, "unexpected number of kick calls") rtest.Equals(t, false, closed, "close function called too early") rtest.OK(t, wd.Close()) rtest.Equals(t, true, closed, "close function not called") } func TestRoundtrip(t *testing.T) { t.Parallel() // at the higher delay values, it takes longer to transmit the request/response body // than the roundTripper timeout for _, delay := range []int{0, 1, 10, 20} { t.Run(fmt.Sprintf("%v", delay), func(t *testing.T) { msg := []byte("ping-pong-data") srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) if err != nil { w.WriteHeader(500) return } w.WriteHeader(200) // slowly send the reply for len(data) >= 2 { _, _ = w.Write(data[:2]) w.(http.Flusher).Flush() data = data[2:] time.Sleep(time.Duration(delay) * time.Millisecond) } _, _ = w.Write(data) })) defer srv.Close() rt := newWatchdogRoundtripper(http.DefaultTransport, 50*time.Millisecond, 2) req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(newSlowReader(bytes.NewReader(msg), time.Duration(delay)*time.Millisecond))) rtest.OK(t, err) resp, err := rt.RoundTrip(req) rtest.OK(t, err) rtest.Equals(t, 200, resp.StatusCode, "unexpected status code") response, err := io.ReadAll(resp.Body) rtest.OK(t, err) rtest.Equals(t, msg, response, "unexpected response") rtest.OK(t, resp.Body.Close()) }) } } func TestCanceledRoundtrip(t *testing.T) { rt := newWatchdogRoundtripper(http.DefaultTransport, time.Second, 2) ctx, cancel := context.WithCancel(context.Background()) cancel() req, err := http.NewRequestWithContext(ctx, "GET", "http://some.random.url.dfdgsfg", nil) rtest.OK(t, err) resp, err := rt.RoundTrip(req) rtest.Equals(t, context.Canceled, err) // make linter happy if resp != nil { rtest.OK(t, resp.Body.Close()) } } type slowReader struct { data io.Reader delay time.Duration } func newSlowReader(data io.Reader, delay time.Duration) *slowReader { return &slowReader{ data: data, delay: delay, } } func (s *slowReader) Read(p []byte) (n int, err error) { time.Sleep(s.delay) return s.data.Read(p) } func TestUploadTimeout(t *testing.T) { t.Parallel() msg := []byte("ping") srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := io.ReadAll(r.Body) if err != nil { w.WriteHeader(500) return } t.Error("upload should have been canceled") })) defer srv.Close() rt := newWatchdogRoundtripper(http.DefaultTransport, 10*time.Millisecond, 1024) req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(newSlowReader(bytes.NewReader(msg), 100*time.Millisecond))) rtest.OK(t, err) resp, err := rt.RoundTrip(req) rtest.Equals(t, context.Canceled, err) // make linter happy if resp != nil { rtest.OK(t, resp.Body.Close()) } } func TestProcessingTimeout(t *testing.T) { t.Parallel() msg := []byte("ping") srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := io.ReadAll(r.Body) if err != nil { w.WriteHeader(500) return } time.Sleep(100 * time.Millisecond) w.WriteHeader(200) })) defer srv.Close() rt := newWatchdogRoundtripper(http.DefaultTransport, 10*time.Millisecond, 1024) req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(bytes.NewReader(msg))) rtest.OK(t, err) resp, err := rt.RoundTrip(req) rtest.Equals(t, context.Canceled, err) // make linter happy if resp != nil { rtest.OK(t, resp.Body.Close()) } } func TestDownloadTimeout(t *testing.T) { t.Parallel() msg := []byte("ping") srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) if err != nil { w.WriteHeader(500) return } w.WriteHeader(200) _, _ = w.Write(data[:2]) w.(http.Flusher).Flush() data = data[2:] time.Sleep(100 * time.Millisecond) _, _ = w.Write(data) })) defer srv.Close() rt := newWatchdogRoundtripper(http.DefaultTransport, 10*time.Millisecond, 1024) req, err := http.NewRequestWithContext(context.TODO(), "GET", srv.URL, io.NopCloser(bytes.NewReader(msg))) rtest.OK(t, err) resp, err := rt.RoundTrip(req) rtest.OK(t, err) rtest.Equals(t, 200, resp.StatusCode, "unexpected status code") _, err = io.ReadAll(resp.Body) rtest.Equals(t, context.Canceled, err, "response download not canceled") rtest.OK(t, resp.Body.Close()) }