From e0a4c18c4a074016beef46886b03e456102e2db1 Mon Sep 17 00:00:00 2001 From: Irina Truong Date: Fri, 21 May 2021 15:32:34 -0700 Subject: Another attempt to fix pgbouncer error (1093.) (#1097) * Another attempt to fix pgbouncer error (1093.) * Fixes for various pgbouncer problems. * different approach with custom cursor. * Fix rebase. * Missed this. * Fix completion refresher test. * Black. * Unused import. * Changelog. * Fix race condition in test. * Switch from is_pgbouncer to more generic is_virtual_database, and duck-type it. Add very dumb unit test for virtual cursor. * Remove debugger code. --- changelog.rst | 1 + pgcli/completion_refresher.py | 5 +- pgcli/main.py | 5 +- pgcli/pgexecute.py | 151 +++++++++++++++++++++++++++---------- tests/features/steps/specials.py | 7 +- tests/features/steps/wrappers.py | 8 +- tests/test_completion_refresher.py | 18 ++--- tests/test_pgexecute.py | 24 ++++++ 8 files changed, 163 insertions(+), 56 deletions(-) diff --git a/changelog.rst b/changelog.rst index 21941499..54951023 100644 --- a/changelog.rst +++ b/changelog.rst @@ -19,6 +19,7 @@ Bug fixes: * Fix pager not being used when output format is set to csv. (#1238) * Add function literals random, generate_series, generate_subscripts * Fix ANSI escape codes in first line make the cli choose expanded output incorrectly +* Fix pgcli crashing with virtual `pgbouncer` database. (#1093) 3.1.0 ===== diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py index 3e847b09..1039d515 100644 --- a/pgcli/completion_refresher.py +++ b/pgcli/completion_refresher.py @@ -3,7 +3,6 @@ import os from collections import OrderedDict from .pgcompleter import PGCompleter -from .pgexecute import PGExecute class CompletionRefresher: @@ -27,6 +26,10 @@ class CompletionRefresher: has completed the refresh. The newly created completion object will be passed in as an argument to each callback. """ + if executor.is_virtual_database(): + # do nothing + return [(None, None, None, "Auto-completion refresh can't be started.")] + if self.is_refreshing(): self._restart_refresh.set() return [(None, None, None, "Auto-completion refresh restarted.")] diff --git a/pgcli/main.py b/pgcli/main.py index 2202c1a7..5135f6fd 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -988,16 +988,13 @@ class PGCli: callback = functools.partial( self._on_completions_refreshed, persist_priorities=persist_priorities ) - self.completion_refresher.refresh( + return self.completion_refresher.refresh( self.pgexecute, self.pgspecial, callback, history=history, settings=self.settings, ) - return [ - (None, None, None, "Auto-completion refresh started in the background.") - ] def _on_completions_refreshed(self, new_completer, persist_priorities): self._swap_completer_objects(new_completer, persist_priorities) diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 5cba7845..a013b558 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -1,13 +1,15 @@ -import traceback import logging +import select +import traceback + +import pgspecial as special import psycopg2 -import psycopg2.extras import psycopg2.errorcodes import psycopg2.extensions as ext +import psycopg2.extras import sqlparse -import pgspecial as special -import select from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn + from .packages.parseutils.meta import FunctionMetadata, ForeignKey _logger = logging.getLogger(__name__) @@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING)) # TODO: Get default timeout from pgclirc? _WAIT_SELECT_TIMEOUT = 1 +_wait_callback_is_set = False def _wait_select(conn): @@ -34,31 +37,41 @@ def _wait_select(conn): copy-pasted from psycopg2.extras.wait_select the default implementation doesn't define a timeout in the select calls """ - while 1: - try: - state = conn.poll() - if state == POLL_OK: - break - elif state == POLL_READ: - select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT) - elif state == POLL_WRITE: - select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT) - else: - raise conn.OperationalError("bad state from poll: %s" % state) - except KeyboardInterrupt: - conn.cancel() - # the loop will be broken by a server error - continue - except OSError as e: - errno = e.args[0] - if errno != 4: - raise + try: + while 1: + try: + state = conn.poll() + if state == POLL_OK: + break + elif state == POLL_READ: + select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT) + elif state == POLL_WRITE: + select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT) + else: + raise conn.OperationalError("bad state from poll: %s" % state) + except KeyboardInterrupt: + conn.cancel() + # the loop will be broken by a server error + continue + except OSError as e: + errno = e.args[0] + if errno != 4: + raise + except psycopg2.OperationalError: + pass -# When running a query, make pressing CTRL+C raise a KeyboardInterrupt -# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ -# See also https://github.com/psycopg/psycopg2/issues/468 -ext.set_wait_callback(_wait_select) +def _set_wait_callback(is_virtual_database): + global _wait_callback_is_set + if _wait_callback_is_set: + return + _wait_callback_is_set = True + if is_virtual_database: + return + # When running a query, make pressing CTRL+C raise a KeyboardInterrupt + # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ + # See also https://github.com/psycopg/psycopg2/issues/468 + ext.set_wait_callback(_wait_select) def register_date_typecasters(connection): @@ -72,6 +85,8 @@ def register_date_typecasters(connection): cursor = connection.cursor() cursor.execute("SELECT NULL::date") + if cursor.description is None: + return date_oid = cursor.description[0][1] cursor.execute("SELECT NULL::timestamp") timestamp_oid = cursor.description[0][1] @@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn): try: psycopg2.extras.register_json(conn, loads=loads_fn, name=name) available.add(name) - except psycopg2.ProgrammingError: + except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation): pass return available @@ -127,6 +142,38 @@ def register_hstore_typecaster(conn): pass +class ProtocolSafeCursor(psycopg2.extensions.cursor): + def __init__(self, *args, **kwargs): + self.protocol_error = False + self.protocol_message = "" + super().__init__(*args, **kwargs) + + def __iter__(self): + if self.protocol_error: + raise StopIteration + return super().__iter__() + + def fetchall(self): + if self.protocol_error: + return [(self.protocol_message,)] + return super().fetchall() + + def fetchone(self): + if self.protocol_error: + return (self.protocol_message,) + return super().fetchone() + + def execute(self, sql, args=None): + try: + psycopg2.extensions.cursor.execute(self, sql, args) + self.protocol_error = False + self.protocol_message = "" + except psycopg2.errors.ProtocolViolation as ex: + self.protocol_error = True + self.protocol_message = ex.pgerror + _logger.debug("%s: %s" % (ex.__class__.__name__, ex)) + + class PGExecute: # The boolean argument to the current_schemas function indicates whether @@ -190,8 +237,6 @@ class PGExecute: SELECT pg_catalog.pg_get_functiondef(f.f_oid) FROM f""" - version_query = "SELECT version();" - def __init__( self, database=None, @@ -203,6 +248,7 @@ class PGExecute: **kwargs, ): self._conn_params = {} + self._is_virtual_database = None self.conn = None self.dbname = None self.user = None @@ -214,6 +260,11 @@ class PGExecute: self.connect(database, user, password, host, port, dsn, **kwargs) self.reset_expanded = None + def is_virtual_database(self): + if self._is_virtual_database is None: + self._is_virtual_database = self.is_protocol_error() + return self._is_virtual_database + def copy(self): """Returns a clone of the current executor.""" return self.__class__(**self._conn_params) @@ -250,9 +301,9 @@ class PGExecute: ) conn_params.update({k: v for k, v in new_params.items() if v}) + conn_params["cursor_factory"] = ProtocolSafeCursor conn = psycopg2.connect(**conn_params) - cursor = conn.cursor() conn.set_client_encoding("utf8") self._conn_params = conn_params @@ -293,16 +344,22 @@ class PGExecute: self.extra_args = kwargs if not self.host: - self.host = self.get_socket_directory() + self.host = ( + "pgbouncer" + if self.is_virtual_database() + else self.get_socket_directory() + ) - pid = self._select_one(cursor, "select pg_backend_pid()")[0] - self.pid = pid + self.pid = conn.get_backend_pid() self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1") - self.server_version = conn.get_parameter_status("server_version") + self.server_version = conn.get_parameter_status("server_version") or "" + + _set_wait_callback(self.is_virtual_database()) - register_date_typecasters(conn) - register_json_typecasters(self.conn, self._json_typecaster) - register_hstore_typecaster(self.conn) + if not self.is_virtual_database(): + register_date_typecasters(conn) + register_json_typecasters(self.conn, self._json_typecaster) + register_hstore_typecaster(self.conn) @property def short_host(self): @@ -395,7 +452,13 @@ class PGExecute: # See https://github.com/dbcli/pgcli/issues/1014. cur = None try: - for result in pgspecial.execute(cur, sql): + response = pgspecial.execute(cur, sql) + if cur and cur.protocol_error: + yield None, None, None, cur.protocol_message, statement, False, False + # this would close connection. We should reconnect. + self.connect() + continue + for result in response: # e.g. execute_from_file already appends these if len(result) < 7: yield result + (sql, True, True) @@ -453,6 +516,9 @@ class PGExecute: if cur.description: headers = [x[0] for x in cur.description] return title, cur, headers, cur.statusmessage + elif cur.protocol_error: + _logger.debug("Protocol error, unsupported command.") + return title, None, None, cur.protocol_message else: _logger.debug("No rows in result.") return title, None, None, cur.statusmessage @@ -617,6 +683,13 @@ class PGExecute: headers = [x[0] for x in cur.description] return cur.fetchall(), headers, cur.statusmessage + def is_protocol_error(self): + query = "SELECT 1" + with self.conn.cursor() as cur: + _logger.debug("Simple Query. sql: %r", query) + cur.execute(query) + return bool(cur.protocol_error) + def get_socket_directory(self): with self.conn.cursor() as cur: _logger.debug( diff --git a/tests/features/steps/specials.py b/tests/features/steps/specials.py index 813292c4..a85f3710 100644 --- a/tests/features/steps/specials.py +++ b/tests/features/steps/specials.py @@ -22,5 +22,10 @@ def step_see_refresh_started(context): Wait to see refresh output. """ wrappers.expect_pager( - context, "Auto-completion refresh started in the background.\r\n", timeout=2 + context, + [ + "Auto-completion refresh started in the background.\r\n", + "Auto-completion refresh restarted.\r\n", + ], + timeout=2, ) diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py index 78d76881..0ca83669 100644 --- a/tests/features/steps/wrappers.py +++ b/tests/features/steps/wrappers.py @@ -39,9 +39,15 @@ def expect_exact(context, expected, timeout): def expect_pager(context, expected, timeout): + formatted = expected if isinstance(expected, list) else [expected] + formatted = [ + f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n" + for t in formatted + ] + expect_exact( context, - "{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected), + formatted, timeout=timeout, ) diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py index 34cf5700..a5529d6b 100644 --- a/tests/test_completion_refresher.py +++ b/tests/test_completion_refresher.py @@ -37,7 +37,7 @@ def test_refresh_called_once(refresher): :return: """ callbacks = Mock() - pgexecute = Mock() + pgexecute = Mock(**{"is_virtual_database.return_value": False}) special = Mock() with patch.object(refresher, "_bg_refresh") as bg_refresh: @@ -57,7 +57,7 @@ def test_refresh_called_twice(refresher): """ callbacks = Mock() - pgexecute = Mock() + pgexecute = Mock(**{"is_virtual_database.return_value": False}) special = Mock() def dummy_bg_refresh(*args): @@ -84,14 +84,12 @@ def test_refresh_with_callbacks(refresher): :param refresher: """ callbacks = [Mock()] - pgexecute_class = Mock() - pgexecute = Mock() + pgexecute = Mock(**{"is_virtual_database.return_value": False}) pgexecute.extra_args = {} special = Mock() - with patch("pgcli.completion_refresher.PGExecute", pgexecute_class): - # Set refreshers to 0: we're not testing refresh logic here - refresher.refreshers = {} - refresher.refresh(pgexecute, special, callbacks) - time.sleep(1) # Wait for the thread to work. - assert callbacks[0].call_count == 1 + # Set refreshers to 0: we're not testing refresh logic here + refresher.refreshers = {} + refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert callbacks[0].call_count == 1 diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index 513e6192..109674cb 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -520,6 +520,21 @@ class BrokenConnection: raise psycopg2.InterfaceError("I'm broken!") +class VirtualCursor: + """Mock a cursor to virtual database like pgbouncer.""" + + def __init__(self): + self.protocol_error = False + self.protocol_message = "" + self.description = None + self.status = None + self.statusmessage = "Error" + + def execute(self, *args, **kwargs): + self.protocol_error = True + self.protocol_message = "Command not supported" + + @dbtest def test_exit_without_active_connection(executor): quit_handler = MagicMock() @@ -542,3 +557,12 @@ def test_exit_without_active_connection(executor): # an exception should be raised when running a query without active connection with pytest.raises(psycopg2.InterfaceError): run(executor, "select 1", pgspecial=pgspecial) + + +@dbtest +def test_virtual_database(executor): + virtual_connection = MagicMock() + virtual_connection.cursor.return_value = VirtualCursor() + with patch.object(executor, "conn", virtual_connection): + result = run(executor, "select 1") + assert "Command not supported" in result -- cgit v1.2.3