package azure import ( "bytes" "context" "crypto/md5" "encoding/base64" "fmt" "hash" "io" "net/http" "path" "strings" "github.com/restic/restic/internal/backend" "github.com/restic/restic/internal/backend/layout" "github.com/restic/restic/internal/backend/location" "github.com/restic/restic/internal/backend/util" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob" azContainer "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" ) // Backend stores data on an azure endpoint. type Backend struct { cfg Config container *azContainer.Client connections uint prefix string listMaxItems int layout.Layout } const saveLargeSize = 256 * 1024 * 1024 const defaultListMaxItems = 5000 // make sure that *Backend implements backend.Backend var _ backend.Backend = &Backend{} func NewFactory() location.Factory { return location.NewHTTPBackendFactory("azure", ParseConfig, location.NoPassword, Create, Open) } func open(cfg Config, rt http.RoundTripper) (*Backend, error) { debug.Log("open, config %#v", cfg) var client *azContainer.Client var err error var endpointSuffix string if cfg.EndpointSuffix != "" { endpointSuffix = cfg.EndpointSuffix } else { endpointSuffix = "core.windows.net" } url := fmt.Sprintf("https://%s.blob.%s/%s", cfg.AccountName, endpointSuffix, cfg.Container) opts := &azContainer.ClientOptions{ ClientOptions: azcore.ClientOptions{ Transport: &http.Client{Transport: rt}, }, } if cfg.AccountKey.String() != "" { // We have an account key value, find the BlobServiceClient // from with a BasicClient debug.Log(" - using account key") cred, err := azblob.NewSharedKeyCredential(cfg.AccountName, cfg.AccountKey.Unwrap()) if err != nil { return nil, errors.Wrap(err, "NewSharedKeyCredential") } client, err = azContainer.NewClientWithSharedKeyCredential(url, cred, opts) if err != nil { return nil, errors.Wrap(err, "NewClientWithSharedKeyCredential") } } else if cfg.AccountSAS.String() != "" { // Get the client using the SAS Token as authentication, this // is longer winded than above because the SDK wants a URL for the Account // if your using a SAS token, and not just the account name // we (as per the SDK ) assume the default Azure portal. // https://github.com/Azure/azure-storage-blob-go/issues/130 debug.Log(" - using sas token") sas := cfg.AccountSAS.Unwrap() // strip query sign prefix if sas[0] == '?' { sas = sas[1:] } urlWithSAS := fmt.Sprintf("%s?%s", url, sas) client, err = azContainer.NewClientWithNoCredential(urlWithSAS, opts) if err != nil { return nil, errors.Wrap(err, "NewAccountSASClientFromEndpointToken") } } else { var cred azcore.TokenCredential if cfg.ForceCliCredential { debug.Log(" - using AzureCLICredential") cred, err = azidentity.NewAzureCLICredential(nil) if err != nil { return nil, errors.Wrap(err, "NewAzureCLICredential") } } else { debug.Log(" - using DefaultAzureCredential") cred, err = azidentity.NewDefaultAzureCredential(nil) if err != nil { return nil, errors.Wrap(err, "NewDefaultAzureCredential") } } client, err = azContainer.NewClient(url, cred, opts) if err != nil { return nil, errors.Wrap(err, "NewClient") } } be := &Backend{ container: client, cfg: cfg, connections: cfg.Connections, Layout: &layout.DefaultLayout{ Path: cfg.Prefix, Join: path.Join, }, listMaxItems: defaultListMaxItems, } return be, nil } // Open opens the Azure backend at specified container. func Open(_ context.Context, cfg Config, rt http.RoundTripper) (*Backend, error) { return open(cfg, rt) } // Create opens the Azure backend at specified container and creates the container if // it does not exist yet. func Create(ctx context.Context, cfg Config, rt http.RoundTripper) (*Backend, error) { be, err := open(cfg, rt) if err != nil { return nil, errors.Wrap(err, "open") } _, err = be.container.GetProperties(ctx, &azContainer.GetPropertiesOptions{}) if err != nil && bloberror.HasCode(err, bloberror.ContainerNotFound) { _, err = be.container.Create(ctx, &azContainer.CreateOptions{}) if err != nil { return nil, errors.Wrap(err, "container.Create") } } else if err != nil { return be, errors.Wrap(err, "container.GetProperties") } return be, nil } // SetListMaxItems sets the number of list items to load per request. func (be *Backend) SetListMaxItems(i int) { be.listMaxItems = i } // IsNotExist returns true if the error is caused by a not existing file. func (be *Backend) IsNotExist(err error) bool { return bloberror.HasCode(err, bloberror.BlobNotFound) } func (be *Backend) IsPermanentError(err error) bool { if be.IsNotExist(err) { return true } var aerr *azcore.ResponseError if errors.As(err, &aerr) { if aerr.StatusCode == http.StatusRequestedRangeNotSatisfiable || aerr.StatusCode == http.StatusUnauthorized || aerr.StatusCode == http.StatusForbidden { return true } } return false } // Join combines path components with slashes. func (be *Backend) Join(p ...string) string { return path.Join(p...) } func (be *Backend) Connections() uint { return be.connections } // Hasher may return a hash function for calculating a content hash for the backend func (be *Backend) Hasher() hash.Hash { return md5.New() } // HasAtomicReplace returns whether Save() can atomically replace files func (be *Backend) HasAtomicReplace() bool { return true } // Path returns the path in the bucket that is used for this backend. func (be *Backend) Path() string { return be.prefix } // Save stores data in the backend at the handle. func (be *Backend) Save(ctx context.Context, h backend.Handle, rd backend.RewindReader) error { objName := be.Filename(h) debug.Log("InsertObject(%v, %v)", be.cfg.AccountName, objName) var err error if rd.Length() < saveLargeSize { // if it's smaller than 256miB, then just create the file directly from the reader err = be.saveSmall(ctx, objName, rd) } else { // otherwise use the more complicated method err = be.saveLarge(ctx, objName, rd) } return err } func (be *Backend) saveSmall(ctx context.Context, objName string, rd backend.RewindReader) error { blockBlobClient := be.container.NewBlockBlobClient(objName) // upload it as a new "block", use the base64 hash for the ID id := base64.StdEncoding.EncodeToString(rd.Hash()) buf := make([]byte, rd.Length()) _, err := io.ReadFull(rd, buf) if err != nil { return errors.Wrap(err, "ReadFull") } reader := bytes.NewReader(buf) _, err = blockBlobClient.StageBlock(ctx, id, streaming.NopCloser(reader), &blockblob.StageBlockOptions{ TransactionalValidation: blob.TransferValidationTypeMD5(rd.Hash()), }) if err != nil { return errors.Wrap(err, "StageBlock") } blocks := []string{id} _, err = blockBlobClient.CommitBlockList(ctx, blocks, &blockblob.CommitBlockListOptions{}) return errors.Wrap(err, "CommitBlockList") } func (be *Backend) saveLarge(ctx context.Context, objName string, rd backend.RewindReader) error { blockBlobClient := be.container.NewBlockBlobClient(objName) buf := make([]byte, 100*1024*1024) blocks := []string{} uploadedBytes := 0 for { n, err := io.ReadFull(rd, buf) if err == io.ErrUnexpectedEOF { err = nil } if err == io.EOF { // end of file reached, no bytes have been read at all break } if err != nil { return errors.Wrap(err, "ReadFull") } buf = buf[:n] uploadedBytes += n // upload it as a new "block", use the base64 hash for the ID h := md5.Sum(buf) id := base64.StdEncoding.EncodeToString(h[:]) reader := bytes.NewReader(buf) debug.Log("StageBlock %v with %d bytes", id, len(buf)) _, err = blockBlobClient.StageBlock(ctx, id, streaming.NopCloser(reader), &blockblob.StageBlockOptions{ TransactionalValidation: blob.TransferValidationTypeMD5(h[:]), }) if err != nil { return errors.Wrap(err, "StageBlock") } blocks = append(blocks, id) } // sanity check if uploadedBytes != int(rd.Length()) { return errors.Errorf("wrote %d bytes instead of the expected %d bytes", uploadedBytes, rd.Length()) } _, err := blockBlobClient.CommitBlockList(ctx, blocks, &blockblob.CommitBlockListOptions{}) debug.Log("uploaded %d parts: %v", len(blocks), blocks) return errors.Wrap(err, "CommitBlockList") } // Load runs fn with a reader that yields the contents of the file at h at the // given offset. func (be *Backend) Load(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error { return util.DefaultLoad(ctx, h, length, offset, be.openReader, fn) } func (be *Backend) openReader(ctx context.Context, h backend.Handle, length int, offset int64) (io.ReadCloser, error) { objName := be.Filename(h) blockBlobClient := be.container.NewBlobClient(objName) resp, err := blockBlobClient.DownloadStream(ctx, &blob.DownloadStreamOptions{ Range: azblob.HTTPRange{ Offset: offset, Count: int64(length), }, }) if err != nil { return nil, err } if length > 0 && (resp.ContentLength == nil || *resp.ContentLength != int64(length)) { _ = resp.Body.Close() return nil, &azcore.ResponseError{ErrorCode: "restic-file-too-short", StatusCode: http.StatusRequestedRangeNotSatisfiable} } return resp.Body, err } // Stat returns information about a blob. func (be *Backend) Stat(ctx context.Context, h backend.Handle) (backend.FileInfo, error) { objName := be.Filename(h) blobClient := be.container.NewBlobClient(objName) props, err := blobClient.GetProperties(ctx, nil) if err != nil { return backend.FileInfo{}, errors.Wrap(err, "blob.GetProperties") } fi := backend.FileInfo{ Size: *props.ContentLength, Name: h.Name, } return fi, nil } // Remove removes the blob with the given name and type. func (be *Backend) Remove(ctx context.Context, h backend.Handle) error { objName := be.Filename(h) blob := be.container.NewBlobClient(objName) _, err := blob.Delete(ctx, &azblob.DeleteBlobOptions{}) if be.IsNotExist(err) { return nil } return errors.Wrap(err, "client.RemoveObject") } // List runs fn for each file in the backend which has the type t. When an // error occurs (or fn returns an error), List stops and returns it. func (be *Backend) List(ctx context.Context, t backend.FileType, fn func(backend.FileInfo) error) error { prefix, _ := be.Basedir(t) // make sure prefix ends with a slash if !strings.HasSuffix(prefix, "/") { prefix += "/" } max := int32(be.listMaxItems) opts := &azContainer.ListBlobsFlatOptions{ MaxResults: &max, Prefix: &prefix, } lister := be.container.NewListBlobsFlatPager(opts) for lister.More() { resp, err := lister.NextPage(ctx) if err != nil { return err } debug.Log("got %v objects", len(resp.Segment.BlobItems)) for _, item := range resp.Segment.BlobItems { m := strings.TrimPrefix(*item.Name, prefix) if m == "" { continue } fi := backend.FileInfo{ Name: path.Base(m), Size: *item.Properties.ContentLength, } if ctx.Err() != nil { return ctx.Err() } err := fn(fi) if err != nil { return err } if ctx.Err() != nil { return ctx.Err() } } } return ctx.Err() } // Delete removes all restic keys in the bucket. It will not remove the bucket itself. func (be *Backend) Delete(ctx context.Context) error { return util.DefaultDelete(ctx, be) } // Close does nothing func (be *Backend) Close() error { return nil }