diff options
author | Amjith Ramanujam <amjith.r@gmail.com> | 2015-07-01 17:14:09 -0700 |
---|---|---|
committer | Amjith Ramanujam <amjith.r@gmail.com> | 2015-07-01 17:14:09 -0700 |
commit | b245e8e0d378344fa24a4ea94e9098fa842387cc (patch) | |
tree | 5700642e4c7eeed9c386c70f09b518b1393dac07 | |
parent | adf9a455f65c92d4e7651dcd52a2a074ffdfaaf4 (diff) | |
parent | 0f9f034970f7c42545b49f25da8be74eb350d462 (diff) |
Merge pull request #273 from dbcli/darikg/specials-refactor-more
More specials refactoring
-rwxr-xr-x | pgcli/main.py | 40 | ||||
-rw-r--r-- | pgcli/packages/pgspecial/iocommands.py | 32 | ||||
-rw-r--r-- | pgcli/packages/pgspecial/main.py | 121 | ||||
-rw-r--r-- | pgcli/packages/pgspecial/tests/conftest.py | 5 | ||||
-rw-r--r-- | pgcli/pgexecute.py | 28 | ||||
-rw-r--r-- | tests/test_pgexecute.py | 37 | ||||
-rw-r--r-- | tests/utils.py | 7 |
7 files changed, 154 insertions, 116 deletions
diff --git a/pgcli/main.py b/pgcli/main.py index 6a6899d0..0c12319b 100755 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -23,7 +23,7 @@ from pygments.token import Token from .packages.tabulate import tabulate from .packages.expanded import expanded_table -from .packages.pgspecial.main import (COMMANDS, NO_QUERY) +from .packages.pgspecial.main import (PGSpecial, NO_QUERY) import pgcli.packages.pgspecial as special from .pgcompleter import PGCompleter from .pgtoolbar import create_toolbar_tokens_func @@ -62,11 +62,13 @@ class PGCli(object): default_config = os.path.join(package_root, 'pgclirc') write_default_config(default_config, '~/.pgclirc') + self.pgspecial = PGSpecial() + # Load config. c = self.config = load_config('~/.pgclirc', default_config) self.multi_line = c['main'].as_bool('multi_line') self.vi_mode = c['main'].as_bool('vi') - special.set_timing(c['main'].as_bool('timing')) + self.pgspecial.timing_enabled = c['main'].as_bool('timing') self.table_format = c['main']['table_format'] self.syntax_style = c['main']['syntax_style'] @@ -82,13 +84,15 @@ class PGCli(object): self.register_special_commands() def register_special_commands(self): - special.register_special_command(self.change_db, '\\c', - '\\c[onnect] database_name', 'Change to a new database.', - aliases=('use', '\\connect', 'USE')) - special.register_special_command(self.refresh_completions, '\\#', - '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY) - special.register_special_command(self.refresh_completions, '\\refresh', - '\\refresh', 'Refresh auto-completions.', arg_type=NO_QUERY) + + self.pgspecial.register(self.change_db, '\\c', + '\\c[onnect] database_name', + 'Change to a new database.', + aliases=('use', '\\connect', 'USE')) + self.pgspecial.register(self.refresh_completions, '\\#', '\\#', + 'Refresh auto-completions.', arg_type=NO_QUERY) + self.pgspecial.register(self.refresh_completions, '\\refresh', '\\refresh', + 'Refresh auto-completions.', arg_type=NO_QUERY) def change_db(self, pattern, **_): if pattern: @@ -287,7 +291,7 @@ class PGCli(object): res = [] start = time() # Run the query. - res = pgexecute.run(document.text) + res = pgexecute.run(document.text, self.pgspecial) duration = time() - start successful = True output = [] @@ -305,8 +309,11 @@ class PGCli(object): if not click.confirm('Do you want to continue?'): click.secho("Aborted!", err=True, fg='red') break - output.extend(format_output(title, cur, headers, - status, self.table_format)) + + formatted = format_output(title, cur, headers, status, + self.table_format, + self.pgspecial.expanded_output) + output.extend(formatted) end = time() total += end - start mutating = mutating or is_mutating(status) @@ -339,7 +346,7 @@ class PGCli(object): click.secho(str(e), err=True, fg='red') else: click.echo_via_pager('\n'.join(output)) - if special.get_timing(): + if self.pgspecial.timing_enabled: print('Command Time: %0.03fs' % duration) print('Format Time: %0.03fs' % total) @@ -397,7 +404,7 @@ class PGCli(object): completer.extend_database_names(pgexecute.databases()) # special commands - completer.extend_special_commands(COMMANDS.keys()) + completer.extend_special_commands(self.pgspecial.commands.keys()) return [(None, None, None, 'Auto-completions refreshed.')] @@ -405,6 +412,7 @@ class PGCli(object): return self.completer.get_completions( Document(text=text, cursor_position=cursor_positition), None) + @click.command() # Default host is '' so psycopg2 can default to either localhost or unix socket @click.option('-h', '--host', default='', envvar='PGHOST', @@ -448,13 +456,13 @@ def cli(database, user, host, port, prompt_passwd, never_prompt, dbname, pgcli.run_cli() -def format_output(title, cur, headers, status, table_format): +def format_output(title, cur, headers, status, table_format, expanded=False): output = [] if title: # Only print the title if it's not None. output.append(title) if cur: headers = [utf8tounicode(x) for x in headers] - if special.is_expanded_output(): + if expanded: output.append(expanded_table(cur, headers)) else: output.append(tabulate(cur, headers, tablefmt=table_format, diff --git a/pgcli/packages/pgspecial/iocommands.py b/pgcli/packages/pgspecial/iocommands.py index 49b5caae..7ef71e14 100644 --- a/pgcli/packages/pgspecial/iocommands.py +++ b/pgcli/packages/pgspecial/iocommands.py @@ -8,38 +8,6 @@ from . import export _logger = logging.getLogger(__name__) -TIMING_ENABLED = True -use_expanded_output = False - -@export -def is_expanded_output(): - return use_expanded_output - -@special_command('\\x', '\\x', 'Toggle expanded output.', arg_type=NO_QUERY) -def toggle_expanded_output(): - global use_expanded_output - use_expanded_output = not use_expanded_output - message = u"Expanded display is " - message += u"on." if use_expanded_output else u"off." - return [(None, None, None, message)] - -@special_command('\\timing', '\\timing', 'Toggle timing of commands.', arg_type=NO_QUERY) -def toggle_timing(): - global TIMING_ENABLED - TIMING_ENABLED = not TIMING_ENABLED - message = "Timing is " - message += "on." if TIMING_ENABLED else "off." - return [(None, None, None, message)] - -@export -def set_timing(enable=True): - global TIMING_ENABLED - TIMING_ENABLED = enable - -@export -def get_timing(): - return TIMING_ENABLED - @export def editor_command(command): diff --git a/pgcli/packages/pgspecial/main.py b/pgcli/packages/pgspecial/main.py index 3fd5161a..31df8872 100644 --- a/pgcli/packages/pgspecial/main.py +++ b/pgcli/packages/pgspecial/main.py @@ -12,12 +12,81 @@ RAW_QUERY = 2 SpecialCommand = namedtuple('SpecialCommand', ['handler', 'syntax', 'description', 'arg_type', 'hidden', 'case_sensitive']) -COMMANDS = {} - @export class CommandNotFound(Exception): pass + +@export +class PGSpecial(object): + + # Default static commands that don't rely on PGSpecial state are registered + # via the special_command decorator and stored in default_commands + default_commands = {} + + def __init__(self): + self.timing_enabled = True + + self.commands = self.default_commands.copy() + + self.timing_enabled = False + self.expanded_output = False + + self.register(self.show_help, '\\?', '\\?', 'Show Help.', + arg_type=NO_QUERY) + + self.register(self.toggle_expanded_output, '\\x', '\\x', + 'Toggle expanded output.', arg_type=NO_QUERY) + + self.register(self.toggle_timing, '\\timing', '\\timing', + 'Toggle timing of commands.', arg_type=NO_QUERY) + + def register(self, *args, **kwargs): + register_special_command(*args, command_dict=self.commands, **kwargs) + + def execute(self, cur, sql): + commands = self.commands + command, verbose, pattern = parse_special_command(sql) + + if (command not in commands) and (command.lower() not in commands): + raise CommandNotFound + + try: + special_cmd = commands[command] + except KeyError: + special_cmd = commands[command.lower()] + if special_cmd.case_sensitive: + raise CommandNotFound('Command not found: %s' % command) + + if special_cmd.arg_type == NO_QUERY: + return special_cmd.handler() + elif special_cmd.arg_type == PARSED_QUERY: + return special_cmd.handler(cur=cur, pattern=pattern, verbose=verbose) + elif special_cmd.arg_type == RAW_QUERY: + return special_cmd.handler(cur=cur, query=sql) + + def show_help(self): + headers = ['Command', 'Description'] + result = [] + + for _, value in sorted(self.commands.items()): + if not value.hidden: + result.append((value.syntax, value.description)) + return [(None, result, headers, None)] + + def toggle_expanded_output(self): + self.expanded_output = not self.expanded_output + message = u"Expanded display is " + message += u"on." if self.expanded_output else u"off." + return [(None, None, None, message)] + + def toggle_timing(self): + self.timing_enabled = not self.timing_enabled + message = "Timing is " + message += "on." if self.timing_enabled else "off." + return [(None, None, None, message)] + + @export def parse_special_command(sql): command, _, arg = sql.partition(' ') @@ -26,59 +95,33 @@ def parse_special_command(sql): command = command.strip().replace('+', '') return (command, verbose, arg.strip()) -@export + def special_command(command, syntax, description, arg_type=PARSED_QUERY, hidden=False, case_sensitive=True, aliases=()): + """A decorator used internally for static special commands""" + def wrapper(wrapped): register_special_command(wrapped, command, syntax, description, - arg_type, hidden, case_sensitive, aliases) + arg_type, hidden, case_sensitive, aliases, + command_dict=PGSpecial.default_commands) return wrapped return wrapper -@export + def register_special_command(handler, command, syntax, description, - arg_type=PARSED_QUERY, hidden=False, case_sensitive=True, aliases=()): + arg_type=PARSED_QUERY, hidden=False, case_sensitive=True, aliases=(), + command_dict=None): + cmd = command.lower() if not case_sensitive else command - COMMANDS[cmd] = SpecialCommand(handler, syntax, description, arg_type, + command_dict[cmd] = SpecialCommand(handler, syntax, description, arg_type, hidden, case_sensitive) for alias in aliases: cmd = alias.lower() if not case_sensitive else alias - COMMANDS[cmd] = SpecialCommand(handler, syntax, description, arg_type, + command_dict[cmd] = SpecialCommand(handler, syntax, description, arg_type, case_sensitive=case_sensitive, hidden=True) -@export -def execute(cur, sql): - command, verbose, pattern = parse_special_command(sql) - - if (command not in COMMANDS) and (command.lower() not in COMMANDS): - raise CommandNotFound - - try: - special_cmd = COMMANDS[command] - except KeyError: - special_cmd = COMMANDS[command.lower()] - if special_cmd.case_sensitive: - raise CommandNotFound('Command not found: %s' % command) - - if special_cmd.arg_type == NO_QUERY: - return special_cmd.handler() - elif special_cmd.arg_type == PARSED_QUERY: - return special_cmd.handler(cur=cur, pattern=pattern, verbose=verbose) - elif special_cmd.arg_type == RAW_QUERY: - return special_cmd.handler(cur=cur, query=sql) - -@special_command('\\?', '\\?', 'Show Help.', arg_type=NO_QUERY) -def show_help(): - headers = ['Command', 'Description'] - result = [] - - for _, value in sorted(COMMANDS.items()): - if not value.hidden: - result.append((value.syntax, value.description)) - return [(None, result, headers, None)] - @special_command('\\e', '\\e [file]', 'Edit the query with external editor.', arg_type=NO_QUERY) def doc_only(): raise RuntimeError diff --git a/pgcli/packages/pgspecial/tests/conftest.py b/pgcli/packages/pgspecial/tests/conftest.py index 11a3d6e0..e7f1f66c 100644 --- a/pgcli/packages/pgspecial/tests/conftest.py +++ b/pgcli/packages/pgspecial/tests/conftest.py @@ -1,6 +1,6 @@ import pytest from dbutils import (create_db, db_connection, setup_db, teardown_db) -from pgcli.packages.pgspecial.main import execute +from pgcli.packages.pgspecial import PGSpecial @pytest.yield_fixture(scope='module') @@ -23,10 +23,11 @@ def cursor(connection): @pytest.fixture def executor(connection): cur = connection.cursor() + pgspecial = PGSpecial() def query_runner(sql): results = [] - for title, rows, headers, status in execute(cur=cur, sql=sql): + for title, rows, headers, status in pgspecial.execute(cur=cur, sql=sql): results.extend((title, list(rows), headers, status)) return results return query_runner diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 535d2c5a..1b27bba3 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -188,9 +188,12 @@ class PGExecute(object): else: return json_data - def run(self, statement): - """Execute the sql in the database and return the results. The results - are a list of tuples. Each tuple has 4 values (title, rows, headers, status). + def run(self, statement, pgspecial=None): + """Execute the sql in the database and return the results. + + :param statement: A string containing one or more sql statements + :param pgspecial: PGSpecial object + :return: List of tuples containing (title, rows, headers, status) """ # Remove spaces and EOL @@ -203,13 +206,18 @@ class PGExecute(object): # Remove spaces, eol and semi-colons. sql = sql.rstrip(';') - try: # Special command - _logger.debug('Trying a pgspecial command. sql: %r', sql) - cur = self.conn.cursor() - for result in special.execute(cur, sql): - yield result - except special.CommandNotFound: # Regular SQL - yield self.execute_normal_sql(sql) + if pgspecial: + # First try to run each query as special + try: + _logger.debug('Trying a pgspecial command. sql: %r', sql) + cur = self.conn.cursor() + for result in pgspecial.execute(cur, sql): + yield result + return + except special.CommandNotFound: + pass + + yield self.execute_normal_sql(sql) def execute_normal_sql(self, split_sql): _logger.debug('Regular sql statement. sql: %r', split_sql) diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index c8dee913..355096b4 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -2,6 +2,7 @@ import pytest import psycopg2 +from pgcli.packages.pgspecial import PGSpecial from textwrap import dedent from utils import run, dbtest, requires_json, requires_jsonb @@ -96,13 +97,11 @@ def test_invalid_column_name(executor): run(executor, 'select invalid command') assert 'column "invalid" does not exist' in str(excinfo.value) -@pytest.yield_fixture(params=[True, False]) -def expanded(request, executor): - if request.param: - run(executor, '\\x') - yield request.param - if request.param: - run(executor, '\\x') + +@pytest.fixture(params=[True, False]) +def expanded(request): + return request.param + @dbtest def test_unicode_support_in_output(executor, expanded): @@ -110,7 +109,9 @@ def test_unicode_support_in_output(executor, expanded): run(executor, "insert into unicodechars (t) values ('é')") # See issue #24, this raises an exception without proper handling - assert u'é' in run(executor, "select * from unicodechars", join=True) + assert u'é' in run(executor, "select * from unicodechars", + join=True, expanded=expanded) + @dbtest def test_multiple_queries_same_line(executor): @@ -120,8 +121,8 @@ def test_multiple_queries_same_line(executor): assert "bar" in result[2] @dbtest -def test_multiple_queries_with_special_command_same_line(executor): - result = run(executor, "select 'foo'; \d") +def test_multiple_queries_with_special_command_same_line(executor, pgspecial): + result = run(executor, "select 'foo'; \d", pgspecial=pgspecial) assert len(result) == 4 # 2 * (output+status) assert "foo" in result[0] # This is a lame check. :( @@ -133,9 +134,15 @@ def test_multiple_queries_same_line_syntaxerror(executor): run(executor, "select 'foo'; invalid syntax") assert 'syntax error at or near "invalid"' in str(excinfo.value) + +@pytest.fixture +def pgspecial(): + return PGSpecial() + + @dbtest -def test_special_command_help(executor): - result = run(executor, '\\?')[0].split('|') +def test_special_command_help(executor, pgspecial): + result = run(executor, '\\?', pgspecial=pgspecial)[0].split('|') assert(result[1].find(u'Command') != -1) assert(result[2].find(u'Description') != -1) @@ -158,7 +165,8 @@ def test_unicode_support_in_unknown_type(executor): def test_json_renders_without_u_prefix(executor, expanded): run(executor, "create table jsontest(d json)") run(executor, """insert into jsontest (d) values ('{"name": "Éowyn"}')""") - result = run(executor, "SELECT d FROM jsontest LIMIT 1", join=True) + result = run(executor, "SELECT d FROM jsontest LIMIT 1", + join=True, expanded=expanded) assert u'{"name": "Éowyn"}' in result @@ -167,7 +175,8 @@ def test_json_renders_without_u_prefix(executor, expanded): def test_jsonb_renders_without_u_prefix(executor, expanded): run(executor, "create table jsonbtest(d jsonb)") run(executor, """insert into jsonbtest (d) values ('{"name": "Éowyn"}')""") - result = run(executor, "SELECT d FROM jsonbtest LIMIT 1", join=True) + result = run(executor, "SELECT d FROM jsonbtest LIMIT 1", + join=True, expanded=expanded) assert u'{"name": "Éowyn"}' in result diff --git a/tests/utils.py b/tests/utils.py index 5cdc8849..93e9f258 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -58,11 +58,12 @@ def drop_tables(conn): DROP SCHEMA IF EXISTS schema2 CASCADE''') -def run(executor, sql, join=False): +def run(executor, sql, join=False, expanded=False, pgspecial=None): " Return string output for the sql to be run " result = [] - for title, rows, headers, status in executor.run(sql): - result.extend(format_output(title, rows, headers, status, 'psql')) + for title, rows, headers, status in executor.run(sql, pgspecial): + result.extend(format_output(title, rows, headers, status, 'psql', + expanded=expanded)) if join: result = '\n'.join(result) return result |