diff options
-rwxr-xr-x | pgcli/main.py | 319 | ||||
-rw-r--r-- | tests/test_main.py | 9 |
2 files changed, 188 insertions, 140 deletions
diff --git a/pgcli/main.py b/pgcli/main.py index 673a42c8..4cdd3169 100755 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -57,8 +57,18 @@ from psycopg2 import OperationalError from collections import namedtuple # Query tuples are used for maintaining history -Query = namedtuple('Query', ['query', 'successful', 'mutating']) - +MetaQuery = namedtuple( + 'Query', + [ + 'query', # The entire text of the command + 'successful', # True If all subqueries were successful + 'total_time', # Time elapsed executing the query + 'meta_changed', # True if any subquery executed create/alter/drop + 'db_changed', # True if any subquery changed the database + 'path_changed', # True if any subquery changed the search path + 'mutated', # True if any subquery executed insert/update/delete + ]) +MetaQuery.__new__.__defaults__ = ('', False, 0, False, False, False, False) class PGCli(object): @@ -256,56 +266,18 @@ class PGCli(object): return document def run_cli(self): - pgexecute = self.pgexecute logger = self.logger original_less_opts = self.adjust_less_opts() self.refresh_completions() - def set_vi_mode(value): - self.vi_mode = value - - key_binding_manager = pgcli_bindings( - get_vi_mode_enabled=lambda: self.vi_mode, - set_vi_mode_enabled=set_vi_mode) + self.cli = self._build_cli() print('Version:', __version__) print('Chat: https://gitter.im/dbcli/pgcli') print('Mail: https://groups.google.com/forum/#!forum/pgcli') print('Home: http://pgcli.com') - def prompt_tokens(cli): - return [(Token.Prompt, '%s> ' % pgexecute.dbname)] - - get_toolbar_tokens = create_toolbar_tokens_func(lambda: self.vi_mode, - self.completion_refresher.is_refreshing) - - layout = create_default_layout(lexer=PostgresLexer, - reserve_space_for_menu=True, - get_prompt_tokens=prompt_tokens, - get_bottom_toolbar_tokens=get_toolbar_tokens, - display_completions_in_columns=self.wider_completion_menu, - multiline=True, - extra_input_processors=[ - # Highlight matching brackets while editing. - ConditionalProcessor( - processor=HighlightMatchingBracketProcessor(chars='[](){}'), - filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()), - ]) - history_file = self.config['main']['history_file'] - with self._completer_lock: - buf = PGBuffer(always_multiline=self.multi_line, completer=self.completer, - history=FileHistory(os.path.expanduser(history_file)), - complete_while_typing=Always()) - - application = Application(style=style_factory(self.syntax_style, self.cli_style), - layout=layout, buffer=buf, - key_bindings_registry=key_binding_manager.registry, - on_exit=AbortAction.RAISE_EXCEPTION, - ignore_case=True) - self.cli = CommandLineInterface(application=application, - eventloop=create_eventloop()) - try: while True: document = self.cli.run() @@ -325,70 +297,22 @@ class PGCli(object): click.secho(str(e), err=True, fg='red') continue - # Keep track of whether or not the query is mutating. In case - # of a multi-statement query, the overall query is considered - # mutating if any one of the component statements is mutating - mutating = False + # Initialize default metaquery in case execution fails + query = MetaQuery(query=document.text, successful=False) try: - logger.debug('sql: %r', document.text) - successful = False - # Initialized to [] because res might never get initialized - # if an exception occurs in pgexecute.run(). Which causes - # finally clause to fail. - res = [] - # Run the query. - start = time() - on_error_resume = self.on_error == 'RESUME' - res = pgexecute.run(document.text, self.pgspecial, - exception_formatter, on_error_resume) - output = [] - total = 0 - for title, cur, headers, status in res: - logger.debug("headers: %r", headers) - logger.debug("rows: %r", cur) - logger.debug("status: %r", status) - threshold = 1000 - if (is_select(status) and - cur and cur.rowcount > threshold): - click.secho('The result set has more than %s rows.' - % threshold, fg='red') - if not click.confirm('Do you want to continue?'): - click.secho("Aborted!", err=True, fg='red') - break - - if self.pgspecial.auto_expand: - max_width = self.cli.output.get_size().columns - else: - max_width = None - - formatted = format_output(title, cur, headers, status, - self.table_format, - self.pgspecial.expanded_output, - max_width) - output.extend(formatted) - end = time() - total += end - start - mutating = mutating or is_mutating(status) - + output, query = self._evaluate_command(document.text) except KeyboardInterrupt: # Restart connection to the database - pgexecute.connect() + self.pgexecute.connect() logger.debug("cancelled query, sql: %r", document.text) click.secho("cancelled query", err=True, fg='red') except NotImplementedError: click.secho('Not Yet Implemented.', fg="yellow") except OperationalError as e: - reconnect = True - if ('server closed the connection' in utf8tounicode(e.args[0])): - reconnect = click.prompt('Connection reset. Reconnect (Y/n)', - show_default=False, type=bool, default=True) - if reconnect: - try: - pgexecute.connect() - click.secho('Reconnected!\nTry the command again.', fg='green') - except OperationalError as e: - click.secho(str(e), err=True, fg='red') + if ('server closed the connection' + in utf8tounicode(e.args[0])): + self._handle_server_closed_connection() else: logger.error("sql: %r, error: %r", document.text, e) logger.error("traceback: %r", traceback.format_exc()) @@ -398,26 +322,28 @@ class PGCli(object): logger.error("traceback: %r", traceback.format_exc()) click.secho(str(e), err=True, fg='red') else: - successful = True try: click.echo_via_pager('\n'.join(output)) except KeyboardInterrupt: pass - if self.pgspecial.timing_enabled: - print('Time: %0.03fs' % total) - # Refresh the table names and column names if necessary. - if need_completion_refresh(document.text): - self.refresh_completions(need_completion_reset(document.text)) - - # Refresh search_path to set default schema. - if need_search_path_refresh(document.text): + if self.pgspecial.timing_enabled: + print('Time: %0.03fs' % query.total_time) + + # Check if we need to update completions, in order of most + # to least drastic changes + if query.db_changed: + self.refresh_completions(reset=True) + elif query.meta_changed: + self.refresh_completions(reset=False) + elif query.path_changed: logger.debug('Refreshing search path') with self._completer_lock: - self.completer.set_search_path(pgexecute.search_path()) - logger.debug('Search path: %r', self.completer.search_path) + self.completer.set_search_path( + self.pgexecute.search_path()) + logger.debug('Search path: %r', + self.completer.search_path) - query = Query(document.text, successful, mutating) self.query_history.append(query) except EOFError: @@ -426,6 +352,133 @@ class PGCli(object): logger.debug('Restoring env var LESS to %r.', original_less_opts) os.environ['LESS'] = original_less_opts + def _build_cli(self): + + def set_vi_mode(value): + self.vi_mode = value + + key_binding_manager = pgcli_bindings( + get_vi_mode_enabled=lambda: self.vi_mode, + set_vi_mode_enabled=set_vi_mode) + + def prompt_tokens(_): + return [(Token.Prompt, '%s> ' % self.pgexecute.dbname)] + + get_toolbar_tokens = create_toolbar_tokens_func( + lambda: self.vi_mode, self.completion_refresher.is_refreshing) + + layout = create_default_layout( + lexer=PostgresLexer, + reserve_space_for_menu=True, + get_prompt_tokens=prompt_tokens, + get_bottom_toolbar_tokens=get_toolbar_tokens, + display_completions_in_columns=self.wider_completion_menu, + multiline=True, + extra_input_processors=[ + # Highlight matching brackets while editing. + ConditionalProcessor( + processor=HighlightMatchingBracketProcessor(chars='[](){}'), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()), + ]) + + history_file = self.config['main']['history_file'] + with self._completer_lock: + buf = PGBuffer( + always_multiline=self.multi_line, + completer=self.completer, + history=FileHistory(os.path.expanduser(history_file)), + complete_while_typing=Always()) + + application = Application( + style=style_factory(self.syntax_style, self.cli_style), + layout=layout, + buffer=buf, + key_bindings_registry=key_binding_manager.registry, + on_exit=AbortAction.RAISE_EXCEPTION, + ignore_case=True) + + cli = CommandLineInterface( + application=application, + eventloop=create_eventloop()) + + return cli + + def _evaluate_command(self, text): + """Used to run a command entered by the user during CLI operation + (Puts the E in REPL) + + returns (results, MetaQuery) + """ + logger = self.logger + logger.debug('sql: %r', text) + + all_success = True + meta_changed = False # CREATE, ALTER, DROP, etc + mutated = False # INSERT, DELETE, etc + db_changed = False + path_changed = False + output = [] + total = 0 + + # Run the query. + start = time() + on_error_resume = self.on_error == 'RESUME' + res = self.pgexecute.run(text, self.pgspecial, + exception_formatter, on_error_resume) + + for title, cur, headers, status, sql, success in res: + logger.debug("headers: %r", headers) + logger.debug("rows: %r", cur) + logger.debug("status: %r", status) + threshold = 1000 + if (is_select(status) and + cur and cur.rowcount > threshold): + click.secho('The result set has more than %s rows.' + % threshold, fg='red') + if not click.confirm('Do you want to continue?'): + click.secho("Aborted!", err=True, fg='red') + break + + if self.pgspecial.auto_expand: + max_width = self.cli.output.get_size().columns + else: + max_width = None + + formatted = format_output( + title, cur, headers, status, self.table_format, + self.pgspecial.expanded_output, max_width) + + output.extend(formatted) + end = time() + total += end - start + + # Keep track of whether any of the queries are mutating or changing + # the database + if success: + mutated = mutated or is_mutating(status) + db_changed = db_changed or has_change_db_cmd(sql) + meta_changed = meta_changed or has_meta_cmd(sql) + path_changed = path_changed or has_change_path_cmd(sql) + else: + all_success = False + + meta_query = MetaQuery(text, all_success, total, meta_changed, + db_changed, path_changed, mutated) + + return output, meta_query + + def _handle_server_closed_connection(self): + """Used during CLI execution""" + reconnect = click.prompt( + 'Connection reset. Reconnect (Y/n)', + show_default=False, type=bool, default=True) + if reconnect: + try: + self.pgexecute.connect() + click.secho('Reconnected!\nTry the command again.', fg='green') + except OperationalError as e: + click.secho(str(e), err=True, fg='red') + def adjust_less_opts(self): less_opts = os.environ.get('LESS', '') self.logger.debug('Original value for LESS env var: %r', less_opts) @@ -534,6 +587,7 @@ def cli(database, user, host, port, prompt_passwd, never_prompt, dbname, pgcli.run_cli() + def obfuscate_process_password(): process_title = setproctitle.getproctitle() if '://' in process_title: @@ -564,53 +618,54 @@ def format_output(title, cur, headers, status, table_format, expanded=False, max output.append(status) return output -def need_completion_refresh(queries): + +def has_meta_cmd(query): """Determines if the completion needs a refresh by checking if the sql - statement is an alter, create, drop or change db.""" - for query in sqlparse.split(queries): - try: - first_token = query.split()[0] - if first_token.lower() in ('alter', 'create', 'use', '\\c', - '\\connect', 'drop'): - return True - except Exception: - return False + statement is an alter, create, or drop""" + try: + first_token = query.split()[0] + if first_token.lower() in ('alter', 'create', 'drop'): + return True + except Exception: + return False return False -def need_completion_reset(queries): - """Determines if the statement is a database switch such as 'use' or '\\c'. - When a database is changed the existing completions must be reset before we - start the completion refresh for the new database. - """ - for query in sqlparse.split(queries): - try: - first_token = query.split()[0] - if first_token.lower() in ('use', '\\c', '\\connect'): - return True - except Exception: - return False +def has_change_db_cmd(query): + """Determines if the statement is a database switch such as 'use' or '\\c'""" + try: + first_token = query.split()[0] + if first_token.lower() in ('use', '\\c', '\\connect'): + return True + except Exception: + return False -def need_search_path_refresh(sql): + return False + + +def has_change_path_cmd(sql): """Determines if the search_path should be refreshed by checking if the sql has 'set search_path'.""" return 'set search_path' in sql.lower() + def is_mutating(status): """Determines if the statement is mutating based on the status.""" if not status: return False - mutating = set(['insert', 'update', 'delete', 'alter', 'create', 'drop']) + mutating = set(['insert', 'update', 'delete']) return status.split(None, 1)[0].lower() in mutating + def is_select(status): """Returns true if the first word in status is 'select'.""" if not status: return False return status.split(None, 1)[0].lower() == 'select' + def quit_command(sql): return (sql.strip().lower() == 'exit' or sql.strip().lower() == 'quit' diff --git a/tests/test_main.py b/tests/test_main.py index f98aae5e..0d026994 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,19 +1,12 @@ import pytest import platform -from pgcli.main import need_completion_refresh, obfuscate_process_password try: import setproctitle except ImportError: setproctitle = None +from pgcli.main import obfuscate_process_password -@pytest.mark.parametrize('sql', [ - 'DROP TABLE foo', - 'SELECT * FROM foo; DROP TABLE foo', -]) -def test_need_completion_refresh(sql): - assert need_completion_refresh(sql) - @pytest.mark.skipif(platform.system() == 'Windows', reason='Not applicable in windows') @pytest.mark.skipif(not setproctitle, |