diff options
author | Amjith Ramanujam <amjith.r@gmail.com> | 2019-03-16 14:06:24 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-03-16 14:06:24 -0700 |
commit | cbd944ffab7206d03a675dc5384553874c2889a7 (patch) | |
tree | dcdc516de86c1a0740efb05e42712cff59daedfc | |
parent | 945461bfa2f901dd66c216a9709b00e6e140be0e (diff) | |
parent | 7d6523e18b2c752e0a404746018df77cbad5bcdb (diff) |
Merge branch 'master' into tab-on-line-start
-rw-r--r-- | AUTHORS | 5 | ||||
-rw-r--r-- | changelog.rst | 48 | ||||
-rw-r--r-- | pgcli/__init__.py | 2 | ||||
-rw-r--r-- | pgcli/completion_refresher.py | 8 | ||||
-rw-r--r-- | pgcli/key_bindings.py | 13 | ||||
-rw-r--r-- | pgcli/main.py | 83 | ||||
-rw-r--r-- | pgcli/packages/parseutils/meta.py | 16 | ||||
-rw-r--r-- | pgcli/packages/parseutils/tables.py | 54 | ||||
-rw-r--r-- | pgcli/pgcompleter.py | 18 | ||||
-rw-r--r-- | pgcli/pgexecute.py | 156 | ||||
-rw-r--r-- | release.py | 2 | ||||
-rw-r--r-- | setup.py | 2 | ||||
-rw-r--r-- | tests/conftest.py | 6 | ||||
-rw-r--r-- | tests/features/environment.py | 1 | ||||
-rw-r--r-- | tests/features/steps/wrappers.py | 4 | ||||
-rw-r--r-- | tests/metadata.py | 8 | ||||
-rw-r--r-- | tests/parseutils/test_function_metadata.py | 9 | ||||
-rw-r--r-- | tests/test_main.py | 14 | ||||
-rw-r--r-- | tests/test_naive_completion.py | 48 | ||||
-rw-r--r-- | tests/test_pgexecute.py | 29 | ||||
-rw-r--r-- | tests/test_smart_completion_multiple_schemata.py | 154 | ||||
-rw-r--r-- | tests/test_smart_completion_public_schema_only.py | 282 | ||||
-rw-r--r-- | tests/utils.py | 11 | ||||
-rw-r--r-- | tox.ini | 2 |
24 files changed, 585 insertions, 390 deletions
@@ -86,6 +86,11 @@ Contributors: * Kenny Do * Max Rothman * Daniel Egger + * Ignacio Campabadal + * Mikhail Elovskikh (wronglink) + * Marcin Cieślak (saper) + * easteregg (verfriemelt-dot-org) + * Scott Brenstuhl (808sAndBR) Creator: -------- diff --git a/changelog.rst b/changelog.rst index 32023f38..2a3e446f 100644 --- a/changelog.rst +++ b/changelog.rst @@ -1,6 +1,44 @@ Upcoming: ========= +Features: +--------- + +* Keybindings for closing the autocomplete list. (Thanks: `easteregg`_) +* Reconnect automatically when server closes connection. (Thanks: `Scott Brenstuhl`_) + +Bug fixes: +---------- +* Avoid error message on the server side if hstore extension is not installed in the current database (#991). (Thanks: `Marcin Cieślak`_) +* All pexpect submodules have been moved into the pexpect package as of version 3.0. Use pexpect.TIMEOUT (Thanks: `Marcin Cieślak`_) +* Fix crash retrieving server version with ``--single-connection``. (Thanks: `Irina Truong`_) + +Internal: +--------- + +* (Fixup) Clean up and add behave logging. (Thanks: `Marcin Cieślak`_, `Dick Marinus`_) +* Override VISUAL environment variable for behave tests. (Thanks: `Marcin Cieślak`_) +* Remove build dir before running sdist, remove stray files from wheel distribution. (Thanks: `Dick Marinus`_) +* Fix unit tests, unhashable formatted text since new python prompttoolkit version. (Thanks: `Dick Marinus`_) + +2.0.2: +====== + +Features: +--------- + +* Allows passing the ``-u`` flag to specify a username. (Thanks: `Ignacio Campabadal`_) +* Fix for lag in v2 (#979). (Thanks: `Irina Truong`_) +* Support for multihost connection string that is convenient if you have postgres cluster. (Thanks: `Mikhail Elovskikh`_) + +Internal: +--------- + +* Added tests for special command completion. (Thanks: `Amjith Ramanujam`_) + +2.0.1: +====== + Bug fixes: ---------- @@ -9,13 +47,15 @@ Bug fixes: * Fix for loading/saving named queries from provided config file (#938). (Thanks: `Daniel Egger`_) * Set default port in `connect_uri` when none is given. (Thanks: `Daniel Egger`_) * Fix for error listing databases (#951). (Thanks: `Irina Truong`_) +* Enable Ctrl-Z to suspend the app (Thanks: `Amjith Ramanujam`_). +* Fix StopIteration exception raised at runtime for Python 3.7 (Thanks: `Amjith Ramanujam`_). Internal: --------- * Clean up and add behave logging. (Thanks: `Dick Marinus`_) * Require prompt_toolkit>=2.0.6. (Thanks: `Dick Marinus`_) -* Improve development guide (Thanks: `Ignacio Campabadal`_) +* Improve development guide. (Thanks: `Ignacio Campabadal`_) 2.0.0: ====== @@ -831,7 +871,7 @@ Improvements: * Faster test runs on TravisCI. (Thanks: https://github.com/macobo) * Integration tests with Postgres!! (Thanks: https://github.com/macobo) -.. _`Amjith Ramanujam`: https://github.com/amjith +.. _`Amjith Ramanujam`: https://blog.amjith.com .. _`Andrew Kuchling`: https://github.com/akuchling .. _`Darik Gamble`: https://github.com/darikg .. _`Daniel Rocco`: https://github.com/drocco007 @@ -903,3 +943,7 @@ Improvements: .. _`Max Rothman`: https://github.com/maxrothman .. _`Daniel Egger`: https://github.com/DanEEStar .. _`Ignacio Campabadal`: https://github.com/igncampa +.. _`Mikhail Elovskikh`: https://github.com/wronglink +.. _`Marcin Cieślak`: https://github.com/saper +.. _`Scott Brenstuhl`: https://github.com/808sAndBR +.. _`easteregg`: https://github.com/verfriemelt-dot-org diff --git a/pgcli/__init__.py b/pgcli/__init__.py index afced147..668c3446 100644 --- a/pgcli/__init__.py +++ b/pgcli/__init__.py @@ -1 +1 @@ -__version__ = '2.0.0' +__version__ = '2.0.2' diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py index 2f6908e1..388bb295 100644 --- a/pgcli/completion_refresher.py +++ b/pgcli/completion_refresher.py @@ -53,12 +53,8 @@ class CompletionRefresher(object): if settings.get('single_connection'): executor = pgexecute else: - # Create a new pgexecute method to popoulate the completions. - e = pgexecute - executor = PGExecute( - e.dbname, e.user, e.password, e.host, e.port, e.dsn, - **e.extra_args) - + # Create a new pgexecute method to populate the completions. + executor = pgexecute.copy() # If callbacks is a single function then push it into a list. if callable(callbacks): callbacks = [callbacks] diff --git a/pgcli/key_bindings.py b/pgcli/key_bindings.py index 6f856197..cda4fb6c 100644 --- a/pgcli/key_bindings.py +++ b/pgcli/key_bindings.py @@ -51,6 +51,16 @@ def pgcli_bindings(pgcli): else: buff.insert_text(tab_insert_text, fire_event=False) + @kb.add('escape') + def _(event): + """Force closing of autocompletion.""" + _logger.debug('Detected <Esc> key.') + + event.current_buffer.complete_state = None + event.app.current_buffer.complete_state = None + + + @kb.add('c-space') def _(event): """ @@ -81,7 +91,6 @@ def pgcli_bindings(pgcli): _logger.debug('Detected enter key.') event.current_buffer.complete_state = None - b = event.app.current_buffer - b.complete_state = None + event.app.current_buffer.complete_state = None return kb diff --git a/pgcli/main.py b/pgcli/main.py index 86a86b1e..8078c5ce 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -30,7 +30,7 @@ try: import setproctitle except ImportError: setproctitle = None -from prompt_toolkit.completion import DynamicCompleter +from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode from prompt_toolkit.shortcuts import PromptSession, CompleteStyle from prompt_toolkit.document import Document @@ -87,14 +87,15 @@ MetaQuery = namedtuple( [ 'query', # The entire text of the command 'successful', # True If all subqueries were successful - 'total_time', # Time elapsed executing the query + 'total_time', # Time elapsed executing the query and formatting results + 'execution_time', # Time elapsed executing the query 'meta_changed', # True if any subquery executed create/alter/drop 'db_changed', # True if any subquery changed the database 'path_changed', # True if any subquery changed the search path 'mutated', # True if any subquery executed insert/update/delete 'is_special', # True if the query is a special command ]) -MetaQuery.__new__.__defaults__ = ('', False, 0, False, False, False, False) +MetaQuery.__new__.__defaults__ = ('', False, 0, 0, False, False, False, False) OutputSettings = namedtuple( 'OutputSettings', @@ -385,25 +386,13 @@ class PGCli(object): self.connect(dsn=dsn) def connect_uri(self, uri): - uri = urlparse(uri) - database = uri.path[1:] # ignore the leading fwd slash - - def fixup_possible_percent_encoding(s): - return unquote(str(s)) if s else s - - arguments = dict(database=fixup_possible_percent_encoding(database), - host=fixup_possible_percent_encoding(uri.hostname), - user=fixup_possible_percent_encoding(uri.username), - port=fixup_possible_percent_encoding(uri.port), - passwd=fixup_possible_percent_encoding(uri.password)) - # Deal with extra params e.g. ?sslmode=verify-ca&sslrootcert=/myrootcert - if uri.query: - arguments = dict( - {k: v for k, (v,) in parse_qs(uri.query).items()}, - **arguments) - - # unquote str(each URI part (they may be percent encoded) - self.connect(**arguments) + kwargs = psycopg2.extensions.parse_dsn(uri) + remap = { + 'dbname': 'database', + 'password': 'passwd', + } + kwargs = {remap.get(k, k): v for k, v in kwargs.items()} + self.connect(**kwargs) def connect(self, database='', host='', user='', port='', passwd='', dsn='', **kwargs): @@ -566,8 +555,8 @@ class PGCli(object): except OperationalError as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - self._handle_server_closed_connection() - except PgCliQuitError as e: + self._handle_server_closed_connection(text) + except (PgCliQuitError, EOFError) as e: raise except Exception as e: logger.error("sql: %r, error: %r", text, e) @@ -591,8 +580,11 @@ class PGCli(object): if self.pgspecial.timing_enabled: # Only add humanized time display if > 1 second if query.total_time > 1: - print('Time: %0.03fs (%s)' % (query.total_time, - humanize.time.naturaldelta(query.total_time))) + print('Time: %0.03fs (%s), executed in: %0.03fs (%s)' % (query.total_time, + humanize.time.naturaldelta( + query.total_time), + query.execution_time, + humanize.time.naturaldelta(query.execution_time))) else: print('Time: %0.03fs' % query.total_time) @@ -626,7 +618,7 @@ class PGCli(object): self.prompt_app = self._build_cli(history) if not self.less_chatty: - print('Server: PostgreSQL', self.pgexecute.get_server_version()) + print('Server: PostgreSQL', self.pgexecute.server_version) print('Version:', __version__) print('Chat: https://gitter.im/dbcli/pgcli') print('Mail: https://groups.google.com/forum/#!forum/pgcli') @@ -722,13 +714,15 @@ class PGCli(object): tempfile_suffix='.sql', multiline=pg_is_multiline(self), history=history, - completer=DynamicCompleter(lambda: self.completer), + completer=ThreadedCompleter( + DynamicCompleter(lambda: self.completer)), complete_while_typing=True, style=style_factory(self.syntax_style, self.cli_style), include_default_pygments_style=False, key_bindings=key_bindings, enable_open_in_editor=True, enable_system_prompt=True, + enable_suspend=True, editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS, search_ignore_case=True) @@ -756,6 +750,7 @@ class PGCli(object): path_changed = False output = [] total = 0 + execution = 0 # Run the query. start = time() @@ -794,6 +789,7 @@ class PGCli(object): ), style_output=self.style_output ) + execution = time() - start formatted = format_output(title, cur, headers, status, settings) output.extend(formatted) @@ -809,22 +805,21 @@ class PGCli(object): else: all_success = False - meta_query = MetaQuery(text, all_success, total, meta_changed, + meta_query = MetaQuery(text, all_success, total, execution, meta_changed, db_changed, path_changed, mutated, is_special) return output, meta_query - def _handle_server_closed_connection(self): - """Used during CLI execution""" - reconnect = click.prompt( - 'Connection reset. Reconnect (Y/n)', - show_default=False, type=bool, default=True) - if reconnect: - try: - self.pgexecute.connect() - click.secho('Reconnected!\nTry the command again.', fg='green') - except OperationalError as e: - click.secho(str(e), err=True, fg='red') + def _handle_server_closed_connection(self, text): + """Used during CLI execution.""" + try: + click.secho('Reconnecting...', fg='green') + self.pgexecute.connect() + click.secho('Reconnected!', fg='green') + self.execute_command(text) + except OperationalError as e: + click.secho('Reconnect Failed', fg='red') + click.secho(str(e), err=True, fg='red') def refresh_completions(self, history=None, persist_priorities='all'): """ Refresh outdated completions @@ -892,10 +887,8 @@ class PGCli(object): string = string.replace('\\dsn_alias', self.dsn_alias or '') string = string.replace('\\t', self.now.strftime('%x %X')) string = string.replace('\\u', self.pgexecute.user or '(none)') - host = self.pgexecute.host or '(none)' - string = string.replace('\\H', host) - short_host, _, _ = host.partition('.') - string = string.replace('\\h', short_host) + string = string.replace('\\H', self.pgexecute.host or '(none)') + string = string.replace('\\h', self.pgexecute.short_host or '(none)') string = string.replace('\\d', self.pgexecute.dbname or '(none)') string = string.replace('\\p', str( self.pgexecute.port) if self.pgexecute.port is not None else '5432') @@ -941,7 +934,7 @@ class PGCli(object): @click.option('-p', '--port', default=5432, help='Port number at which the ' 'postgres instance is listening.', envvar='PGPORT', type=click.INT) @click.option('-U', '--username', 'username_opt', help='Username to connect to the postgres database.') -@click.option('--user', 'username_opt', help='Username to connect to the postgres database.') +@click.option('-u', '--user', 'username_opt', help='Username to connect to the postgres database.') @click.option('-W', '--password', 'prompt_passwd', is_flag=True, default=False, help='Force password prompt.') @click.option('-w', '--no-password', 'never_prompt', is_flag=True, diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py index 16a49676..32dd0aff 100644 --- a/pgcli/packages/parseutils/meta.py +++ b/pgcli/packages/parseutils/meta.py @@ -51,7 +51,8 @@ class FunctionMetadata(object): def __init__( self, schema_name, func_name, arg_names, arg_types, arg_modes, - return_type, is_aggregate, is_window, is_set_returning, arg_defaults + return_type, is_aggregate, is_window, is_set_returning, is_extension, + arg_defaults ): """Class for describing a postgresql function""" @@ -79,6 +80,8 @@ class FunctionMetadata(object): self.is_aggregate = is_aggregate self.is_window = is_window self.is_set_returning = is_set_returning + self.is_extension = bool(is_extension) + self.is_public = (self.schema_name and self.schema_name == 'public') def __eq__(self, other): return (isinstance(other, self.__class__) @@ -89,9 +92,9 @@ class FunctionMetadata(object): def _signature(self): return ( - self.schema_name, self.func_name, self.arg_names, self.arg_types, - self.arg_modes, self.return_type, self.is_aggregate, - self.is_window, self.is_set_returning, self.arg_defaults + self.schema_name, self.func_name, self.arg_names, + self.arg_types, self.arg_modes, self.return_type, self.is_aggregate, + self.is_window, self.is_set_returning, self.is_extension, self.arg_defaults ) def __hash__(self): @@ -102,8 +105,8 @@ class FunctionMetadata(object): ( '%s(schema_name=%r, func_name=%r, arg_names=%r, ' 'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, ' - 'is_window=%r, is_set_returning=%r, arg_defaults=%r)' - ) % (self.__class__.__name__,) + self._signature() + 'is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)' + ) % ((self.__class__.__name__,) + self._signature()) ) def has_variadic(self): @@ -132,7 +135,6 @@ class FunctionMetadata(object): return [arg(name, typ, num) for num, (name, typ) in enumerate(args)] - def fields(self): """Returns a list of output-field ColumnMetadata namedtuples""" diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py index dc80a2ce..607b5009 100644 --- a/pgcli/packages/parseutils/tables.py +++ b/pgcli/packages/parseutils/tables.py @@ -36,11 +36,11 @@ def extract_from_part(parsed, stop_at_punctuation=True): for x in extract_from_part(item, stop_at_punctuation): yield x elif stop_at_punctuation and item.ttype is Punctuation: - raise StopIteration + return # An incomplete nested select won't be recognized correctly as a # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes # the second FROM to trigger this elif condition resulting in a - # StopIteration. So we need to ignore the keyword if the keyword + # `return`. So we need to ignore the keyword if the keyword # FROM. # Also 'SELECT * FROM abc JOIN def' will trigger this elif # condition. So we need to ignore the keyword JOIN and its variants @@ -93,30 +93,32 @@ def extract_table_identifiers(token_stream, allow_functions=True): name = name.lower() return schema_name, name, alias - - for item in token_stream: - if isinstance(item, IdentifierList): - for identifier in item.get_identifiers(): - # Sometimes Keywords (such as FROM ) are classified as - # identifiers which don't have the get_real_name() method. - try: - schema_name = identifier.get_parent_name() - real_name = identifier.get_real_name() - is_function = (allow_functions and - _identifier_is_function(identifier)) - except AttributeError: - continue - if real_name: - yield TableReference(schema_name, real_name, - identifier.get_alias(), is_function) - elif isinstance(item, Identifier): - schema_name, real_name, alias = parse_identifier(item) - is_function = allow_functions and _identifier_is_function(item) - - yield TableReference(schema_name, real_name, alias, is_function) - elif isinstance(item, Function): - schema_name, real_name, alias = parse_identifier(item) - yield TableReference(None, real_name, alias, allow_functions) + try: + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + # Sometimes Keywords (such as FROM ) are classified as + # identifiers which don't have the get_real_name() method. + try: + schema_name = identifier.get_parent_name() + real_name = identifier.get_real_name() + is_function = (allow_functions and + _identifier_is_function(identifier)) + except AttributeError: + continue + if real_name: + yield TableReference(schema_name, real_name, + identifier.get_alias(), is_function) + elif isinstance(item, Identifier): + schema_name, real_name, alias = parse_identifier(item) + is_function = allow_functions and _identifier_is_function(item) + + yield TableReference(schema_name, real_name, alias, is_function) + elif isinstance(item, Function): + schema_name, real_name, alias = parse_identifier(item) + yield TableReference(None, real_name, alias, allow_functions) + except StopIteration: + return # extract_tables is inspired from examples in the sqlparse lib. diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 2ec0ac5a..a4c3724d 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -50,6 +50,7 @@ arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$') normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' + def generate_alias(tbl): """ Generate a table alias, consisting of all upper-case letters in the table name, or, if there are no upper-case letters, the first letter + @@ -636,22 +637,33 @@ class PGCompleter(Completer): return self.find_matches(word_before_cursor, conds, meta='join') def get_function_matches(self, suggestion, word_before_cursor, alias=False): + if suggestion.usage == 'from': # Only suggest functions allowed in FROM clause - def filt(f): return not f.is_aggregate and not f.is_window + + def filt(f): + return (not f.is_aggregate and + not f.is_window and + not f.is_extension and + (f.is_public or f.schema_name == suggestion.schema)) else: alias = False - def filt(f): return True + def filt(f): + return (not f.is_extension and + (f.is_public or f.schema_name == suggestion.schema)) + arg_mode = { 'signature': 'signature', 'special': None, }.get(suggestion.usage, 'call') + # Function overloading means we way have multiple functions of the same # name at this point, so keep unique names only + all_functions = self.populate_functions(suggestion.schema, filt) funcs = set( self._make_cand(f, alias, suggestion, arg_mode) - for f in self.populate_functions(suggestion.schema, filt) + for f in all_functions ) matches = self.find_matches(word_before_cursor, funcs, meta='function') diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 8bcc5c63..c7d204f0 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -113,7 +113,8 @@ def register_hstore_typecaster(conn): """ with conn.cursor() as cur: try: - cur.execute("SELECT 'hstore'::regtype::oid") + cur.execute( + "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined") oid = cur.fetchone()[0] ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE)) except Exception: @@ -185,82 +186,80 @@ class PGExecute(object): version_query = "SELECT version();" - def __init__(self, database, user, password, host, port, dsn, **kwargs): - self.dbname = database - self.user = user - self.password = password - self.host = host - self.port = port - self.dsn = dsn - self.extra_args = {k: unicode2utf8(v) for k, v in kwargs.items()} + def __init__(self, database=None, user=None, password=None, host=None, + port=None, dsn=None, **kwargs): + self._conn_params = {} + self.conn = None + self.dbname = None + self.user = None + self.password = None + self.host = None + self.port = None self.server_version = None - self.connect() - - def get_server_version(self): - if self.server_version: - return self.server_version - with self.conn.cursor() as cur: - _logger.debug('Version Query. sql: %r', self.version_query) - cur.execute(self.version_query) - result = cur.fetchone() - if result: - # full version string looks like this: - # PostgreSQL 10.3 on x86_64-apple-darwin17.3.0, compiled by Apple LLVM version 9.0.0 (clang-900.0.39.2), 64-bit # noqa - # let's only retrieve version number - version_parts = result[0].split() - self.server_version = version_parts[1] - else: - self.server_version = '' - return self.server_version + self.connect(database, user, password, host, port, dsn, **kwargs) + + def copy(self): + """Returns a clone of the current executor.""" + return self.__class__(**self._conn_params) + + def get_server_version(self, cursor): + _logger.debug('Version Query. sql: %r', self.version_query) + cursor.execute(self.version_query) + result = cursor.fetchone() + server_version = '' + if result: + # full version string looks like this: + # PostgreSQL 10.3 on x86_64-apple-darwin17.3.0, compiled by Apple LLVM version 9.0.0 (clang-900.0.39.2), 64-bit # noqa + # let's only retrieve version number + version_parts = result[0].split() + server_version = version_parts[1] + return server_version def connect(self, database=None, user=None, password=None, host=None, port=None, dsn=None, **kwargs): - db = (database or self.dbname) - user = (user or self.user) - password = (password or self.password) - host = (host or self.host) - port = (port or self.port) - dsn = (dsn or self.dsn) - kwargs = (kwargs or self.extra_args) - pid = -1 - if dsn: - if password: - dsn = "{0} password={1}".format(dsn, password) - conn = psycopg2.connect(dsn=unicode2utf8(dsn)) - cursor = conn.cursor() - else: - conn = psycopg2.connect( - database=unicode2utf8(db), - user=unicode2utf8(user), - password=unicode2utf8(password), - host=unicode2utf8(host), - port=unicode2utf8(port), - **kwargs) - - cursor = conn.cursor() + conn_params = self._conn_params.copy() + + new_params = { + 'database': database, + 'user': user, + 'password': password, + 'host': host, + 'port': port, + 'dsn': dsn, + } + new_params.update(kwargs) + conn_params.update({ + k: unicode2utf8(v) for k, v in new_params.items() if v is not None + }) + + if 'password' in conn_params and 'dsn' in conn_params: + conn_params['dsn'] = "{0} password={1}".format( + conn_params['dsn'], conn_params.pop('password') + ) + conn = psycopg2.connect(**conn_params) + cursor = conn.cursor() conn.set_client_encoding('utf8') - if hasattr(self, 'conn'): + + self._conn_params = conn_params + if self.conn: self.conn.close() self.conn = conn self.conn.autocommit = True - if dsn: - # When we connect using a DSN, we don't really know what db, - # user, etc. we connected to. Let's read it. - # Note: moved this after setting autocommit because of #664. - dsn_parameters = conn.get_dsn_parameters() - db = dsn_parameters['dbname'] - user = dsn_parameters['user'] - host = dsn_parameters['host'] - port = dsn_parameters['port'] - - self.dbname = db - self.user = user + # When we connect using a DSN, we don't really know what db, + # user, etc. we connected to. Let's read it. + # Note: moved this after setting autocommit because of #664. + # TODO: use actual connection info from psycopg2.extensions.Connection.info as psycopg>2.8 is available and required dependency # noqa + dsn_parameters = conn.get_dsn_parameters() + + self.dbname = dsn_parameters['dbname'] + self.user = dsn_parameters['user'] self.password = password - self.host = host - self.port = port + self.host = dsn_parameters['host'] + self.port = dsn_parameters['port'] + self.extra_args = kwargs if not self.host: self.host = self.get_socket_directory() @@ -272,10 +271,21 @@ class PGExecute(object): self.pid = pid self.superuser = db_parameters.get('is_superuser') == '1' + self.server_version = self.get_server_version(cursor) + register_date_typecasters(conn) register_json_typecasters(self.conn, self._json_typecaster) register_hstore_typecaster(self.conn) + @property + def short_host(self): + if ',' in self.host: + host, _, _ = self.host.partition(',') + else: + host = self.host + short_host, _, _ = host.partition('.') + return short_host + def _select_one(self, cur, sql): """ Helper method to run a select and retrieve a single field value @@ -629,10 +639,12 @@ class PGExecute(object): p.prokind = 'a' is_aggregate, p.prokind = 'w' is_window, p.proretset is_set_returning, + d.deptype = 'e' is_extension, pg_get_expr(proargdefaults, 0) AS arg_defaults FROM pg_catalog.pg_proc p INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 ''' @@ -647,10 +659,12 @@ class PGExecute(object): p.proisagg is_aggregate, p.proiswindow is_window, p.proretset is_set_returning, + d.deptype = 'e' is_extension, pg_get_expr(proargdefaults, 0) AS arg_defaults |