diff options
-rw-r--r-- | pgcli/packages/parseutils.py | 12 | ||||
-rw-r--r-- | pgcli/pgcompleter.py | 57 | ||||
-rw-r--r-- | tests/test_parseutils.py | 11 | ||||
-rw-r--r-- | tests/test_smart_completion_multiple_schemata.py | 15 | ||||
-rw-r--r-- | tests/test_smart_completion_public_schema_only.py | 50 | ||||
-rw-r--r-- | tests/test_sqlcompletion.py | 13 |
6 files changed, 117 insertions, 41 deletions
diff --git a/pgcli/packages/parseutils.py b/pgcli/packages/parseutils.py index d491cc2b..cbdd1e3e 100644 --- a/pgcli/packages/parseutils.py +++ b/pgcli/packages/parseutils.py @@ -66,7 +66,9 @@ def last_word(text, include='alphanum_underscore'): TableReference = namedtuple('TableReference', ['schema', 'name', 'alias', 'is_function']) -TableReference.ref = property(lambda self: self.alias or self.name) +TableReference.ref = property(lambda self: self.alias or ( + self.name if self.name.islower() or self.name[0] == '"' + else '"' + self.name + '"')) # This code is borrowed from sqlparse example script. @@ -140,7 +142,10 @@ def extract_table_identifiers(token_stream, allow_functions=True): schema_name = schema_name.lower() quote_count = item.value.count('"') name_quoted = quote_count > 2 or (quote_count and not schema_quoted) - if name and not name_quoted and name != name.lower(): + alias_quoted = alias and item.value[-1] == '"' + if alias_quoted or name_quoted and not alias and name.islower(): + alias = '"' + (alias or name) + '"' + if name and not name_quoted and not name.islower(): if not alias: alias = name name = name.lower() @@ -198,7 +203,8 @@ def extract_tables(sql): identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt) # In the case 'sche.<cursor>', we get an empty TableReference; remove that - return tuple(i for i in identifiers if i.ref) + return tuple(i for i in identifiers if i.name) + def find_prev_keyword(sql): diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 7b2a25a1..11c29882 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -30,6 +30,7 @@ NamedQueries.instance = NamedQueries.from_config( Match = namedtuple('Match', ['completion', 'priority']) +normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' class PGCompleter(Completer): keywords = get_literals('keywords') @@ -384,38 +385,50 @@ class PGCompleter(Completer): return self.find_matches(word_before_cursor, flat_cols, meta='column') + def generate_alias(self, tbl, tbls): + if tbl[0] == '"': + aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in itertools.count(2)) + else: + aliases = (self.case(tbl) + str(i) for i in itertools.count(2)) + return (a for a in aliases if normalize_ref(a) not in tbls).next() + def get_join_matches(self, suggestion, word_before_cursor): - scoped_cols = self.populate_scoped_cols(suggestion.tables) + tbls = suggestion.tables + cols = self.populate_scoped_cols(tbls) # Set up some data structures for efficient access - qualified = dict((t.ref, t.schema) for t in suggestion.tables) - tbls = set((t.schema, t.name) for t in scoped_cols.keys()) - tblprio = dict((t.ref, n) for n, t in enumerate(suggestion.tables)) + qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls) + ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls)) + refs = set(normalize_ref(t.ref) for t in tbls) + other_tbls = set((t.schema, t.name) for t in cols.keys()[:-1]) joins, prios = [], [] # Iterate over FKs in existing tables to find potential joins - fks = ((fk, rtbl, rcol) for rtbl, rcols in scoped_cols.items() + fks = ((fk, rtbl, rcol) for rtbl, rcols in cols.items() for rcol in rcols for fk in rcol.foreignkeys) + col = namedtuple('col', 'schema tbl col') for fk, rtbl, rcol in fks: - if (fk.childschema, fk.childtable, fk.childcolumn) == ( - rtbl.schema, rtbl.name, rcol.name): - lsch = fk.parentschema - ltbl = fk.parenttable - lcol = fk.parentcolumn - else: - lsch = fk.childschema - ltbl = fk.childtable - lcol = fk.childcolumn - if suggestion.schema and lsch != suggestion.schema: + right = col(rtbl.schema, rtbl.name, rcol.name) + child = col(fk.childschema, fk.childtable, fk.childcolumn) + parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn) + left = child if parent == right else parent + if suggestion.schema and left.schema != suggestion.schema: continue - rsch, rtbl, rcol = rtbl.schema, rtbl.ref, rcol.name - join = '{0} ON {0}.{1} = {2}.{3}'.format( - self.case(ltbl), self.case(lcol), rtbl, self.case(rcol)) + c = self.case + if normalize_ref(left.tbl) in refs: + lref = self.generate_alias(left.tbl, refs) + join = '{0} {4} ON {4}.{1} = {2}.{3}'.format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref) + else: + join = '{0} ON {0}.{1} = {2}.{3}'.format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col)) # Schema-qualify if (1) new table in same schema as old, and old # is schema-qualified, or (2) new in other schema, except public - if not suggestion.schema and (qualified[rtbl] and lsch == rsch - or lsch not in(rsch, 'public')): - join = lsch + '.' + join + if not suggestion.schema and (qualified[normalize_ref(rtbl.ref)] + and left.schema == right.schema + or left.schema not in(right.schema, 'public')): + join = left.schema + '.' + join joins.append(join) - prios.append(tblprio[rtbl] * 2 + 0 if (lsch, ltbl) in tbls else 1) + prios.append(ref_prio[normalize_ref(rtbl.ref)] * 2 + ( + 0 if (left.schema, left.tbl) in other_tbls else 1)) return self.find_matches(word_before_cursor, joins, meta='join', priority_collection=prios, type_priority=100) diff --git a/tests/test_parseutils.py b/tests/test_parseutils.py index a2a2fd15..f82f271f 100644 --- a/tests/test_parseutils.py +++ b/tests/test_parseutils.py @@ -12,11 +12,18 @@ def test_simple_select_single_table(): @pytest.mark.parametrize('sql', [ - 'select * from abc.def', - 'select * from "abc".def', 'select * from "abc"."def"', 'select * from abc."def"', ]) +def test_simple_select_single_table_schema_qualified_quoted_table(sql): + tables = extract_tables(sql) + assert tables == (('abc', 'def', '"def"', False),) + + +@pytest.mark.parametrize('sql', [ + 'select * from abc.def', + 'select * from "abc".def', +]) def test_simple_select_single_table_schema_qualified(sql): tables = extract_tables(sql) assert tables == (('abc', 'def', None, False),) diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index 381657de..9f74a7d4 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals import pytest +import itertools from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey @@ -158,19 +159,21 @@ def test_suggested_join_conditions(completer, complete_event, text): Completion(text='shipments.id = users.id', start_position=0, display_meta='name join'), Completion(text='shipments.user_id = users.id', start_position=0, display_meta='fk join')]) -@pytest.mark.parametrize('text', [ - 'SELECT * FROM public.users RIGHT OUTER JOIN ', +@pytest.mark.parametrize(('query', 'tbl'), itertools.product(( + 'SELECT * FROM public.{0} RIGHT OUTER JOIN ', '''SELECT * - FROM users + FROM {0} JOIN ''' -]) -def test_suggested_joins(completer, complete_event, text): +), ('users', '"users"', 'Users'))) +def test_suggested_joins(completer, complete_event, query, tbl): + text = query.format(tbl) position = len(text) result = set(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) + join = 'custom.shipments ON shipments.user_id = {0}.id'.format(tbl) assert set(result) == set([ - Completion(text='custom.shipments ON shipments.user_id = users.id', start_position=0, display_meta='join'), + Completion(text=join, start_position=0, display_meta='join'), Completion(text='public', start_position=0, display_meta='schema'), Completion(text='custom', start_position=0, display_meta='schema'), Completion(text='"Custom"', start_position=0, display_meta='schema'), diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index 04b605c9..a097492c 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -6,7 +6,7 @@ from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey metadata = { 'tables': { - 'users': ['id', 'email', 'first_name', 'last_name'], + 'users': ['id', 'parentid', 'email', 'first_name', 'last_name'], 'Users': ['userid', 'username'], 'orders': ['id', 'ordered_date', 'status', 'email'], 'select': ['id', 'insert', 'ABC']}, @@ -21,6 +21,7 @@ metadata = { ['o', 'o'], '', False, False, True]], 'datatypes': ['custom_type1', 'custom_type2'], 'foreignkeys': [ + ('public', 'users', 'id', 'public', 'users', 'parentid'), ('public', 'users', 'id', 'public', 'Users', 'userid') ], } @@ -170,6 +171,7 @@ def test_suggested_column_names_from_visible_table(completer, complete_event): complete_event)) assert set(result) == set([ Completion(text='id', start_position=0, display_meta='column'), + Completion(text='parentid', start_position=0, display_meta='column'), Completion(text='email', start_position=0, display_meta='column'), Completion(text='first_name', start_position=0, display_meta='column'), Completion(text='last_name', start_position=0, display_meta='column'), @@ -196,6 +198,7 @@ def test_suggested_column_names_in_function(completer, complete_event): complete_event) assert set(result) == set([ Completion(text='id', start_position=0, display_meta='column'), + Completion(text='parentid', start_position=0, display_meta='column'), Completion(text='email', start_position=0, display_meta='column'), Completion(text='first_name', start_position=0, display_meta='column'), Completion(text='last_name', start_position=0, display_meta='column')]) @@ -214,6 +217,7 @@ def test_suggested_column_names_with_table_dot(completer, complete_event): complete_event)) assert set(result) == set([ Completion(text='id', start_position=0, display_meta='column'), + Completion(text='parentid', start_position=0, display_meta='column'), Completion(text='email', start_position=0, display_meta='column'), Completion(text='first_name', start_position=0, display_meta='column'), Completion(text='last_name', start_position=0, display_meta='column')]) @@ -232,6 +236,7 @@ def test_suggested_column_names_with_alias(completer, complete_event): complete_event)) assert set(result) == set([ Completion(text='id', start_position=0, display_meta='column'), + Completion(text='parentid', start_position=0, display_meta='column'), Completion(text='email', start_position=0, display_meta='column'), Completion(text='first_name', start_position=0, display_meta='column'), Completion(text='last_name', start_position=0, display_meta='column')]) @@ -251,6 +256,7 @@ def test_suggested_multiple_column_names(completer, complete_event): complete_event)) assert set(result) == set([ Completion(text='id', start_position=0, display_meta='column'), + Completion(text='parentid', start_position=0, display_meta='column'), Completion(text='email', start_position=0, display_meta='column'), Completion(text='first_name', start_position=0, display_meta='column'), Completion(text='last_name', start_position=0, display_meta='column'), @@ -276,6 +282,7 @@ def test_suggested_multiple_column_names_with_alias(completer, complete_event): complete_event)) assert set(result) == set([ Completion(text='id', start_position=0, display_meta='column'), + Completion(text='parentid', start_position=0, display_meta='column'), Completion(text='email', start_position=0, display_meta='column'), Completion(text='first_name', start_position=0, display_meta='column'), Completion(text='last_name', start_position=0, display_meta='column')]) @@ -295,6 +302,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event): complete_event)) assert set(result) == set([ Completion(text='id', start_position=0, display_meta='column'), + Completion(text='parentid', start_position=0, display_meta='column'), Completion(text='email', start_position=0, display_meta='column'), Completion(text='first_name', start_position=0, display_meta='column'), Completion(text='last_name', start_position=0, display_meta='column')]) @@ -321,7 +329,8 @@ def test_suggest_columns_after_three_way_join(completer, complete_event): 'SELECT * FROM users u FULL OUTER JOIN "Users" u2 ON ', '''SELECT * FROM users u - FULL OUTER JOIN "Users" u2 ON ''' + FULL OUTER JOIN "Users" u2 ON + ''' ]) def test_suggested_join_conditions(completer, complete_event, text): position = len(text) @@ -346,6 +355,32 @@ def test_suggested_joins(completer, complete_event, text): complete_event)) assert set(result) == set([ Completion(text='"Users" ON "Users".userid = users.id', start_position=0, display_meta='join'), + Completion(text='users users2 ON users2.id = users.parentid', start_position=0, display_meta='join'), + Completion(text='users users2 ON users2.parentid = users.id', start_position=0, display_meta='join'), + Completion(text='public', start_position=0, display_meta='schema'), + Completion(text='"Users"', start_position=0, display_meta='table'), + Completion(text='"select"', start_position=0, display_meta='table'), + Completion(text='orders', start_position=0, display_meta='table'), + Completion(text='users', start_position=0, display_meta='table'), + Completion(text='user_emails', start_position=0, display_meta='view'), + Completion(text='custom_func2', start_position=0, display_meta='function'), + Completion(text='set_returning_func', start_position=0, display_meta='function'), + Completion(text='custom_func1', start_position=0, display_meta='function')]) + +@pytest.mark.parametrize('text', [ + 'SELECT * FROM public."Users" JOIN ', + 'SELECT * FROM public."Users" RIGHT OUTER JOIN ', + '''SELECT * + FROM public."Users" + LEFT JOIN ''' +]) +def test_suggested_joins_quoted_schema_qualified_table(completer, complete_event, text): + position = len(text) + result = set(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert set(result) == set([ + Completion(text='public.users ON users.id = "Users".userid', start_position=0, display_meta='join'), Completion(text='public', start_position=0, display_meta='schema'), Completion(text='"Users"', start_position=0, display_meta='table'), Completion(text='"select"', start_position=0, display_meta='table'), @@ -642,7 +677,7 @@ def test_wildcard_column_expansion(completer, complete_event): completions = completer.get_completions( Document(text=sql, cursor_position=pos), complete_event) - col_list = 'id, email, first_name, last_name' + col_list = 'id, parentid, email, first_name, last_name' expected = [Completion(text=col_list, start_position=-1, display='*', display_meta='columns')] @@ -656,7 +691,7 @@ def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_even completions = completer.get_completions( Document(text=sql, cursor_position=pos), complete_event) - col_list = 'id, u.email, u.first_name, u.last_name' + col_list = 'id, u.parentid, u.email, u.first_name, u.last_name' expected = [Completion(text=col_list, start_position=-1, display='*', display_meta='columns')] @@ -664,9 +699,9 @@ def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_even @pytest.mark.parametrize('text,expected', [ ('SELECT users.* FROM users', - 'id, users.email, users.first_name, users.last_name'), + 'id, users.parentid, users.email, users.first_name, users.last_name'), ('SELECT Users.* FROM Users', - 'id, Users.email, Users.first_name, Users.last_name'), + 'id, Users.parentid, Users.email, Users.first_name, Users.last_name'), ]) def test_wildcard_column_expansion_with_table_qualifier(completer, complete_event, text, expected): pos = len('SELECT users.*') @@ -687,7 +722,7 @@ def test_wildcard_column_expansion_with_two_tables(completer, complete_event): Document(text=sql, cursor_position=pos), complete_event) cols = ('"select".id, "select"."insert", "select"."ABC", ' - 'u.id, u.email, u.first_name, u.last_name') + 'u.id, u.parentid, u.email, u.first_name, u.last_name') expected = [Completion(text=cols, start_position=-1, display='*', display_meta='columns')] assert completions == expected @@ -718,6 +753,7 @@ def test_suggest_columns_from_unquoted_table(completer, complete_event, text): complete_event) assert set(result) == set([ Completion(text='id', start_position=0, display_meta='column'), + Completion(text='parentid', start_position=0, display_meta='column'), Completion(text='email', start_position=0, display_meta='column'), Completion(text='first_name', start_position=0, display_meta='column'), Completion(text='last_name', start_position=0, display_meta='column')]) diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py index de95936b..a7f5330d 100644 --- a/tests/test_sqlcompletion.py +++ b/tests/test_sqlcompletion.py @@ -23,8 +23,19 @@ def test_select_suggests_cols_with_qualified_table_scope(): @pytest.mark.parametrize('expression', [ - 'SELECT * FROM tabl WHERE ', 'SELECT * FROM "tabl" WHERE ', +]) +def test_where_suggests_columns_functions_quoted_table(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == set([ + Column(tables=((None, 'tabl', '"tabl"', False),)), + Function(schema=None), + Keyword(), + ]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM tabl WHERE ', 'SELECT * FROM tabl WHERE (', 'SELECT * FROM tabl WHERE foo = ', 'SELECT * FROM tabl WHERE bar OR ', |