summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authordarikg <darikg@users.noreply.github.com>2015-06-16 19:40:57 -0400
committerdarikg <darikg@users.noreply.github.com>2015-06-16 19:40:57 -0400
commit0941ffc148ce65bd6c916f0ddc28f561cf2eec30 (patch)
treee1351c378c61156023ba7fa7f14f629337939301
parent1717186c640c40dea8997a849ad7b669c212562a (diff)
parentc576f47c0d9677a9fd8dcff25a2947eb40fd0bbc (diff)
Merge pull request #255 from dbcli/amjith/specials-decorator
Amjith/specials decorator
-rwxr-xr-xpgcli/main.py50
-rw-r--r--pgcli/packages/iospecial.py65
-rw-r--r--pgcli/packages/pgspecial/__init__.py10
-rw-r--r--pgcli/packages/pgspecial/dbcommands.py (renamed from pgcli/packages/pgspecial.py)185
-rw-r--r--pgcli/packages/pgspecial/iocommands.py149
-rw-r--r--pgcli/packages/pgspecial/main.py84
-rw-r--r--pgcli/packages/pgspecial/namedqueries.py (renamed from pgcli/packages/namedqueries.py)10
-rw-r--r--pgcli/packages/pgspecial/tests/conftest.py32
-rw-r--r--pgcli/packages/pgspecial/tests/dbutils.py68
-rw-r--r--pgcli/packages/pgspecial/tests/pytest.ini2
-rw-r--r--pgcli/packages/pgspecial/tests/test_specials.py60
-rw-r--r--pgcli/pgexecute.py56
12 files changed, 485 insertions, 286 deletions
diff --git a/pgcli/main.py b/pgcli/main.py
index 39e1b6fa..22bb1a3d 100755
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -15,17 +15,16 @@ from prompt_toolkit.enums import DEFAULT_BUFFER
from prompt_toolkit.shortcuts import create_default_layout, create_eventloop
from prompt_toolkit.document import Document
from prompt_toolkit.filters import Always, HasFocus, IsDone
-from prompt_toolkit.layout.processors import HighlightMatchingBracketProcessor, ConditionalProcessor
+from prompt_toolkit.layout.processors import (ConditionalProcessor,
+ HighlightMatchingBracketProcessor)
from prompt_toolkit.history import FileHistory
from pygments.lexers.sql import PostgresLexer
from pygments.token import Token
from .packages.tabulate import tabulate
from .packages.expanded import expanded_table
-from .packages.pgspecial import (CASE_SENSITIVE_COMMANDS,
- NON_CASE_SENSITIVE_COMMANDS, is_expanded_output)
-import pgcli.packages.pgspecial as pgspecial
-import pgcli.packages.iospecial as iospecial
+from .packages.pgspecial.main import (COMMANDS)
+import pgcli.packages.pgspecial as special
from .pgcompleter import PGCompleter
from .pgtoolbar import create_toolbar_tokens_func
from .pgstyle import style_factory
@@ -67,7 +66,7 @@ class PGCli(object):
c = self.config = load_config('~/.pgclirc', default_config)
self.multi_line = c['main'].as_bool('multi_line')
self.vi_mode = c['main'].as_bool('vi')
- pgspecial.TIMING_ENABLED = c['main'].as_bool('timing')
+ special.set_timing(c['main'].as_bool('timing'))
self.table_format = c['main']['table_format']
self.syntax_style = c['main']['syntax_style']
@@ -79,9 +78,22 @@ class PGCli(object):
# Initialize completer
smart_completion = c['main'].as_bool('smart_completion')
completer = PGCompleter(smart_completion)
- completer.extend_special_commands(CASE_SENSITIVE_COMMANDS.keys())
- completer.extend_special_commands(NON_CASE_SENSITIVE_COMMANDS.keys())
self.completer = completer
+ self.register_special_commands()
+
+ def register_special_commands(self):
+ special.register_special_command(self.change_db, '\\c',
+ '\\c[onnect] database_name', 'Change to a new database.',
+ aliases=('use', '\\connect', 'USE'))
+
+ def change_db(self, pattern, **_):
+ if pattern is None:
+ self.pgexecute.connect()
+ else:
+ self.pgexecute.connect(database=pattern)
+
+ yield (None, None, None, 'You are now connected to database "%s" as '
+ 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user))
def initialize_logging(self):
@@ -119,12 +131,11 @@ class PGCli(object):
def connect(self, database='', host='', user='', port='', passwd=''):
# Connect to the database.
+ if not user:
+ user = getuser()
+
if not database:
- if user:
- database = user
- else:
- # default to current OS username just like psql
- database = user = getuser()
+ database = user
# Prompt for a password immediately if requested via the -W flag. This
# avoids wasting time trying to connect to the database and catching a
@@ -176,9 +187,9 @@ class PGCli(object):
:param document: Document
:return: Document
"""
- while iospecial.editor_command(document.text):
- filename = iospecial.get_filename(document.text)
- sql, message = iospecial.open_external_editor(filename,
+ while special.editor_command(document.text):
+ filename = special.get_filename(document.text)
+ sql, message = special.open_external_editor(filename,
sql=document.text)
if message:
# Something went wrong. Raise an exception and bail.
@@ -318,7 +329,7 @@ class PGCli(object):
click.secho(str(e), err=True, fg='red')
else:
click.echo_via_pager('\n'.join(output))
- if pgspecial.TIMING_ENABLED:
+ if special.get_timing():
print('Command Time: %0.03fs' % duration)
print('Format Time: %0.03fs' % total)
@@ -380,6 +391,9 @@ class PGCli(object):
# databases
completer.extend_database_names(pgexecute.databases())
+ # special commands
+ completer.extend_special_commands(COMMANDS.keys())
+
def get_completions(self, text, cursor_positition):
return self.completer.get_completions(
Document(text=text, cursor_position=cursor_positition), None)
@@ -433,7 +447,7 @@ def format_output(title, cur, headers, status, table_format):
output.append(title)
if cur:
headers = [utf8tounicode(x) for x in headers]
- if is_expanded_output():
+ if special.is_expanded_output():
output.append(expanded_table(cur, headers))
else:
output.append(tabulate(cur, headers, tablefmt=table_format,
diff --git a/pgcli/packages/iospecial.py b/pgcli/packages/iospecial.py
deleted file mode 100644
index e827c9d4..00000000
--- a/pgcli/packages/iospecial.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from __future__ import print_function
-
-import re
-import logging
-from codecs import open
-import click
-
-_logger = logging.getLogger(__name__)
-
-
-def editor_command(command):
- """
- Is this an external editor command?
- :param command: string
- """
- # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
- # for both conditions.
- return command.strip().endswith('\\e') or command.strip().startswith('\\e')
-
-def get_filename(sql):
- if sql.strip().startswith('\\e'):
- command, _, filename = sql.partition(' ')
- return filename.strip() or None
-
-def open_external_editor(filename=None, sql=''):
- """
- Open external editor, wait for the user to type in his query,
- return the query.
- :return: list with one tuple, query as first element.
- """
-
- sql = sql.strip()
-
- # The reason we can't simply do .strip('\e') is that it strips characters,
- # not a substring. So it'll strip "e" in the end of the sql also!
- # Ex: "select * from style\e" -> "select * from styl".
- pattern = re.compile('(^\\\e|\\\e$)')
- while pattern.search(sql):
- sql = pattern.sub('', sql)
-
- message = None
- filename = filename.strip().split(' ', 1)[0] if filename else None
-
- MARKER = '# Type your query above this line.\n'
-
- # Populate the editor buffer with the partial sql (if available) and a
- # placeholder comment.
- query = click.edit(sql + '\n\n' + MARKER, filename=filename,
- extension='.sql')
-
- if filename:
- try:
- with open(filename, encoding='utf-8') as f:
- query = f.read()
- except IOError:
- message = 'Error reading file: %s.' % filename
-
- if query is not None:
- query = query.split(MARKER, 1)[0].rstrip('\n')
- else:
- # Don't return None for the caller to deal with.
- # Empty string is ok.
- query = sql
-
- return (query, message)
diff --git a/pgcli/packages/pgspecial/__init__.py b/pgcli/packages/pgspecial/__init__.py
new file mode 100644
index 00000000..92bcca6d
--- /dev/null
+++ b/pgcli/packages/pgspecial/__init__.py
@@ -0,0 +1,10 @@
+__all__ = []
+
+def export(defn):
+ """Decorator to explicitly mark functions that are exposed in a lib."""
+ globals()[defn.__name__] = defn
+ __all__.append(defn.__name__)
+ return defn
+
+from . import dbcommands
+from . import iocommands
diff --git a/pgcli/packages/pgspecial.py b/pgcli/packages/pgspecial/dbcommands.py
index 29163911..72e77e39 100644
--- a/pgcli/packages/pgspecial.py
+++ b/pgcli/packages/pgspecial/dbcommands.py
@@ -1,30 +1,24 @@
-from __future__ import print_function
-import sys
import logging
from collections import namedtuple
-from .tabulate import tabulate
-from .namedqueries import namedqueries
+from .main import special_command, RAW_QUERY
TableInfo = namedtuple("TableInfo", ['checks', 'relkind', 'hasindex',
'hasrules', 'hastriggers', 'hasoids', 'tablespace', 'reloptions', 'reloftype',
'relpersistence'])
-
log = logging.getLogger(__name__)
-use_expanded_output = False
-def is_expanded_output():
- return use_expanded_output
-
-TIMING_ENABLED = False
-
-def parse_special_command(sql):
- command, _, arg = sql.partition(' ')
- verbose = '+' in command
-
- command = command.strip().replace('+', '')
- return (command, verbose, arg.strip())
+@special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY)
+def list_databases(cur, **_):
+ query = 'SELECT datname FROM pg_database;'
+ cur.execute(query)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ return [(None, cur, headers, cur.statusmessage)]
+ else:
+ return [(None, None, None, cur.statusmessage)]
+@special_command('\\du', '\\du[+] [pattern]', 'List roles.')
def list_roles(cur, pattern, verbose):
"""
Returns (title, rows, headers, status)
@@ -53,6 +47,7 @@ def list_roles(cur, pattern, verbose):
headers = [x[0] for x in cur.description]
return [(None, cur, headers, cur.statusmessage)]
+@special_command('\\dn', '\\dn[+] [pattern]', 'List schemas.')
def list_schemas(cur, pattern, verbose):
"""
Returns (title, rows, headers, status)
@@ -138,20 +133,24 @@ def list_objects(cur, pattern, verbose, relkinds):
return [(None, cur, headers, cur.statusmessage)]
+@special_command('\\dt', '\\dt[+] [pattern]', 'List tables.')
def list_tables(cur, pattern, verbose):
return list_objects(cur, pattern, verbose, ['r', ''])
+@special_command('\\dv', '\\dv[+] [pattern]', 'List views.')
def list_views(cur, pattern, verbose):
return list_objects(cur, pattern, verbose, ['v', 's', ''])
+@special_command('\\ds', '\\ds[+] [pattern]', 'List sequences.')
def list_sequences(cur, pattern, verbose):
return list_objects(cur, pattern, verbose, ['S', 's', ''])
+@special_command('\\di', '\\di[+] [pattern]', 'List indexes.')
def list_indexes(cur, pattern, verbose):
return list_objects(cur, pattern, verbose, ['i', 's', ''])
-
+@special_command('\\df', '\\df[+] [pattern]', 'List functions.')
def list_functions(cur, pattern, verbose):
if verbose:
@@ -216,7 +215,7 @@ def list_functions(cur, pattern, verbose):
headers = [x[0] for x in cur.description]
return [(None, cur, headers, cur.statusmessage)]
-
+@special_command('\\dT', '\\dT[S+] [pattern]', 'List data types')
def list_datatypes(cur, pattern, verbose):
assert True
sql = '''SELECT n.nspname as "Schema",
@@ -284,6 +283,8 @@ def list_datatypes(cur, pattern, verbose):
headers = [x[0] for x in cur.description]
return [(None, cur, headers, cur.statusmessage)]
+@special_command('describe', 'DESCRIBE [pattern]', '', hidden=True, case_sensitive=False)
+@special_command('\\d', '\\d [pattern]', 'List or describe tables, views and sequences.')
def describe_table_details(cur, pattern, verbose):
"""
Returns (title, rows, headers, status)
@@ -983,149 +984,3 @@ def sql_name_pattern(pattern):
schema = '^(' + schema + ')$'
return schema, relname
-
-def show_help(cur, arg, verbose): # All the parameters are ignored.
- headers = ['Command', 'Description']
- result = []
-
- for command, value in sorted(CASE_SENSITIVE_COMMANDS.items()):
- if value[1]:
- result.append(value[1])
- return [(None, result, headers, None)]
-
-
-def dummy_command(cur, arg, verbose):
- """
- Used by commands that are actually handled elsewhere.
- But we want to keep their docstrings in the same list
- as all the rest of commands.
- """
- raise NotImplementedError
-
-
-def in_progress(cur, arg, verbose):
- """
- Stub method to signal about commands being under development.
- """
- raise NotImplementedError
-
-def expanded_output(cur, arg, verbose):
- global use_expanded_output
- use_expanded_output = not use_expanded_output
- message = u"Expanded display is "
- message += u"on." if use_expanded_output else u"off."
- return [(None, None, None, message)]
-
-def toggle_timing(cur, arg, verbose):
- global TIMING_ENABLED
- TIMING_ENABLED = not TIMING_ENABLED
- message = "Timing is "
- message += "on." if TIMING_ENABLED else "off."
- return [(None, None, None, message)]
-
-def list_named_queries(cur, arg, verbose):
- """Returns (title, rows, headers, status)"""
- if not verbose:
- rows = [[r] for r in namedqueries.list()]
- headers = ["Name"]
- else:
- headers = ["Name", "Query"]
- rows = [[r, namedqueries.get(r)] for r in namedqueries.list()]
- return [('', rows, headers, "")]
-
-def save_named_query(cur, arg, verbose):
- """Returns (title, rows, headers, status)"""
- if ' ' not in arg:
- return [(None, None, None, "Invalid argument.")]
- name, query = arg.split(' ', 1)
- namedqueries.save(name, query)
- return [(None, None, None, "Saved.")]
-
-def delete_named_query(cur, arg, verbose):
- if len(arg) == 0:
- return [(None, None, None, "Invalid argument.")]
- namedqueries.delete(arg)
- return [(None, None, None, "Deleted.")]
-
-def execute_named_query(cur, arg, verbose):
- """Returns (title, rows, headers, status)"""
- if arg == '':
- return list_named_queries(cur, arg, verbose)
-
- query = namedqueries.get(arg)
- title = '> {}'.format(query)
- if query is None:
- message = "No named query: {}".format(arg)
- return [(None, None, None, message)]
- cur.execute(query)
- if cur.description:
- headers = [x[0] for x in cur.description]
- return [(title, cur, headers, cur.statusmessage)]
- else:
- return [(title, None, None, cur.statusmessage)]
-
-
-CASE_SENSITIVE_COMMANDS = {
- '\?': (show_help, ['\?', 'Help on pgcli commands.']),
- '\c': (dummy_command, ['\c database_name', 'Connect to a new database.']),
- '\l': ('''SELECT datname FROM pg_database;''', ['\l', 'List databases.']),
- '\d': (describe_table_details, ['\d [pattern]', 'List or describe tables, views and sequences.']),
- '\dn': (list_schemas, ['\dn[+] [pattern]', 'List schemas.']),
- '\du': (list_roles, ['\du[+] [pattern]', 'List roles.']),
- '\\x': (expanded_output, ['\\x', 'Toggle expanded output.']),
- '\\timing': (toggle_timing, ['\\timing', 'Toggle timing of commands.']),
- '\\dt': (list_tables, ['\\dt[+] [pattern]', 'List tables.']),
- '\\di': (list_indexes, ['\\di[+] [pattern]', 'List indexes.']),
- '\\dv': (list_views, ['\\dv[+] [pattern]', 'List views.']),
- '\\ds': (list_sequences, ['\\ds[+] [pattern]', 'List sequences.']),
- '\\df': (list_functions, ['\\df[+] [pattern]', 'List functions.']),
- '\\dT': (list_datatypes, ['\dT[S+] [pattern]', 'List data types']),
- '\e': (dummy_command, ['\e [file]', 'Edit the query buffer (or file) with external editor.']),
- '\ef': (in_progress, ['\ef [funcname [line]]', 'Not yet implemented.']),
- '\sf': (in_progress, ['\sf[+] funcname', 'Not yet implemented.']),
- '\z': (in_progress, ['\z [pattern]', 'Not yet implemented.']),
- '\do': (in_progress, ['\do[S] [pattern]', 'Not yet implemented.']),
- '\\n': (execute_named_query, ['\\n[+] [name]', 'List or execute named queries.']),
- '\\ns': (save_named_query, ['\\ns [name [query]]', 'Save a named query.']),
- '\\nd': (delete_named_query, ['\\nd [name]', 'Delete a named query.']),
- }
-
-NON_CASE_SENSITIVE_COMMANDS = {
- 'describe': (describe_table_details, ['DESCRIBE [pattern]', '']),
- }
-
-def execute(cur, sql):
- """Execute a special command and return the results. If the special command
- is not supported a KeyError will be raised.
- """
- command, verbose, arg = parse_special_command(sql)
-
- # Look up the command in the case-sensitive dict, if it's not there look in
- # non-case-sensitive dict. If not there either, throw a KeyError exception.
- global CASE_SENSITIVE_COMMANDS
- global NON_CASE_SENSITIVE_COMMANDS
- try:
- command_executor = CASE_SENSITIVE_COMMANDS[command][0]
- except KeyError:
- command_executor = NON_CASE_SENSITIVE_COMMANDS[command.lower()][0]
-
- # If the command executor is a function, then call the function with the
- # args. If it's a string, then assume it's an SQL command and run it.
- if callable(command_executor):
- return command_executor(cur, arg, verbose)
- elif isinstance(command_executor, str):
- cur.execute(command_executor)
- if cur.description:
- headers = [x[0] for x in cur.description]
- return [(None, cur, headers, cur.statusmessage)]
- else:
- return [(None, None, None, cur.statusmessage)]
-
-if __name__ == '__main__':
- import psycopg2
- con = psycopg2.connect(database='misago_testforum')
- cur = con.cursor()
- table = sys.argv[1]
- for rows, headers, status in describe_table_details(cur, table, False):
- print(tabulate(rows, headers, tablefmt='psql'))
- print(status)
diff --git a/pgcli/packages/pgspecial/iocommands.py b/pgcli/packages/pgspecial/iocommands.py
new file mode 100644
index 00000000..5b93e6ce
--- /dev/null
+++ b/pgcli/packages/pgspecial/iocommands.py
@@ -0,0 +1,149 @@
+import re
+import logging
+from codecs import open
+import click
+from .namedqueries import namedqueries
+from .main import special_command, NO_QUERY
+from . import export
+
+_logger = logging.getLogger(__name__)
+
+TIMING_ENABLED = True
+use_expanded_output = False
+
+@export
+def is_expanded_output():
+ return use_expanded_output
+
+@special_command('\\x', '\\x', 'Toggle expanded output.', arg_type=NO_QUERY)
+def toggle_expanded_output():
+ global use_expanded_output
+ use_expanded_output = not use_expanded_output
+ message = u"Expanded display is "
+ message += u"on." if use_expanded_output else u"off."
+ return [(None, None, None, message)]
+
+@special_command('\\timing', '\\timing', 'Toggle timing of commands.', arg_type=NO_QUERY)
+def toggle_timing():
+ global TIMING_ENABLED
+ TIMING_ENABLED = not TIMING_ENABLED
+ message = "Timing is "
+ message += "on." if TIMING_ENABLED else "off."
+ return [(None, None, None, message)]
+
+@export
+def set_timing(enable=True):
+ global TIMING_ENABLED
+ TIMING_ENABLED = enable
+
+@export
+def get_timing():
+ return TIMING_ENABLED
+
+
+@export
+def editor_command(command):
+ """
+ Is this an external editor command?
+ :param command: string
+ """
+ # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
+ # for both conditions.
+ return command.strip().endswith('\\e') or command.strip().startswith('\\e')
+
+@export
+def get_filename(sql):
+ if sql.strip().startswith('\\e'):
+ command, _, filename = sql.partition(' ')
+ return filename.strip() or None
+
+@export
+def open_external_editor(filename=None, sql=''):
+ """
+ Open external editor, wait for the user to type in his query,
+ return the query.
+ :return: list with one tuple, query as first element.
+ """
+
+ sql = sql.strip()
+
+ # The reason we can't simply do .strip('\e') is that it strips characters,
+ # not a substring. So it'll strip "e" in the end of the sql also!
+ # Ex: "select * from style\e" -> "select * from styl".
+ pattern = re.compile('(^\\\e|\\\e$)')
+ while pattern.search(sql):
+ sql = pattern.sub('', sql)
+
+ message = None
+ filename = filename.strip().split(' ', 1)[0] if filename else None
+
+ MARKER = '# Type your query above this line.\n'
+
+ # Populate the editor buffer with the partial sql (if available) and a
+ # placeholder comment.
+ query = click.edit(sql + '\n\n' + MARKER, filename=filename,
+ extension='.sql')
+
+ if filename:
+ try:
+ with open(filename, encoding='utf-8') as f:
+ query = f.read()
+ except IOError:
+ message = 'Error reading file: %s.' % filename
+
+ if query is not None:
+ query = query.split(MARKER, 1)[0].rstrip('\n')
+ else:
+ # Don't return None for the caller to deal with.
+ # Empty string is ok.
+ query = sql
+
+ return (query, message)
+
+@special_command('\\n', '\\n[+] [name]', 'List or execute named queries.')
+def execute_named_query(cur, pattern, verbose):
+ """Returns (title, rows, headers, status)"""
+ if pattern == '':
+ return list_named_queries(verbose)
+
+ query = namedqueries.get(pattern)
+ title = '> {}'.format(query)
+ if query is None:
+ message = "No named query: {}".format(pattern)
+ return [(None, None, None, message)]
+ cur.execute(query)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ return [(title, cur, headers, cur.statusmessage)]
+ else:
+ return [(title, None, None, cur.statusmessage)]
+
+def list_named_queries(verbose):
+ """List of all named queries.
+ Returns (title, rows, headers, status)"""
+ if not verbose:
+ rows = [[r] for r in namedqueries.list()]
+ headers = ["Name"]
+ else:
+ headers = ["Name", "Query"]
+ rows = [[r, namedqueries.get(r)] for r in namedqueries.list()]
+ return [('', rows, headers, "")]
+
+@special_command('\\ns', '\\ns [name [query]]', 'Save a named query.')
+def save_named_query(pattern, **_):
+ """Save a new named query.
+ Returns (title, rows, headers, status)"""
+ if ' ' not in pattern:
+ return [(None, None, None, "Invalid argument.")]
+ name, query = pattern.split(' ', 1)
+ namedqueries.save(name, query)
+ return [(None, None, None, "Saved.")]
+
+@special_command('\\nd', '\\nd [name [query]]', 'Delete a named query.')
+def delete_named_query(pattern, **_):
+ """Delete an existing named query.
+ """
+ if len(pattern) == 0:
+ return [(None, None, None, "Invalid argument.")]
+ namedqueries.delete(pattern)
+ return [(None, None, None, "Deleted.")]
diff --git a/pgcli/packages/pgspecial/main.py b/pgcli/packages/pgspecial/main.py
new file mode 100644
index 00000000..3bee09f9
--- /dev/null
+++ b/pgcli/packages/pgspecial/main.py
@@ -0,0 +1,84 @@
+import logging
+from collections import namedtuple
+
+from . import export
+
+log = logging.getLogger(__name__)
+
+NO_QUERY = 0
+PARSED_QUERY = 1
+RAW_QUERY = 2
+
+SpecialCommand = namedtuple('SpecialCommand',
+ ['handler', 'syntax', 'description', 'arg_type', 'hidden', 'case_sensitive'])
+
+COMMANDS = {}
+
+@export
+def parse_special_command(sql):
+ command, _, arg = sql.partition(' ')
+ verbose = '+' in command
+
+ command = command.strip().replace('+', '')
+ return (command, verbose, arg.strip())
+
+@export
+def special_command(command, syntax, description, arg_type=PARSED_QUERY,
+ hidden=False, case_sensitive=True, aliases=()):
+ def wrapper(wrapped):
+ register_special_command(wrapped, command, syntax, description,
+ arg_type, hidden, case_sensitive, aliases)
+ return wrapped
+ return wrapper
+
+@export
+def register_special_command(handler, command, syntax, description,
+ arg_type=PARSED_QUERY, hidden=False, case_sensitive=True, aliases=()):
+ cmd = command.lower() if not case_sensitive else command
+ COMMANDS[cmd] = SpecialCommand(handler, syntax, description, arg_type,
+ hidden, case_sensitive)
+ for alias in aliases:
+ cmd = alias.lower() if not case_sensitive else alias
+ COMMANDS[cmd] = SpecialCommand(handler, syntax, description, arg_type,
+ case_sensitive=case_sensitive,
+ hidden=True)
+
+
+@export
+def execute(cur, sql):
+ command, verbose, pattern = parse_special_command(sql)
+ try:
+ special_cmd = COMMANDS[command]
+ except KeyError:
+ special_cmd = COMMANDS[command.lower()]
+ if special_cmd.case_sensitive:
+ raise KeyError('Command not found: %s' % command)
+
+ if special_cmd.arg_type == NO_QUERY:
+ return special_cmd.handler()
+ elif special_cmd.arg_type == PARSED_QUERY:
+ return special_cmd.handler(cur=cur, pattern=pattern, verbose=verbose)
+ elif special_cmd.arg_type == RAW_QUERY:
+ return special_cmd.handler(cur=cur, query=sql)
+
+@special_command('\\?', '\\?', 'Show Help.', arg_type=NO_QUERY)
+def show_help():
+ headers = ['Command', 'Description']
+ result = []
+
+ for _, value in sorted(COMMANDS.items()):
+ if not value.hidden:
+ result.append((value.syntax, value.description))
+ return [(None, result, headers, None)]
+
+@special_command('\\e', '\\e [file]', 'Edit the query with external editor.', arg_type=NO_QUERY)
+def doc_only():
+ raise RuntimeError
+
+@special_command('\\ef', '\\ef [funcname [line]]', 'Edit the contents of the query buffer.', arg_type=NO_QUERY, hidden=True)
+@special_command('\\sf', '\\sf[+] FUNCNAME', 'Show a function\'s definition.', arg_type=NO_QUERY, hidden=True)
+@special_command('\\do', '\\do[S] [pattern]', 'List operators.', arg_type=NO_QUERY, hidden=True)
+@special_command('\\dp', '\\dp [pattern]', 'List table, view, and sequence access privileges.', arg_type=NO_QUERY, hidden=True)
+@special_command('\\z', '\\z [pattern]', 'Same as \\dp.', arg_type=NO_QUERY, hidden=True)
+def place_holder():
+ raise NotImplementedError
diff --git a/pgcli/packages/namedqueries.py b/pgcli/packages/pgspecial/namedqueries.py
index e83519cf..0c5f2957 100644
--- a/pgcli/packages/namedqueries.py
+++ b/pgcli/packages/pgspecial/namedqueries.py
@@ -1,11 +1,9 @@
-from ..config import load_config
-
class NamedQueries(object):
section_name = 'named queries'
- def __init__(self, filename):
- self.config = load_config(filename)
+ def __init__(self, config):
+ self.config = config
def list(self):
return self.config.get(self.section_name, [])
@@ -24,5 +22,5 @@ class NamedQueries(object):
del self.config[self.section_name][name]
self.config.write()
-
-namedqueries = NamedQueries('~/.pgclirc')
+from ...config import load_config
+namedqueries = NamedQueries(load_config('~/.pgclirc'))
diff --git a/pgcli/packages/pgspecial/tests/conftest.py b/pgcli/packages/pgspecial/tests/conftest.py
new file mode 100644
index 00000000..11a3d6e0
--- /dev/null
+++ b/pgcli/packages/pgspecial/tests/conftest.py
@@ -0,0 +1,32 @@
+import pytest
+from dbutils import (create_db, db_connection, setup_db, teardown_db)
+from pgcli.packages.pgspecial.main import execute
+
+
+@pytest.yield_fixture(scope='module')
+def connection():
+ create_db('_test_db')
+ connection = db_connection('_test_db')
+ setup_db(connection)
+ yield connection
+
+ teardown_db(connection)
+ connection.close()
+
+
+@pytest.fixture
+def cursor(connection):
+ with connection.cursor() as cur:
+ return cur
+
+
+@pytest.fixture
+def executor(connection):
+ cur = connection.cursor()
+
+ def query_runner(sql):
+ results = []
+ for title, rows, headers, status in execute(cur=cur, sql=sql):
+ results.extend((title, list(rows), headers, status))
+ return results
+ return query_runner
diff --git a/pgcli/packages/pgspecial/tests/dbutils.py b/pgcli/packages/pgspecial/tests/dbutils.py
new file mode 100644
index 00000000..e417892d
--- /dev/null
+++ b/pgcli/packages/pgspecial/tests/dbutils.py
@@ -0,0 +1,68 @@
+import pytest
+import psycopg2
+import psycopg2.extras
+
+# TODO: should this be somehow be divined from environment?
+POSTGRES_USER, POSTGRES_HOST = 'postgres', 'localhost'
+
+
+def db_connection(dbname=None):
+ conn = psycopg2.connect(user=POSTGRES_USER, host=POSTGRES_HOST,
+ database=dbname)
+ conn.autocommit = True
+ return conn
+
+
+try:
+ conn = db_connection()
+ CAN_CONNECT_TO_DB = True
+ SERVER_VERSION = conn.server_version
+except:
+ CAN_CONNECT_TO_DB = False
+ SERVER_VERSION = 0
+
+
+dbtest = pytest.mark.skipif(
+ not CAN_CONNECT_TO_DB,
+ reason="Need a postgres instance at localhost accessible by user 'postgres'")
+
+def create_db(dbname):
+ with db_connection().cursor() as cur:
+ try:
+ cur.execute('''CREATE DATABASE _test_db''')
+ except:
+ pass
+
+def setup_db(conn):
+ with conn.cursor() as cur:
+ # schemas
+ cur.execute('create schema schema1')
+ cur.execute('create schema schema2')
+
+ # tables
+ cur.execute('create table tbl1(id1 integer, txt1 text)')
+ cur.execute('create table tbl2(id2 integer, txt2 text)')
+ cur.execute('create table schema1.s1_tbl1(id1 integer, txt1 text)')
+
+ # views
+ cur.execute('create view vw1 as select * from tbl1')
+ cur.execute('''create view schema1.s1_vw1 as select * from
+ schema1.s1_tbl1''')
+
+ # datatype
+ cur.execute('create type foo AS (a int, b text)')
+
+ # functions
+ cur.execute('''create function func1() returns int language sql as
+ $$select 1$$''')
+ cur.execute('''create function schema1.s1_func1() returns int language
+ sql as $$select 2$$''')
+
+
+def teardown_db(conn):
+ with conn.cursor() as cur:
+ cur.execute('''
+ DROP SCHEMA public CASCADE;
+ CREATE SCHEMA public;
+ DROP SCHEMA IF EXISTS schema1 CASCADE;
+ DROP SCHEMA IF EXISTS schema2 CASCADE''')
diff --git a/pgcli/packages/pgspecial/tests/pytest.ini b/pgcli/packages/pgspecial/tests/pytest.ini
new file mode 100644
index 00000000..f7877405
--- /dev/null
+++ b/pgcli/packages/pgspecial/tests/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts=--capture=sys --showlocals \ No newline at end of file
diff --git a/pgcli/packages/pgspecial/tests/test_specials.py b/pgcli/packages/pgspecial/tests/test_specials.py
new file mode 100644
index 00000000..889efad5
--- /dev/null
+++ b/pgcli/packages/pgspecial/tests/test_specials.py
@@ -0,0 +1,60 @@
+from dbutils import dbtest
+
+@dbtest
+def test_slash_d(executor):
+ results = executor('\d')
+ title = None
+ rows = [('public', 'tbl1', 'table', 'postgres'),
+ ('public', 'tbl2', 'table', 'postgres'),
+ ('public', 'vw1', 'view', 'postgres')]
+ headers = ['Schema', 'Name', 'Type', 'Owner']
+ status = 'SELECT 3'
+ expected = [title, rows, headers, status]
+ assert results == expected
+
+@dbtest
+def test_slash_dn(executor):
+ """List all schemas."""
+ results = executor('\dn')
+ title = None
+ rows = [('public', 'postgres'),
+ ('schema1', 'postgres'),
+ ('schema2', 'postgres')]
+ headers = ['Name', 'Owner']
+ status = 'SELECT 3'
+ expected = [title, rows, headers, status]
+ assert results == expected
+
+@dbtest
+def test_slash_dt(executor):
+ """List all tables in public schema."""
+ results = executor('\dt')
+ title = None
+ rows = [('public', 'tbl1', 'table', 'postgres'),
+ ('public', 'tbl2', 'table', 'postgres')]
+ headers = ['Schema', 'Name', 'Type', 'Owner']
+ status = 'SELECT 2'
+ expected = [title, rows, headers, status]
+ assert results == expected
+
+@dbtest
+def test_slash_dT(executor):
+ """List all datatypes."""
+ results = executor('\dT')
+ title = None
+ rows = [('public', 'foo', None)]
+ headers = ['Schema', 'Name', 'Description']
+ status = 'SELECT 1'
+ expected = [title, rows, headers, status]
+ assert results == expected
+
+@dbtest
+def test_slash_df(executor):
+ results = executor('\df')
+ title = None
+ rows = [('public', 'func1', 'integer', '', 'normal')]
+ headers = ['Schema', 'Name', 'Result data type', 'Argument data types',
+ 'Type']
+ status = 'SELECT 1'
+ expected = [title, rows, headers, status]
+ assert results == expected
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index 14677e50..ed4290da 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -3,7 +3,7