summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorIrina Truong <i.chernyavska@gmail.com>2021-05-21 15:32:34 -0700
committerGitHub <noreply@github.com>2021-05-21 15:32:34 -0700
commite0a4c18c4a074016beef46886b03e456102e2db1 (patch)
tree8e0f57064035bd73faab5d4dea3446dc2c5c376c
parentd8532df22e1f309c8a05a452d752bd30f4869bbe (diff)
Another attempt to fix pgbouncer error (1093.) (#1097)HEADmaster
* Another attempt to fix pgbouncer error (1093.) * Fixes for various pgbouncer problems. * different approach with custom cursor. * Fix rebase. * Missed this. * Fix completion refresher test. * Black. * Unused import. * Changelog. * Fix race condition in test. * Switch from is_pgbouncer to more generic is_virtual_database, and duck-type it. Add very dumb unit test for virtual cursor. * Remove debugger code.
-rw-r--r--changelog.rst1
-rw-r--r--pgcli/completion_refresher.py5
-rw-r--r--pgcli/main.py5
-rw-r--r--pgcli/pgexecute.py151
-rw-r--r--tests/features/steps/specials.py7
-rw-r--r--tests/features/steps/wrappers.py8
-rw-r--r--tests/test_completion_refresher.py18
-rw-r--r--tests/test_pgexecute.py24
8 files changed, 163 insertions, 56 deletions
diff --git a/changelog.rst b/changelog.rst
index 21941499..54951023 100644
--- a/changelog.rst
+++ b/changelog.rst
@@ -19,6 +19,7 @@ Bug fixes:
* Fix pager not being used when output format is set to csv. (#1238)
* Add function literals random, generate_series, generate_subscripts
* Fix ANSI escape codes in first line make the cli choose expanded output incorrectly
+* Fix pgcli crashing with virtual `pgbouncer` database. (#1093)
3.1.0
=====
diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py
index 3e847b09..1039d515 100644
--- a/pgcli/completion_refresher.py
+++ b/pgcli/completion_refresher.py
@@ -3,7 +3,6 @@ import os
from collections import OrderedDict
from .pgcompleter import PGCompleter
-from .pgexecute import PGExecute
class CompletionRefresher:
@@ -27,6 +26,10 @@ class CompletionRefresher:
has completed the refresh. The newly created completion
object will be passed in as an argument to each callback.
"""
+ if executor.is_virtual_database():
+ # do nothing
+ return [(None, None, None, "Auto-completion refresh can't be started.")]
+
if self.is_refreshing():
self._restart_refresh.set()
return [(None, None, None, "Auto-completion refresh restarted.")]
diff --git a/pgcli/main.py b/pgcli/main.py
index 2202c1a7..5135f6fd 100644
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -988,16 +988,13 @@ class PGCli:
callback = functools.partial(
self._on_completions_refreshed, persist_priorities=persist_priorities
)
- self.completion_refresher.refresh(
+ return self.completion_refresher.refresh(
self.pgexecute,
self.pgspecial,
callback,
history=history,
settings=self.settings,
)
- return [
- (None, None, None, "Auto-completion refresh started in the background.")
- ]
def _on_completions_refreshed(self, new_completer, persist_priorities):
self._swap_completer_objects(new_completer, persist_priorities)
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index 5cba7845..a013b558 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -1,13 +1,15 @@
-import traceback
import logging
+import select
+import traceback
+
+import pgspecial as special
import psycopg2
-import psycopg2.extras
import psycopg2.errorcodes
import psycopg2.extensions as ext
+import psycopg2.extras
import sqlparse
-import pgspecial as special
-import select
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
+
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
_logger = logging.getLogger(__name__)
@@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1
+_wait_callback_is_set = False
def _wait_select(conn):
@@ -34,31 +37,41 @@ def _wait_select(conn):
copy-pasted from psycopg2.extras.wait_select
the default implementation doesn't define a timeout in the select calls
"""
- while 1:
- try:
- state = conn.poll()
- if state == POLL_OK:
- break
- elif state == POLL_READ:
- select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
- elif state == POLL_WRITE:
- select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
- else:
- raise conn.OperationalError("bad state from poll: %s" % state)
- except KeyboardInterrupt:
- conn.cancel()
- # the loop will be broken by a server error
- continue
- except OSError as e:
- errno = e.args[0]
- if errno != 4:
- raise
+ try:
+ while 1:
+ try:
+ state = conn.poll()
+ if state == POLL_OK:
+ break
+ elif state == POLL_READ:
+ select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
+ elif state == POLL_WRITE:
+ select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
+ else:
+ raise conn.OperationalError("bad state from poll: %s" % state)
+ except KeyboardInterrupt:
+ conn.cancel()
+ # the loop will be broken by a server error
+ continue
+ except OSError as e:
+ errno = e.args[0]
+ if errno != 4:
+ raise
+ except psycopg2.OperationalError:
+ pass
-# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
-# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
-# See also https://github.com/psycopg/psycopg2/issues/468
-ext.set_wait_callback(_wait_select)
+def _set_wait_callback(is_virtual_database):
+ global _wait_callback_is_set
+ if _wait_callback_is_set:
+ return
+ _wait_callback_is_set = True
+ if is_virtual_database:
+ return
+ # When running a query, make pressing CTRL+C raise a KeyboardInterrupt
+ # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
+ # See also https://github.com/psycopg/psycopg2/issues/468
+ ext.set_wait_callback(_wait_select)
def register_date_typecasters(connection):
@@ -72,6 +85,8 @@ def register_date_typecasters(connection):
cursor = connection.cursor()
cursor.execute("SELECT NULL::date")
+ if cursor.description is None:
+ return
date_oid = cursor.description[0][1]
cursor.execute("SELECT NULL::timestamp")
timestamp_oid = cursor.description[0][1]
@@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn):
try:
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
available.add(name)
- except psycopg2.ProgrammingError:
+ except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
pass
return available
@@ -127,6 +142,38 @@ def register_hstore_typecaster(conn):
pass
+class ProtocolSafeCursor(psycopg2.extensions.cursor):
+ def __init__(self, *args, **kwargs):
+ self.protocol_error = False
+ self.protocol_message = ""
+ super().__init__(*args, **kwargs)
+
+ def __iter__(self):
+ if self.protocol_error:
+ raise StopIteration
+ return super().__iter__()
+
+ def fetchall(self):
+ if self.protocol_error:
+ return [(self.protocol_message,)]
+ return super().fetchall()
+
+ def fetchone(self):
+ if self.protocol_error:
+ return (self.protocol_message,)
+ return super().fetchone()
+
+ def execute(self, sql, args=None):
+ try:
+ psycopg2.extensions.cursor.execute(self, sql, args)
+ self.protocol_error = False
+ self.protocol_message = ""
+ except psycopg2.errors.ProtocolViolation as ex:
+ self.protocol_error = True
+ self.protocol_message = ex.pgerror
+ _logger.debug("%s: %s" % (ex.__class__.__name__, ex))
+
+
class PGExecute:
# The boolean argument to the current_schemas function indicates whether
@@ -190,8 +237,6 @@ class PGExecute:
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
FROM f"""
- version_query = "SELECT version();"
-
def __init__(
self,
database=None,
@@ -203,6 +248,7 @@ class PGExecute:
**kwargs,
):
self._conn_params = {}
+ self._is_virtual_database = None
self.conn = None
self.dbname = None
self.user = None
@@ -214,6 +260,11 @@ class PGExecute:
self.connect(database, user, password, host, port, dsn, **kwargs)
self.reset_expanded = None
+ def is_virtual_database(self):
+ if self._is_virtual_database is None:
+ self._is_virtual_database = self.is_protocol_error()
+ return self._is_virtual_database
+
def copy(self):
"""Returns a clone of the current executor."""
return self.__class__(**self._conn_params)
@@ -250,9 +301,9 @@ class PGExecute:
)
conn_params.update({k: v for k, v in new_params.items() if v})
+ conn_params["cursor_factory"] = ProtocolSafeCursor
conn = psycopg2.connect(**conn_params)
- cursor = conn.cursor()
conn.set_client_encoding("utf8")
self._conn_params = conn_params
@@ -293,16 +344,22 @@ class PGExecute:
self.extra_args = kwargs
if not self.host:
- self.host = self.get_socket_directory()
+ self.host = (
+ "pgbouncer"
+ if self.is_virtual_database()
+ else self.get_socket_directory()
+ )
- pid = self._select_one(cursor, "select pg_backend_pid()")[0]
- self.pid = pid
+ self.pid = conn.get_backend_pid()
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
- self.server_version = conn.get_parameter_status("server_version")
+ self.server_version = conn.get_parameter_status("server_version") or ""
+
+ _set_wait_callback(self.is_virtual_database())
- register_date_typecasters(conn)
- register_json_typecasters(self.conn, self._json_typecaster)
- register_hstore_typecaster(self.conn)
+ if not self.is_virtual_database():
+ register_date_typecasters(conn)
+ register_json_typecasters(self.conn, self._json_typecaster)
+ register_hstore_typecaster(self.conn)
@property
def short_host(self):
@@ -395,7 +452,13 @@ class PGExecute:
# See https://github.com/dbcli/pgcli/issues/1014.
cur = None
try:
- for result in pgspecial.execute(cur, sql):
+ response = pgspecial.execute(cur, sql)
+ if cur and cur.protocol_error:
+ yield None, None, None, cur.protocol_message, statement, False, False
+ # this would close connection. We should reconnect.
+ self.connect()
+ continue
+ for result in response:
# e.g. execute_from_file already appends these
if len(result) < 7:
yield result + (sql, True, True)
@@ -453,6 +516,9 @@ class PGExecute:
if cur.description:
headers = [x[0] for x in cur.description]
return title, cur, headers, cur.statusmessage
+ elif cur.protocol_error:
+ _logger.debug("Protocol error, unsupported command.")
+ return title, None, None, cur.protocol_message
else:
_logger.debug("No rows in result.")
return title, None, None, cur.statusmessage
@@ -617,6 +683,13 @@ class PGExecute:
headers = [x[0] for x in cur.description]
return cur.fetchall(), headers, cur.statusmessage
+ def is_protocol_error(self):
+ query = "SELECT 1"
+ with self.conn.cursor() as cur:
+ _logger.debug("Simple Query. sql: %r", query)
+ cur.execute(query)
+ return bool(cur.protocol_error)
+
def get_socket_directory(self):
with self.conn.cursor() as cur:
_logger.debug(
diff --git a/tests/features/steps/specials.py b/tests/features/steps/specials.py
index 813292c4..a85f3710 100644
--- a/tests/features/steps/specials.py
+++ b/tests/features/steps/specials.py
@@ -22,5 +22,10 @@ def step_see_refresh_started(context):
Wait to see refresh output.
"""
wrappers.expect_pager(
- context, "Auto-completion refresh started in the background.\r\n", timeout=2
+ context,
+ [
+ "Auto-completion refresh started in the background.\r\n",
+ "Auto-completion refresh restarted.\r\n",
+ ],
+ timeout=2,
)
diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py
index 78d76881..0ca83669 100644
--- a/tests/features/steps/wrappers.py
+++ b/tests/features/steps/wrappers.py
@@ -39,9 +39,15 @@ def expect_exact(context, expected, timeout):
def expect_pager(context, expected, timeout):
+ formatted = expected if isinstance(expected, list) else [expected]
+ formatted = [
+ f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n"
+ for t in formatted
+ ]
+
expect_exact(
context,
- "{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected),
+ formatted,
timeout=timeout,
)
diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py
index 34cf5700..a5529d6b 100644
--- a/tests/test_completion_refresher.py
+++ b/tests/test_completion_refresher.py
@@ -37,7 +37,7 @@ def test_refresh_called_once(refresher):
:return:
"""
callbacks = Mock()
- pgexecute = Mock()
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
special = Mock()
with patch.object(refresher, "_bg_refresh") as bg_refresh:
@@ -57,7 +57,7 @@ def test_refresh_called_twice(refresher):
"""
callbacks = Mock()
- pgexecute = Mock()
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
special = Mock()
def dummy_bg_refresh(*args):
@@ -84,14 +84,12 @@ def test_refresh_with_callbacks(refresher):
:param refresher:
"""
callbacks = [Mock()]
- pgexecute_class = Mock()
- pgexecute = Mock()
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
pgexecute.extra_args = {}
special = Mock()
- with patch("pgcli.completion_refresher.PGExecute", pgexecute_class):
- # Set refreshers to 0: we're not testing refresh logic here
- refresher.refreshers = {}
- refresher.refresh(pgexecute, special, callbacks)
- time.sleep(1) # Wait for the thread to work.
- assert callbacks[0].call_count == 1
+ # Set refreshers to 0: we're not testing refresh logic here
+ refresher.refreshers = {}
+ refresher.refresh(pgexecute, special, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert callbacks[0].call_count == 1
diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py
index 513e6192..109674cb 100644
--- a/tests/test_pgexecute.py
+++ b/tests/test_pgexecute.py
@@ -520,6 +520,21 @@ class BrokenConnection:
raise psycopg2.InterfaceError("I'm broken!")
+class VirtualCursor:
+ """Mock a cursor to virtual database like pgbouncer."""
+
+ def __init__(self):
+ self.protocol_error = False
+ self.protocol_message = ""
+ self.description = None
+ self.status = None
+ self.statusmessage = "Error"
+
+ def execute(self, *args, **kwargs):
+ self.protocol_error = True
+ self.protocol_message = "Command not supported"
+
+
@dbtest
def test_exit_without_active_connection(executor):
quit_handler = MagicMock()
@@ -542,3 +557,12 @@ def test_exit_without_active_connection(executor):
# an exception should be raised when running a query without active connection
with pytest.raises(psycopg2.InterfaceError):
run(executor, "select 1", pgspecial=pgspecial)
+
+
+@dbtest
+def test_virtual_database(executor):
+ virtual_connection = MagicMock()
+ virtual_connection.cursor.return_value = VirtualCursor()
+ with patch.object(executor, "conn", virtual_connection):
+ result = run(executor, "select 1")
+ assert "Command not supported" in result