diff --git a/internal/fs/fs_local_vss.go b/internal/fs/fs_local_vss.go index 5f55dcfd1..de30bcedb 100644 --- a/internal/fs/fs_local_vss.go +++ b/internal/fs/fs_local_vss.go @@ -17,6 +17,7 @@ type VSSConfig struct { ExcludeAllMountPoints bool `option:"excludeallmountpoints" help:"exclude mountpoints from snapshotting on all volumes"` ExcludeVolumes string `option:"excludevolumes" help:"semicolon separated list of volumes to exclude from snapshotting (ex. 'c:\\;e:\\mnt;\\\\?\\Volume{...}')"` Timeout time.Duration `option:"timeout" help:"time that the VSS can spend creating snapshot before timing out"` + Provider string `option:"provider" help:"VSS provider identifier which will be used for snapshotting"` } func init() { @@ -64,6 +65,7 @@ type LocalVss struct { excludeAllMountPoints bool excludeVolumes map[string]struct{} timeout time.Duration + provider string } // statically ensure that LocalVss implements FS. @@ -102,6 +104,7 @@ func NewLocalVss(msgError ErrorHandler, msgMessage MessageHandler, cfg VSSConfig excludeAllMountPoints: cfg.ExcludeAllMountPoints, excludeVolumes: parseMountPoints(cfg.ExcludeVolumes, msgError), timeout: cfg.Timeout, + provider: cfg.Provider, } } @@ -209,7 +212,7 @@ func (fs *LocalVss) snapshotPath(path string) string { } } - if snapshot, err := NewVssSnapshot(vssVolume, fs.timeout, filter, fs.msgError); err != nil { + if snapshot, err := NewVssSnapshot(fs.provider, vssVolume, fs.timeout, filter, fs.msgError); err != nil { fs.msgError(vssVolume, errors.Errorf("failed to create snapshot for [%s]: %s", vssVolume, err)) fs.failedSnapshots[volumeNameLower] = struct{}{} diff --git a/internal/fs/vss.go b/internal/fs/vss.go index 838bdf79b..a54475480 100644 --- a/internal/fs/vss.go +++ b/internal/fs/vss.go @@ -41,7 +41,7 @@ func GetVolumeNameForVolumeMountPoint(mountPoint string) (string, error) { // NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't // finish within the timeout an error is returned. -func NewVssSnapshot( +func NewVssSnapshot(_ string, _ string, _ time.Duration, _ VolumeFilter, _ ErrorHandler) (VssSnapshot, error) { return VssSnapshot{}, errors.New("VSS snapshots are only supported on windows") } diff --git a/internal/fs/vss_windows.go b/internal/fs/vss_windows.go index 424548a74..18aea419d 100644 --- a/internal/fs/vss_windows.go +++ b/internal/fs/vss_windows.go @@ -367,7 +367,7 @@ func (vss *IVssBackupComponents) convertToVSSAsync( } // IsVolumeSupported calls the equivalent VSS api. -func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, error) { +func (vss *IVssBackupComponents) IsVolumeSupported(providerID *ole.GUID, volumeName string) (bool, error) { volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName) if err != nil { panic(err) @@ -377,7 +377,7 @@ func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, err var result uintptr if runtime.GOARCH == "386" { - id := (*[4]uintptr)(unsafe.Pointer(ole.IID_NULL)) + id := (*[4]uintptr)(unsafe.Pointer(providerID)) result, _, _ = syscall.Syscall9(vss.getVTable().isVolumeSupported, 7, uintptr(unsafe.Pointer(vss)), id[0], id[1], id[2], id[3], @@ -385,7 +385,7 @@ func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, err 0) } else { result, _, _ = syscall.Syscall6(vss.getVTable().isVolumeSupported, 4, - uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(ole.IID_NULL)), + uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(providerID)), uintptr(unsafe.Pointer(volumeNamePointer)), uintptr(unsafe.Pointer(&isSupportedRaw)), 0, 0) } @@ -411,7 +411,7 @@ func (vss *IVssBackupComponents) StartSnapshotSet() (ole.GUID, error) { } // AddToSnapshotSet calls the equivalent VSS api. -func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, idSnapshot *ole.GUID) error { +func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, providerID *ole.GUID, idSnapshot *ole.GUID) error { volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName) if err != nil { panic(err) @@ -420,15 +420,15 @@ func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, idSnapshot var result uintptr if runtime.GOARCH == "386" { - id := (*[4]uintptr)(unsafe.Pointer(ole.IID_NULL)) + id := (*[4]uintptr)(unsafe.Pointer(providerID)) result, _, _ = syscall.Syscall9(vss.getVTable().addToSnapshotSet, 7, - uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)), id[0], id[1], - id[2], id[3], uintptr(unsafe.Pointer(idSnapshot)), 0, 0) + uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)), + id[0], id[1], id[2], id[3], uintptr(unsafe.Pointer(idSnapshot)), 0, 0) } else { result, _, _ = syscall.Syscall6(vss.getVTable().addToSnapshotSet, 4, uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)), - uintptr(unsafe.Pointer(ole.IID_NULL)), uintptr(unsafe.Pointer(idSnapshot)), 0, 0) + uintptr(unsafe.Pointer(providerID)), uintptr(unsafe.Pointer(idSnapshot)), 0, 0) } return newVssErrorIfResultNotOK("AddToSnapshotSet() failed", HRESULT(result)) @@ -535,6 +535,13 @@ func vssFreeSnapshotProperties(properties *VssSnapshotProperties) error { return nil } +func vssFreeProviderProperties(p *VssProviderProperties) { + ole.CoTaskMemFree(uintptr(unsafe.Pointer(p.providerName))) + p.providerName = nil + ole.CoTaskMemFree(uintptr(unsafe.Pointer(p.providerVersion))) + p.providerName = nil +} + // BackupComplete calls the equivalent VSS api. func (vss *IVssBackupComponents) BackupComplete() (*IVSSAsync, error) { var oleIUnknown *ole.IUnknown @@ -563,6 +570,17 @@ type VssSnapshotProperties struct { status uint } +// VssProviderProperties defines the properties of a VSS provider as part of the VSS api. +// nolint:structcheck +type VssProviderProperties struct { + providerID ole.GUID + providerName *uint16 + providerType uint32 + providerVersion *uint16 + providerVersionID ole.GUID + classID ole.GUID +} + // GetSnapshotDeviceObject returns root path to access the snapshot files // and folders. func (p *VssSnapshotProperties) GetSnapshotDeviceObject() string { @@ -660,6 +678,75 @@ func (vssAsync *IVSSAsync) WaitUntilAsyncFinished(timeout time.Duration) error { return nil } +// UIID_IVSS_ADMIN defines the GUID of IVSSAdmin. +var ( + UIID_IVSS_ADMIN = ole.NewGUID("{77ED5996-2F63-11d3-8A39-00C04F72D8E3}") + CLSID_VSS_COORDINATOR = ole.NewGUID("{E579AB5F-1CC4-44b4-BED9-DE0991FF0623}") +) + +// IVSSAdmin VSS api interface. +type IVSSAdmin struct { + ole.IUnknown +} + +// IVSSAdminVTable is the vtable for IVSSAdmin. +// nolint:structcheck +type IVSSAdminVTable struct { + ole.IUnknownVtbl + registerProvider uintptr + unregisterProvider uintptr + queryProviders uintptr + abortAllSnapshotsInProgress uintptr +} + +// getVTable returns the vtable for IVSSAdmin. +func (vssAdmin *IVSSAdmin) getVTable() *IVSSAdminVTable { + return (*IVSSAdminVTable)(unsafe.Pointer(vssAdmin.RawVTable)) +} + +// QueryProviders calls the equivalent VSS api. +func (vssAdmin *IVSSAdmin) QueryProviders() (*IVssEnumObject, error) { + var enum *IVssEnumObject + + result, _, _ := syscall.Syscall(vssAdmin.getVTable().queryProviders, 2, + uintptr(unsafe.Pointer(vssAdmin)), uintptr(unsafe.Pointer(&enum)), 0) + + return enum, newVssErrorIfResultNotOK("QueryProviders() failed", HRESULT(result)) +} + +// IVssEnumObject VSS api interface. +type IVssEnumObject struct { + ole.IUnknown +} + +// IVssEnumObjectVTable is the vtable for IVssEnumObject. +// nolint:structcheck +type IVssEnumObjectVTable struct { + ole.IUnknownVtbl + next uintptr + skip uintptr + reset uintptr + clone uintptr +} + +// getVTable returns the vtable for IVssEnumObject. +func (vssEnum *IVssEnumObject) getVTable() *IVssEnumObjectVTable { + return (*IVssEnumObjectVTable)(unsafe.Pointer(vssEnum.RawVTable)) +} + +// Next calls the equivalent VSS api. +func (vssEnum *IVssEnumObject) Next(count uint, props unsafe.Pointer) (uint, error) { + var fetched uint32 + result, _, _ := syscall.Syscall6(vssEnum.getVTable().next, 4, + uintptr(unsafe.Pointer(vssEnum)), uintptr(count), uintptr(props), + uintptr(unsafe.Pointer(&fetched)), 0, 0) + if result == 1 { + return uint(fetched), nil + } + + return uint(fetched), newVssErrorIfResultNotOK("Next() failed", HRESULT(result)) +} + // MountPoint wraps all information of a snapshot of a mountpoint on a volume. type MountPoint struct { isSnapshotted bool @@ -766,7 +853,7 @@ func GetVolumeNameForVolumeMountPoint(mountPoint string) (string, error) { // NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't // finish within the timeout an error is returned. -func NewVssSnapshot( +func NewVssSnapshot(provider string, volume string, timeout time.Duration, filter VolumeFilter, msgError ErrorHandler) (VssSnapshot, error) { is64Bit, err := isRunningOn64BitWindows() if err != nil { @@ -814,6 +901,12 @@ func NewVssSnapshot( iVssBackupComponents := (*IVssBackupComponents)(unsafe.Pointer(comInterface)) + providerID, err := getProviderID(provider) + if err != nil { + iVssBackupComponents.Release() + return VssSnapshot{}, err + } + if err := iVssBackupComponents.InitializeForBackup(); err != nil { iVssBackupComponents.Release() return VssSnapshot{}, err @@ -838,7 +931,7 @@ func NewVssSnapshot( return VssSnapshot{}, err } - if isSupported, err := iVssBackupComponents.IsVolumeSupported(volume); err != nil { + if isSupported, err := iVssBackupComponents.IsVolumeSupported(providerID, volume); err != nil { iVssBackupComponents.Release() return VssSnapshot{}, err } else if !isSupported { @@ -853,7 +946,7 @@ func NewVssSnapshot( return VssSnapshot{}, err } - if err := iVssBackupComponents.AddToSnapshotSet(volume, &snapshotSetID); err != nil { + if err := iVssBackupComponents.AddToSnapshotSet(volume, providerID, &snapshotSetID); err != nil { iVssBackupComponents.Release() return VssSnapshot{}, err } @@ -877,14 +970,14 @@ func NewVssSnapshot( if !filter(mountPoint) { continue - } else if isSupported, err := iVssBackupComponents.IsVolumeSupported(mountPoint); err != nil { + } else if isSupported, err := iVssBackupComponents.IsVolumeSupported(providerID, mountPoint); err != nil { continue } else if !isSupported { continue } var mountPointSnapshotSetID ole.GUID - err := iVssBackupComponents.AddToSnapshotSet(mountPoint, &mountPointSnapshotSetID) + err := iVssBackupComponents.AddToSnapshotSet(mountPoint, providerID, &mountPointSnapshotSetID) if err != nil { iVssBackupComponents.Release() @@ -988,6 +1081,55 @@ func (p *VssSnapshot) Delete() error { return nil } +func getProviderID(provider string) (*ole.GUID, error) { + comInterface, err := ole.CreateInstance(CLSID_VSS_COORDINATOR, UIID_IVSS_ADMIN) + if err != nil { + return nil, err + } + defer comInterface.Release() + + vssAdmin := (*IVSSAdmin)(unsafe.Pointer(comInterface)) + + providerLower := strings.ToLower(provider) + switch providerLower { + case "": + return ole.IID_NULL, nil + case "ms": + return ole.NewGUID("{b5946137-7b9f-4925-af80-51abd60b20d5}"), nil + } + + enum, err := vssAdmin.QueryProviders() + if err != nil { + return nil, err + } + defer enum.Release() + + id := ole.NewGUID(provider) + + var props struct { + objectType uint32 + provider VssProviderProperties + } + for { + count, err := enum.Next(1, unsafe.Pointer(&props)) + if err != nil { + return nil, err + } + + if count < 1 { + return nil, errors.Errorf(`invalid VSS provider "%s"`, provider) + } + + name := ole.UTF16PtrToString(props.provider.providerName) + vssFreeProviderProperties(&props.provider) + + if id != nil && *id == props.provider.providerID || + id == nil && providerLower == strings.ToLower(name) { + return &props.provider.providerID, nil + } + } +} + // asyncCallFunc is the callback type for callAsyncFunctionAndWait. type asyncCallFunc func() (*IVSSAsync, error)