diff options
43 files changed, 237 insertions, 154 deletions
diff --git a/pkg/gui/context/branches_context.go b/pkg/gui/context/branches_context.go index 5905168ea..d2647ef84 100644 --- a/pkg/gui/context/branches_context.go +++ b/pkg/gui/context/branches_context.go @@ -59,15 +59,6 @@ func NewBranchesContext(c *ContextCommon) *BranchesContext { return self } -func (self *BranchesContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} - func (self *BranchesContext) GetSelectedRef() types.Ref { branch := self.GetSelected() if branch == nil { diff --git a/pkg/gui/context/commit_files_context.go b/pkg/gui/context/commit_files_context.go index fbfff7144..7af968fb7 100644 --- a/pkg/gui/context/commit_files_context.go +++ b/pkg/gui/context/commit_files_context.go @@ -72,15 +72,6 @@ func NewCommitFilesContext(c *ContextCommon) *CommitFilesContext { return ctx } -func (self *CommitFilesContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} - func (self *CommitFilesContext) GetDiffTerminals() []string { return []string{self.GetRef().RefName()} } diff --git a/pkg/gui/context/filtered_list_view_model.go b/pkg/gui/context/filtered_list_view_model.go index c8abbe4a1..2c2841964 100644 --- a/pkg/gui/context/filtered_list_view_model.go +++ b/pkg/gui/context/filtered_list_view_model.go @@ -1,12 +1,12 @@ package context -type FilteredListViewModel[T any] struct { +type FilteredListViewModel[T HasID] struct { *FilteredList[T] *ListViewModel[T] *SearchHistory } -func NewFilteredListViewModel[T any](getList func() []T, getFilterFields func(T) []string) *FilteredListViewModel[T] { +func NewFilteredListViewModel[T HasID](getList func() []T, getFilterFields func(T) []string) *FilteredListViewModel[T] { filteredList := NewFilteredList(getList, getFilterFields) self := &FilteredListViewModel[T]{ diff --git a/pkg/gui/context/list_renderer_test.go b/pkg/gui/context/list_renderer_test.go index 98e3f60aa..99c476427 100644 --- a/pkg/gui/context/list_renderer_test.go +++ b/pkg/gui/context/list_renderer_test.go @@ -9,10 +9,17 @@ import ( "github.com/stretchr/testify/assert" ) +// wrapping string in my own type to give it an ID method which is required for list items +type mystring string + +func (self mystring) ID() string { + return string(self) +} + func TestListRenderer_renderLines(t *testing.T) { scenarios := []struct { name string - modelStrings []string + modelStrings []mystring nonModelIndices []int startIdx int endIdx int @@ -20,7 +27,7 @@ func TestListRenderer_renderLines(t *testing.T) { }{ { name: "Render whole list", - modelStrings: []string{"a", "b", "c"}, + modelStrings: []mystring{"a", "b", "c"}, startIdx: 0, endIdx: 3, expectedOutput: ` @@ -30,7 +37,7 @@ func TestListRenderer_renderLines(t *testing.T) { }, { name: "Partial list, beginning", - modelStrings: []string{"a", "b", "c"}, + modelStrings: []mystring{"a", "b", "c"}, startIdx: 0, endIdx: 2, expectedOutput: ` @@ -39,7 +46,7 @@ func TestListRenderer_renderLines(t *testing.T) { }, { name: "Partial list, end", - modelStrings: []string{"a", "b", "c"}, + modelStrings: []mystring{"a", "b", "c"}, startIdx: 1, endIdx: 3, expectedOutput: ` @@ -48,7 +55,7 @@ func TestListRenderer_renderLines(t *testing.T) { }, { name: "Pass an endIdx greater than the model length", - modelStrings: []string{"a", "b", "c"}, + modelStrings: []mystring{"a", "b", "c"}, startIdx: 2, endIdx: 5, expectedOutput: ` @@ -56,7 +63,7 @@ func TestListRenderer_renderLines(t *testing.T) { }, { name: "Whole list with section headers", - modelStrings: []string{"a", "b", "c"}, + modelStrings: []mystring{"a", "b", "c"}, nonModelIndices: []int{1, 3}, startIdx: 0, endIdx: 5, @@ -69,7 +76,7 @@ func TestListRenderer_renderLines(t *testing.T) { }, { name: "Multiple consecutive headers", - modelStrings: []string{"a", "b", "c"}, + modelStrings: []mystring{"a", "b", "c"}, nonModelIndices: []int{0, 0, 2, 2, 2}, startIdx: 0, endIdx: 8, @@ -85,7 +92,7 @@ func TestListRenderer_renderLines(t *testing.T) { }, { name: "Partial list with headers, beginning", - modelStrings: []string{"a", "b", "c"}, + modelStrings: []mystring{"a", "b", "c"}, nonModelIndices: []int{1, 3}, startIdx: 0, endIdx: 3, @@ -96,7 +103,7 @@ func TestListRenderer_renderLines(t *testing.T) { }, { name: "Partial list with headers, end (beyond end index)", - modelStrings: []string{"a", "b", "c"}, + modelStrings: []mystring{"a", "b", "c"}, nonModelIndices: []int{1, 3}, startIdx: 2, endIdx: 7, @@ -108,7 +115,7 @@ func TestListRenderer_renderLines(t *testing.T) { } for _, s := range scenarios { t.Run(s.name, func(t *testing.T) { - viewModel := NewListViewModel[string](func() []string { return s.modelStrings }) + viewModel := NewListViewModel[mystring](func() []mystring { return s.modelStrings }) var getNonModelItems func() []*NonModelItem if s.nonModelIndices != nil { getNonModelItems = func() []*NonModelItem { @@ -124,7 +131,7 @@ func TestListRenderer_renderLines(t *testing.T) { list: viewModel, getDisplayStrings: func(startIdx int, endIdx int) [][]string { return lo.Map(s.modelStrings[startIdx:endIdx], - func(s string, _ int) []string { return []string{s} }) + func(s mystring, _ int) []string { return []string{string(s)} }) }, getNonModelItems: getNonModelItems, } @@ -138,6 +145,12 @@ func TestListRenderer_renderLines(t *testing.T) { } } +type myint int + +func (self myint) ID() string { + return fmt.Sprint(int(self)) +} + func TestListRenderer_ModelIndexToViewIndex_and_back(t *testing.T) { scenarios := []struct { name string @@ -222,8 +235,8 @@ func TestListRenderer_ModelIndexToViewIndex_and_back(t *testing.T) { assert.Equal(t, len(s.modelIndices), len(s.expectedViewIndices)) assert.Equal(t, len(s.viewIndices), len(s.expectedModelIndices)) - modelInts := lo.Range(s.numModelItems) - viewModel := NewListViewModel[int](func() []int { return modelInts }) + modelInts := lo.Map(lo.Range(s.numModelItems), func(i int, _ int) myint { return myint(i) }) + viewModel := NewListViewModel[myint](func() []myint { return modelInts }) var getNonModelItems func() []*NonModelItem if s.nonModelIndices != nil { getNonModelItems = func() []*NonModelItem { @@ -236,7 +249,7 @@ func TestListRenderer_ModelIndexToViewIndex_and_back(t *testing.T) { list: viewModel, getDisplayStrings: func(startIdx int, endIdx int) [][]string { return lo.Map(modelInts[startIdx:endIdx], - func(i int, _ int) []string { return []string{fmt.Sprint(i)} }) + func(i myint, _ int) []string { return []string{fmt.Sprint(i)} }) }, getNonModelItems: getNonModelItems, } diff --git a/pkg/gui/context/list_view_model.go b/pkg/gui/context/list_view_model.go index 22416bff1..bf8c80e23 100644 --- a/pkg/gui/context/list_view_model.go +++ b/pkg/gui/context/list_view_model.go @@ -3,14 +3,19 @@ package context import ( "github.com/jesseduffield/lazygit/pkg/gui/context/traits" "github.com/jesseduffield/lazygit/pkg/gui/types" + "github.com/samber/lo" ) -type ListViewModel[T any] struct { +type HasID interface { + ID() string +} + +type ListViewModel[T HasID] struct { *traits.ListCursor getModel func() []T } -func NewListViewModel[T any](getModel func() []T) *ListViewModel[T] { +func NewListViewModel[T HasID](getModel func() []T) *ListViewModel[T] { self := &ListViewModel[T]{ getModel: getModel, } @@ -32,6 +37,34 @@ func (self *ListViewModel[T]) GetSelected() T { return self.getModel()[self.GetSelectedLineIdx()] } +func (self *ListViewModel[T]) GetSelectedItemId() string { + if self.Len() == 0 { + return "" + } + + return self.GetSelected().ID() +} + +func (self *ListViewModel[T]) GetSelectedItems() ([]T, int, int) { + if self.Len() == 0 { + return nil, -1, -1 + } + + startIdx, endIdx := self.GetSelectionRange() + + return self.getModel()[startIdx : endIdx+1], startIdx, endIdx +} + +func (self *ListViewModel[T]) GetSelectedItemIds() ([]string, int, int) { + selectedItems, startIdx, endIdx := self.GetSelectedItems() + + ids := lo.Map(selectedItems, func(item T, _ int) string { + return item.ID() + }) + + return ids, startIdx, endIdx +} + func (self *ListViewModel[T]) GetItems() []T { return self.getModel() } diff --git a/pkg/gui/context/local_commits_context.go b/pkg/gui/context/local_commits_context.go index 61a40b30b..5ff361e09 100644 --- a/pkg/gui/context/local_commits_context.go +++ b/pkg/gui/context/local_commits_context.go @@ -92,15 +92,6 @@ func NewLocalCommitsContext(c *ContextCommon) *LocalCommitsContext { return ctx } -func (self *LocalCommitsContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} - type LocalCommitsViewModel struct { *ListViewModel[*models.Commit] diff --git a/pkg/gui/context/menu_context.go b/pkg/gui/context/menu_context.go index 131aa8665..bb1060de6 100644 --- a/pkg/gui/context/menu_context.go +++ b/pkg/gui/context/menu_context.go @@ -45,16 +45,6 @@ func NewMenuContext( } } -// TODO: remove this thing. -func (self *MenuContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.Label -} - type MenuViewModel struct { c *ContextCommon menuItems []*types.MenuItem diff --git a/pkg/gui/context/reflog_commits_context.go b/pkg/gui/context/reflog_commits_context.go index 8dc52cde7..65137d633 100644 --- a/pkg/gui/context/reflog_commits_context.go +++ b/pkg/gui/context/reflog_commits_context.go @@ -59,15 +59,6 @@ func NewReflogCommitsContext(c *ContextCommon) *ReflogCommitsContext { } } -func (self *ReflogCommitsContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} - func (self *ReflogCommitsContext) CanRebase() bool { return false } diff --git a/pkg/gui/context/remote_branches_context.go b/pkg/gui/context/remote_branches_context.go index 82d37b613..884d3debb 100644 --- a/pkg/gui/context/remote_branches_context.go +++ b/pkg/gui/context/remote_branches_context.go @@ -4,6 +4,7 @@ import ( "github.com/jesseduffield/lazygit/pkg/commands/models" "github.com/jesseduffield/lazygit/pkg/gui/presentation" "github.com/jesseduffield/lazygit/pkg/gui/types" + "github.com/samber/lo" ) type RemoteBranchesContext struct { @@ -53,15 +54,6 @@ func NewRemoteBranchesContext( } } -func (self *RemoteBranchesContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} - func (self *RemoteBranchesContext) GetSelectedRef() types.Ref { remoteBranch := self.GetSelected() if remoteBranch == nil { @@ -70,6 +62,16 @@ func (self *RemoteBranchesContext) GetSelectedRef() types.Ref { return remoteBranch } +func (self *RemoteBranchesContext) GetSelectedRefs() ([]types.Ref, int, int) { + items, startIdx, endIdx := self.GetSelectedItems() + + refs := lo.Map(items, func(item *models.RemoteBranch, _ int) types.Ref { + return item + }) + + return refs, startIdx, endIdx +} + func (self *RemoteBranchesContext) GetDiffTerminals() []string { itemId := self.GetSelectedItemId() diff --git a/pkg/gui/context/remotes_context.go b/pkg/gui/context/remotes_context.go index 035fb2321..ec59d5fd7 100644 --- a/pkg/gui/context/remotes_context.go +++ b/pkg/gui/context/remotes_context.go @@ -47,15 +47,6 @@ func NewRemotesContext(c *ContextCommon) *RemotesContext { } } -func (self *RemotesContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} - func (self *RemotesContext) GetDiffTerminals() []string { itemId := self.GetSelectedItemId() diff --git a/pkg/gui/context/stash_context.go b/pkg/gui/context/stash_context.go index 2b86d945f..c8d487688 100644 --- a/pkg/gui/context/stash_context.go +++ b/pkg/gui/context/stash_context.go @@ -49,15 +49,6 @@ func NewStashContext( } } -func (self *StashContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} - func (self *StashContext) CanRebase() bool { return false } diff --git a/pkg/gui/context/sub_commits_context.go b/pkg/gui/context/sub_commits_context.go index 1f795b44d..7a797e61d 100644 --- a/pkg/gui/context/sub_commits_context.go +++ b/pkg/gui/context/sub_commits_context.go @@ -175,15 +175,6 @@ func (self *SubCommitsViewModel) GetShowBranchHeads() bool { return self.showBranchHeads } -func (self *SubCommitsContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} - func (self *SubCommitsContext) CanRebase() bool { return false } diff --git a/pkg/gui/context/submodules_context.go b/pkg/gui/context/submodules_context.go index 2cffd82d6..82deb25af 100644 --- a/pkg/gui/context/submodules_context.go +++ b/pkg/gui/context/submodules_context.go @@ -43,12 +43,3 @@ func NewSubmodulesContext(c *ContextCommon) *SubmodulesContext { }, } } - -func (self *SubmodulesContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} diff --git a/pkg/gui/context/suggestions_context.go b/pkg/gui/context/suggestions_context.go index 30781fce1..59908fe5e 100644 --- a/pkg/gui/context/suggestions_context.go +++ b/pkg/gui/context/suggestions_context.go @@ -63,15 +63,6 @@ func NewSuggestionsContext( } } -func (self *SuggestionsContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.Value -} - func (self *SuggestionsContext) SetSuggestions(suggestions []*types.Suggestion) { self.State.Suggestions = suggestions self.SetSelection(0) diff --git a/pkg/gui/context/tags_context.go b/pkg/gui/context/tags_context.go index 3da5a9576..d827564dd 100644 --- a/pkg/gui/context/tags_context.go +++ b/pkg/gui/context/tags_context.go @@ -52,15 +52,6 @@ func NewTagsContext( } } -func (self *TagsContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} - func (self *TagsContext) GetSelectedRef() types.Ref { tag := self.GetSelected() if tag == nil { diff --git a/pkg/gui/context/working_tree_context.go b/pkg/gui/context/working_tree_context.go index f3bc91929..6fa462cb1 100644 --- a/pkg/gui/context/working_tree_context.go +++ b/pkg/gui/context/working_tree_context.go @@ -58,12 +58,3 @@ func NewWorkingTreeContext(c *ContextCommon) *WorkingTreeContext { return ctx } - -func (self *WorkingTreeContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} diff --git a/pkg/gui/context/worktrees_context.go b/pkg/gui/context/worktrees_context.go index c616dd49e..3e45f2d45 100644 --- a/pkg/gui/context/worktrees_context.go +++ b/pkg/gui/context/worktrees_context.go @@ -46,12 +46,3 @@ func NewWorktreesContext(c *ContextCommon) *WorktreesContext { }, } } - -func (self *WorktreesContext) GetSelectedItemId() string { - item := self.GetSelected() - if item == nil { - return "" - } - - return item.ID() -} diff --git a/pkg/gui/controllers/basic_commits_controller.go b/pkg/gui/controllers/basic_commits_controller.go index 386877b4d..6c378ecf0 100644 --- a/pkg/gui/controllers/basic_commits_controller.go +++ b/pkg/gui/controllers/basic_commits_controller.go @@ -16,6 +16,7 @@ type ContainsCommits interface { types.Context types.IListContext GetSelected() *models.Commit + GetSelectedItems() ([]*models.Commit, int, int) GetCommits() []*models.Commit GetSelectedLineIdx() int } @@ -36,6 +37,7 @@ func NewBasicCommitsController(c *ControllerCommon, context ContainsCommits) *Ba c, context, context.GetSelected, + context.GetSelectedItems, ), } } diff --git a/pkg/gui/controllers/bisect_controller.go b/pkg/gui/controllers/bisect_controller.go index deb4f1b7a..2f9a7ec36 100644 --- a/pkg/gui/controllers/bisect_controller.go +++ b/pkg/gui/controllers/bisect_controller.go @@ -30,6 +30,7 @@ func NewBisectController( c, c.Contexts().LocalCommits, c.Contexts().LocalCommits.GetSelected, + c.Contexts().LocalCommits.GetSelectedItems, ), } } diff --git a/pkg/gui/controllers/branches_controller.go b/pkg/gui/controllers/branches_controller.go index 8cac9537d..dbd15ef93 100644 --- a/pkg/gui/controllers/branches_controller.go +++ b/pkg/gui/controllers/branches_controller.go @@ -33,6 +33,7 @@ func NewBranchesController( c, c.Contexts().Branches, c.Contexts().Branches.GetSelected, + c.Contexts().Branches.GetSelectedItems, |