diff options
author | Irina Truong <i.chernyavska@gmail.com> | 2019-05-25 13:08:56 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-05-25 13:08:56 -0700 |
commit | 8cb7009bcd0f0062942932c853706a36178f566c (patch) | |
tree | cb5bb42674ccce90b173e7a2f09bf72157ae4a86 /pgcli/pgexecute.py | |
parent | a5e607b6fc889afd3f8960ca3903ae16b641c304 (diff) |
black all the things. (#1049)
* added black to develop guide
* no need for pep8radius.
* changelog.
* Add pre-commit checkbox.
* Add pre-commit to dev reqs.
* Add pyproject.toml for black.
* Pre-commit config.
* Add black to travis and dev reqs.
* Install and run black in travis.
* Remove black from dev reqs.
* Lower black target version.
* Re-format with black.
Diffstat (limited to 'pgcli/pgexecute.py')
-rw-r--r-- | pgcli/pgexecute.py | 256 |
1 files changed, 136 insertions, 120 deletions
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 337be145..ad5ed4a3 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -24,7 +24,7 @@ ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE)) # Cast bytea fields to text. By default, this will render as hex strings with # Postgres 9+ and as escaped binary in earlier versions. -ext.register_type(ext.new_type((17,), 'BYTEA_TEXT', psycopg2.STRING)) +ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING)) # TODO: Get default timeout from pgclirc? _WAIT_SELECT_TIMEOUT = 1 @@ -55,6 +55,7 @@ def _wait_select(conn): if errno != 4: raise + # When running a query, make pressing CTRL+C raise a KeyboardInterrupt # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ # See also https://github.com/psycopg/psycopg2/issues/468 @@ -66,17 +67,19 @@ def register_date_typecasters(connection): Casts date and timestamp values to string, resolves issues with out of range dates (e.g. BC) which psycopg2 can't handle """ + def cast_date(value, cursor): return value + cursor = connection.cursor() - cursor.execute('SELECT NULL::date') + cursor.execute("SELECT NULL::date") date_oid = cursor.description[0][1] - cursor.execute('SELECT NULL::timestamp') + cursor.execute("SELECT NULL::timestamp") timestamp_oid = cursor.description[0][1] - cursor.execute('SELECT NULL::timestamp with time zone') + cursor.execute("SELECT NULL::timestamp with time zone") timestamptz_oid = cursor.description[0][1] oids = (date_oid, timestamp_oid, timestamptz_oid) - new_type = psycopg2.extensions.new_type(oids, 'DATE', cast_date) + new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date) psycopg2.extensions.register_type(new_type) @@ -97,7 +100,7 @@ def register_json_typecasters(conn, loads_fn): """ available = set() - for name in ['json', 'jsonb']: + for name in ["json", "jsonb"]: try: psycopg2.extras.register_json(conn, loads=loads_fn, name=name) available.add(name) @@ -117,7 +120,8 @@ def register_hstore_typecaster(conn): with conn.cursor() as cur: try: cur.execute( - "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined") + "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined" + ) oid = cur.fetchone()[0] ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE)) except Exception: @@ -128,29 +132,29 @@ class PGExecute(object): # The boolean argument to the current_schemas function indicates whether # implicit schemas, e.g. pg_catalog - search_path_query = ''' - SELECT * FROM unnest(current_schemas(true))''' + search_path_query = """ + SELECT * FROM unnest(current_schemas(true))""" - schemata_query = ''' + schemata_query = """ SELECT nspname FROM pg_catalog.pg_namespace - ORDER BY 1 ''' + ORDER BY 1 """ - tables_query = ''' + tables_query = """ SELECT n.nspname schema_name, c.relname table_name FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE c.relkind = ANY(%s) - ORDER BY 1,2;''' + ORDER BY 1,2;""" - databases_query = ''' + databases_query = """ SELECT d.datname FROM pg_catalog.pg_database d - ORDER BY 1''' + ORDER BY 1""" - full_databases_query = ''' + full_databases_query = """ SELECT d.datname as "Name", pg_catalog.pg_get_userbyid(d.datdba) as "Owner", pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding", @@ -158,15 +162,15 @@ class PGExecute(object): d.datctype as "Ctype", pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges" FROM pg_catalog.pg_database d - ORDER BY 1''' + ORDER BY 1""" - socket_directory_query = ''' + socket_directory_query = """ SELECT setting FROM pg_settings WHERE name = 'unix_socket_directories' - ''' + """ - view_definition_query = ''' + view_definition_query = """ WITH v AS (SELECT %s::pg_catalog.regclass::pg_catalog.oid AS v_oid) SELECT nspname, relname, relkind, pg_catalog.pg_get_viewdef(c.oid, true), @@ -179,18 +183,26 @@ class PGExecute(object): END AS checkoption FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid) - JOIN v ON (c.oid = v.v_oid)''' + JOIN v ON (c.oid = v.v_oid)""" - function_definition_query = ''' + function_definition_query = """ WITH f AS (SELECT %s::pg_catalog.regproc::pg_catalog.oid AS f_oid) SELECT pg_catalog.pg_get_functiondef(f.f_oid) - FROM f''' + FROM f""" version_query = "SELECT version();" - def __init__(self, database=None, user=None, password=None, host=None, - port=None, dsn=None, **kwargs): + def __init__( + self, + database=None, + user=None, + password=None, + host=None, + port=None, + dsn=None, + **kwargs + ): self._conn_params = {} self.conn = None self.dbname = None @@ -207,10 +219,10 @@ class PGExecute(object): return self.__class__(**self._conn_params) def get_server_version(self, cursor): - _logger.debug('Version Query. sql: %r', self.version_query) + _logger.debug("Version Query. sql: %r", self.version_query) cursor.execute(self.version_query) result = cursor.fetchone() - server_version = '' + server_version = "" if result: # full version string looks like this: # PostgreSQL 10.3 on x86_64-apple-darwin17.3.0, compiled by Apple LLVM version 9.0.0 (clang-900.0.39.2), 64-bit # noqa @@ -219,38 +231,42 @@ class PGExecute(object): server_version = version_parts[1] return server_version - def connect(self, database=None, user=None, password=None, host=None, - port=None, dsn=None, **kwargs): + def connect( + self, + database=None, + user=None, + password=None, + host=None, + port=None, + dsn=None, + **kwargs + ): conn_params = self._conn_params.copy() new_params = { - 'database': database, - 'user': user, - 'password': password, - 'host': host, - 'port': port, - 'dsn': dsn, + "database": database, + "user": user, + "password": password, + "host": host, + "port": port, + "dsn": dsn, } new_params.update(kwargs) - if new_params['dsn']: - new_params = { - 'dsn': new_params['dsn'], - 'password': new_params['password'] - } + if new_params["dsn"]: + new_params = {"dsn": new_params["dsn"], "password": new_params["password"]} - if new_params['password']: - new_params['dsn'] = make_dsn( - new_params['dsn'], password=new_params.pop('password')) + if new_params["password"]: + new_params["dsn"] = make_dsn( + new_params["dsn"], password=new_params.pop("password") + ) - conn_params.update({ - k: unicode2utf8(v) for k, v in new_params.items() if v - }) + conn_params.update({k: unicode2utf8(v) for k, v in new_params.items() if v}) conn = psycopg2.connect(**conn_params) cursor = conn.cursor() - conn.set_client_encoding('utf8') + conn.set_client_encoding("utf8") self._conn_params = conn_params if self.conn: @@ -264,11 +280,11 @@ class PGExecute(object): # TODO: use actual connection info from psycopg2.extensions.Connection.info as psycopg>2.8 is available and required dependency # noqa dsn_parameters = conn.get_dsn_parameters() - self.dbname = dsn_parameters.get('dbname') - self.user = dsn_parameters.get('user') + self.dbname = dsn_parameters.get("dbname") + self.user = dsn_parameters.get("user") self.password = password - self.host = dsn_parameters.get('host') - self.port = dsn_parameters.get('port') + self.host = dsn_parameters.get("host") + self.port = dsn_parameters.get("port") self.extra_args = kwargs if not self.host: @@ -277,9 +293,9 @@ class PGExecute(object): cursor.execute("SHOW ALL") db_parameters = dict(name_val_desc[:2] for name_val_desc in cursor.fetchall()) - pid = self._select_one(cursor, 'select pg_backend_pid()')[0] + pid = self._select_one(cursor, "select pg_backend_pid()")[0] self.pid = pid - self.superuser = db_parameters.get('is_superuser') == '1' + self.superuser = db_parameters.get("is_superuser") == "1" self.server_version = self.get_server_version(cursor) @@ -289,11 +305,11 @@ class PGExecute(object): @property def short_host(self): - if ',' in self.host: - host, _, _ = self.host.partition(',') + if "," in self.host: + host, _, _ = self.host.partition(",") else: host = self.host - short_host, _, _ = host.partition('.') + short_host, _, _ = host.partition(".") return short_host def _select_one(self, cur, sql): @@ -326,11 +342,14 @@ class PGExecute(object): def valid_transaction(self): status = self.conn.get_transaction_status() - return (status == ext.TRANSACTION_STATUS_ACTIVE or - status == ext.TRANSACTION_STATUS_INTRANS) - - def run(self, statement, pgspecial=None, exception_formatter=None, - on_error_resume=False): + return ( + status == ext.TRANSACTION_STATUS_ACTIVE + or status == ext.TRANSACTION_STATUS_INTRANS + ) + + def run( + self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False + ): """Execute the sql in the database and return the results. :param statement: A string containing one or more sql statements @@ -355,12 +374,12 @@ class PGExecute(object): # Split the sql into separate queries and run each one. for sql in sqlparse.split(statement): # Remove spaces, eol and semi-colons. - sql = sql.rstrip(';') + sql = sql.rstrip(";") try: if pgspecial: # First try to run each query as special - _logger.debug('Trying a pgspecial command. sql: %r', sql) + _logger.debug("Trying a pgspecial command. sql: %r", sql) try: cur = self.conn.cursor() except psycopg2.InterfaceError: @@ -385,8 +404,7 @@ class PGExecute(object): _logger.error("sql: %r, error: %r", sql, e) _logger.error("traceback: %r", traceback.format_exc()) - if (self._must_raise(e) - or not exception_formatter): + if self._must_raise(e) or not exception_formatter: raise yield None, None, None, exception_formatter(e), sql, False, False @@ -410,12 +428,12 @@ class PGExecute(object): def execute_normal_sql(self, split_sql): """Returns tuple (title, rows, headers, status)""" - _logger.debug('Regular sql statement. sql: %r', split_sql) + _logger.debug("Regular sql statement. sql: %r", split_sql) cur = self.conn.cursor() cur.execute(split_sql) # conn.notices persist between queies, we use pop to clear out the list - title = '' + title = "" while len(self.conn.notices) > 0: title = utf8tounicode(self.conn.notices.pop()) + title @@ -425,7 +443,7 @@ class PGExecute(object): headers = [x[0] for x in cur.description] return title, cur, headers, cur.statusmessage else: - _logger.debug('No rows in result.') + _logger.debug("No rows in result.") return title, None, None, cur.statusmessage def search_path(self): @@ -433,33 +451,32 @@ class PGExecute(object): try: with self.conn.cursor() as cur: - _logger.debug('Search path query. sql: %r', self.search_path_query) + _logger.debug("Search path query. sql: %r", self.search_path_query) cur.execute(self.search_path_query) return [x[0] for x in cur.fetchall()] except psycopg2.ProgrammingError: - fallback = 'SELECT * FROM current_schemas(true)' + fallback = "SELECT * FROM current_schemas(true)" with self.conn.cursor() as cur: - _logger.debug('Search path query. sql: %r', fallback) + _logger.debug("Search path query. sql: %r", fallback) cur.execute(fallback) return cur.fetchone()[0] def view_definition(self, spec): """Returns the SQL defining views described by `spec`""" - template = 'CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}' + template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}" # 2: relkind, v or m (materialized) # 4: reloptions, null # 5: checkoption: local or cascaded with self.conn.cursor() as cur: sql = self.view_definition_query - _logger.debug('View Definition Query. sql: %r\nspec: %r', - sql, spec) + _logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec) try: - cur.execute(sql, (spec, )) + cur.execute(sql, (spec,)) except psycopg2.ProgrammingError: - raise RuntimeError('View {} does not exist.'.format(spec)) + raise RuntimeError("View {} does not exist.".format(spec)) result = cur.fetchone() - view_type = 'MATERIALIZED' if result[2] == 'm' else '' + view_type = "MATERIALIZED" if result[2] == "m" else "" return template.format(*result + (view_type,)) def function_definition(self, spec): @@ -467,24 +484,23 @@ class PGExecute(object): with self.conn.cursor() as cur: sql = self.function_definition_query - _logger.debug('Function Definition Query. sql: %r\nspec: %r', - sql, spec) + _logger.debug("Function Definition Query. sql: %r\nspec: %r", sql, spec) try: cur.execute(sql, (spec,)) result = cur.fetchone() return result[0] except psycopg2.ProgrammingError: - raise RuntimeError('Function {} does not exist.'.format(spec)) + raise RuntimeError("Function {} does not exist.".format(spec)) def schemata(self): """Returns a list of schema names in the database""" with self.conn.cursor() as cur: - _logger.debug('Schemata Query. sql: %r', self.schemata_query) + _logger.debug("Schemata Query. sql: %r", self.schemata_query) cur.execute(self.schemata_query) return [x[0] for x in cur.fetchall()] - def _relations(self, kinds=('r', 'v', 'm')): + def _relations(self, kinds=("r", "v", "m")): """Get table or view name metadata :param kinds: list of postgres relkind filters: @@ -496,14 +512,14 @@ class PGExecute(object): with self.conn.cursor() as cur: sql = cur.mogrify(self.tables_query, [kinds]) - _logger.debug('Tables Query. sql: %r', sql) + _logger.debug("Tables Query. sql: %r", sql) cur.execute(sql) for row in cur: yield row def tables(self): """Yields (schema_name, table_name) tuples""" - for row in self._relations(kinds=['r']): + for row in self._relations(kinds=["r"]): yield row def views(self): @@ -511,10 +527,10 @@ class PGExecute(object): Includes both views and and materialized views """ - for row in self._relations(kinds=['v', 'm']): + for row in self._relations(kinds=["v", "m"]): yield row - def _columns(self, kinds=('r', 'v', 'm')): + def _columns(self, kinds=("r", "v", "m")): """Get column metadata for tables and views :param kinds: kinds: list of postgres relkind filters: @@ -525,7 +541,7 @@ class PGExecute(object): """ if self.conn.server_version >= 80400: - columns_query = ''' + columns_query = """ SELECT nsp.nspname schema_name, cls.relname table_name, att.attname column_name, @@ -543,9 +559,9 @@ class PGExecute(object): WHERE cls.relkind = ANY(%s) AND NOT att.attisdropped AND att.attnum > 0 - ORDER BY 1, 2, att.attnum''' + ORDER BY 1, 2, att.attnum""" else: - columns_query = ''' + columns_query = """ SELECT nsp.nspname schema_name, cls.relname table_name, att.attname column_name, @@ -562,44 +578,44 @@ class PGExecute(object): WHERE cls.relkind = ANY(%s) AND NOT att.attisdropped AND att.attnum > 0 - ORDER BY 1, 2, att.attnum''' + ORDER BY 1, 2, att.attnum""" with self.conn.cursor() as cur: sql = cur.mogrify(columns_query, [kinds]) - _logger.debug('Columns Query. sql: %r', sql) + _logger.debug("Columns Query. sql: %r", sql) cur.execute(sql) for row in cur: yield row def table_columns(self): - for row in self._columns(kinds=['r']): + for row in self._columns(kinds=["r"]): yield row def view_columns(self): - for row in self._columns(kinds=['v', 'm']): + for row in self._columns(kinds=["v", "m"]): yield row def databases(self): with self.conn.cursor() as cur: - _logger.debug('Databases Query. sql: %r', self.databases_query) + _logger.debug("Databases Query. sql: %r", self.databases_query) cur.execute(self.databases_query) return [x[0] for x in cur.fetchall()] def full_databases(self): with self.conn.cursor() as cur: - _logger.debug('Databases Query. sql: %r', - self.full_databases_query) + _logger.debug("Databases Query. sql: %r", self.full_databases_query) cur.execute(self.full_databases_query) headers = [x[0] for x in cur.description] return cur.fetchall(), headers, cur.statusmessage def get_socket_directory(self): with self.conn.cursor() as cur: - _logger.debug('Socket directory Query. sql: %r', - self.socket_directory_query) + _logger.debug( + "Socket directory Query. sql: %r", self.socket_directory_query + ) cur.execute(self.socket_directory_query) result = cur.fetchone() - return result[0] if result else '' + return result[0] if result else "" def foreignkeys(self): """Yields ForeignKey named tuples""" @@ -608,7 +624,7 @@ class PGExecute(object): return with self.conn.cursor() as cur: - query = ''' + query = """ SELECT s_p.nspname AS parentschema, t_p.relname AS parenttable, unnest(( @@ -635,8 +651,8 @@ class PGExecute(object): JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace WHERE fk.contype = 'f'; - ''' - _logger.debug('Functions Query. sql: %r', query) + """ + _logger.debug("Functions Query. sql: %r", query) cur.execute(query) for row in cur: yield ForeignKey(*row) @@ -645,7 +661,7 @@ class PGExecute(object): """Yields FunctionMetadata named tuples""" if self.conn.server_version >= 110000: - query = ''' + query = """ SELECT n.nspname schema_name, p.proname func_name, p.proargnames, @@ -663,9 +679,9 @@ class PGExecute(object): LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 - ''' + """ elif self.conn.server_version > 90000: - query = ''' + query = """ SELECT n.nspname schema_name, p.proname func_name, p.proargnames, @@ -683,9 +699,9 @@ class PGExecute(object): LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 - ''' + """ elif self.conn.server_version >= 80400: - query = ''' + query = """ SELECT n.nspname schema_name, p.proname func_name, p.proargnames, @@ -703,9 +719,9 @@ class PGExecute(object): LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 - ''' + """ else: - query = ''' + query = """ SELECT n.nspname schema_name, p.proname func_name, p.proargnames, @@ -723,20 +739,20 @@ class PGExecute(object): LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 - ''' + """ with self.conn.cursor() as cur: - _logger.debug('Functions Query. sql: %r', query) + _logger.debug("Functions Query. sql: %r", query) cur.execute(query) for row in cur: - yield FunctionMetadata(*row) + yield FunctionMetadata(*row) def datatypes(self): """Yields tuples of (schema_name, type_name)""" with self.conn.cursor() as cur: if self.conn.server_version > 90000: - query = ''' + query = """ SELECT n.nspname schema_name, t.typname type_name FROM pg_catalog.pg_type t @@ -757,9 +773,9 @@ class PGExecute(object): AND n.nspname <> 'pg_catalog' AND n.nspname <> 'information_schema' ORDER BY 1, 2; - ''' + """ else: - query = ''' + query = """ SELECT n.nspname schema_name, pg_catalog.format_type(t.oid, NULL) type_name FROM pg_catalog.pg_type t @@ -770,8 +786,8 @@ class PGExecute(object): AND n.nspname <> 'information_schema' AND pg_catalog.pg_type_is_visible(t.oid) ORDER BY 1, 2; - ''' - _logger.debug('Datatypes Query. sql: %r', query) + """ + _logger.debug("Datatypes Query. sql: %r", query) cur.execute(query) for row in cur: yield row @@ -779,7 +795,7 @@ class PGExecute(object): def casing(self): """Yields the most common casing for names used in db functions""" with self.conn.cursor() as cur: - query = r''' + query = r""" WITH Words AS ( SELECT regexp_split_to_table(prosrc, '\W+') AS Word, COUNT(1) FROM pg_catalog.pg_proc P @@ -819,8 +835,8 @@ class PGExecute(object): FROM OrderWords WHERE LOWER(Word) IN (SELECT Name FROM Names) AND Row_Number = 1; - ''' - _logger.debug('Casing Query. sql: %r', query) + """ + _logger.debug("Casing Query. sql: %r", query) cur.execute(query) for row in cur: yield row[0] |