summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJesse Duffield <jessedduffield@gmail.com>2021-11-07 13:25:06 +1100
committerJesse Duffield <jessedduffield@gmail.com>2021-11-07 13:32:36 +1100
commit3fb478a30e5c88af10c6e1e90eb22c99113f2967 (patch)
treef54ec583dffcc5b3a747a06aeb0f7daf6ac268a5
parent5d12a6bf9905f90fad6f32332f034bed7706c2f2 (diff)
add testsv0.31.2
-rw-r--r--pkg/tasks/tasks.go120
-rw-r--r--pkg/tasks/tasks_test.go136
2 files changed, 194 insertions, 62 deletions
diff --git a/pkg/tasks/tasks.go b/pkg/tasks/tasks.go
index 4d9fe67e3..4a987039c 100644
--- a/pkg/tasks/tasks.go
+++ b/pkg/tasks/tasks.go
@@ -19,18 +19,14 @@ const THROTTLE_TIME = time.Millisecond * 30
// we use this to check if the system is under stress right now. Hopefully this makes sense on other machines
const COMMAND_START_THRESHOLD = time.Millisecond * 10
-type Task struct {
- stop chan struct{}
- stopped bool
- stopMutex sync.Mutex
- notifyStopped chan struct{}
- Log *logrus.Entry
- f func(chan struct{}) error
-}
-
type ViewBufferManager struct {
- writer io.Writer
- currentTask *Task
+ // this blocks until the task has been properly stopped
+ stopCurrentTask func()
+
+ // this is what we write the output of the task to. It's typically a view
+ writer io.Writer
+
+ // this is for when we wait to get
waitingMutex sync.Mutex
taskIDMutex sync.Mutex
Log *logrus.Entry
@@ -80,8 +76,15 @@ func (m *ViewBufferManager) ReadLines(n int) {
})
}
+// note: onDone may be called twice
func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), prefix string, linesToRead int, onDone func()) func(chan struct{}) error {
return func(stop chan struct{}) error {
+ var once sync.Once
+ var onDoneWrapper func()
+ if onDone != nil {
+ onDoneWrapper = func() { once.Do(onDone) }
+ }
+
if m.throttle {
m.Log.Info("throttling task")
time.Sleep(THROTTLE_TIME)
@@ -110,8 +113,9 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref
}
}
- if onDone != nil {
- onDone()
+ // for pty's we need to call onDone here so that cmd.Wait() doesn't block forever
+ if onDoneWrapper != nil {
+ onDoneWrapper()
}
})
@@ -122,34 +126,42 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref
done := make(chan struct{})
- go utils.Safe(func() {
- scanner := bufio.NewScanner(r)
- scanner.Split(bufio.ScanLines)
+ scanner := bufio.NewScanner(r)
+ scanner.Split(bufio.ScanLines)
- loaded := false
+ loaded := false
- go utils.Safe(func() {
- ticker := time.NewTicker(time.Millisecond * 200)
- defer ticker.Stop()
- select {
- case <-ticker.C:
- loadingMutex.Lock()
- if !loaded {
- m.beforeStart()
- _, _ = m.writer.Write([]byte("loading..."))
- m.refreshView()
- }
- loadingMutex.Unlock()
- case <-stop:
- return
+ go utils.Safe(func() {
+ ticker := time.NewTicker(time.Millisecond * 200)
+ defer ticker.Stop()
+ select {
+ case <-stop:
+ return
+ case <-ticker.C:
+ loadingMutex.Lock()
+ if !loaded {
+ m.beforeStart()
+ _, _ = m.writer.Write([]byte("loading..."))
+ m.refreshView()
}
- })
+ loadingMutex.Unlock()
+ }
+ })
+ go utils.Safe(func() {
outer:
for {
select {
+ case <-stop:
+ break outer
case linesToRead := <-m.readLines:
for i := 0; i < linesToRead; i++ {
+ select {
+ case <-stop:
+ break outer
+ default:
+ }
+
ok := scanner.Scan()
loadingMutex.Lock()
if !loaded {
@@ -161,11 +173,6 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref
}
loadingMutex.Unlock()
- select {
- case <-stop:
- break outer
- default:
- }
if !ok {
// if we're here then there's nothing left to scan from the source
// so we're at the EOF and can flush the stale content
@@ -175,8 +182,6 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref
_, _ = m.writer.Write(append(scanner.Bytes(), '\n'))
}
m.refreshView()
- case <-stop:
- break outer
}
}
@@ -189,8 +194,9 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref
}
}
- if onDone != nil {
- onDone()
+ // calling onDoneWrapper here again in case the program ended on its own accord
+ if onDoneWrapper != nil {
+ onDoneWrapper()
}
close(done)
@@ -206,14 +212,14 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref
// Close closes the task manager, killing whatever task may currently be running
func (t *ViewBufferManager) Close() {
- if t.currentTask == nil {
+ if t.stopCurrentTask == nil {
return
}
c := make(chan struct{})
go utils.Safe(func() {
- t.currentTask.Stop()
+ t.stopCurrentTask()
c <- struct{}{}
})
@@ -249,19 +255,20 @@ func (m *ViewBufferManager) NewTask(f func(stop chan struct{}) error, key string
return
}
+ if m.stopCurrentTask != nil {
+ m.stopCurrentTask()
+ }
+
stop := make(chan struct{})
notifyStopped := make(chan struct{})
- if m.currentTask != nil {
- m.currentTask.Stop()
+ var once sync.Once
+ onStop := func() {
+ close(stop)
+ <-notifyStopped
}
- m.currentTask = &Task{
- stop: stop,
- notifyStopped: notifyStopped,
- Log: m.Log,
- f: f,
- }
+ m.stopCurrentTask = func() { once.Do(onStop) }
go utils.Safe(func() {
if err := f(stop); err != nil {
@@ -274,14 +281,3 @@ func (m *ViewBufferManager) NewTask(f func(stop chan struct{}) error, key string
return nil
}
-
-func (t *Task) Stop() {
- t.stopMutex.Lock()
- defer t.stopMutex.Unlock()
- if t.stopped {
- return
- }
- close(t.stop)
- <-t.notifyStopped
- t.stopped = true
-}
diff --git a/pkg/tasks/tasks_test.go b/pkg/tasks/tasks_test.go
new file mode 100644
index 000000000..d580c95f5
--- /dev/null
+++ b/pkg/tasks/tasks_test.go
@@ -0,0 +1,136 @@
+package tasks
+
+import (
+ "bytes"
+ "io"
+ "os/exec"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/jesseduffield/lazygit/pkg/secureexec"
+ "github.com/jesseduffield/lazygit/pkg/utils"
+)
+
+func getCounter() (func(), func() int) {
+ counter := 0
+ return func() { counter++ }, func() int { return counter }
+}
+
+func TestNewCmdTaskInstantStop(t *testing.T) {
+ writer := bytes.NewBuffer(nil)
+ beforeStart, getBeforeStartCallCount := getCounter()
+ refreshView, getRefreshViewCallCount := getCounter()
+ onEndOfInput, getOnEndOfInputCallCount := getCounter()
+ onNewKey, getOnNewKeyCallCount := getCounter()
+ onDone, getOnDoneCallCount := getCounter()
+
+ manager := NewViewBufferManager(
+ utils.NewDummyLog(),
+ writer,
+ beforeStart,
+ refreshView,
+ onEndOfInput,
+ onNewKey,
+ )
+
+ stop := make(chan struct{})
+ reader := bytes.NewBufferString("test")
+ start := func() (*exec.Cmd, io.Reader) {
+ // not actually starting this because it's not necessary
+ cmd := secureexec.Command("blah blah")
+
+ close(stop)
+
+ return cmd, reader
+ }
+
+ fn := manager.NewCmdTask(start, "prefix\n", 20, onDone)
+
+ _ = fn(stop)
+
+ callCountExpectations := []struct {
+ expected int
+ actual int
+ name string
+ }{
+ {0, getBeforeStartCallCount(), "beforeStart"},
+ {1, getRefreshViewCallCount(), "refreshView"},
+ {0, getOnEndOfInputCallCount(), "onEndOfInput"},
+ {0, getOnNewKeyCallCount(), "onNewKey"},
+ {1, getOnDoneCallCount(), "onDone"},
+ }
+ for _, expectation := range callCountExpectations {
+ if expectation.actual != expectation.expected {
+ t.Errorf("expected %s to be called %d times, got %d", expectation.name, expectation.expected, expectation.actual)
+ }
+ }
+
+ expectedContent := ""
+ actualContent := writer.String()
+ if actualContent != expectedContent {
+ t.Errorf("expected writer to receive the following content: \n%s\n. But instead it recevied: %s", expectedContent, actualContent)
+ }
+}
+
+func TestNewCmdTask(t *testing.T) {
+ writer := bytes.NewBuffer(nil)
+ beforeStart, getBeforeStartCallCount := getCounter()
+ refreshView, getRefreshViewCallCount := getCounter()
+ onEndOfInput, getOnEndOfInputCallCount := getCounter()
+ onNewKey, getOnNewKeyCallCount := getCounter()
+ onDone, getOnDoneCallCount := getCounter()
+
+ manager := NewViewBufferManager(
+ utils.NewDummyLog(),
+ writer,
+ beforeStart,
+ refreshView,
+ onEndOfInput,
+ onNewKey,
+ )
+
+ stop := make(chan struct{})
+ reader := bytes.NewBufferString("test")
+ start := func() (*exec.Cmd, io.Reader) {
+ // not actually starting this because it's not necessary
+ cmd := secureexec.Command("blah blah")
+
+ return cmd, reader
+ }
+
+ fn := manager.NewCmdTask(start, "prefix\n", 20, onDone)
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ close(stop)
+ wg.Done()
+ }()
+ _ = fn(stop)
+
+ wg.Wait()
+
+ callCountExpectations := []struct {
+ expected int
+ actual int
+ name string
+ }{
+ {1, getBeforeStartCallCount(), "beforeStart"},
+ {1, getRefreshViewCallCount(), "refreshView"},
+ {1, getOnEndOfInputCallCount(), "onEndOfInput"},
+ {0, getOnNewKeyCallCount(), "onNewKey"},
+ {1, getOnDoneCallCount(), "onDone"},
+ }
+ for _, expectation := range callCountExpectations {
+ if expectation.actual != expectation.expected {
+ t.Errorf("expected %s to be called %d times, got %d", expectation.name, expectation.expected, expectation.actual)
+ }
+ }
+
+ expectedContent := "prefix\ntest\n"
+ actualContent := writer.String()
+ if actualContent != expectedContent {
+ t.Errorf("expected writer to receive the following content: \n%s\n. But instead it recevied: %s", expectedContent, actualContent)
+ }
+}