diff --git a/changelog/unreleased/issue-4768 b/changelog/unreleased/issue-4768 new file mode 100644 index 000000000..dd52089e1 --- /dev/null +++ b/changelog/unreleased/issue-4768 @@ -0,0 +1,8 @@ +Enhancement: Allow custom User-Agent to be specified for outgoing requests + +Restic now permits setting a custom `User-Agent` for outgoing HTTP requests +using the global flag `--http-user-agent` or the `RESTIC_HTTP_USER_AGENT` +environment variable. + +https://github.com/restic/restic/issues/4768 +https://github.com/restic/restic/pull/4810 \ No newline at end of file diff --git a/cmd/restic/global.go b/cmd/restic/global.go index 144445cc0..a5250ca38 100644 --- a/cmd/restic/global.go +++ b/cmd/restic/global.go @@ -135,6 +135,7 @@ func init() { f.IntVar(&globalOptions.Limits.DownloadKb, "limit-download", 0, "limits downloads to a maximum `rate` in KiB/s. (default: unlimited)") f.UintVar(&globalOptions.PackSize, "pack-size", 0, "set target pack `size` in MiB, created pack files may be larger (default: $RESTIC_PACK_SIZE)") f.StringSliceVarP(&globalOptions.Options, "option", "o", []string{}, "set extended option (`key=value`, can be specified multiple times)") + f.StringVar(&globalOptions.HTTPUserAgent, "http-user-agent", "", "set a http user agent for outgoing http requests") // Use our "generate" command instead of the cobra provided "completion" command cmdRoot.CompletionOptions.DisableDefaultCmd = true @@ -155,6 +156,10 @@ func init() { // parse target pack size from env, on error the default value will be used targetPackSize, _ := strconv.ParseUint(os.Getenv("RESTIC_PACK_SIZE"), 10, 32) globalOptions.PackSize = uint(targetPackSize) + + if os.Getenv("RESTIC_HTTP_USER_AGENT") != "" { + globalOptions.HTTPUserAgent = os.Getenv("RESTIC_HTTP_USER_AGENT") + } } func stdinIsTerminal() bool { diff --git a/doc/manual_rest.rst b/doc/manual_rest.rst index 3f8b3a2c7..9d9d6e141 100644 --- a/doc/manual_rest.rst +++ b/doc/manual_rest.rst @@ -54,6 +54,7 @@ Usage help is available: --cleanup-cache auto remove old cache directories --compression mode compression mode (only available for repository format version 2), one of (auto|off|max) (default: $RESTIC_COMPRESSION) (default auto) -h, --help help for restic + --http-user-agent value set a custom user agent for outgoing http requests --insecure-no-password use an empty password for the repository, must be passed to every restic command (insecure) --insecure-tls skip TLS certificate verification when connecting to the repository (insecure) --json set output mode to JSON for commands that support it @@ -134,6 +135,7 @@ command: --cache-dir directory set the cache directory. (default: use system default cache directory) --cleanup-cache auto remove old cache directories --compression mode compression mode (only available for repository format version 2), one of (auto|off|max) (default: $RESTIC_COMPRESSION) (default auto) + --http-user-agent value set a custom user agent for outgoing http requests --insecure-no-password use an empty password for the repository, must be passed to every restic command (insecure) --insecure-tls skip TLS certificate verification when connecting to the repository (insecure) --json set output mode to JSON for commands that support it diff --git a/internal/backend/http_transport.go b/internal/backend/http_transport.go index 97fd521e3..19613e810 100644 --- a/internal/backend/http_transport.go +++ b/internal/backend/http_transport.go @@ -28,6 +28,9 @@ type TransportOptions struct { // Skip TLS certificate verification InsecureTLS bool + + // Specify Custom User-Agent for the http Client + HTTPUserAgent string } // readPEMCertKey reads a file and returns the PEM encoded certificate and key @@ -132,6 +135,13 @@ func Transport(opts TransportOptions) (http.RoundTripper, error) { } rt := http.RoundTripper(tr) + + // if the userAgent is set in the Transport Options, wrap the + // http.RoundTripper + if opts.HTTPUserAgent != "" { + rt = newCustomUserAgentRoundTripper(rt, opts.HTTPUserAgent) + } + if feature.Flag.Enabled(feature.BackendErrorRedesign) { rt = newWatchdogRoundtripper(rt, 120*time.Second, 128*1024) } diff --git a/internal/backend/httpuseragent_roundtripper.go b/internal/backend/httpuseragent_roundtripper.go new file mode 100644 index 000000000..6272aa41a --- /dev/null +++ b/internal/backend/httpuseragent_roundtripper.go @@ -0,0 +1,25 @@ +package backend + +import "net/http" + +// httpUserAgentRoundTripper is a custom http.RoundTripper that modifies the User-Agent header +// of outgoing HTTP requests. +type httpUserAgentRoundTripper struct { + userAgent string + rt http.RoundTripper +} + +func newCustomUserAgentRoundTripper(rt http.RoundTripper, userAgent string) *httpUserAgentRoundTripper { + return &httpUserAgentRoundTripper{ + rt: rt, + userAgent: userAgent, + } +} + +// RoundTrip modifies the User-Agent header of the request and then delegates the request +// to the underlying RoundTripper. +func (c *httpUserAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set("User-Agent", c.userAgent) + return c.rt.RoundTrip(req) +} diff --git a/internal/backend/httpuseragent_roundtripper_test.go b/internal/backend/httpuseragent_roundtripper_test.go new file mode 100644 index 000000000..0a81c418a --- /dev/null +++ b/internal/backend/httpuseragent_roundtripper_test.go @@ -0,0 +1,50 @@ +package backend + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestCustomUserAgentTransport(t *testing.T) { + // Create a mock HTTP handler that checks the User-Agent header + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userAgent := r.Header.Get("User-Agent") + if userAgent != "TestUserAgent" { + t.Errorf("Expected User-Agent: TestUserAgent, got: %s", userAgent) + } + w.WriteHeader(http.StatusOK) + }) + + // Create a test server with the mock handler + server := httptest.NewServer(handler) + defer server.Close() + + // Create a custom user agent transport + customUserAgent := "TestUserAgent" + transport := &httpUserAgentRoundTripper{ + userAgent: customUserAgent, + rt: http.DefaultTransport, + } + + // Create an HTTP client with the custom transport + client := &http.Client{ + Transport: transport, + } + + // Make a request to the test server + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.Log("failed to close response body") + } + }() + + // Check the response status code + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code: %d, got: %d", http.StatusOK, resp.StatusCode) + } +}