diff options
author | Amjith Ramanujam <amjith.r@gmail.com> | 2015-01-26 20:26:52 -0800 |
---|---|---|
committer | Amjith Ramanujam <amjith.r@gmail.com> | 2015-01-26 20:26:52 -0800 |
commit | 6944ef60f83e7e116aa687a972b5402f1a1bbe04 (patch) | |
tree | 593d5c212b78f3de496c16044de42a565fed60de | |
parent | 750206c779060ea8b1cf19e2058a87cd536363e2 (diff) | |
parent | 7d3f276e8330854e0153d25e02df1b103529e640 (diff) |
Merge pull request #127 from darikg/schema_autocomplete
Make autocomplete schema-aware
-rwxr-xr-x | pgcli/main.py | 40 | ||||
-rw-r--r-- | pgcli/packages/parseutils.py | 26 | ||||
-rw-r--r-- | pgcli/packages/sqlcompletion.py | 65 | ||||
-rw-r--r-- | pgcli/pgcompleter.py | 187 | ||||
-rw-r--r-- | pgcli/pgexecute.py | 76 | ||||
-rw-r--r-- | setup.py | 2 | ||||
-rw-r--r-- | tests/test_parseutils.py | 68 | ||||
-rw-r--r-- | tests/test_pgexecute.py | 18 | ||||
-rw-r--r-- | tests/test_pgspecial.py | 19 | ||||
-rw-r--r-- | tests/test_smart_completion_multiple_schemata.py | 227 | ||||
-rw-r--r-- | tests/test_smart_completion_public_schema_only.py (renamed from tests/test_smart_completion.py) | 46 | ||||
-rw-r--r-- | tests/test_sqlcompletion.py | 170 | ||||
-rw-r--r-- | tests/utils.py | 6 |
13 files changed, 739 insertions, 211 deletions
diff --git a/pgcli/main.py b/pgcli/main.py index 7ed04c39..f47f7888 100755 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -214,6 +214,12 @@ class PGCli(object): end = time() total += end - start mutating = mutating or is_mutating(status) + + if need_search_path_refresh(document.text, status): + logger.debug('Refreshing search path') + completer.set_search_path(pgexecute.search_path()) + logger.debug('Search path: %r', completer.search_path) + except KeyboardInterrupt: # Restart connection to the database pgexecute.connect() @@ -262,13 +268,16 @@ class PGCli(object): return less_opts def refresh_completions(self): - self.completer.reset_completions() - tables, columns = self.pgexecute.tables() - self.completer.extend_table_names(tables) - for table in tables: - table = table[1:-1] if table[0] == '"' and table[-1] == '"' else table - self.completer.extend_column_names(table, columns[table]) - self.completer.extend_database_names(self.pgexecute.databases()) + completer = self.completer + completer.reset_completions() + + pgexecute = self.pgexecute + + completer.set_search_path(pgexecute.search_path()) + completer.extend_schemata(pgexecute.schemata()) + completer.extend_tables(pgexecute.tables()) + completer.extend_columns(pgexecute.columns()) + completer.extend_database_names(pgexecute.databases()) def get_completions(self, text, cursor_positition): return self.completer.get_completions( @@ -329,6 +338,22 @@ def need_completion_refresh(sql): except Exception: return False +def need_search_path_refresh(sql, status): + # note that sql may be a multi-command query, but status belongs to an + # individual query, since pgexecute handles splitting up multi-commands + try: + status = status.split()[0] + if status.lower() == 'set': + # Since sql could be a multi-line query, it's hard to robustly + # pick out the variable name that's been set. Err on the side of + # false positives here, since the worst case is we refresh the + # search path when it's not necessary + return 'search_path' in sql.lower() + else: + return False + except Exception: + return False + def is_mutating(status): """Determines if the statement is mutating based on the status.""" if not status: @@ -349,6 +374,5 @@ def quit_command(sql): or sql.strip() == '\q' or sql.strip() == ':q') - if __name__ == "__main__": cli() diff --git a/pgcli/packages/parseutils.py b/pgcli/packages/parseutils.py index 370bb4db..4122a332 100644 --- a/pgcli/packages/parseutils.py +++ b/pgcli/packages/parseutils.py @@ -101,50 +101,50 @@ def extract_from_part(parsed, stop_at_punctuation=True): break def extract_table_identifiers(token_stream): + """yields tuples of (schema_name, table_name, table_alias)""" + for item in token_stream: if isinstance(item, IdentifierList): for identifier in item.get_identifiers(): # Sometimes Keywords (such as FROM ) are classified as # identifiers which don't have the get_real_name() method. try: + schema_name = identifier.get_parent_name() real_name = identifier.get_real_name() except AttributeError: continue if real_name: - yield (real_name, identifier.get_alias() or real_name) + yield (schema_name, real_name, identifier.get_alias()) elif isinstance(item, Identifier): real_name = item.get_real_name() + schema_name = item.get_parent_name() if real_name: - yield (real_name, item.get_alias() or real_name) + yield (schema_name, real_name, item.get_alias()) else: name = item.get_name() - yield (name, item.get_alias() or name) + yield (None, name, item.get_alias() or name) elif isinstance(item, Function): - yield (item.get_name(), item.get_name()) + yield (None, item.get_name(), item.get_name()) # extract_tables is inspired from examples in the sqlparse lib. -def extract_tables(sql, include_alias=False): +def extract_tables(sql): """Extract the table names from an SQL statment. - Returns a list of table names if include_alias=False (default). - If include_alias=True, then a dictionary is returned where the keys are - aliases and values are real table names. + Returns a list of (schema, table, alias) tuples """ parsed = sqlparse.parse(sql) if not parsed: return [] + # INSERT statements must stop looking for tables at the sign of first # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) # abc is the table name, but if we don't stop at the first lparen, then # we'll identify abc, col1 and col2 as table names. insert_stmt = parsed[0].token_first().value.lower() == 'insert' stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) - if include_alias: - return dict((alias, t) for t, alias in extract_table_identifiers(stream)) - else: - return [x[0] for x in extract_table_identifiers(stream)] + return list(extract_table_identifiers(stream)) def find_prev_keyword(sql): if not sql.strip(): @@ -156,4 +156,4 @@ def find_prev_keyword(sql): if __name__ == '__main__': sql = 'select * from (select t. from tabl t' - print (extract_tables(sql, True)) + print (extract_tables(sql)) diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index f409dae0..45e38fed 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -66,28 +66,65 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text): # If the lparen is preceeded by a space chances are we're about to # do a sub-select. if last_word(text_before_cursor, 'all_punctuations').startswith('('): - return 'keywords', [] - return 'columns', extract_tables(full_text) + return [{'type': 'keyword'}] + + return [{'type': 'column', 'tables': extract_tables(full_text)}] + if token_v.lower() in ('set', 'by', 'distinct'): - return 'columns', extract_tables(full_text) + return [{'type': 'column', 'tables': extract_tables(full_text)}] elif token_v.lower() in ('select', 'where', 'having'): - return 'columns-and-functions', extract_tables(full_text) + return [{'type': 'column', 'tables': extract_tables(full_text)}, + {'type': 'function'}] elif token_v.lower() in ('from', 'update', 'into', 'describe', 'join', 'table'): - return 'tables', [] + return [{'type': 'schema'}, {'type': 'table', 'schema': []}] elif token_v.lower() == 'on': - tables = extract_tables(full_text, include_alias=True) - return 'tables-or-aliases', tables.keys() + tables = extract_tables(full_text) # [(schema, table, alias), ...] + + # Use table alias if there is one, otherwise the table name + alias = [t[2] or t[1] for t in tables] + + return [{'type': 'alias', 'aliases': alias}] + elif token_v in ('d',): # \d - return 'tables', [] + # Apparently "\d <other>" is parsed by sqlparse as + # Identifer('d', Whitespace, '<other>') + if len(token.tokens) > 2: + other = token.tokens[-1].value + identifiers = other.split('.') + if len(identifiers) == 1: + # "\d table" or "\d schema" + return [{'type': 'schema'}, {'type': 'table', 'schema': []}] + elif len(identifiers) == 2: + # \d schema.table + return [{'type': 'table', 'schema': identifiers[0]}] + else: + return [{'type': 'schema'}, {'type': 'table', 'schema': []}] elif token_v.lower() in ('c', 'use'): # \c - return 'databases', [] + return [{'type': 'database'}] elif token_v.endswith(',') or token_v == '=': prev_keyword = find_prev_keyword(text_before_cursor) if prev_keyword: - return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text) + return suggest_based_on_last_token( + prev_keyword, text_before_cursor, full_text) elif token_v.endswith('.'): - current_alias = last_word(token_v[:-1]) - tables = extract_tables(full_text, include_alias=True) - return 'columns', [tables.get(current_alias) or current_alias] - return 'keywords', [] + suggestions = [] + + identifier = last_word(token_v[:-1], 'all_punctuations') + + # TABLE.<suggestion> or SCHEMA.TABLE.<suggestion> + tables = extract_tables(full_text) + tables = [t for t in tables if identifies(identifier, *t)] + suggestions.append({'type': 'column', 'tables': tables}) + + # SCHEMA.<suggestion> + suggestions.append({'type': 'table', 'schema': identifier}) + + return suggestions + + return [{'type': 'keyword'}] + + +def identifies(id, schema, table, alias): + return id == alias or id == table or ( + schema and (id == schema + '.' + table)) diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index a8123a72..3585abeb 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -1,11 +1,11 @@ from __future__ import print_function import logging -from collections import defaultdict from prompt_toolkit.completion import Completer, Completion from .packages.sqlcompletion import suggest_type from .packages.parseutils import last_word from re import compile + _logger = logging.getLogger(__name__) class PGCompleter(Completer): @@ -21,7 +21,7 @@ class PGCompleter(Completer): 'MAXEXTENTS', 'MINUS', 'MLSLABEL', 'MODE', 'MODIFY', 'NOAUDIT', 'NOCOMPRESS', 'NOT', 'NOWAIT', 'NULL', 'NUMBER', 'OF', 'OFFLINE', 'ON', 'ONLINE', 'OPTION', 'OR', 'ORDER BY', 'OUTER', 'PCTFREE', - 'PRIMARY', 'PRIOR', 'PRIVILEGES', 'PUBLIC', 'RAW', 'RENAME', + 'PRIMARY', 'PRIOR', 'PRIVILEGES', 'RAW', 'RENAME', 'RESOURCE', 'REVOKE', 'RIGHT', 'ROW', 'ROWID', 'ROWNUM', 'ROWS', 'SELECT', 'SESSION', 'SET', 'SHARE', 'SIZE', 'SMALLINT', 'START', 'SUCCESSFUL', 'SYNONYM', 'SYSDATE', 'TABLE', 'THEN', 'TO', @@ -33,15 +33,6 @@ class PGCompleter(Completer): 'LCASE', 'LEN', 'MAX', 'MIN', 'MID', 'NOW', 'ROUND', 'SUM', 'TOP', 'UCASE'] - special_commands = [] - - databases = [] - tables = [] - # This will create a defaultdict which is initialized with a list that has - # a '*' by default. - columns = defaultdict(lambda: ['*']) - all_completions = set(keywords + functions) - def __init__(self, smart_completion=True): super(self.__class__, self).__init__() self.smart_completion = smart_completion @@ -50,8 +41,15 @@ class PGCompleter(Completer): self.reserved_words.update(x.split()) self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$") + self.special_commands = [] + self.databases = [] + self.dbmetadata = {} + self.search_path = [] + + self.all_completions = set(self.keywords + self.functions) + def escape_name(self, name): - if ((not self.name_pattern.match(name)) + if name and ((not self.name_pattern.match(name)) or (name.upper() in self.reserved_words) or (name.upper() in self.functions)): name = '"%s"' % name @@ -60,7 +58,7 @@ class PGCompleter(Completer): def unescape_name(self, name): """ Unquote a string.""" - if name[0] == '"' and name[-1] == '"': + if name and name[0] == '"' and name[-1] == '"': name = name[1:-1] return name @@ -75,31 +73,50 @@ class PGCompleter(Completer): def extend_database_names(self, databases): databases = self.escaped_names(databases) - self.databases.extend(databases) def extend_keywords(self, additional_keywords): self.keywords.extend(additional_keywords) self.all_completions.update(additional_keywords) - def extend_table_names(self, tables): - tables = self.escaped_names(tables) + def extend_schemata(self, schemata): + + # data is a DataFrame with columns [schema] + schemata = self.escaped_names(schemata) + for schema in schemata: + self.dbmetadata[schema] = {} + + self.all_completions.update(schemata) - self.tables.extend(tables) - self.all_completions.update(tables) + def extend_tables(self, table_data): - def extend_column_names(self, table, columns): - columns = self.escaped_names(columns) + # table_data is a list of (schema_name, table_name) tuples + table_data = [self.escaped_names(d) for d in table_data] - unescaped_table_name = self.unescape_name(table) + # dbmetadata['schema_name']['table_name'] should be a list of column + # names. Default to an asterisk + for schema, table in table_data: + self.dbmetadata[schema][table] = ['*'] - self.columns[unescaped_table_name].extend(columns) - self.all_completions.update(columns) + self.all_completions.update(t[1] for t in table_data) + + def extend_columns(self, column_data): + + # column_data is a list of (schema_name, table_name, column_name) tuples + column_data = [self.escaped_names(d) for d in column_data] + + for schema, table, column in column_data: + self.dbmetadata[schema][table].append(column) + + self.all_completions.update(t[2] for t in column_data) + + def set_search_path(self, search_path): + self.search_path = self.escaped_names(search_path) def reset_completions(self): self.databases = [] - self.tables = [] - self.columns = defaultdict(lambda: ['*']) + self.search_path = [] + self.dbmetadata = {} self.all_completions = set(self.keywords) @staticmethod @@ -119,36 +136,90 @@ class PGCompleter(Completer): if not smart_completion: return self.find_matches(word_before_cursor, self.all_completions) - category, scope = suggest_type(document.text, - document.text_before_cursor) - - if category == 'columns': - _logger.debug("Completion: 'columns' Scope: %r", scope) - scoped_cols = self.populate_scoped_cols(scope) - return self.find_matches(word_before_cursor, scoped_cols) - elif category == 'columns-and-functions': - _logger.debug("Completion: 'columns-and-functions' Scope: %r", - scope) - scoped_cols = self.populate_scoped_cols(scope) - return self.find_matches(word_before_cursor, scoped_cols + - self.functions) - elif category == 'tables': - _logger.debug("Completion: 'tables' Scope: %r", scope) - return self.find_matches(word_before_cursor, self.tables) - elif category == 'tables-or-aliases': - _logger.debug("Completion: 'tables-or-aliases' Scope: %r", scope) - return self.find_matches(word_before_cursor, scope) - elif category == 'databases': - _logger.debug("Completion: 'databases' Scope: %r", scope) - return self.find_matches(word_before_cursor, self.databases) - elif category == 'keywords': - _logger.debug("Completion: 'keywords' Scope: %r", scope) - return self.find_matches(word_before_cursor, self.keywords + - self.special_commands) - - def populate_scoped_cols(self, tables): - scoped_cols = [] - for table in tables: - unescaped_table_name = self.unescape_name(table) - scoped_cols.extend(self.columns[unescaped_table_name]) - return scoped_cols + completions = [] + suggestions = suggest_type(document.text, document.text_before_cursor) + + for suggestion in suggestions: + + _logger.debug('Suggestion type: %r', suggestion['type']) + + if suggestion['type'] == 'column': + tables = suggestion['tables'] + _logger.debug("Completion column scope: %r", tables) + scoped_cols = self.populate_scoped_cols(tables) + cols = self.find_matches(word_before_cursor, scoped_cols) + completions.extend(cols) + + elif suggestion['type'] == 'function': + funcs = self.find_matches(word_before_cursor, self.functions) + completions.extend(funcs) + + elif suggestion['type'] == 'schema': + schema_names = self.dbmetadata.keys() + schema_names = self.find_matches(word_before_cursor, schema_names) + completions.extend(schema_names) + + elif suggestion['type'] == 'table': + + if suggestion['schema']: + try: + tables = self.dbmetadata[suggestion['schema']].keys() + except KeyError: + #schema doesn't exist + tables = [] + else: + schemas = self.search_path + meta = self.dbmetadata + tables = [tbl for schema in schemas + for tbl in meta[schema].keys()] + + tables = self.find_matches(word_before_cursor, tables) + completions.extend(tables) + elif suggestion['type'] == 'alias': + aliases = suggestion['aliases'] + aliases = self.find_matches(word_before_cursor, aliases) + completions.extend(aliases) + elif suggestion['type'] == 'database': + dbs = self.find_matches(word_before_cursor, self.databases) + completions.extend(dbs) + + elif suggestion['type'] == 'keyword': + keywords = self.keywords + self.special_commands + keywords = self.find_matches(word_before_cursor, keywords) + completions.extend(keywords) + + return completions + + def populate_scoped_cols(self, scoped_tbls): + """ Find all columns in a set of scoped_tables + :param scoped_tbls: list of (schema, table, alias) tuples + :return: list of column names + """ + + columns = [] + meta = self.dbmetadata + + for tbl in scoped_tbls: + if tbl[0]: + # A fully qualified schema.table reference + schema = self.escape_name(tbl[0]) + table = self.escape_name(tbl[1]) + try: + # Get columns from the corresponding schema.table + columns.extend(meta[schema][table]) + except KeyError: + # Either the schema or table doesn't exist + pass + else: + for schema in self.search_path: + table = self.escape_name(tbl[1]) + try: + columns.extend(meta[schema][table]) + break + except KeyError: + pass + + return columns + + + diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 5ba0dcd3..e16fbd41 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -4,7 +4,6 @@ import psycopg2 import psycopg2.extras import psycopg2.extensions as ext import sqlparse -from collections import defaultdict from .packages import pgspecial PY2 = sys.version_info[0] == 2 @@ -30,13 +29,43 @@ psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select) class PGExecute(object): - tables_query = '''SELECT c.relname as "Name" FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE - c.relkind IN ('r','') AND n.nspname <> 'pg_catalog' AND n.nspname <> - 'information_schema' AND n.nspname !~ '^pg_toast' AND - pg_catalog.pg_table_is_visible(c.oid) ORDER BY 1;''' + search_path_query = ''' + SELECT * FROM unnest(current_schemas(false))''' + + schemata_query = ''' + SELECT nspname + FROM pg_catalog.pg_namespace + WHERE nspname !~ '^pg_' + AND nspname <> 'information_schema' + ORDER BY 1 ''' + + tables_query = ''' + SELECT n.nspname schema_name, + c.relname table_name + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n + ON n.oid = c.relnamespace + WHERE c.relkind IN ('r','v', 'm') -- table, view, materialized view + AND n.nspname !~ '^pg_toast' + AND n.nspname NOT IN ('information_schema', 'pg_catalog') + ORDER BY 1,2;''' + + columns_query = ''' + SELECT nsp.nspname schema_name, + cls.relname table_name, + att.attname column_name + FROM pg_catalog.pg_attribute att + INNER JOIN pg_catalog.pg_class cls + ON att.attrelid = cls.oid + INNER JOIN pg_catalog.pg_namespace nsp + ON cls.relnamespace = nsp.oid + WHERE cls.relkind IN ('r', 'v', 'm') + AND nsp.nspname !~ '^pg_' + AND nsp.nspname <> 'information_schema' + AND NOT att.attisdropped + AND att.attnum > 0 + ORDER BY 1, 2, 3''' - columns_query = '''SELECT table_name, column_name FROM information_schema.columns''' databases_query = """SELECT d.datname as "Name", pg_catalog.pg_get_userbyid(d.datdba) as "Owner", @@ -133,22 +162,37 @@ class PGExecute(object): _logger.debug('No rows in result.') return (None, None, cur.statusmessage) + def search_path(self): + """Returns the current search path as a list of schema names""" + + with self.conn.cursor() as cur: + _logger.debug('Search path query. sql: %r', self.search_path_query) + cur.execute(self.search_path_query) + return [x[0] for x in cur.fetchall()] + + def schemata(self): + """Returns a list of schema names in the database""" + + with self.conn.cursor() as cur: + _logger.debug('Schemata Query. sql: %r', self.schemata_query) + cur.execute(self.schemata_query) + return [x[0] for x in cur.fetchall()] + def tables(self): - """ Returns tuple (sorted_tables, columns). Columns is a dictionary of - table name -> list of columns """ - columns = defaultdict(list) + """Returns a list of (schema_name, table_name) tuples """ + with self.conn.cursor() as cur: _logger.debug('Tables Query. sql: %r', self.tables_query) cur.execute(self.tables_query) - tables = [x[0] for x in cur.fetchall()] + return cur.fetchall() - table_set = set(tables) + def columns(self): + """Returns a list of (schema_name, table_name, column_name) tuples""" + + with self.conn.cursor() as cur: _logger.debug('Columns Query. sql: %r', self.columns_query) cur.execute(self.columns_query) - for table, column in cur.fetchall(): - if table in table_set: - columns[table].append(column) - return tables, columns + return cur.fetchall() def databases(self): with self.conn.cursor() as cur: @@ -28,7 +28,7 @@ setup( 'jedi == 0.8.1', # Temporary fix for installation woes. 'prompt_toolkit==0.26', 'psycopg2 >= 2.5.4', - 'sqlparse >= 0.1.14', + 'sqlparse >= 0.1.14' ], entry_points=''' [console_scripts] diff --git a/tests/test_parseutils.py b/tests/test_parseutils.py index e49aec68..b90f3855 100644 --- a/tests/test_parseutils.py +++ b/tests/test_parseutils.py @@ -1,3 +1,4 @@ +import pytest from pgcli.packages.parseutils import extract_tables @@ -7,48 +8,77 @@ def test_empty_string(): def test_simple_select_single_table(): tables = extract_tables('select * from abc') - assert tables == ['abc'] + assert tables == [(None, 'abc', None)] + +def test_simple_select_single_table_schema_qualified(): + tables = extract_tables('select * from abc.def') + assert tables == [('abc', 'def', None)] def test_simple_select_multiple_tables(): tables = extract_tables('select * from abc, def') - assert tables == ['abc', 'def'] + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + +def test_simple_select_multiple_tables_schema_qualified(): + tables = extract_tables('select * from abc.def, ghi.jkl') + assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)] def test_simple_select_with_cols_single_table(): tables = extract_tables('select a,b from abc') - assert tables == ['abc'] + assert tables == [(None, 'abc', None)] + +def test_simple_select_with_cols_single_table_schema_qualified(): + tables = extract_tables('select a,b from abc.def') + assert tables == [('abc', 'def', None)] def test_simple_select_with_cols_multiple_tables(): tables = extract_tables('select a,b from abc, def') - assert tables == ['abc', 'def'] + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + +def test_simple_select_with_cols_multiple_tables(): + tables = extract_tables('select a,b from abc.def, def.ghi') + assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)] def test_select_with_hanging_comma_single_table(): tables = extract_tables('select a, from abc') - assert tables == ['abc'] + assert tables == [(None, 'abc', None)] def test_select_with_hanging_comma_multiple_tables(): tables = extract_tables('select a, from abc, def') - assert tables == ['abc', 'def'] + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + +def test_select_with_hanging_period_multiple_tables(): + tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2') + assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')] def test_simple_insert_single_table(): tables = extract_tables('insert into abc (id, name) values (1, "def")') - assert tables == ['abc'] + + # sqlparse mistakenly assigns an alias to the table + # assert tables == [(None, 'abc', None)] + assert tables == [(None, 'abc', 'abc')] + +@pytest.mark.xfail +def test_simple_insert_single_table_schema_qualified(): + tables = extract_tables('insert into abc.def (id, name) values (1, "def")') + assert tables == [('abc', 'def', None)] def test_simple_update_table(): tables = extract_tables('update abc set id = 1') - assert tables == ['abc'] + assert tables == [(None, 'abc', None)] + +def test_simple_update_table(): + tables = extract_tables('update abc.def set id = 1') + assert tables == [('abc', 'def', None)] def test_join_table(): - expected = {'a': 'abc', 'd': 'def'} tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num') - tables_aliases = extract_tables( - 'SELECT * FROM abc a JOIN def d ON a.id = d.num', True) - assert tables == sorted(expected.values()) - assert tables_aliases == expected + assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')] + +def test_join_table_schema_qualified(): + tables = extract_tables('SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num') + assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')] def test_join_as_table(): - expected = {'m': 'my_table'} - assert extract_tables( - 'SELECT * FROM my_table AS m WHERE m.a > 5') == \ - sorted(expected.values()) - assert extract_tables( - 'SELECT * FROM my_table AS m WHERE m.a > 5', True) == expected + tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5') + assert tables == [(None, 'my_table', 'm')] + diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index 94cb6ef0..37815766 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -16,14 +16,22 @@ def test_conn(executor): SELECT 1""") @dbtest -def test_table_and_columns_query(executor): +def test_schemata_table_and_columns_query(executor): run(executor, "create table a(x text, y text)") run(executor, "create table b(z text)") + run(executor, "create schema schema1") + run(executor, "create table schema1.c (w text)") + run(executor, "create schema schema2") - tables, columns = executor.tables() - assert tables == ['a', 'b'] - assert columns['a'] == ['x', 'y'] - assert columns['b'] == ['z'] + assert executor.schemata() == ['public', 'schema1', 'schema2'] + assert executor.tables() == [ + ('public', 'a'), ('public', 'b'), ('schema1', 'c')] + + assert executor.columns() == [ + ('public', 'a', 'x'), ('public', 'a', 'y'), + ('public', 'b', 'z'), ('schema1', 'c', 'w')] + + assert executor.search_path() == ['public'] @dbtest def test_database_list(executor): diff --git a/tests/test_pgspecial.py b/tests/test_pgspecial.py index e69de29b..f02efac9 100644 --- a/tests/test_pgspecial.py +++ b/tests/test_pgspecial.py @@ -0,0 +1,19 @@ +from pgcli.packages.sqlcompletion import suggest_type +from test_sqlcompletion import sorted_dicts + +def test_d_suggests_tables_and_schemas(): + suggestions = suggest_type('\d ', '\d ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'schema'}, {'type': 'table', 'schema': []}]) + + suggestions = suggest_type('\d xxx', '\d xxx') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'schema'}, {'type': 'table', 'schema': []}]) + +def test_d_dot_suggests_schema_qualified_tables(): + suggestions = suggest_type('\d myschema.', '\d myschema.') + assert suggestions == [{'type': 'table', 'schema': 'myschema'}] + + suggestions = suggest_type('\d myschema.xxx', '\d myschema.xxx') + assert suggestions == [{'type': 'table', 'schema': 'myschema'}] + diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py new file mode 100644 index 00000000..d36f4c5a --- /dev/null +++ b/tests/test_smart_completion_multiple_schemata.py @@ -0,0 +1,227 @@ +import pytest +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document + +metadata = { + 'public': { + 'users': ['id', 'email', 'first_name', 'last_name'], + 'orders': ['id', 'ordered_date', 'status'], + 'select': ['id', 'insert', 'ABC'] + }, + 'custom': { + 'users': ['id', 'phone_number'], + 'products': ['id', 'product_name', 'price'], + 'shipments': ['id', 'address', 'user_id'] + } + } + +@pytest.fixture +def completer(): + + import pgcli.pgcompleter as pgcompleter + comp = pgcompleter.PGCompleter(smart_completion=True) + + schemata, tables, columns = [], [], [] + + for schema, tbls in metadata.items(): + schemata.append(schema) + + for table, cols in tbls.items(): + tables.append((schema, table)) + columns.extend([(schema, table, col) for col in cols]) + + comp.extend_schemata(schemata) + comp.extend_tables(tables) + comp.extend_columns(columns) + comp.set_search_path(['public']) + + return comp + +@pytest.fixture +def complete_event(): + from mock import Mock + return Mock() + +def test_schema_or_visible_table_completion(completer, complete_event): + text = 'SELECT * FROM ' + position = len(text) + result = completer.get_completions( + Document(text=text, cursor_position=position), complete_event) + assert set(result) == set([Completion(text='public', start_position=0), + Completion(text='custom', start_position=0), + Completion(text='users', start_position=0), + Completion(text='"select"', start_position=0), + Completion(text='orders', start_position=0)]) + +def test_suggested_column_names_from_shadowed_visible_table(completer, complete_event): + """ + Suggest column and function names when selecting from table + :param completer: + :param complete_event: + :return: + """ + text = 'SELECT from users' + position = len('SELECT ') + result = set(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert set(result) == set([ + Completi |