summaryrefslogtreecommitdiffstats
path: root/pgcli/pgexecute.py
diff options
context:
space:
mode:
Diffstat (limited to 'pgcli/pgexecute.py')
-rw-r--r--pgcli/pgexecute.py156
1 files changed, 87 insertions, 69 deletions
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index 8bcc5c63..c7d204f0 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -113,7 +113,8 @@ def register_hstore_typecaster(conn):
"""
with conn.cursor() as cur:
try:
- cur.execute("SELECT 'hstore'::regtype::oid")
+ cur.execute(
+ "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined")
oid = cur.fetchone()[0]
ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE))
except Exception:
@@ -185,82 +186,80 @@ class PGExecute(object):
version_query = "SELECT version();"
- def __init__(self, database, user, password, host, port, dsn, **kwargs):
- self.dbname = database
- self.user = user
- self.password = password
- self.host = host
- self.port = port
- self.dsn = dsn
- self.extra_args = {k: unicode2utf8(v) for k, v in kwargs.items()}
+ def __init__(self, database=None, user=None, password=None, host=None,
+ port=None, dsn=None, **kwargs):
+ self._conn_params = {}
+ self.conn = None
+ self.dbname = None
+ self.user = None
+ self.password = None
+ self.host = None
+ self.port = None
self.server_version = None
- self.connect()
-
- def get_server_version(self):
- if self.server_version:
- return self.server_version
- with self.conn.cursor() as cur:
- _logger.debug('Version Query. sql: %r', self.version_query)
- cur.execute(self.version_query)
- result = cur.fetchone()
- if result:
- # full version string looks like this:
- # PostgreSQL 10.3 on x86_64-apple-darwin17.3.0, compiled by Apple LLVM version 9.0.0 (clang-900.0.39.2), 64-bit # noqa
- # let's only retrieve version number
- version_parts = result[0].split()
- self.server_version = version_parts[1]
- else:
- self.server_version = ''
- return self.server_version
+ self.connect(database, user, password, host, port, dsn, **kwargs)
+
+ def copy(self):
+ """Returns a clone of the current executor."""
+ return self.__class__(**self._conn_params)
+
+ def get_server_version(self, cursor):
+ _logger.debug('Version Query. sql: %r', self.version_query)
+ cursor.execute(self.version_query)
+ result = cursor.fetchone()
+ server_version = ''
+ if result:
+ # full version string looks like this:
+ # PostgreSQL 10.3 on x86_64-apple-darwin17.3.0, compiled by Apple LLVM version 9.0.0 (clang-900.0.39.2), 64-bit # noqa
+ # let's only retrieve version number
+ version_parts = result[0].split()
+ server_version = version_parts[1]
+ return server_version
def connect(self, database=None, user=None, password=None, host=None,
port=None, dsn=None, **kwargs):
- db = (database or self.dbname)
- user = (user or self.user)
- password = (password or self.password)
- host = (host or self.host)
- port = (port or self.port)
- dsn = (dsn or self.dsn)
- kwargs = (kwargs or self.extra_args)
- pid = -1
- if dsn:
- if password:
- dsn = "{0} password={1}".format(dsn, password)
- conn = psycopg2.connect(dsn=unicode2utf8(dsn))
- cursor = conn.cursor()
- else:
- conn = psycopg2.connect(
- database=unicode2utf8(db),
- user=unicode2utf8(user),
- password=unicode2utf8(password),
- host=unicode2utf8(host),
- port=unicode2utf8(port),
- **kwargs)
-
- cursor = conn.cursor()
+ conn_params = self._conn_params.copy()
+
+ new_params = {
+ 'database': database,
+ 'user': user,
+ 'password': password,
+ 'host': host,
+ 'port': port,
+ 'dsn': dsn,
+ }
+ new_params.update(kwargs)
+ conn_params.update({
+ k: unicode2utf8(v) for k, v in new_params.items() if v is not None
+ })
+
+ if 'password' in conn_params and 'dsn' in conn_params:
+ conn_params['dsn'] = "{0} password={1}".format(
+ conn_params['dsn'], conn_params.pop('password')
+ )
+ conn = psycopg2.connect(**conn_params)
+ cursor = conn.cursor()
conn.set_client_encoding('utf8')
- if hasattr(self, 'conn'):
+
+ self._conn_params = conn_params
+ if self.conn:
self.conn.close()
self.conn = conn
self.conn.autocommit = True
- if dsn:
- # When we connect using a DSN, we don't really know what db,
- # user, etc. we connected to. Let's read it.
- # Note: moved this after setting autocommit because of #664.
- dsn_parameters = conn.get_dsn_parameters()
- db = dsn_parameters['dbname']
- user = dsn_parameters['user']
- host = dsn_parameters['host']
- port = dsn_parameters['port']
-
- self.dbname = db
- self.user = user
+ # When we connect using a DSN, we don't really know what db,
+ # user, etc. we connected to. Let's read it.
+ # Note: moved this after setting autocommit because of #664.
+ # TODO: use actual connection info from psycopg2.extensions.Connection.info as psycopg>2.8 is available and required dependency # noqa
+ dsn_parameters = conn.get_dsn_parameters()
+
+ self.dbname = dsn_parameters['dbname']
+ self.user = dsn_parameters['user']
self.password = password
- self.host = host
- self.port = port
+ self.host = dsn_parameters['host']
+ self.port = dsn_parameters['port']
+ self.extra_args = kwargs
if not self.host:
self.host = self.get_socket_directory()
@@ -272,10 +271,21 @@ class PGExecute(object):
self.pid = pid
self.superuser = db_parameters.get('is_superuser') == '1'
+ self.server_version = self.get_server_version(cursor)
+
register_date_typecasters(conn)
register_json_typecasters(self.conn, self._json_typecaster)
register_hstore_typecaster(self.conn)
+ @property
+ def short_host(self):
+ if ',' in self.host:
+ host, _, _ = self.host.partition(',')
+ else:
+ host = self.host
+ short_host, _, _ = host.partition('.')
+ return short_host
+
def _select_one(self, cur, sql):
"""
Helper method to run a select and retrieve a single field value
@@ -629,10 +639,12 @@ class PGExecute(object):
p.prokind = 'a' is_aggregate,
p.prokind = 'w' is_window,
p.proretset is_set_returning,
+ d.deptype = 'e' is_extension,
pg_get_expr(proargdefaults, 0) AS arg_defaults
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = p.pronamespace
+ LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
'''
@@ -647,10 +659,12 @@ class PGExecute(object):
p.proisagg is_aggregate,
p.proiswindow is_window,
p.proretset is_set_returning,
+ d.deptype = 'e' is_extension,
pg_get_expr(proargdefaults, 0) AS arg_defaults
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = p.pronamespace
+ LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
'''
@@ -665,10 +679,12 @@ class PGExecute(object):
p.proisagg is_aggregate,
false is_window,
p.proretset is_set_returning,
+ d.deptype = 'e' is_extension,
NULL AS arg_defaults
FROM pg_catalog.pg_proc p
- INNER JOIN pg_catalog.pg_namespace n
- ON n.oid = p.pronamespace
+ INNER JOIN pg_catalog.pg_namespace n
+ ON n.oid = p.pronamespace
+ LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
'''
@@ -683,10 +699,12 @@ class PGExecute(object):
p.proisagg is_aggregate,
false is_window,
p.proretset is_set_returning,
+ d.deptype = 'e' is_extension,
NULL AS arg_defaults
FROM pg_catalog.pg_proc p
- INNER JOIN pg_catalog.pg_namespace n
- ON n.oid = p.pronamespace
+ INNER JOIN pg_catalog.pg_namespace n
+ ON n.oid = p.pronamespace
+ LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
'''