From e499bbe3ae747541b56572271912df17e910afd1 Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Thu, 29 Dec 2022 12:29:46 +0100 Subject: [PATCH 1/2] progress: extract progress updating into Updater struct This allows reusing the code to create periodic progress updates. --- cmd/restic/progress.go | 2 +- internal/restic/find_test.go | 4 +- internal/ui/progress/counter.go | 71 ++++------------------- internal/ui/progress/counter_test.go | 22 +------- internal/ui/progress/updater.go | 84 ++++++++++++++++++++++++++++ internal/ui/progress/updater_test.go | 52 +++++++++++++++++ 6 files changed, 152 insertions(+), 83 deletions(-) create mode 100644 internal/ui/progress/updater.go create mode 100644 internal/ui/progress/updater_test.go diff --git a/cmd/restic/progress.go b/cmd/restic/progress.go index 3caa34a26..4b6025a54 100644 --- a/cmd/restic/progress.go +++ b/cmd/restic/progress.go @@ -37,7 +37,7 @@ func newProgressMax(show bool, max uint64, description string) *progress.Counter interval := calculateProgressInterval(show, false) canUpdateStatus := stdoutCanUpdateStatus() - return progress.New(interval, max, func(v uint64, max uint64, d time.Duration, final bool) { + return progress.NewCounter(interval, max, func(v uint64, max uint64, d time.Duration, final bool) { var status string if max == 0 { status = fmt.Sprintf("[%s] %d %s", diff --git a/internal/restic/find_test.go b/internal/restic/find_test.go index 80f616513..f5e288b9d 100644 --- a/internal/restic/find_test.go +++ b/internal/restic/find_test.go @@ -93,7 +93,7 @@ func TestFindUsedBlobs(t *testing.T) { snapshots = append(snapshots, sn) } - p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {}) + p := progress.NewCounter(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {}) defer p.Done() for i, sn := range snapshots { @@ -142,7 +142,7 @@ func TestMultiFindUsedBlobs(t *testing.T) { want.Merge(loadIDSet(t, goldenFilename)) } - p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {}) + p := progress.NewCounter(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {}) defer p.Done() // run twice to check progress bar handling of duplicate tree roots diff --git a/internal/ui/progress/counter.go b/internal/ui/progress/counter.go index 90a09d0d8..c1275d2f2 100644 --- a/internal/ui/progress/counter.go +++ b/internal/ui/progress/counter.go @@ -3,9 +3,6 @@ package progress import ( "sync" "time" - - "github.com/restic/restic/internal/debug" - "github.com/restic/restic/internal/ui/signals" ) // A Func is a callback for a Counter. @@ -19,32 +16,22 @@ type Func func(value uint64, total uint64, runtime time.Duration, final bool) // // The Func is also called when SIGUSR1 (or SIGINFO, on BSD) is received. type Counter struct { - report Func - start time.Time - stopped chan struct{} // Closed by run. - stop chan struct{} // Close to stop run. - tick *time.Ticker + Updater valueMutex sync.Mutex value uint64 max uint64 } -// New starts a new Counter. -func New(interval time.Duration, total uint64, report Func) *Counter { +// NewCounter starts a new Counter. +func NewCounter(interval time.Duration, total uint64, report Func) *Counter { c := &Counter{ - report: report, - start: time.Now(), - stopped: make(chan struct{}), - stop: make(chan struct{}), - max: total, + max: total, } - - if interval > 0 { - c.tick = time.NewTicker(interval) - } - - go c.run() + c.Updater = *NewUpdater(interval, func(runtime time.Duration, final bool) { + v, max := c.Get() + report(v, max, runtime, final) + }) return c } @@ -69,18 +56,6 @@ func (c *Counter) SetMax(max uint64) { c.valueMutex.Unlock() } -// Done tells a Counter to stop and waits for it to report its final value. -func (c *Counter) Done() { - if c == nil { - return - } - if c.tick != nil { - c.tick.Stop() - } - close(c.stop) - <-c.stopped // Wait for last progress report. -} - // Get returns the current value and the maximum of c. // This method is concurrency-safe. func (c *Counter) Get() (v, max uint64) { @@ -91,32 +66,8 @@ func (c *Counter) Get() (v, max uint64) { return v, max } -func (c *Counter) run() { - defer close(c.stopped) - defer func() { - // Must be a func so that time.Since isn't called at defer time. - v, max := c.Get() - c.report(v, max, time.Since(c.start), true) - }() - - var tick <-chan time.Time - if c.tick != nil { - tick = c.tick.C - } - signalsCh := signals.GetProgressChannel() - for { - var now time.Time - - select { - case now = <-tick: - case sig := <-signalsCh: - debug.Log("Signal received: %v\n", sig) - now = time.Now() - case <-c.stop: - return - } - - v, max := c.Get() - c.report(v, max, now.Sub(c.start), false) +func (c *Counter) Done() { + if c != nil { + c.Updater.Done() } } diff --git a/internal/ui/progress/counter_test.go b/internal/ui/progress/counter_test.go index 85695d209..49c694e06 100644 --- a/internal/ui/progress/counter_test.go +++ b/internal/ui/progress/counter_test.go @@ -35,7 +35,7 @@ func TestCounter(t *testing.T) { lastTotal = total ncalls++ } - c := progress.New(10*time.Millisecond, startTotal, report) + c := progress.NewCounter(10*time.Millisecond, startTotal, report) done := make(chan struct{}) go func() { @@ -63,24 +63,6 @@ func TestCounterNil(t *testing.T) { // Shouldn't panic. var c *progress.Counter c.Add(1) + c.SetMax(42) c.Done() } - -func TestCounterNoTick(t *testing.T) { - finalSeen := false - otherSeen := false - - report := func(value, total uint64, d time.Duration, final bool) { - if final { - finalSeen = true - } else { - otherSeen = true - } - } - c := progress.New(0, 1, report) - time.Sleep(time.Millisecond) - c.Done() - - test.Assert(t, finalSeen, "final call did not happen") - test.Assert(t, !otherSeen, "unexpected status update") -} diff --git a/internal/ui/progress/updater.go b/internal/ui/progress/updater.go new file mode 100644 index 000000000..7fb6c8836 --- /dev/null +++ b/internal/ui/progress/updater.go @@ -0,0 +1,84 @@ +package progress + +import ( + "time" + + "github.com/restic/restic/internal/debug" + "github.com/restic/restic/internal/ui/signals" +) + +// An UpdateFunc is a callback for a (progress) Updater. +// +// The final argument is true if Updater.Done has been called, +// which means that the current call will be the last. +type UpdateFunc func(runtime time.Duration, final bool) + +// An Updater controls a goroutine that periodically calls an UpdateFunc. +// +// The UpdateFunc is also called when SIGUSR1 (or SIGINFO, on BSD) is received. +type Updater struct { + report UpdateFunc + start time.Time + stopped chan struct{} // Closed by run. + stop chan struct{} // Close to stop run. + tick *time.Ticker +} + +// NewUpdater starts a new Updater. +func NewUpdater(interval time.Duration, report UpdateFunc) *Updater { + c := &Updater{ + report: report, + start: time.Now(), + stopped: make(chan struct{}), + stop: make(chan struct{}), + } + + if interval > 0 { + c.tick = time.NewTicker(interval) + } + + go c.run() + return c +} + +// Done tells an Updater to stop and waits for it to report its final value. +// Later calls do nothing. +func (c *Updater) Done() { + if c == nil || c.stop == nil { + return + } + if c.tick != nil { + c.tick.Stop() + } + close(c.stop) + <-c.stopped // Wait for last progress report. + c.stop = nil +} + +func (c *Updater) run() { + defer close(c.stopped) + defer func() { + // Must be a func so that time.Since isn't called at defer time. + c.report(time.Since(c.start), true) + }() + + var tick <-chan time.Time + if c.tick != nil { + tick = c.tick.C + } + signalsCh := signals.GetProgressChannel() + for { + var now time.Time + + select { + case now = <-tick: + case sig := <-signalsCh: + debug.Log("Signal received: %v\n", sig) + now = time.Now() + case <-c.stop: + return + } + + c.report(now.Sub(c.start), false) + } +} diff --git a/internal/ui/progress/updater_test.go b/internal/ui/progress/updater_test.go new file mode 100644 index 000000000..5b5207dd5 --- /dev/null +++ b/internal/ui/progress/updater_test.go @@ -0,0 +1,52 @@ +package progress_test + +import ( + "testing" + "time" + + "github.com/restic/restic/internal/test" + "github.com/restic/restic/internal/ui/progress" +) + +func TestUpdater(t *testing.T) { + finalSeen := false + var ncalls int + + report := func(d time.Duration, final bool) { + if final { + finalSeen = true + } + ncalls++ + } + c := progress.NewUpdater(10*time.Millisecond, report) + time.Sleep(100 * time.Millisecond) + c.Done() + + test.Assert(t, finalSeen, "final call did not happen") + test.Assert(t, ncalls > 0, "no progress was reported") +} + +func TestUpdaterStopTwice(t *testing.T) { + c := progress.NewUpdater(0, func(runtime time.Duration, final bool) {}) + c.Done() + c.Done() +} + +func TestUpdaterNoTick(t *testing.T) { + finalSeen := false + otherSeen := false + + report := func(d time.Duration, final bool) { + if final { + finalSeen = true + } else { + otherSeen = true + } + } + c := progress.NewUpdater(0, report) + time.Sleep(time.Millisecond) + c.Done() + + test.Assert(t, finalSeen, "final call did not happen") + test.Assert(t, !otherSeen, "unexpected status update") +} From 4a7a6b06afbb776328353740959765c50a0d84ee Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Thu, 29 Dec 2022 12:31:20 +0100 Subject: [PATCH 2/2] ui/backup: Use progress.Updater for progress updates --- cmd/restic/cmd_backup.go | 10 ++-- internal/ui/backup/progress.go | 80 +++++++++-------------------- internal/ui/backup/progress_test.go | 5 -- 3 files changed, 30 insertions(+), 65 deletions(-) diff --git a/cmd/restic/cmd_backup.go b/cmd/restic/cmd_backup.go index 683ce9268..e59f503db 100644 --- a/cmd/restic/cmd_backup.go +++ b/cmd/restic/cmd_backup.go @@ -483,16 +483,12 @@ func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, ter } progressReporter := backup.NewProgress(progressPrinter, calculateProgressInterval(!gopts.Quiet, gopts.JSON)) + defer progressReporter.Done() if opts.DryRun { repo.SetDryRun() } - wg, wgCtx := errgroup.WithContext(ctx) - cancelCtx, cancel := context.WithCancel(wgCtx) - defer cancel() - wg.Go(func() error { progressReporter.Run(cancelCtx); return nil }) - if !gopts.JSON { progressPrinter.V("lock repository") } @@ -590,6 +586,10 @@ func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, ter targets = []string{filename} } + wg, wgCtx := errgroup.WithContext(ctx) + cancelCtx, cancel := context.WithCancel(wgCtx) + defer cancel() + if !opts.NoScan { sc := archiver.NewScanner(targetFS) sc.SelectByName = selectByNameFilter diff --git a/internal/ui/backup/progress.go b/internal/ui/backup/progress.go index 746fe7a49..71facda4c 100644 --- a/internal/ui/backup/progress.go +++ b/internal/ui/backup/progress.go @@ -1,13 +1,12 @@ package backup import ( - "context" "sync" "time" "github.com/restic/restic/internal/archiver" "github.com/restic/restic/internal/restic" - "github.com/restic/restic/internal/ui/signals" + "github.com/restic/restic/internal/ui/progress" ) // A ProgressPrinter can print various progress messages. @@ -41,10 +40,10 @@ type Summary struct { // Progress reports progress for the `backup` command. type Progress struct { + progress.Updater mu sync.Mutex - interval time.Duration - start time.Time + start time.Time scanStarted, scanFinished bool @@ -52,66 +51,37 @@ type Progress struct { processed, total Counter errors uint - closed chan struct{} - summary Summary printer ProgressPrinter } func NewProgress(printer ProgressPrinter, interval time.Duration) *Progress { - return &Progress{ - interval: interval, - start: time.Now(), - + p := &Progress{ + start: time.Now(), currentFiles: make(map[string]struct{}), - closed: make(chan struct{}), - - printer: printer, + printer: printer, } -} + p.Updater = *progress.NewUpdater(interval, func(runtime time.Duration, final bool) { + if final { + p.printer.Reset() + } else { + p.mu.Lock() + defer p.mu.Unlock() + if !p.scanStarted { + return + } -// Run regularly updates the status lines. It should be called in a separate -// goroutine. -func (p *Progress) Run(ctx context.Context) { - defer close(p.closed) - // Reset status when finished - defer p.printer.Reset() + var secondsRemaining uint64 + if p.scanFinished { + secs := float64(runtime / time.Second) + todo := float64(p.total.Bytes - p.processed.Bytes) + secondsRemaining = uint64(secs / float64(p.processed.Bytes) * todo) + } - var tick <-chan time.Time - if p.interval != 0 { - t := time.NewTicker(p.interval) - defer t.Stop() - tick = t.C - } - - signalsCh := signals.GetProgressChannel() - - for { - var now time.Time - select { - case <-ctx.Done(): - return - case now = <-tick: - case <-signalsCh: - now = time.Now() + p.printer.Update(p.total, p.processed, p.errors, p.currentFiles, p.start, secondsRemaining) } - - p.mu.Lock() - if !p.scanStarted { - p.mu.Unlock() - continue - } - - var secondsRemaining uint64 - if p.scanFinished { - secs := float64(now.Sub(p.start) / time.Second) - todo := float64(p.total.Bytes - p.processed.Bytes) - secondsRemaining = uint64(secs / float64(p.processed.Bytes) * todo) - } - - p.printer.Update(p.total, p.processed, p.errors, p.currentFiles, p.start, secondsRemaining) - p.mu.Unlock() - } + }) + return p } // Error is the error callback function for the archiver, it prints the error and returns nil. @@ -236,6 +206,6 @@ func (p *Progress) ReportTotal(item string, s archiver.ScanStats) { // Finish prints the finishing messages. func (p *Progress) Finish(snapshotID restic.ID, dryrun bool) { // wait for the status update goroutine to shut down - <-p.closed + p.Updater.Done() p.printer.Finish(snapshotID, p.start, &p.summary, dryrun) } diff --git a/internal/ui/backup/progress_test.go b/internal/ui/backup/progress_test.go index e0dc093d2..a7282c7da 100644 --- a/internal/ui/backup/progress_test.go +++ b/internal/ui/backup/progress_test.go @@ -1,7 +1,6 @@ package backup import ( - "context" "sync" "testing" "time" @@ -53,9 +52,6 @@ func TestProgress(t *testing.T) { prnt := &mockPrinter{} prog := NewProgress(prnt, time.Millisecond) - ctx, cancel := context.WithCancel(context.Background()) - go prog.Run(ctx) - prog.StartFile("foo") prog.CompleteBlob(1024) @@ -67,7 +63,6 @@ func TestProgress(t *testing.T) { prog.CompleteItem("foo", nil, &node, archiver.ItemStats{}, 0) time.Sleep(10 * time.Millisecond) - cancel() id := restic.NewRandomID() prog.Finish(id, false)