diff options
-rw-r--r-- | pkg/gui/files_panel.go | 22 | ||||
-rw-r--r-- | pkg/gui/files_panel_test.go | 34 |
2 files changed, 54 insertions, 2 deletions
diff --git a/pkg/gui/files_panel.go b/pkg/gui/files_panel.go index 3bc600420..a0e71bc20 100644 --- a/pkg/gui/files_panel.go +++ b/pkg/gui/files_panel.go @@ -660,9 +660,11 @@ func (gui *Gui) handlePullFiles() error { } } + suggestedRemote := getSuggestedRemote(gui.State.Remotes) + return gui.prompt(promptOpts{ title: gui.Tr.EnterUpstream, - initialContent: "origin/" + currentBranch.Name, + initialContent: suggestedRemote + "/" + currentBranch.Name, findSuggestionsFunc: gui.getRemoteBranchesSuggestionsFunc("/"), handleConfirm: func(upstream string) error { if err := gui.GitCommand.SetUpstreamBranch(upstream); err != nil { @@ -797,12 +799,14 @@ func (gui *Gui) pushFiles() error { ) } + suggestedRemote := getSuggestedRemote(gui.State.Remotes) + if gui.GitCommand.PushToCurrent { return gui.push(pushOpts{setUpstream: true}) } else { return gui.prompt(promptOpts{ title: gui.Tr.EnterUpstream, - initialContent: "origin " + currentBranch.Name, + initialContent: suggestedRemote + " " + currentBranch.Name, findSuggestionsFunc: gui.getRemoteBranchesSuggestionsFunc(" "), handleConfirm: func(upstream string) error { var upstreamBranch, upstreamRemote string @@ -827,6 +831,20 @@ func (gui *Gui) pushFiles() error { } } +func getSuggestedRemote(remotes []*models.Remote) string { + if len(remotes) == 0 { + return "origin" + } + + for _, remote := range remotes { + if remote.Name == "origin" { + return remote.Name + } + } + + return remotes[0].Name +} + func (gui *Gui) requestToForcePush() error { forcePushDisabled := gui.Config.GetUserConfig().Git.DisableForcePushing if forcePushDisabled { diff --git a/pkg/gui/files_panel_test.go b/pkg/gui/files_panel_test.go new file mode 100644 index 000000000..fcf8fe66b --- /dev/null +++ b/pkg/gui/files_panel_test.go @@ -0,0 +1,34 @@ +package gui + +import ( + "testing" + + "github.com/jesseduffield/lazygit/pkg/commands/models" + "github.com/stretchr/testify/assert" +) + +func TestGetSuggestedRemote(t *testing.T) { + cases := []struct { + remotes []*models.Remote + expected string + }{ + {mkRemoteList(), "origin"}, + {mkRemoteList("upstream", "origin", "foo"), "origin"}, + {mkRemoteList("upstream", "foo", "bar"), "upstream"}, + } + + for _, c := range cases { + result := getSuggestedRemote(c.remotes) + assert.EqualValues(t, c.expected, result) + } +} + +func mkRemoteList(names ...string) []*models.Remote { + var result []*models.Remote + + for _, name := range names { + result = append(result, &models.Remote{Name: name}) + } + + return result +} |