summaryrefslogtreecommitdiffstats
path: root/pgcli/pgexecute.py
diff options
context:
space:
mode:
authorIrina Truong <i.chernyavska@gmail.com>2019-05-25 13:08:56 -0700
committerGitHub <noreply@github.com>2019-05-25 13:08:56 -0700
commit8cb7009bcd0f0062942932c853706a36178f566c (patch)
treecb5bb42674ccce90b173e7a2f09bf72157ae4a86 /pgcli/pgexecute.py
parenta5e607b6fc889afd3f8960ca3903ae16b641c304 (diff)
black all the things. (#1049)
* added black to develop guide * no need for pep8radius. * changelog. * Add pre-commit checkbox. * Add pre-commit to dev reqs. * Add pyproject.toml for black. * Pre-commit config. * Add black to travis and dev reqs. * Install and run black in travis. * Remove black from dev reqs. * Lower black target version. * Re-format with black.
Diffstat (limited to 'pgcli/pgexecute.py')
-rw-r--r--pgcli/pgexecute.py256
1 files changed, 136 insertions, 120 deletions
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index 337be145..ad5ed4a3 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -24,7 +24,7 @@ ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE))
# Cast bytea fields to text. By default, this will render as hex strings with
# Postgres 9+ and as escaped binary in earlier versions.
-ext.register_type(ext.new_type((17,), 'BYTEA_TEXT', psycopg2.STRING))
+ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1
@@ -55,6 +55,7 @@ def _wait_select(conn):
if errno != 4:
raise
+
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
# See also https://github.com/psycopg/psycopg2/issues/468
@@ -66,17 +67,19 @@ def register_date_typecasters(connection):
Casts date and timestamp values to string, resolves issues with out of
range dates (e.g. BC) which psycopg2 can't handle
"""
+
def cast_date(value, cursor):
return value
+
cursor = connection.cursor()
- cursor.execute('SELECT NULL::date')
+ cursor.execute("SELECT NULL::date")
date_oid = cursor.description[0][1]
- cursor.execute('SELECT NULL::timestamp')
+ cursor.execute("SELECT NULL::timestamp")
timestamp_oid = cursor.description[0][1]
- cursor.execute('SELECT NULL::timestamp with time zone')
+ cursor.execute("SELECT NULL::timestamp with time zone")
timestamptz_oid = cursor.description[0][1]
oids = (date_oid, timestamp_oid, timestamptz_oid)
- new_type = psycopg2.extensions.new_type(oids, 'DATE', cast_date)
+ new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date)
psycopg2.extensions.register_type(new_type)
@@ -97,7 +100,7 @@ def register_json_typecasters(conn, loads_fn):
"""
available = set()
- for name in ['json', 'jsonb']:
+ for name in ["json", "jsonb"]:
try:
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
available.add(name)
@@ -117,7 +120,8 @@ def register_hstore_typecaster(conn):
with conn.cursor() as cur:
try:
cur.execute(
- "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined")
+ "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined"
+ )
oid = cur.fetchone()[0]
ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE))
except Exception:
@@ -128,29 +132,29 @@ class PGExecute(object):
# The boolean argument to the current_schemas function indicates whether
# implicit schemas, e.g. pg_catalog
- search_path_query = '''
- SELECT * FROM unnest(current_schemas(true))'''
+ search_path_query = """
+ SELECT * FROM unnest(current_schemas(true))"""
- schemata_query = '''
+ schemata_query = """
SELECT nspname
FROM pg_catalog.pg_namespace
- ORDER BY 1 '''
+ ORDER BY 1 """
- tables_query = '''
+ tables_query = """
SELECT n.nspname schema_name,
c.relname table_name
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n
ON n.oid = c.relnamespace
WHERE c.relkind = ANY(%s)
- ORDER BY 1,2;'''
+ ORDER BY 1,2;"""
- databases_query = '''
+ databases_query = """
SELECT d.datname
FROM pg_catalog.pg_database d
- ORDER BY 1'''
+ ORDER BY 1"""
- full_databases_query = '''
+ full_databases_query = """
SELECT d.datname as "Name",
pg_catalog.pg_get_userbyid(d.datdba) as "Owner",
pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding",
@@ -158,15 +162,15 @@ class PGExecute(object):
d.datctype as "Ctype",
pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges"
FROM pg_catalog.pg_database d
- ORDER BY 1'''
+ ORDER BY 1"""
- socket_directory_query = '''
+ socket_directory_query = """
SELECT setting
FROM pg_settings
WHERE name = 'unix_socket_directories'
- '''
+ """
- view_definition_query = '''
+ view_definition_query = """
WITH v AS (SELECT %s::pg_catalog.regclass::pg_catalog.oid AS v_oid)
SELECT nspname, relname, relkind,
pg_catalog.pg_get_viewdef(c.oid, true),
@@ -179,18 +183,26 @@ class PGExecute(object):
END AS checkoption
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid)
- JOIN v ON (c.oid = v.v_oid)'''
+ JOIN v ON (c.oid = v.v_oid)"""
- function_definition_query = '''
+ function_definition_query = """
WITH f AS
(SELECT %s::pg_catalog.regproc::pg_catalog.oid AS f_oid)
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
- FROM f'''
+ FROM f"""
version_query = "SELECT version();"
- def __init__(self, database=None, user=None, password=None, host=None,
- port=None, dsn=None, **kwargs):
+ def __init__(
+ self,
+ database=None,
+ user=None,
+ password=None,
+ host=None,
+ port=None,
+ dsn=None,
+ **kwargs
+ ):
self._conn_params = {}
self.conn = None
self.dbname = None
@@ -207,10 +219,10 @@ class PGExecute(object):
return self.__class__(**self._conn_params)
def get_server_version(self, cursor):
- _logger.debug('Version Query. sql: %r', self.version_query)
+ _logger.debug("Version Query. sql: %r", self.version_query)
cursor.execute(self.version_query)
result = cursor.fetchone()
- server_version = ''
+ server_version = ""
if result:
# full version string looks like this:
# PostgreSQL 10.3 on x86_64-apple-darwin17.3.0, compiled by Apple LLVM version 9.0.0 (clang-900.0.39.2), 64-bit # noqa
@@ -219,38 +231,42 @@ class PGExecute(object):
server_version = version_parts[1]
return server_version
- def connect(self, database=None, user=None, password=None, host=None,
- port=None, dsn=None, **kwargs):
+ def connect(
+ self,
+ database=None,
+ user=None,
+ password=None,
+ host=None,
+ port=None,
+ dsn=None,
+ **kwargs
+ ):
conn_params = self._conn_params.copy()
new_params = {
- 'database': database,
- 'user': user,
- 'password': password,
- 'host': host,
- 'port': port,
- 'dsn': dsn,
+ "database": database,
+ "user": user,
+ "password": password,
+ "host": host,
+ "port": port,
+ "dsn": dsn,
}
new_params.update(kwargs)
- if new_params['dsn']:
- new_params = {
- 'dsn': new_params['dsn'],
- 'password': new_params['password']
- }
+ if new_params["dsn"]:
+ new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}
- if new_params['password']:
- new_params['dsn'] = make_dsn(
- new_params['dsn'], password=new_params.pop('password'))
+ if new_params["password"]:
+ new_params["dsn"] = make_dsn(
+ new_params["dsn"], password=new_params.pop("password")
+ )
- conn_params.update({
- k: unicode2utf8(v) for k, v in new_params.items() if v
- })
+ conn_params.update({k: unicode2utf8(v) for k, v in new_params.items() if v})
conn = psycopg2.connect(**conn_params)
cursor = conn.cursor()
- conn.set_client_encoding('utf8')
+ conn.set_client_encoding("utf8")
self._conn_params = conn_params
if self.conn:
@@ -264,11 +280,11 @@ class PGExecute(object):
# TODO: use actual connection info from psycopg2.extensions.Connection.info as psycopg>2.8 is available and required dependency # noqa
dsn_parameters = conn.get_dsn_parameters()
- self.dbname = dsn_parameters.get('dbname')
- self.user = dsn_parameters.get('user')
+ self.dbname = dsn_parameters.get("dbname")
+ self.user = dsn_parameters.get("user")
self.password = password
- self.host = dsn_parameters.get('host')
- self.port = dsn_parameters.get('port')
+ self.host = dsn_parameters.get("host")
+ self.port = dsn_parameters.get("port")
self.extra_args = kwargs
if not self.host:
@@ -277,9 +293,9 @@ class PGExecute(object):
cursor.execute("SHOW ALL")
db_parameters = dict(name_val_desc[:2] for name_val_desc in cursor.fetchall())
- pid = self._select_one(cursor, 'select pg_backend_pid()')[0]
+ pid = self._select_one(cursor, "select pg_backend_pid()")[0]
self.pid = pid
- self.superuser = db_parameters.get('is_superuser') == '1'
+ self.superuser = db_parameters.get("is_superuser") == "1"
self.server_version = self.get_server_version(cursor)
@@ -289,11 +305,11 @@ class PGExecute(object):
@property
def short_host(self):
- if ',' in self.host:
- host, _, _ = self.host.partition(',')
+ if "," in self.host:
+ host, _, _ = self.host.partition(",")
else:
host = self.host
- short_host, _, _ = host.partition('.')
+ short_host, _, _ = host.partition(".")
return short_host
def _select_one(self, cur, sql):
@@ -326,11 +342,14 @@ class PGExecute(object):
def valid_transaction(self):
status = self.conn.get_transaction_status()
- return (status == ext.TRANSACTION_STATUS_ACTIVE or
- status == ext.TRANSACTION_STATUS_INTRANS)
-
- def run(self, statement, pgspecial=None, exception_formatter=None,
- on_error_resume=False):
+ return (
+ status == ext.TRANSACTION_STATUS_ACTIVE
+ or status == ext.TRANSACTION_STATUS_INTRANS
+ )
+
+ def run(
+ self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False
+ ):
"""Execute the sql in the database and return the results.
:param statement: A string containing one or more sql statements
@@ -355,12 +374,12 @@ class PGExecute(object):
# Split the sql into separate queries and run each one.
for sql in sqlparse.split(statement):
# Remove spaces, eol and semi-colons.
- sql = sql.rstrip(';')
+ sql = sql.rstrip(";")
try:
if pgspecial:
# First try to run each query as special
- _logger.debug('Trying a pgspecial command. sql: %r', sql)
+ _logger.debug("Trying a pgspecial command. sql: %r", sql)
try:
cur = self.conn.cursor()
except psycopg2.InterfaceError:
@@ -385,8 +404,7 @@ class PGExecute(object):
_logger.error("sql: %r, error: %r", sql, e)
_logger.error("traceback: %r", traceback.format_exc())
- if (self._must_raise(e)
- or not exception_formatter):
+ if self._must_raise(e) or not exception_formatter:
raise
yield None, None, None, exception_formatter(e), sql, False, False
@@ -410,12 +428,12 @@ class PGExecute(object):
def execute_normal_sql(self, split_sql):
"""Returns tuple (title, rows, headers, status)"""
- _logger.debug('Regular sql statement. sql: %r', split_sql)
+ _logger.debug("Regular sql statement. sql: %r", split_sql)
cur = self.conn.cursor()
cur.execute(split_sql)
# conn.notices persist between queies, we use pop to clear out the list
- title = ''
+ title = ""
while len(self.conn.notices) > 0:
title = utf8tounicode(self.conn.notices.pop()) + title
@@ -425,7 +443,7 @@ class PGExecute(object):
headers = [x[0] for x in cur.description]
return title, cur, headers, cur.statusmessage
else:
- _logger.debug('No rows in result.')
+ _logger.debug("No rows in result.")
return title, None, None, cur.statusmessage
def search_path(self):
@@ -433,33 +451,32 @@ class PGExecute(object):
try:
with self.conn.cursor() as cur:
- _logger.debug('Search path query. sql: %r', self.search_path_query)
+ _logger.debug("Search path query. sql: %r", self.search_path_query)
cur.execute(self.search_path_query)
return [x[0] for x in cur.fetchall()]
except psycopg2.ProgrammingError:
- fallback = 'SELECT * FROM current_schemas(true)'
+ fallback = "SELECT * FROM current_schemas(true)"
with self.conn.cursor() as cur:
- _logger.debug('Search path query. sql: %r', fallback)
+ _logger.debug("Search path query. sql: %r", fallback)
cur.execute(fallback)
return cur.fetchone()[0]
def view_definition(self, spec):
"""Returns the SQL defining views described by `spec`"""
- template = 'CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}'
+ template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}"
# 2: relkind, v or m (materialized)
# 4: reloptions, null
# 5: checkoption: local or cascaded
with self.conn.cursor() as cur:
sql = self.view_definition_query
- _logger.debug('View Definition Query. sql: %r\nspec: %r',
- sql, spec)
+ _logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec)
try:
- cur.execute(sql, (spec, ))
+ cur.execute(sql, (spec,))
except psycopg2.ProgrammingError:
- raise RuntimeError('View {} does not exist.'.format(spec))
+ raise RuntimeError("View {} does not exist.".format(spec))
result = cur.fetchone()
- view_type = 'MATERIALIZED' if result[2] == 'm' else ''
+ view_type = "MATERIALIZED" if result[2] == "m" else ""
return template.format(*result + (view_type,))
def function_definition(self, spec):
@@ -467,24 +484,23 @@ class PGExecute(object):
with self.conn.cursor() as cur:
sql = self.function_definition_query
- _logger.debug('Function Definition Query. sql: %r\nspec: %r',
- sql, spec)
+ _logger.debug("Function Definition Query. sql: %r\nspec: %r", sql, spec)
try:
cur.execute(sql, (spec,))
result = cur.fetchone()
return result[0]
except psycopg2.ProgrammingError:
- raise RuntimeError('Function {} does not exist.'.format(spec))
+ raise RuntimeError("Function {} does not exist.".format(spec))
def schemata(self):
"""Returns a list of schema names in the database"""
with self.conn.cursor() as cur:
- _logger.debug('Schemata Query. sql: %r', self.schemata_query)
+ _logger.debug("Schemata Query. sql: %r", self.schemata_query)
cur.execute(self.schemata_query)
return [x[0] for x in cur.fetchall()]
- def _relations(self, kinds=('r', 'v', 'm')):
+ def _relations(self, kinds=("r", "v", "m")):
"""Get table or view name metadata
:param kinds: list of postgres relkind filters:
@@ -496,14 +512,14 @@ class PGExecute(object):
with self.conn.cursor() as cur:
sql = cur.mogrify(self.tables_query, [kinds])
- _logger.debug('Tables Query. sql: %r', sql)
+ _logger.debug("Tables Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
def tables(self):
"""Yields (schema_name, table_name) tuples"""
- for row in self._relations(kinds=['r']):
+ for row in self._relations(kinds=["r"]):
yield row
def views(self):
@@ -511,10 +527,10 @@ class PGExecute(object):
Includes both views and and materialized views
"""
- for row in self._relations(kinds=['v', 'm']):
+ for row in self._relations(kinds=["v", "m"]):
yield row
- def _columns(self, kinds=('r', 'v', 'm')):
+ def _columns(self, kinds=("r", "v", "m")):
"""Get column metadata for tables and views
:param kinds: kinds: list of postgres relkind filters:
@@ -525,7 +541,7 @@ class PGExecute(object):
"""
if self.conn.server_version >= 80400:
- columns_query = '''
+ columns_query = """
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name,
@@ -543,9 +559,9 @@ class PGExecute(object):
WHERE cls.relkind = ANY(%s)
AND NOT att.attisdropped
AND att.attnum > 0
- ORDER BY 1, 2, att.attnum'''
+ ORDER BY 1, 2, att.attnum"""
else:
- columns_query = '''
+ columns_query = """
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name,
@@ -562,44 +578,44 @@ class PGExecute(object):
WHERE cls.relkind = ANY(%s)
AND NOT att.attisdropped
AND att.attnum > 0
- ORDER BY 1, 2, att.attnum'''
+ ORDER BY 1, 2, att.attnum"""
with self.conn.cursor() as cur:
sql = cur.mogrify(columns_query, [kinds])
- _logger.debug('Columns Query. sql: %r', sql)
+ _logger.debug("Columns Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
def table_columns(self):
- for row in self._columns(kinds=['r']):
+ for row in self._columns(kinds=["r"]):
yield row
def view_columns(self):
- for row in self._columns(kinds=['v', 'm']):
+ for row in self._columns(kinds=["v", "m"]):
yield row
def databases(self):
with self.conn.cursor() as cur:
- _logger.debug('Databases Query. sql: %r', self.databases_query)
+ _logger.debug("Databases Query. sql: %r", self.databases_query)
cur.execute(self.databases_query)
return [x[0] for x in cur.fetchall()]
def full_databases(self):
with self.conn.cursor() as cur:
- _logger.debug('Databases Query. sql: %r',
- self.full_databases_query)
+ _logger.debug("Databases Query. sql: %r", self.full_databases_query)
cur.execute(self.full_databases_query)
headers = [x[0] for x in cur.description]
return cur.fetchall(), headers, cur.statusmessage
def get_socket_directory(self):
with self.conn.cursor() as cur:
- _logger.debug('Socket directory Query. sql: %r',
- self.socket_directory_query)
+ _logger.debug(
+ "Socket directory Query. sql: %r", self.socket_directory_query
+ )
cur.execute(self.socket_directory_query)
result = cur.fetchone()
- return result[0] if result else ''
+ return result[0] if result else ""
def foreignkeys(self):
"""Yields ForeignKey named tuples"""
@@ -608,7 +624,7 @@ class PGExecute(object):
return
with self.conn.cursor() as cur:
- query = '''
+ query = """
SELECT s_p.nspname AS parentschema,
t_p.relname AS parenttable,
unnest((
@@ -635,8 +651,8 @@ class PGExecute(object):
JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid
JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace
WHERE fk.contype = 'f';
- '''
- _logger.debug('Functions Query. sql: %r', query)
+ """
+ _logger.debug("Functions Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield ForeignKey(*row)
@@ -645,7 +661,7 @@ class PGExecute(object):
"""Yields FunctionMetadata named tuples"""
if self.conn.server_version >= 110000:
- query = '''
+ query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
@@ -663,9 +679,9 @@ class PGExecute(object):
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
- '''
+ """
elif self.conn.server_version > 90000:
- query = '''
+ query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
@@ -683,9 +699,9 @@ class PGExecute(object):
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
- '''
+ """
elif self.conn.server_version >= 80400:
- query = '''
+ query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
@@ -703,9 +719,9 @@ class PGExecute(object):
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
- '''
+ """
else:
- query = '''
+ query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
@@ -723,20 +739,20 @@ class PGExecute(object):
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
- '''
+ """
with self.conn.cursor() as cur:
- _logger.debug('Functions Query. sql: %r', query)
+ _logger.debug("Functions Query. sql: %r", query)
cur.execute(query)
for row in cur:
- yield FunctionMetadata(*row)
+ yield FunctionMetadata(*row)
def datatypes(self):
"""Yields tuples of (schema_name, type_name)"""
with self.conn.cursor() as cur:
if self.conn.server_version > 90000:
- query = '''
+ query = """
SELECT n.nspname schema_name,
t.typname type_name
FROM pg_catalog.pg_type t
@@ -757,9 +773,9 @@ class PGExecute(object):
AND n.nspname <> 'pg_catalog'
AND n.nspname <> 'information_schema'
ORDER BY 1, 2;
- '''
+ """
else:
- query = '''
+ query = """
SELECT n.nspname schema_name,
pg_catalog.format_type(t.oid, NULL) type_name
FROM pg_catalog.pg_type t
@@ -770,8 +786,8 @@ class PGExecute(object):
AND n.nspname <> 'information_schema'
AND pg_catalog.pg_type_is_visible(t.oid)
ORDER BY 1, 2;
- '''
- _logger.debug('Datatypes Query. sql: %r', query)
+ """
+ _logger.debug("Datatypes Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield row
@@ -779,7 +795,7 @@ class PGExecute(object):
def casing(self):
"""Yields the most common casing for names used in db functions"""
with self.conn.cursor() as cur:
- query = r'''
+ query = r"""
WITH Words AS (
SELECT regexp_split_to_table(prosrc, '\W+') AS Word, COUNT(1)
FROM pg_catalog.pg_proc P
@@ -819,8 +835,8 @@ class PGExecute(object):
FROM OrderWords
WHERE LOWER(Word) IN (SELECT Name FROM Names)
AND Row_Number = 1;
- '''
- _logger.debug('Casing Query. sql: %r', query)
+ """
+ _logger.debug("Casing Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield row[0]