summaryrefslogtreecommitdiffstats
path: root/pgcli/pgexecute.py
diff options
context:
space:
mode:
Diffstat (limited to 'pgcli/pgexecute.py')
-rw-r--r--pgcli/pgexecute.py151
1 files changed, 112 insertions, 39 deletions
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index 5cba7845..a013b558 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -1,13 +1,15 @@
-import traceback
import logging
+import select
+import traceback
+
+import pgspecial as special
import psycopg2
-import psycopg2.extras
import psycopg2.errorcodes
import psycopg2.extensions as ext
+import psycopg2.extras
import sqlparse
-import pgspecial as special
-import select
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
+
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
_logger = logging.getLogger(__name__)
@@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1
+_wait_callback_is_set = False
def _wait_select(conn):
@@ -34,31 +37,41 @@ def _wait_select(conn):
copy-pasted from psycopg2.extras.wait_select
the default implementation doesn't define a timeout in the select calls
"""
- while 1:
- try:
- state = conn.poll()
- if state == POLL_OK:
- break
- elif state == POLL_READ:
- select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
- elif state == POLL_WRITE:
- select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
- else:
- raise conn.OperationalError("bad state from poll: %s" % state)
- except KeyboardInterrupt:
- conn.cancel()
- # the loop will be broken by a server error
- continue
- except OSError as e:
- errno = e.args[0]
- if errno != 4:
- raise
+ try:
+ while 1:
+ try:
+ state = conn.poll()
+ if state == POLL_OK:
+ break
+ elif state == POLL_READ:
+ select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
+ elif state == POLL_WRITE:
+ select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
+ else:
+ raise conn.OperationalError("bad state from poll: %s" % state)
+ except KeyboardInterrupt:
+ conn.cancel()
+ # the loop will be broken by a server error
+ continue
+ except OSError as e:
+ errno = e.args[0]
+ if errno != 4:
+ raise
+ except psycopg2.OperationalError:
+ pass
-# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
-# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
-# See also https://github.com/psycopg/psycopg2/issues/468
-ext.set_wait_callback(_wait_select)
+def _set_wait_callback(is_virtual_database):
+ global _wait_callback_is_set
+ if _wait_callback_is_set:
+ return
+ _wait_callback_is_set = True
+ if is_virtual_database:
+ return
+ # When running a query, make pressing CTRL+C raise a KeyboardInterrupt
+ # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
+ # See also https://github.com/psycopg/psycopg2/issues/468
+ ext.set_wait_callback(_wait_select)
def register_date_typecasters(connection):
@@ -72,6 +85,8 @@ def register_date_typecasters(connection):
cursor = connection.cursor()
cursor.execute("SELECT NULL::date")
+ if cursor.description is None:
+ return
date_oid = cursor.description[0][1]
cursor.execute("SELECT NULL::timestamp")
timestamp_oid = cursor.description[0][1]
@@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn):
try:
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
available.add(name)
- except psycopg2.ProgrammingError:
+ except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
pass
return available
@@ -127,6 +142,38 @@ def register_hstore_typecaster(conn):
pass
+class ProtocolSafeCursor(psycopg2.extensions.cursor):
+ def __init__(self, *args, **kwargs):
+ self.protocol_error = False
+ self.protocol_message = ""
+ super().__init__(*args, **kwargs)
+
+ def __iter__(self):
+ if self.protocol_error:
+ raise StopIteration
+ return super().__iter__()
+
+ def fetchall(self):
+ if self.protocol_error:
+ return [(self.protocol_message,)]
+ return super().fetchall()
+
+ def fetchone(self):
+ if self.protocol_error:
+ return (self.protocol_message,)
+ return super().fetchone()
+
+ def execute(self, sql, args=None):
+ try:
+ psycopg2.extensions.cursor.execute(self, sql, args)
+ self.protocol_error = False
+ self.protocol_message = ""
+ except psycopg2.errors.ProtocolViolation as ex:
+ self.protocol_error = True
+ self.protocol_message = ex.pgerror
+ _logger.debug("%s: %s" % (ex.__class__.__name__, ex))
+
+
class PGExecute:
# The boolean argument to the current_schemas function indicates whether
@@ -190,8 +237,6 @@ class PGExecute:
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
FROM f"""
- version_query = "SELECT version();"
-
def __init__(
self,
database=None,
@@ -203,6 +248,7 @@ class PGExecute:
**kwargs,
):
self._conn_params = {}
+ self._is_virtual_database = None
self.conn = None
self.dbname = None
self.user = None
@@ -214,6 +260,11 @@ class PGExecute:
self.connect(database, user, password, host, port, dsn, **kwargs)
self.reset_expanded = None
+ def is_virtual_database(self):
+ if self._is_virtual_database is None:
+ self._is_virtual_database = self.is_protocol_error()
+ return self._is_virtual_database
+
def copy(self):
"""Returns a clone of the current executor."""
return self.__class__(**self._conn_params)
@@ -250,9 +301,9 @@ class PGExecute:
)
conn_params.update({k: v for k, v in new_params.items() if v})
+ conn_params["cursor_factory"] = ProtocolSafeCursor
conn = psycopg2.connect(**conn_params)
- cursor = conn.cursor()
conn.set_client_encoding("utf8")
self._conn_params = conn_params
@@ -293,16 +344,22 @@ class PGExecute:
self.extra_args = kwargs
if not self.host:
- self.host = self.get_socket_directory()
+ self.host = (
+ "pgbouncer"
+ if self.is_virtual_database()
+ else self.get_socket_directory()
+ )
- pid = self._select_one(cursor, "select pg_backend_pid()")[0]
- self.pid = pid
+ self.pid = conn.get_backend_pid()
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
- self.server_version = conn.get_parameter_status("server_version")
+ self.server_version = conn.get_parameter_status("server_version") or ""
+
+ _set_wait_callback(self.is_virtual_database())
- register_date_typecasters(conn)
- register_json_typecasters(self.conn, self._json_typecaster)
- register_hstore_typecaster(self.conn)
+ if not self.is_virtual_database():
+ register_date_typecasters(conn)
+ register_json_typecasters(self.conn, self._json_typecaster)
+ register_hstore_typecaster(self.conn)
@property
def short_host(self):
@@ -395,7 +452,13 @@ class PGExecute:
# See https://github.com/dbcli/pgcli/issues/1014.
cur = None
try:
- for result in pgspecial.execute(cur, sql):
+ response = pgspecial.execute(cur, sql)
+ if cur and cur.protocol_error:
+ yield None, None, None, cur.protocol_message, statement, False, False
+ # this would close connection. We should reconnect.
+ self.connect()
+ continue
+ for result in response:
# e.g. execute_from_file already appends these
if len(result) < 7:
yield result + (sql, True, True)
@@ -453,6 +516,9 @@ class PGExecute:
if cur.description:
headers = [x[0] for x in cur.description]
return title, cur, headers, cur.statusmessage
+ elif cur.protocol_error:
+ _logger.debug("Protocol error, unsupported command.")
+ return title, None, None, cur.protocol_message
else:
_logger.debug("No rows in result.")
return title, None, None, cur.statusmessage
@@ -617,6 +683,13 @@ class PGExecute:
headers = [x[0] for x in cur.description]
return cur.fetchall(), headers, cur.statusmessage
+ def is_protocol_error(self):
+ query = "SELECT 1"
+ with self.conn.cursor() as cur:
+ _logger.debug("Simple Query. sql: %r", query)
+ cur.execute(query)
+ return bool(cur.protocol_error)
+
def get_socket_directory(self):
with self.conn.cursor() as cur:
_logger.debug(