summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAmjith Ramanujam <amjith.r@gmail.com>2015-01-26 20:26:52 -0800
committerAmjith Ramanujam <amjith.r@gmail.com>2015-01-26 20:26:52 -0800
commit6944ef60f83e7e116aa687a972b5402f1a1bbe04 (patch)
tree593d5c212b78f3de496c16044de42a565fed60de
parent750206c779060ea8b1cf19e2058a87cd536363e2 (diff)
parent7d3f276e8330854e0153d25e02df1b103529e640 (diff)
Merge pull request #127 from darikg/schema_autocomplete
Make autocomplete schema-aware
-rwxr-xr-xpgcli/main.py40
-rw-r--r--pgcli/packages/parseutils.py26
-rw-r--r--pgcli/packages/sqlcompletion.py65
-rw-r--r--pgcli/pgcompleter.py187
-rw-r--r--pgcli/pgexecute.py76
-rw-r--r--setup.py2
-rw-r--r--tests/test_parseutils.py68
-rw-r--r--tests/test_pgexecute.py18
-rw-r--r--tests/test_pgspecial.py19
-rw-r--r--tests/test_smart_completion_multiple_schemata.py227
-rw-r--r--tests/test_smart_completion_public_schema_only.py (renamed from tests/test_smart_completion.py)46
-rw-r--r--tests/test_sqlcompletion.py170
-rw-r--r--tests/utils.py6
13 files changed, 739 insertions, 211 deletions
diff --git a/pgcli/main.py b/pgcli/main.py
index 7ed04c39..f47f7888 100755
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -214,6 +214,12 @@ class PGCli(object):
end = time()
total += end - start
mutating = mutating or is_mutating(status)
+
+ if need_search_path_refresh(document.text, status):
+ logger.debug('Refreshing search path')
+ completer.set_search_path(pgexecute.search_path())
+ logger.debug('Search path: %r', completer.search_path)
+
except KeyboardInterrupt:
# Restart connection to the database
pgexecute.connect()
@@ -262,13 +268,16 @@ class PGCli(object):
return less_opts
def refresh_completions(self):
- self.completer.reset_completions()
- tables, columns = self.pgexecute.tables()
- self.completer.extend_table_names(tables)
- for table in tables:
- table = table[1:-1] if table[0] == '"' and table[-1] == '"' else table
- self.completer.extend_column_names(table, columns[table])
- self.completer.extend_database_names(self.pgexecute.databases())
+ completer = self.completer
+ completer.reset_completions()
+
+ pgexecute = self.pgexecute
+
+ completer.set_search_path(pgexecute.search_path())
+ completer.extend_schemata(pgexecute.schemata())
+ completer.extend_tables(pgexecute.tables())
+ completer.extend_columns(pgexecute.columns())
+ completer.extend_database_names(pgexecute.databases())
def get_completions(self, text, cursor_positition):
return self.completer.get_completions(
@@ -329,6 +338,22 @@ def need_completion_refresh(sql):
except Exception:
return False
+def need_search_path_refresh(sql, status):
+ # note that sql may be a multi-command query, but status belongs to an
+ # individual query, since pgexecute handles splitting up multi-commands
+ try:
+ status = status.split()[0]
+ if status.lower() == 'set':
+ # Since sql could be a multi-line query, it's hard to robustly
+ # pick out the variable name that's been set. Err on the side of
+ # false positives here, since the worst case is we refresh the
+ # search path when it's not necessary
+ return 'search_path' in sql.lower()
+ else:
+ return False
+ except Exception:
+ return False
+
def is_mutating(status):
"""Determines if the statement is mutating based on the status."""
if not status:
@@ -349,6 +374,5 @@ def quit_command(sql):
or sql.strip() == '\q'
or sql.strip() == ':q')
-
if __name__ == "__main__":
cli()
diff --git a/pgcli/packages/parseutils.py b/pgcli/packages/parseutils.py
index 370bb4db..4122a332 100644
--- a/pgcli/packages/parseutils.py
+++ b/pgcli/packages/parseutils.py
@@ -101,50 +101,50 @@ def extract_from_part(parsed, stop_at_punctuation=True):
break
def extract_table_identifiers(token_stream):
+ """yields tuples of (schema_name, table_name, table_alias)"""
+
for item in token_stream:
if isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
# Sometimes Keywords (such as FROM ) are classified as
# identifiers which don't have the get_real_name() method.
try:
+ schema_name = identifier.get_parent_name()
real_name = identifier.get_real_name()
except AttributeError:
continue
if real_name:
- yield (real_name, identifier.get_alias() or real_name)
+ yield (schema_name, real_name, identifier.get_alias())
elif isinstance(item, Identifier):
real_name = item.get_real_name()
+ schema_name = item.get_parent_name()
if real_name:
- yield (real_name, item.get_alias() or real_name)
+ yield (schema_name, real_name, item.get_alias())
else:
name = item.get_name()
- yield (name, item.get_alias() or name)
+ yield (None, name, item.get_alias() or name)
elif isinstance(item, Function):
- yield (item.get_name(), item.get_name())
+ yield (None, item.get_name(), item.get_name())
# extract_tables is inspired from examples in the sqlparse lib.
-def extract_tables(sql, include_alias=False):
+def extract_tables(sql):
"""Extract the table names from an SQL statment.
- Returns a list of table names if include_alias=False (default).
- If include_alias=True, then a dictionary is returned where the keys are
- aliases and values are real table names.
+ Returns a list of (schema, table, alias) tuples
"""
parsed = sqlparse.parse(sql)
if not parsed:
return []
+
# INSERT statements must stop looking for tables at the sign of first
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
# abc is the table name, but if we don't stop at the first lparen, then
# we'll identify abc, col1 and col2 as table names.
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
- if include_alias:
- return dict((alias, t) for t, alias in extract_table_identifiers(stream))
- else:
- return [x[0] for x in extract_table_identifiers(stream)]
+ return list(extract_table_identifiers(stream))
def find_prev_keyword(sql):
if not sql.strip():
@@ -156,4 +156,4 @@ def find_prev_keyword(sql):
if __name__ == '__main__':
sql = 'select * from (select t. from tabl t'
- print (extract_tables(sql, True))
+ print (extract_tables(sql))
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
index f409dae0..45e38fed 100644
--- a/pgcli/packages/sqlcompletion.py
+++ b/pgcli/packages/sqlcompletion.py
@@ -66,28 +66,65 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text):
# If the lparen is preceeded by a space chances are we're about to
# do a sub-select.
if last_word(text_before_cursor, 'all_punctuations').startswith('('):
- return 'keywords', []
- return 'columns', extract_tables(full_text)
+ return [{'type': 'keyword'}]
+
+ return [{'type': 'column', 'tables': extract_tables(full_text)}]
+
if token_v.lower() in ('set', 'by', 'distinct'):
- return 'columns', extract_tables(full_text)
+ return [{'type': 'column', 'tables': extract_tables(full_text)}]
elif token_v.lower() in ('select', 'where', 'having'):
- return 'columns-and-functions', extract_tables(full_text)
+ return [{'type': 'column', 'tables': extract_tables(full_text)},
+ {'type': 'function'}]
elif token_v.lower() in ('from', 'update', 'into', 'describe', 'join', 'table'):
- return 'tables', []
+ return [{'type': 'schema'}, {'type': 'table', 'schema': []}]
elif token_v.lower() == 'on':
- tables = extract_tables(full_text, include_alias=True)
- return 'tables-or-aliases', tables.keys()
+ tables = extract_tables(full_text) # [(schema, table, alias), ...]
+
+ # Use table alias if there is one, otherwise the table name
+ alias = [t[2] or t[1] for t in tables]
+
+ return [{'type': 'alias', 'aliases': alias}]
+
elif token_v in ('d',): # \d
- return 'tables', []
+ # Apparently "\d <other>" is parsed by sqlparse as
+ # Identifer('d', Whitespace, '<other>')
+ if len(token.tokens) > 2:
+ other = token.tokens[-1].value
+ identifiers = other.split('.')
+ if len(identifiers) == 1:
+ # "\d table" or "\d schema"
+ return [{'type': 'schema'}, {'type': 'table', 'schema': []}]
+ elif len(identifiers) == 2:
+ # \d schema.table
+ return [{'type': 'table', 'schema': identifiers[0]}]
+ else:
+ return [{'type': 'schema'}, {'type': 'table', 'schema': []}]
elif token_v.lower() in ('c', 'use'): # \c
- return 'databases', []
+ return [{'type': 'database'}]
elif token_v.endswith(',') or token_v == '=':
prev_keyword = find_prev_keyword(text_before_cursor)
if prev_keyword:
- return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text)
+ return suggest_based_on_last_token(
+ prev_keyword, text_before_cursor, full_text)
elif token_v.endswith('.'):
- current_alias = last_word(token_v[:-1])
- tables = extract_tables(full_text, include_alias=True)
- return 'columns', [tables.get(current_alias) or current_alias]
- return 'keywords', []
+ suggestions = []
+
+ identifier = last_word(token_v[:-1], 'all_punctuations')
+
+ # TABLE.<suggestion> or SCHEMA.TABLE.<suggestion>
+ tables = extract_tables(full_text)
+ tables = [t for t in tables if identifies(identifier, *t)]
+ suggestions.append({'type': 'column', 'tables': tables})
+
+ # SCHEMA.<suggestion>
+ suggestions.append({'type': 'table', 'schema': identifier})
+
+ return suggestions
+
+ return [{'type': 'keyword'}]
+
+
+def identifies(id, schema, table, alias):
+ return id == alias or id == table or (
+ schema and (id == schema + '.' + table))
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
index a8123a72..3585abeb 100644
--- a/pgcli/pgcompleter.py
+++ b/pgcli/pgcompleter.py
@@ -1,11 +1,11 @@
from __future__ import print_function
import logging
-from collections import defaultdict
from prompt_toolkit.completion import Completer, Completion
from .packages.sqlcompletion import suggest_type
from .packages.parseutils import last_word
from re import compile
+
_logger = logging.getLogger(__name__)
class PGCompleter(Completer):
@@ -21,7 +21,7 @@ class PGCompleter(Completer):
'MAXEXTENTS', 'MINUS', 'MLSLABEL', 'MODE', 'MODIFY', 'NOAUDIT',
'NOCOMPRESS', 'NOT', 'NOWAIT', 'NULL', 'NUMBER', 'OF', 'OFFLINE',
'ON', 'ONLINE', 'OPTION', 'OR', 'ORDER BY', 'OUTER', 'PCTFREE',
- 'PRIMARY', 'PRIOR', 'PRIVILEGES', 'PUBLIC', 'RAW', 'RENAME',
+ 'PRIMARY', 'PRIOR', 'PRIVILEGES', 'RAW', 'RENAME',
'RESOURCE', 'REVOKE', 'RIGHT', 'ROW', 'ROWID', 'ROWNUM', 'ROWS',
'SELECT', 'SESSION', 'SET', 'SHARE', 'SIZE', 'SMALLINT', 'START',
'SUCCESSFUL', 'SYNONYM', 'SYSDATE', 'TABLE', 'THEN', 'TO',
@@ -33,15 +33,6 @@ class PGCompleter(Completer):
'LCASE', 'LEN', 'MAX', 'MIN', 'MID', 'NOW', 'ROUND', 'SUM', 'TOP',
'UCASE']
- special_commands = []
-
- databases = []
- tables = []
- # This will create a defaultdict which is initialized with a list that has
- # a '*' by default.
- columns = defaultdict(lambda: ['*'])
- all_completions = set(keywords + functions)
-
def __init__(self, smart_completion=True):
super(self.__class__, self).__init__()
self.smart_completion = smart_completion
@@ -50,8 +41,15 @@ class PGCompleter(Completer):
self.reserved_words.update(x.split())
self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$")
+ self.special_commands = []
+ self.databases = []
+ self.dbmetadata = {}
+ self.search_path = []
+
+ self.all_completions = set(self.keywords + self.functions)
+
def escape_name(self, name):
- if ((not self.name_pattern.match(name))
+ if name and ((not self.name_pattern.match(name))
or (name.upper() in self.reserved_words)
or (name.upper() in self.functions)):
name = '"%s"' % name
@@ -60,7 +58,7 @@ class PGCompleter(Completer):
def unescape_name(self, name):
""" Unquote a string."""
- if name[0] == '"' and name[-1] == '"':
+ if name and name[0] == '"' and name[-1] == '"':
name = name[1:-1]
return name
@@ -75,31 +73,50 @@ class PGCompleter(Completer):
def extend_database_names(self, databases):
databases = self.escaped_names(databases)
-
self.databases.extend(databases)
def extend_keywords(self, additional_keywords):
self.keywords.extend(additional_keywords)
self.all_completions.update(additional_keywords)
- def extend_table_names(self, tables):
- tables = self.escaped_names(tables)
+ def extend_schemata(self, schemata):
+
+ # data is a DataFrame with columns [schema]
+ schemata = self.escaped_names(schemata)
+ for schema in schemata:
+ self.dbmetadata[schema] = {}
+
+ self.all_completions.update(schemata)
- self.tables.extend(tables)
- self.all_completions.update(tables)
+ def extend_tables(self, table_data):
- def extend_column_names(self, table, columns):
- columns = self.escaped_names(columns)
+ # table_data is a list of (schema_name, table_name) tuples
+ table_data = [self.escaped_names(d) for d in table_data]
- unescaped_table_name = self.unescape_name(table)
+ # dbmetadata['schema_name']['table_name'] should be a list of column
+ # names. Default to an asterisk
+ for schema, table in table_data:
+ self.dbmetadata[schema][table] = ['*']
- self.columns[unescaped_table_name].extend(columns)
- self.all_completions.update(columns)
+ self.all_completions.update(t[1] for t in table_data)
+
+ def extend_columns(self, column_data):
+
+ # column_data is a list of (schema_name, table_name, column_name) tuples
+ column_data = [self.escaped_names(d) for d in column_data]
+
+ for schema, table, column in column_data:
+ self.dbmetadata[schema][table].append(column)
+
+ self.all_completions.update(t[2] for t in column_data)
+
+ def set_search_path(self, search_path):
+ self.search_path = self.escaped_names(search_path)
def reset_completions(self):
self.databases = []
- self.tables = []
- self.columns = defaultdict(lambda: ['*'])
+ self.search_path = []
+ self.dbmetadata = {}
self.all_completions = set(self.keywords)
@staticmethod
@@ -119,36 +136,90 @@ class PGCompleter(Completer):
if not smart_completion:
return self.find_matches(word_before_cursor, self.all_completions)
- category, scope = suggest_type(document.text,
- document.text_before_cursor)
-
- if category == 'columns':
- _logger.debug("Completion: 'columns' Scope: %r", scope)
- scoped_cols = self.populate_scoped_cols(scope)
- return self.find_matches(word_before_cursor, scoped_cols)
- elif category == 'columns-and-functions':
- _logger.debug("Completion: 'columns-and-functions' Scope: %r",
- scope)
- scoped_cols = self.populate_scoped_cols(scope)
- return self.find_matches(word_before_cursor, scoped_cols +
- self.functions)
- elif category == 'tables':
- _logger.debug("Completion: 'tables' Scope: %r", scope)
- return self.find_matches(word_before_cursor, self.tables)
- elif category == 'tables-or-aliases':
- _logger.debug("Completion: 'tables-or-aliases' Scope: %r", scope)
- return self.find_matches(word_before_cursor, scope)
- elif category == 'databases':
- _logger.debug("Completion: 'databases' Scope: %r", scope)
- return self.find_matches(word_before_cursor, self.databases)
- elif category == 'keywords':
- _logger.debug("Completion: 'keywords' Scope: %r", scope)
- return self.find_matches(word_before_cursor, self.keywords +
- self.special_commands)
-
- def populate_scoped_cols(self, tables):
- scoped_cols = []
- for table in tables:
- unescaped_table_name = self.unescape_name(table)
- scoped_cols.extend(self.columns[unescaped_table_name])
- return scoped_cols
+ completions = []
+ suggestions = suggest_type(document.text, document.text_before_cursor)
+
+ for suggestion in suggestions:
+
+ _logger.debug('Suggestion type: %r', suggestion['type'])
+
+ if suggestion['type'] == 'column':
+ tables = suggestion['tables']
+ _logger.debug("Completion column scope: %r", tables)
+ scoped_cols = self.populate_scoped_cols(tables)
+ cols = self.find_matches(word_before_cursor, scoped_cols)
+ completions.extend(cols)
+
+ elif suggestion['type'] == 'function':
+ funcs = self.find_matches(word_before_cursor, self.functions)
+ completions.extend(funcs)
+
+ elif suggestion['type'] == 'schema':
+ schema_names = self.dbmetadata.keys()
+ schema_names = self.find_matches(word_before_cursor, schema_names)
+ completions.extend(schema_names)
+
+ elif suggestion['type'] == 'table':
+
+ if suggestion['schema']:
+ try:
+ tables = self.dbmetadata[suggestion['schema']].keys()
+ except KeyError:
+ #schema doesn't exist
+ tables = []
+ else:
+ schemas = self.search_path
+ meta = self.dbmetadata
+ tables = [tbl for schema in schemas
+ for tbl in meta[schema].keys()]
+
+ tables = self.find_matches(word_before_cursor, tables)
+ completions.extend(tables)
+ elif suggestion['type'] == 'alias':
+ aliases = suggestion['aliases']
+ aliases = self.find_matches(word_before_cursor, aliases)
+ completions.extend(aliases)
+ elif suggestion['type'] == 'database':
+ dbs = self.find_matches(word_before_cursor, self.databases)
+ completions.extend(dbs)
+
+ elif suggestion['type'] == 'keyword':
+ keywords = self.keywords + self.special_commands
+ keywords = self.find_matches(word_before_cursor, keywords)
+ completions.extend(keywords)
+
+ return completions
+
+ def populate_scoped_cols(self, scoped_tbls):
+ """ Find all columns in a set of scoped_tables
+ :param scoped_tbls: list of (schema, table, alias) tuples
+ :return: list of column names
+ """
+
+ columns = []
+ meta = self.dbmetadata
+
+ for tbl in scoped_tbls:
+ if tbl[0]:
+ # A fully qualified schema.table reference
+ schema = self.escape_name(tbl[0])
+ table = self.escape_name(tbl[1])
+ try:
+ # Get columns from the corresponding schema.table
+ columns.extend(meta[schema][table])
+ except KeyError:
+ # Either the schema or table doesn't exist
+ pass
+ else:
+ for schema in self.search_path:
+ table = self.escape_name(tbl[1])
+ try:
+ columns.extend(meta[schema][table])
+ break
+ except KeyError:
+ pass
+
+ return columns
+
+
+
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index 5ba0dcd3..e16fbd41 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -4,7 +4,6 @@ import psycopg2
import psycopg2.extras
import psycopg2.extensions as ext
import sqlparse
-from collections import defaultdict
from .packages import pgspecial
PY2 = sys.version_info[0] == 2
@@ -30,13 +29,43 @@ psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
class PGExecute(object):
- tables_query = '''SELECT c.relname as "Name" FROM pg_catalog.pg_class c
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE
- c.relkind IN ('r','') AND n.nspname <> 'pg_catalog' AND n.nspname <>
- 'information_schema' AND n.nspname !~ '^pg_toast' AND
- pg_catalog.pg_table_is_visible(c.oid) ORDER BY 1;'''
+ search_path_query = '''
+ SELECT * FROM unnest(current_schemas(false))'''
+
+ schemata_query = '''
+ SELECT nspname
+ FROM pg_catalog.pg_namespace
+ WHERE nspname !~ '^pg_'
+ AND nspname <> 'information_schema'
+ ORDER BY 1 '''
+
+ tables_query = '''
+ SELECT n.nspname schema_name,
+ c.relname table_name
+ FROM pg_catalog.pg_class c
+ LEFT JOIN pg_catalog.pg_namespace n
+ ON n.oid = c.relnamespace
+ WHERE c.relkind IN ('r','v', 'm') -- table, view, materialized view
+ AND n.nspname !~ '^pg_toast'
+ AND n.nspname NOT IN ('information_schema', 'pg_catalog')
+ ORDER BY 1,2;'''
+
+ columns_query = '''
+ SELECT nsp.nspname schema_name,
+ cls.relname table_name,
+ att.attname column_name
+ FROM pg_catalog.pg_attribute att
+ INNER JOIN pg_catalog.pg_class cls
+ ON att.attrelid = cls.oid
+ INNER JOIN pg_catalog.pg_namespace nsp
+ ON cls.relnamespace = nsp.oid
+ WHERE cls.relkind IN ('r', 'v', 'm')
+ AND nsp.nspname !~ '^pg_'
+ AND nsp.nspname <> 'information_schema'
+ AND NOT att.attisdropped
+ AND att.attnum > 0
+ ORDER BY 1, 2, 3'''
- columns_query = '''SELECT table_name, column_name FROM information_schema.columns'''
databases_query = """SELECT d.datname as "Name",
pg_catalog.pg_get_userbyid(d.datdba) as "Owner",
@@ -133,22 +162,37 @@ class PGExecute(object):
_logger.debug('No rows in result.')
return (None, None, cur.statusmessage)
+ def search_path(self):
+ """Returns the current search path as a list of schema names"""
+
+ with self.conn.cursor() as cur:
+ _logger.debug('Search path query. sql: %r', self.search_path_query)
+ cur.execute(self.search_path_query)
+ return [x[0] for x in cur.fetchall()]
+
+ def schemata(self):
+ """Returns a list of schema names in the database"""
+
+ with self.conn.cursor() as cur:
+ _logger.debug('Schemata Query. sql: %r', self.schemata_query)
+ cur.execute(self.schemata_query)
+ return [x[0] for x in cur.fetchall()]
+
def tables(self):
- """ Returns tuple (sorted_tables, columns). Columns is a dictionary of
- table name -> list of columns """
- columns = defaultdict(list)
+ """Returns a list of (schema_name, table_name) tuples """
+
with self.conn.cursor() as cur:
_logger.debug('Tables Query. sql: %r', self.tables_query)
cur.execute(self.tables_query)
- tables = [x[0] for x in cur.fetchall()]
+ return cur.fetchall()
- table_set = set(tables)
+ def columns(self):
+ """Returns a list of (schema_name, table_name, column_name) tuples"""
+
+ with self.conn.cursor() as cur:
_logger.debug('Columns Query. sql: %r', self.columns_query)
cur.execute(self.columns_query)
- for table, column in cur.fetchall():
- if table in table_set:
- columns[table].append(column)
- return tables, columns
+ return cur.fetchall()
def databases(self):
with self.conn.cursor() as cur:
diff --git a/setup.py b/setup.py
index c5a0feac..f52d2c51 100644
--- a/setup.py
+++ b/setup.py
@@ -28,7 +28,7 @@ setup(
'jedi == 0.8.1', # Temporary fix for installation woes.
'prompt_toolkit==0.26',
'psycopg2 >= 2.5.4',
- 'sqlparse >= 0.1.14',
+ 'sqlparse >= 0.1.14'
],
entry_points='''
[console_scripts]
diff --git a/tests/test_parseutils.py b/tests/test_parseutils.py
index e49aec68..b90f3855 100644
--- a/tests/test_parseutils.py
+++ b/tests/test_parseutils.py
@@ -1,3 +1,4 @@
+import pytest
from pgcli.packages.parseutils import extract_tables
@@ -7,48 +8,77 @@ def test_empty_string():
def test_simple_select_single_table():
tables = extract_tables('select * from abc')
- assert tables == ['abc']
+ assert tables == [(None, 'abc', None)]
+
+def test_simple_select_single_table_schema_qualified():
+ tables = extract_tables('select * from abc.def')
+ assert tables == [('abc', 'def', None)]
def test_simple_select_multiple_tables():
tables = extract_tables('select * from abc, def')
- assert tables == ['abc', 'def']
+ assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+
+def test_simple_select_multiple_tables_schema_qualified():
+ tables = extract_tables('select * from abc.def, ghi.jkl')
+ assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)]
def test_simple_select_with_cols_single_table():
tables = extract_tables('select a,b from abc')
- assert tables == ['abc']
+ assert tables == [(None, 'abc', None)]
+
+def test_simple_select_with_cols_single_table_schema_qualified():
+ tables = extract_tables('select a,b from abc.def')
+ assert tables == [('abc', 'def', None)]
def test_simple_select_with_cols_multiple_tables():
tables = extract_tables('select a,b from abc, def')
- assert tables == ['abc', 'def']
+ assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+
+def test_simple_select_with_cols_multiple_tables():
+ tables = extract_tables('select a,b from abc.def, def.ghi')
+ assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)]
def test_select_with_hanging_comma_single_table():
tables = extract_tables('select a, from abc')
- assert tables == ['abc']
+ assert tables == [(None, 'abc', None)]
def test_select_with_hanging_comma_multiple_tables():
tables = extract_tables('select a, from abc, def')
- assert tables == ['abc', 'def']
+ assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+
+def test_select_with_hanging_period_multiple_tables():
+ tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2')
+ assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')]
def test_simple_insert_single_table():
tables = extract_tables('insert into abc (id, name) values (1, "def")')
- assert tables == ['abc']
+
+ # sqlparse mistakenly assigns an alias to the table
+ # assert tables == [(None, 'abc', None)]
+ assert tables == [(None, 'abc', 'abc')]
+
+@pytest.mark.xfail
+def test_simple_insert_single_table_schema_qualified():
+ tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
+ assert tables == [('abc', 'def', None)]
def test_simple_update_table():
tables = extract_tables('update abc set id = 1')
- assert tables == ['abc']
+ assert tables == [(None, 'abc', None)]
+
+def test_simple_update_table():
+ tables = extract_tables('update abc.def set id = 1')
+ assert tables == [('abc', 'def', None)]
def test_join_table():
- expected = {'a': 'abc', 'd': 'def'}
tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num')
- tables_aliases = extract_tables(
- 'SELECT * FROM abc a JOIN def d ON a.id = d.num', True)
- assert tables == sorted(expected.values())
- assert tables_aliases == expected
+ assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')]
+
+def test_join_table_schema_qualified():
+ tables = extract_tables('SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num')
+ assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')]
def test_join_as_table():
- expected = {'m': 'my_table'}
- assert extract_tables(
- 'SELECT * FROM my_table AS m WHERE m.a > 5') == \
- sorted(expected.values())
- assert extract_tables(
- 'SELECT * FROM my_table AS m WHERE m.a > 5', True) == expected
+ tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5')
+ assert tables == [(None, 'my_table', 'm')]
+
diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py
index 94cb6ef0..37815766 100644
--- a/tests/test_pgexecute.py
+++ b/tests/test_pgexecute.py
@@ -16,14 +16,22 @@ def test_conn(executor):
SELECT 1""")
@dbtest
-def test_table_and_columns_query(executor):
+def test_schemata_table_and_columns_query(executor):
run(executor, "create table a(x text, y text)")
run(executor, "create table b(z text)")
+ run(executor, "create schema schema1")
+ run(executor, "create table schema1.c (w text)")
+ run(executor, "create schema schema2")
- tables, columns = executor.tables()
- assert tables == ['a', 'b']
- assert columns['a'] == ['x', 'y']
- assert columns['b'] == ['z']
+ assert executor.schemata() == ['public', 'schema1', 'schema2']
+ assert executor.tables() == [
+ ('public', 'a'), ('public', 'b'), ('schema1', 'c')]
+
+ assert executor.columns() == [
+ ('public', 'a', 'x'), ('public', 'a', 'y'),
+ ('public', 'b', 'z'), ('schema1', 'c', 'w')]
+
+ assert executor.search_path() == ['public']
@dbtest
def test_database_list(executor):
diff --git a/tests/test_pgspecial.py b/tests/test_pgspecial.py
index e69de29b..f02efac9 100644
--- a/tests/test_pgspecial.py
+++ b/tests/test_pgspecial.py
@@ -0,0 +1,19 @@
+from pgcli.packages.sqlcompletion import suggest_type
+from test_sqlcompletion import sorted_dicts
+
+def test_d_suggests_tables_and_schemas():
+ suggestions = suggest_type('\d ', '\d ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'schema'}, {'type': 'table', 'schema': []}])
+
+ suggestions = suggest_type('\d xxx', '\d xxx')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'schema'}, {'type': 'table', 'schema': []}])
+
+def test_d_dot_suggests_schema_qualified_tables():
+ suggestions = suggest_type('\d myschema.', '\d myschema.')
+ assert suggestions == [{'type': 'table', 'schema': 'myschema'}]
+
+ suggestions = suggest_type('\d myschema.xxx', '\d myschema.xxx')
+ assert suggestions == [{'type': 'table', 'schema': 'myschema'}]
+
diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py
new file mode 100644
index 00000000..d36f4c5a
--- /dev/null
+++ b/tests/test_smart_completion_multiple_schemata.py
@@ -0,0 +1,227 @@
+import pytest
+from prompt_toolkit.completion import Completion
+from prompt_toolkit.document import Document
+
+metadata = {
+ 'public': {
+ 'users': ['id', 'email', 'first_name', 'last_name'],
+ 'orders': ['id', 'ordered_date', 'status'],
+ 'select': ['id', 'insert', 'ABC']
+ },
+ 'custom': {
+ 'users': ['id', 'phone_number'],
+ 'products': ['id', 'product_name', 'price'],
+ 'shipments': ['id', 'address', 'user_id']
+ }
+ }
+
+@pytest.fixture
+def completer():
+
+ import pgcli.pgcompleter as pgcompleter
+ comp = pgcompleter.PGCompleter(smart_completion=True)
+
+ schemata, tables, columns = [], [], []
+
+ for schema, tbls in metadata.items():
+ schemata.append(schema)
+
+ for table, cols in tbls.items():
+ tables.append((schema, table))
+ columns.extend([(schema, table, col) for col in cols])
+
+ comp.extend_schemata(schemata)
+ comp.extend_tables(tables)
+ comp.extend_columns(columns)
+ comp.set_search_path(['public'])
+
+ return comp
+
+@pytest.fixture
+def complete_event():
+ from mock import Mock
+ return Mock()
+
+def test_schema_or_visible_table_completion(completer, complete_event):
+ text = 'SELECT * FROM '
+ position = len(text)
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event)
+ assert set(result) == set([Completion(text='public', start_position=0),
+ Completion(text='custom', start_position=0),
+ Completion(text='users', start_position=0),
+ Completion(text='"select"', start_position=0),
+ Completion(text='orders', start_position=0)])
+
+def test_suggested_column_names_from_shadowed_visible_table(completer, complete_event):
+ """
+ Suggest column and function names when selecting from table
+ :param completer:
+ :param complete_event:
+ :return:
+ """
+ text = 'SELECT from users'
+ position = len('SELECT ')
+ result = set(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert set(result) == set([
+ Completi