diff options
Diffstat (limited to 'pgcli/pgexecute.py')
-rw-r--r-- | pgcli/pgexecute.py | 156 |
1 files changed, 87 insertions, 69 deletions
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 8bcc5c63..c7d204f0 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -113,7 +113,8 @@ def register_hstore_typecaster(conn): """ with conn.cursor() as cur: try: - cur.execute("SELECT 'hstore'::regtype::oid") + cur.execute( + "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined") oid = cur.fetchone()[0] ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE)) except Exception: @@ -185,82 +186,80 @@ class PGExecute(object): version_query = "SELECT version();" - def __init__(self, database, user, password, host, port, dsn, **kwargs): - self.dbname = database - self.user = user - self.password = password - self.host = host - self.port = port - self.dsn = dsn - self.extra_args = {k: unicode2utf8(v) for k, v in kwargs.items()} + def __init__(self, database=None, user=None, password=None, host=None, + port=None, dsn=None, **kwargs): + self._conn_params = {} + self.conn = None + self.dbname = None + self.user = None + self.password = None + self.host = None + self.port = None self.server_version = None - self.connect() - - def get_server_version(self): - if self.server_version: - return self.server_version - with self.conn.cursor() as cur: - _logger.debug('Version Query. sql: %r', self.version_query) - cur.execute(self.version_query) - result = cur.fetchone() - if result: - # full version string looks like this: - # PostgreSQL 10.3 on x86_64-apple-darwin17.3.0, compiled by Apple LLVM version 9.0.0 (clang-900.0.39.2), 64-bit # noqa - # let's only retrieve version number - version_parts = result[0].split() - self.server_version = version_parts[1] - else: - self.server_version = '' - return self.server_version + self.connect(database, user, password, host, port, dsn, **kwargs) + + def copy(self): + """Returns a clone of the current executor.""" + return self.__class__(**self._conn_params) + + def get_server_version(self, cursor): + _logger.debug('Version Query. sql: %r', self.version_query) + cursor.execute(self.version_query) + result = cursor.fetchone() + server_version = '' + if result: + # full version string looks like this: + # PostgreSQL 10.3 on x86_64-apple-darwin17.3.0, compiled by Apple LLVM version 9.0.0 (clang-900.0.39.2), 64-bit # noqa + # let's only retrieve version number + version_parts = result[0].split() + server_version = version_parts[1] + return server_version def connect(self, database=None, user=None, password=None, host=None, port=None, dsn=None, **kwargs): - db = (database or self.dbname) - user = (user or self.user) - password = (password or self.password) - host = (host or self.host) - port = (port or self.port) - dsn = (dsn or self.dsn) - kwargs = (kwargs or self.extra_args) - pid = -1 - if dsn: - if password: - dsn = "{0} password={1}".format(dsn, password) - conn = psycopg2.connect(dsn=unicode2utf8(dsn)) - cursor = conn.cursor() - else: - conn = psycopg2.connect( - database=unicode2utf8(db), - user=unicode2utf8(user), - password=unicode2utf8(password), - host=unicode2utf8(host), - port=unicode2utf8(port), - **kwargs) - - cursor = conn.cursor() + conn_params = self._conn_params.copy() + + new_params = { + 'database': database, + 'user': user, + 'password': password, + 'host': host, + 'port': port, + 'dsn': dsn, + } + new_params.update(kwargs) + conn_params.update({ + k: unicode2utf8(v) for k, v in new_params.items() if v is not None + }) + + if 'password' in conn_params and 'dsn' in conn_params: + conn_params['dsn'] = "{0} password={1}".format( + conn_params['dsn'], conn_params.pop('password') + ) + conn = psycopg2.connect(**conn_params) + cursor = conn.cursor() conn.set_client_encoding('utf8') - if hasattr(self, 'conn'): + + self._conn_params = conn_params + if self.conn: self.conn.close() self.conn = conn self.conn.autocommit = True - if dsn: - # When we connect using a DSN, we don't really know what db, - # user, etc. we connected to. Let's read it. - # Note: moved this after setting autocommit because of #664. - dsn_parameters = conn.get_dsn_parameters() - db = dsn_parameters['dbname'] - user = dsn_parameters['user'] - host = dsn_parameters['host'] - port = dsn_parameters['port'] - - self.dbname = db - self.user = user + # When we connect using a DSN, we don't really know what db, + # user, etc. we connected to. Let's read it. + # Note: moved this after setting autocommit because of #664. + # TODO: use actual connection info from psycopg2.extensions.Connection.info as psycopg>2.8 is available and required dependency # noqa + dsn_parameters = conn.get_dsn_parameters() + + self.dbname = dsn_parameters['dbname'] + self.user = dsn_parameters['user'] self.password = password - self.host = host - self.port = port + self.host = dsn_parameters['host'] + self.port = dsn_parameters['port'] + self.extra_args = kwargs if not self.host: self.host = self.get_socket_directory() @@ -272,10 +271,21 @@ class PGExecute(object): self.pid = pid self.superuser = db_parameters.get('is_superuser') == '1' + self.server_version = self.get_server_version(cursor) + register_date_typecasters(conn) register_json_typecasters(self.conn, self._json_typecaster) register_hstore_typecaster(self.conn) + @property + def short_host(self): + if ',' in self.host: + host, _, _ = self.host.partition(',') + else: + host = self.host + short_host, _, _ = host.partition('.') + return short_host + def _select_one(self, cur, sql): """ Helper method to run a select and retrieve a single field value @@ -629,10 +639,12 @@ class PGExecute(object): p.prokind = 'a' is_aggregate, p.prokind = 'w' is_window, p.proretset is_set_returning, + d.deptype = 'e' is_extension, pg_get_expr(proargdefaults, 0) AS arg_defaults FROM pg_catalog.pg_proc p INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 ''' @@ -647,10 +659,12 @@ class PGExecute(object): p.proisagg is_aggregate, p.proiswindow is_window, p.proretset is_set_returning, + d.deptype = 'e' is_extension, pg_get_expr(proargdefaults, 0) AS arg_defaults FROM pg_catalog.pg_proc p INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 ''' @@ -665,10 +679,12 @@ class PGExecute(object): p.proisagg is_aggregate, false is_window, p.proretset is_set_returning, + d.deptype = 'e' is_extension, NULL AS arg_defaults FROM pg_catalog.pg_proc p - INNER JOIN pg_catalog.pg_namespace n - ON n.oid = p.pronamespace + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 ''' @@ -683,10 +699,12 @@ class PGExecute(object): p.proisagg is_aggregate, false is_window, p.proretset is_set_returning, + d.deptype = 'e' is_extension, NULL AS arg_defaults FROM pg_catalog.pg_proc p - INNER JOIN pg_catalog.pg_namespace n - ON n.oid = p.pronamespace + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 ''' |