summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBruno Inec <7051978+sweenu@users.noreply.github.com>2022-02-21 19:20:11 +0100
committerGitHub <noreply@github.com>2022-02-21 10:20:11 -0800
commit54f0cc9ddd060ac62269666073357b1eedadd545 (patch)
treed3c4ddb5837d7d18565c8544b7ca1be330e50681
parented9d123073eacbef4199e77113078cc6add02ad9 (diff)
ssh tunnels: allow configuring auto matches (#1302)
-rw-r--r--pgcli/main.py21
-rw-r--r--tests/test_ssh_tunnel.py57
2 files changed, 71 insertions, 7 deletions
diff --git a/pgcli/main.py b/pgcli/main.py
index 7feb74b7..a72f7089 100644
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -286,6 +286,7 @@ class PGCli:
self.prompt_app = None
+ self.ssh_tunnel_config = c.get("ssh tunnels")
self.ssh_tunnel_url = ssh_tunnel_url
self.ssh_tunnel = None
@@ -599,18 +600,24 @@ class PGCli:
return True
return False
+ if dsn:
+ parsed_dsn = parse_dsn(dsn)
+ if "host" in parsed_dsn:
+ host = parsed_dsn["host"]
+ if "port" in parsed_dsn:
+ port = parsed_dsn["port"]
+
+ if self.ssh_tunnel_config and not self.ssh_tunnel_url:
+ for db_host_regex, tunnel_url in self.ssh_tunnel_config.items():
+ if re.search(db_host_regex, host):
+ self.ssh_tunnel_url = tunnel_url
+ break
+
if self.ssh_tunnel_url:
# We add the protocol as urlparse doesn't find it by itself
if "://" not in self.ssh_tunnel_url:
self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}"
- if dsn:
- parsed_dsn = parse_dsn(dsn)
- if "host" in parsed_dsn:
- host = parsed_dsn["host"]
- if "port" in parsed_dsn:
- port = parsed_dsn["port"]
-
tunnel_info = urlparse(self.ssh_tunnel_url)
params = {
"local_bind_address": ("127.0.0.1",),
diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py
index 9163c08c..ae865f4a 100644
--- a/tests/test_ssh_tunnel.py
+++ b/tests/test_ssh_tunnel.py
@@ -1,6 +1,8 @@
+import os
from unittest.mock import patch, MagicMock, ANY
import pytest
+from configobj import ConfigObj
from click.testing import CliRunner
from sshtunnel import SSHTunnelForwarder
@@ -129,3 +131,58 @@ def test_cli_with_tunnel() -> None:
mock_pgcli.assert_called_once()
call_args, call_kwargs = mock_pgcli.call_args
assert call_kwargs["ssh_tunnel_url"] == tunnel_url
+
+
+def test_config(
+ tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
+) -> None:
+ pgclirc = str(tmpdir.join("rcfile"))
+
+ tunnel_user = "tunnel_user"
+ tunnel_passwd = "tunnel_pass"
+ tunnel_host = "tunnel.host"
+ tunnel_port = 1022
+ tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
+
+ tunnel2_url = "tunnel2.host"
+
+ config = ConfigObj()
+ config.filename = pgclirc
+ config["ssh tunnels"] = {}
+ config["ssh tunnels"][r"\.com$"] = tunnel_url
+ config["ssh tunnels"][r"^hello-"] = tunnel2_url
+ config.write()
+
+ # Unmatched host
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="unmatched.host")
+ mock_ssh_tunnel_forwarder.assert_not_called()
+
+ # Host matching first tunnel
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="matched.host.com")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
+ assert call_kwargs["ssh_username"] == tunnel_user
+ assert call_kwargs["ssh_password"] == tunnel_passwd
+ mock_ssh_tunnel_forwarder.reset_mock()
+
+ # Host matching second tunnel
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="hello-i-am-matched")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22)
+ mock_ssh_tunnel_forwarder.reset_mock()
+
+ # Host matching both tunnels (will use the first one matched)
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="hello-i-am-matched.com")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
+ assert call_kwargs["ssh_username"] == tunnel_user
+ assert call_kwargs["ssh_password"] == tunnel_passwd