From 18ecd9df30d758a2d4b4a074b4f15791d50ea226 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Fri, 29 Dec 2017 12:43:49 +0100 Subject: [PATCH] Improve limiting HTTP based backends --- cmd/restic/global.go | 16 ++++++++------ internal/limiter/limiter.go | 4 ++++ internal/limiter/static_limiter.go | 34 ++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/cmd/restic/global.go b/cmd/restic/global.go index 4138c222a..37e3ebea1 100644 --- a/cmd/restic/global.go +++ b/cmd/restic/global.go @@ -323,16 +323,11 @@ func OpenRepository(opts GlobalOptions) (*repository.Repository, error) { return nil, errors.Fatal("Please specify repository location (-r)") } - be, err := open(opts.Repo, opts.extended) + be, err := open(opts.Repo, opts, opts.extended) if err != nil { return nil, err } - if opts.LimitUploadKb > 0 || opts.LimitDownloadKb > 0 { - debug.Log("rate limiting backend to %d KiB/s upload and %d KiB/s download", opts.LimitUploadKb, opts.LimitDownloadKb) - be = limiter.LimitBackend(be, limiter.NewStaticLimiter(opts.LimitUploadKb, opts.LimitDownloadKb)) - } - be = backend.NewRetryBackend(be, 10, func(msg string, err error, d time.Duration) { Warnf("%v returned error, retrying after %v: %v\n", msg, d, err) }) @@ -532,7 +527,7 @@ func parseConfig(loc location.Location, opts options.Options) (interface{}, erro } // Open the backend specified by a location config. -func open(s string, opts options.Options) (restic.Backend, error) { +func open(s string, gopts GlobalOptions, opts options.Options) (restic.Backend, error) { debug.Log("parsing location %v", s) loc, err := location.Parse(s) if err != nil { @@ -551,11 +546,18 @@ func open(s string, opts options.Options) (restic.Backend, error) { return nil, err } + // wrap the transport so that the throughput via HTTP is limited + rt = limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb).Transport(rt) + switch loc.Scheme { case "local": be, err = local.Open(cfg.(local.Config)) + // wrap the backend in a LimitBackend so that the throughput is limited + be = limiter.LimitBackend(be, limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb)) case "sftp": be, err = sftp.Open(cfg.(sftp.Config), SuspendSignalHandler, InstallSignalHandler) + // wrap the backend in a LimitBackend so that the throughput is limited + be = limiter.LimitBackend(be, limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb)) case "s3": be, err = s3.Open(cfg.(s3.Config), rt) case "gs": diff --git a/internal/limiter/limiter.go b/internal/limiter/limiter.go index c73d2bff5..abdbeaf75 100644 --- a/internal/limiter/limiter.go +++ b/internal/limiter/limiter.go @@ -2,6 +2,7 @@ package limiter import ( "io" + "net/http" ) // Limiter defines an interface that implementors can use to rate limit I/O @@ -14,4 +15,7 @@ type Limiter interface { // Downstream returns a rate limited reader that is intended to be used // for downloads. Downstream(r io.Reader) io.Reader + + // Transport returns an http.RoundTripper limited with the limiter. + Transport(http.RoundTripper) http.RoundTripper } diff --git a/internal/limiter/static_limiter.go b/internal/limiter/static_limiter.go index 62205cb4f..c2ff96ce4 100644 --- a/internal/limiter/static_limiter.go +++ b/internal/limiter/static_limiter.go @@ -2,6 +2,7 @@ package limiter import ( "io" + "net/http" "github.com/juju/ratelimit" ) @@ -41,6 +42,39 @@ func (l staticLimiter) Downstream(r io.Reader) io.Reader { return l.limit(r, l.downstream) } +type roundTripper func(*http.Request) (*http.Response, error) + +func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return rt(req) +} + +func (l staticLimiter) roundTripper(rt http.RoundTripper, req *http.Request) (*http.Response, error) { + if req.Body != nil { + req.Body = limitedReadCloser{ + limited: l.Upstream(req.Body), + original: req.Body, + } + } + + res, err := rt.RoundTrip(req) + + if res != nil && res.Body != nil { + res.Body = limitedReadCloser{ + limited: l.Downstream(res.Body), + original: res.Body, + } + } + + return res, err +} + +// Transport returns an HTTP transport limited with the limiter l. +func (l staticLimiter) Transport(rt http.RoundTripper) http.RoundTripper { + return roundTripper(func(req *http.Request) (*http.Response, error) { + return l.roundTripper(rt, req) + }) +} + func (l staticLimiter) limit(r io.Reader, b *ratelimit.Bucket) io.Reader { if b == nil { return r