diff options
author | Bruno Inec <7051978+sweenu@users.noreply.github.com> | 2022-02-18 23:57:42 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-18 14:57:42 -0800 |
commit | ed9d123073eacbef4199e77113078cc6add02ad9 (patch) | |
tree | a49478909d6f3ec8c3b188d8f06187bc814fbcf1 | |
parent | 78843ac30f1406a6e472366bc7d0c7cffe9c4aee (diff) |
Add SSH tunnel support (#1301)
* Add initial sshtunnel support
* Force CI to rerun.
* Fix unit test for 3.6.
* Black.
Co-authored-by: Irina Truong <i.chernyavska@gmail.com>
-rw-r--r-- | .github/workflows/ci.yml | 2 | ||||
-rw-r--r-- | AUTHORS | 3 | ||||
-rw-r--r-- | changelog.rst | 1 | ||||
-rw-r--r-- | pgcli/main.py | 82 | ||||
-rw-r--r-- | pgcli/packages/parseutils/tables.py | 18 | ||||
-rw-r--r-- | pgcli/pgcompleter.py | 13 | ||||
-rw-r--r-- | setup.py | 5 | ||||
-rw-r--r-- | tests/features/steps/basic_commands.py | 4 | ||||
-rw-r--r-- | tests/test_ssh_tunnel.py | 131 |
9 files changed, 234 insertions, 25 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ce54d6f5..f0e6fd88 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,7 +38,7 @@ jobs: - name: Install requirements run: | pip install -U pip setuptools - pip install --no-cache-dir . + pip install --no-cache-dir ".[sshtunnel]" pip install -r requirements-dev.txt pip install keyrings.alt>=3.1 @@ -116,8 +116,9 @@ Contributors: * Kevin Marsh (kevinmarsh) * Eero Ruohola (ruohola) * Miroslav Šedivý (eumiro) - * Eric R Young (ERYoung11) + * Eric R Young (ERYoung11) * Paweł Sacawa (psacawa) + * Bruno Inec (sweenu) Creator: -------- diff --git a/changelog.rst b/changelog.rst index d4bbd39c..7fbd35a6 100644 --- a/changelog.rst +++ b/changelog.rst @@ -16,6 +16,7 @@ Features: * Add `max_field_width` setting to config, to enable more control over field truncation ([related issue](https://github.com/dbcli/pgcli/issues/1250)). * Re-run last query via bare `\watch`. (Thanks: `Saif Hakim`_) +* Add optional support for automatically creating an SSH tunnel to a machine with access to the remote database ([related issue](https://github.com/dbcli/pgcli/issues/459)). Bug fixes: ---------- diff --git a/pgcli/main.py b/pgcli/main.py index e4a2ee39..7feb74b7 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -1,6 +1,5 @@ import platform import warnings -from os.path import expanduser from configobj import ConfigObj, ParseError from pgspecial.namedqueries import NamedQueries @@ -8,6 +7,7 @@ from .config import skip_initial_comment warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2") +import atexit import os import re import sys @@ -21,6 +21,8 @@ import datetime as dt import itertools import platform from time import time, sleep +from typing import Optional +from urllib.parse import urlparse keyring = None # keyring will be loaded later @@ -78,12 +80,21 @@ except ImportError: from getpass import getuser from psycopg2 import OperationalError, InterfaceError +from psycopg2.extensions import make_dsn, parse_dsn import psycopg2 from collections import namedtuple from textwrap import dedent +try: + import sshtunnel + + SSH_TUNNEL_SUPPORT = True +except ImportError: + SSH_TUNNEL_SUPPORT = False + + # Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))") DEFAULT_MAX_FIELD_WIDTH = 500 @@ -168,8 +179,8 @@ class PGCli: prompt_dsn=None, auto_vertical_output=False, warn=None, + ssh_tunnel_url: Optional[str] = None, ): - self.force_passwd_prompt = force_passwd_prompt self.never_passwd_prompt = never_passwd_prompt self.pgexecute = pgexecute @@ -275,6 +286,9 @@ class PGCli: self.prompt_app = None + self.ssh_tunnel_url = ssh_tunnel_url + self.ssh_tunnel = None + def quit(self): raise PgCliQuitError @@ -585,6 +599,50 @@ class PGCli: return True return False + if self.ssh_tunnel_url: + # We add the protocol as urlparse doesn't find it by itself + if "://" not in self.ssh_tunnel_url: + self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}" + + if dsn: + parsed_dsn = parse_dsn(dsn) + if "host" in parsed_dsn: + host = parsed_dsn["host"] + if "port" in parsed_dsn: + port = parsed_dsn["port"] + + tunnel_info = urlparse(self.ssh_tunnel_url) + params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (host, int(port or 5432)), + "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22), + "logger": self.logger, + } + if tunnel_info.username: + params["ssh_username"] = tunnel_info.username + if tunnel_info.password: + params["ssh_password"] = tunnel_info.password + + # Hack: sshtunnel adds a console handler to the logger, so we revert handlers. + # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged. + logger_handlers = self.logger.handlers.copy() + try: + self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params) + self.ssh_tunnel.start() + except Exception as e: + self.logger.handlers = logger_handlers + self.logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + exit(1) + self.logger.handlers = logger_handlers + + atexit.register(self.ssh_tunnel.stop) + host = "127.0.0.1" + port = self.ssh_tunnel.local_bind_ports[0] + + if dsn: + dsn = make_dsn(dsn, host=host, port=port) + # Attempt to connect to the database. # Note that passwd may be empty on the first attempt. If connection # fails because of a missing or incorrect password, but we're allowed to @@ -1222,7 +1280,7 @@ class PGCli: "--list", "list_databases", is_flag=True, - help="list " "available databases, then exit.", + help="list available databases, then exit.", ) @click.option( "--auto-vertical-output", @@ -1235,6 +1293,11 @@ class PGCli: type=click.Choice(["all", "moderate", "off"]), help="Warn before running a destructive query.", ) +@click.option( + "--ssh-tunnel", + default=None, + help="Open an SSH tunnel to the given address and connect to the database from it.", +) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) def cli( @@ -1258,6 +1321,7 @@ def cli( auto_vertical_output, list_dsn, warn, + ssh_tunnel: str, ): if version: print("Version:", __version__) @@ -1294,6 +1358,15 @@ def cli( ) exit(1) + if ssh_tunnel and not SSH_TUNNEL_SUPPORT: + click.secho( + 'Cannot open SSH tunnel, "sshtunnel" package was not found. ' + "Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.", + err=True, + fg="red", + ) + exit(1) + pgcli = PGCli( prompt_passwd, never_prompt, @@ -1305,6 +1378,7 @@ def cli( prompt_dsn=prompt_dsn, auto_vertical_output=auto_vertical_output, warn=warn, + ssh_tunnel_url=ssh_tunnel, ) # Choose which ever one has a valid value. @@ -1548,7 +1622,7 @@ def parse_service_info(service): elif os.getenv("PGSYSCONFDIR"): service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf") else: - service_file = expanduser("~/.pg_service.conf") + service_file = os.path.expanduser("~/.pg_service.conf") if not service or not os.path.exists(service_file): # nothing to do return None, service_file diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py index aaa676cc..f2e1e42b 100644 --- a/pgcli/packages/parseutils/tables.py +++ b/pgcli/packages/parseutils/tables.py @@ -63,17 +63,13 @@ def extract_from_part(parsed, stop_at_punctuation=True): yield item elif item.ttype is Keyword or item.ttype is Keyword.DML: item_val = item.value.upper() - if ( - item_val - in ( - "COPY", - "FROM", - "INTO", - "UPDATE", - "TABLE", - ) - or item_val.endswith("JOIN") - ): + if item_val in ( + "COPY", + "FROM", + "INTO", + "UPDATE", + "TABLE", + ) or item_val.endswith("JOIN"): tbl_prefix_seen = True # 'SELECT a, FROM abc' will detect FROM as part of the column list. # So this check here is necessary. diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 227e25c6..e66c3dc2 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -491,11 +491,14 @@ class PGCompleter(Completer): def get_column_matches(self, suggestion, word_before_cursor): tables = suggestion.table_refs - do_qualify = suggestion.qualifiable and { - "always": True, - "never": False, - "if_more_than_one_table": len(tables) > 1, - }[self.qualify_columns] + do_qualify = ( + suggestion.qualifiable + and { + "always": True, + "never": False, + "if_more_than_one_table": len(tables) > 1, + }[self.qualify_columns] + ) qualify = lambda col, tbl: ( (tbl + "." + self.case(col)) if do_qualify else self.case(col) ) @@ -39,7 +39,10 @@ setup( description=description, long_description=open("README.rst").read(), install_requires=install_requirements, - extras_require={"keyring": ["keyring >= 12.2.0"]}, + extras_require={ + "keyring": ["keyring >= 12.2.0"], + "sshtunnel": ["sshtunnel >= 0.4.0"], + }, python_requires=">=3.6", entry_points=""" [console_scripts] diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py index 7ca20f06..a7c99eea 100644 --- a/tests/features/steps/basic_commands.py +++ b/tests/features/steps/basic_commands.py @@ -97,9 +97,9 @@ def step_see_error_message(context): @when("we send source command") def step_send_source_command(context): context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_") - context.tmpfile_sql_help.write(br"\?") + context.tmpfile_sql_help.write(rb"\?") context.tmpfile_sql_help.flush() - context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}") + context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}") wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py new file mode 100644 index 00000000..9163c08c --- /dev/null +++ b/tests/test_ssh_tunnel.py @@ -0,0 +1,131 @@ +from unittest.mock import patch, MagicMock, ANY + +import pytest +from click.testing import CliRunner +from sshtunnel import SSHTunnelForwarder + +from pgcli.main import cli, PGCli +from pgcli.pgexecute import PGExecute + + +@pytest.fixture +def mock_ssh_tunnel_forwarder() -> MagicMock: + mock_ssh_tunnel_forwarder = MagicMock( + SSHTunnelForwarder, local_bind_ports=[1111], autospec=True + ) + with patch( + "pgcli.main.sshtunnel.SSHTunnelForwarder", + return_value=mock_ssh_tunnel_forwarder, + ) as mock: + yield mock + + +@pytest.fixture +def mock_pgexecute() -> MagicMock: + with patch.object(PGExecute, "__init__", return_value=None) as mock_pgexecute: + yield mock_pgexecute + + +def test_ssh_tunnel( + mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + # Test with just a host + tunnel_url = "some.host" + db_params = { + "database": "dbname", + "host": "db.host", + "user": "db_user", + "passwd": "db_passwd", + } + expected_tunnel_params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (db_params["host"], 5432), + "ssh_address_or_host": (tunnel_url, 22), + "logger": ANY, + } + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with a full url and with a specific db port + tunnel_user = "tunnel_user" + tunnel_passwd = "tunnel_pass" + tunnel_host = "some.other.host" + tunnel_port = 1022 + tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}" + db_params["port"] = 1234 + + expected_tunnel_params["remote_bind_address"] = ( + db_params["host"], + db_params["port"], + ) + expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port) + expected_tunnel_params["ssh_username"] = tunnel_user + expected_tunnel_params["ssh_password"] = tunnel_passwd + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with DSN + dsn = ( + f"user={db_params['user']} password={db_params['passwd']} " + f"host={db_params['host']} port={db_params['port']}" + ) + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(dsn=dsn) + + expected_dsn = ( + f"user={db_params['user']} password={db_params['passwd']} " + f"host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}" + ) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert expected_dsn in call_args + + +def test_cli_with_tunnel() -> None: + runner = CliRunner() + tunnel_url = "mytunnel" + with patch.object( + PGCli, "__init__", autospec=True, return_value=None + ) as mock_pgcli: + runner.invoke(cli, ["--ssh-tunnel", tunnel_url]) + mock_pgcli.assert_called_once() + call_args, call_kwargs = mock_pgcli.call_args + assert call_kwargs["ssh_tunnel_url"] == tunnel_url |