summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pgcli/packages/parseutils.py12
-rw-r--r--pgcli/pgcompleter.py57
-rw-r--r--tests/test_parseutils.py11
-rw-r--r--tests/test_smart_completion_multiple_schemata.py15
-rw-r--r--tests/test_smart_completion_public_schema_only.py50
-rw-r--r--tests/test_sqlcompletion.py13
6 files changed, 117 insertions, 41 deletions
diff --git a/pgcli/packages/parseutils.py b/pgcli/packages/parseutils.py
index d491cc2b..cbdd1e3e 100644
--- a/pgcli/packages/parseutils.py
+++ b/pgcli/packages/parseutils.py
@@ -66,7 +66,9 @@ def last_word(text, include='alphanum_underscore'):
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
'is_function'])
-TableReference.ref = property(lambda self: self.alias or self.name)
+TableReference.ref = property(lambda self: self.alias or (
+ self.name if self.name.islower() or self.name[0] == '"'
+ else '"' + self.name + '"'))
# This code is borrowed from sqlparse example script.
@@ -140,7 +142,10 @@ def extract_table_identifiers(token_stream, allow_functions=True):
schema_name = schema_name.lower()
quote_count = item.value.count('"')
name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
- if name and not name_quoted and name != name.lower():
+ alias_quoted = alias and item.value[-1] == '"'
+ if alias_quoted or name_quoted and not alias and name.islower():
+ alias = '"' + (alias or name) + '"'
+ if name and not name_quoted and not name.islower():
if not alias:
alias = name
name = name.lower()
@@ -198,7 +203,8 @@ def extract_tables(sql):
identifiers = extract_table_identifiers(stream,
allow_functions=not insert_stmt)
# In the case 'sche.<cursor>', we get an empty TableReference; remove that
- return tuple(i for i in identifiers if i.ref)
+ return tuple(i for i in identifiers if i.name)
+
def find_prev_keyword(sql):
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
index 7b2a25a1..11c29882 100644
--- a/pgcli/pgcompleter.py
+++ b/pgcli/pgcompleter.py
@@ -30,6 +30,7 @@ NamedQueries.instance = NamedQueries.from_config(
Match = namedtuple('Match', ['completion', 'priority'])
+normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"'
class PGCompleter(Completer):
keywords = get_literals('keywords')
@@ -384,38 +385,50 @@ class PGCompleter(Completer):
return self.find_matches(word_before_cursor, flat_cols, meta='column')
+ def generate_alias(self, tbl, tbls):
+ if tbl[0] == '"':
+ aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in itertools.count(2))
+ else:
+ aliases = (self.case(tbl) + str(i) for i in itertools.count(2))
+ return (a for a in aliases if normalize_ref(a) not in tbls).next()
+
def get_join_matches(self, suggestion, word_before_cursor):
- scoped_cols = self.populate_scoped_cols(suggestion.tables)
+ tbls = suggestion.tables
+ cols = self.populate_scoped_cols(tbls)
# Set up some data structures for efficient access
- qualified = dict((t.ref, t.schema) for t in suggestion.tables)
- tbls = set((t.schema, t.name) for t in scoped_cols.keys())
- tblprio = dict((t.ref, n) for n, t in enumerate(suggestion.tables))
+ qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
+ ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
+ refs = set(normalize_ref(t.ref) for t in tbls)
+ other_tbls = set((t.schema, t.name) for t in cols.keys()[:-1])
joins, prios = [], []
# Iterate over FKs in existing tables to find potential joins
- fks = ((fk, rtbl, rcol) for rtbl, rcols in scoped_cols.items()
+ fks = ((fk, rtbl, rcol) for rtbl, rcols in cols.items()
for rcol in rcols for fk in rcol.foreignkeys)
+ col = namedtuple('col', 'schema tbl col')
for fk, rtbl, rcol in fks:
- if (fk.childschema, fk.childtable, fk.childcolumn) == (
- rtbl.schema, rtbl.name, rcol.name):
- lsch = fk.parentschema
- ltbl = fk.parenttable
- lcol = fk.parentcolumn
- else:
- lsch = fk.childschema
- ltbl = fk.childtable
- lcol = fk.childcolumn
- if suggestion.schema and lsch != suggestion.schema:
+ right = col(rtbl.schema, rtbl.name, rcol.name)
+ child = col(fk.childschema, fk.childtable, fk.childcolumn)
+ parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
+ left = child if parent == right else parent
+ if suggestion.schema and left.schema != suggestion.schema:
continue
- rsch, rtbl, rcol = rtbl.schema, rtbl.ref, rcol.name
- join = '{0} ON {0}.{1} = {2}.{3}'.format(
- self.case(ltbl), self.case(lcol), rtbl, self.case(rcol))
+ c = self.case
+ if normalize_ref(left.tbl) in refs:
+ lref = self.generate_alias(left.tbl, refs)
+ join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
+ c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref)
+ else:
+ join = '{0} ON {0}.{1} = {2}.{3}'.format(
+ c(left.tbl), c(left.col), rtbl.ref, c(right.col))
# Schema-qualify if (1) new table in same schema as old, and old
# is schema-qualified, or (2) new in other schema, except public
- if not suggestion.schema and (qualified[rtbl] and lsch == rsch
- or lsch not in(rsch, 'public')):
- join = lsch + '.' + join
+ if not suggestion.schema and (qualified[normalize_ref(rtbl.ref)]
+ and left.schema == right.schema
+ or left.schema not in(right.schema, 'public')):
+ join = left.schema + '.' + join
joins.append(join)
- prios.append(tblprio[rtbl] * 2 + 0 if (lsch, ltbl) in tbls else 1)
+ prios.append(ref_prio[normalize_ref(rtbl.ref)] * 2 + (
+ 0 if (left.schema, left.tbl) in other_tbls else 1))
return self.find_matches(word_before_cursor, joins, meta='join',
priority_collection=prios, type_priority=100)
diff --git a/tests/test_parseutils.py b/tests/test_parseutils.py
index a2a2fd15..f82f271f 100644
--- a/tests/test_parseutils.py
+++ b/tests/test_parseutils.py
@@ -12,11 +12,18 @@ def test_simple_select_single_table():
@pytest.mark.parametrize('sql', [
- 'select * from abc.def',
- 'select * from "abc".def',
'select * from "abc"."def"',
'select * from abc."def"',
])
+def test_simple_select_single_table_schema_qualified_quoted_table(sql):
+ tables = extract_tables(sql)
+ assert tables == (('abc', 'def', '"def"', False),)
+
+
+@pytest.mark.parametrize('sql', [
+ 'select * from abc.def',
+ 'select * from "abc".def',
+])
def test_simple_select_single_table_schema_qualified(sql):
tables = extract_tables(sql)
assert tables == (('abc', 'def', None, False),)
diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py
index 381657de..9f74a7d4 100644
--- a/tests/test_smart_completion_multiple_schemata.py
+++ b/tests/test_smart_completion_multiple_schemata.py
@@ -1,5 +1,6 @@
from __future__ import unicode_literals
import pytest
+import itertools
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey
@@ -158,19 +159,21 @@ def test_suggested_join_conditions(completer, complete_event, text):
Completion(text='shipments.id = users.id', start_position=0, display_meta='name join'),
Completion(text='shipments.user_id = users.id', start_position=0, display_meta='fk join')])
-@pytest.mark.parametrize('text', [
- 'SELECT * FROM public.users RIGHT OUTER JOIN ',
+@pytest.mark.parametrize(('query', 'tbl'), itertools.product((
+ 'SELECT * FROM public.{0} RIGHT OUTER JOIN ',
'''SELECT *
- FROM users
+ FROM {0}
JOIN '''
-])
-def test_suggested_joins(completer, complete_event, text):
+), ('users', '"users"', 'Users')))
+def test_suggested_joins(completer, complete_event, query, tbl):
+ text = query.format(tbl)
position = len(text)
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
+ join = 'custom.shipments ON shipments.user_id = {0}.id'.format(tbl)
assert set(result) == set([
- Completion(text='custom.shipments ON shipments.user_id = users.id', start_position=0, display_meta='join'),
+ Completion(text=join, start_position=0, display_meta='join'),
Completion(text='public', start_position=0, display_meta='schema'),
Completion(text='custom', start_position=0, display_meta='schema'),
Completion(text='"Custom"', start_position=0, display_meta='schema'),
diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py
index 04b605c9..a097492c 100644
--- a/tests/test_smart_completion_public_schema_only.py
+++ b/tests/test_smart_completion_public_schema_only.py
@@ -6,7 +6,7 @@ from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey
metadata = {
'tables': {
- 'users': ['id', 'email', 'first_name', 'last_name'],
+ 'users': ['id', 'parentid', 'email', 'first_name', 'last_name'],
'Users': ['userid', 'username'],
'orders': ['id', 'ordered_date', 'status', 'email'],
'select': ['id', 'insert', 'ABC']},
@@ -21,6 +21,7 @@ metadata = {
['o', 'o'], '', False, False, True]],
'datatypes': ['custom_type1', 'custom_type2'],
'foreignkeys': [
+ ('public', 'users', 'id', 'public', 'users', 'parentid'),
('public', 'users', 'id', 'public', 'Users', 'userid')
],
}
@@ -170,6 +171,7 @@ def test_suggested_column_names_from_visible_table(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
+ Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column'),
@@ -196,6 +198,7 @@ def test_suggested_column_names_in_function(completer, complete_event):
complete_event)
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
+ Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])
@@ -214,6 +217,7 @@ def test_suggested_column_names_with_table_dot(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
+ Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])
@@ -232,6 +236,7 @@ def test_suggested_column_names_with_alias(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
+ Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])
@@ -251,6 +256,7 @@ def test_suggested_multiple_column_names(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
+ Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column'),
@@ -276,6 +282,7 @@ def test_suggested_multiple_column_names_with_alias(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
+ Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])
@@ -295,6 +302,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
+ Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])
@@ -321,7 +329,8 @@ def test_suggest_columns_after_three_way_join(completer, complete_event):
'SELECT * FROM users u FULL OUTER JOIN "Users" u2 ON ',
'''SELECT *
FROM users u
- FULL OUTER JOIN "Users" u2 ON '''
+ FULL OUTER JOIN "Users" u2 ON
+ '''
])
def test_suggested_join_conditions(completer, complete_event, text):
position = len(text)
@@ -346,6 +355,32 @@ def test_suggested_joins(completer, complete_event, text):
complete_event))
assert set(result) == set([
Completion(text='"Users" ON "Users".userid = users.id', start_position=0, display_meta='join'),
+ Completion(text='users users2 ON users2.id = users.parentid', start_position=0, display_meta='join'),
+ Completion(text='users users2 ON users2.parentid = users.id', start_position=0, display_meta='join'),
+ Completion(text='public', start_position=0, display_meta='schema'),
+ Completion(text='"Users"', start_position=0, display_meta='table'),
+ Completion(text='"select"', start_position=0, display_meta='table'),
+ Completion(text='orders', start_position=0, display_meta='table'),
+ Completion(text='users', start_position=0, display_meta='table'),
+ Completion(text='user_emails', start_position=0, display_meta='view'),
+ Completion(text='custom_func2', start_position=0, display_meta='function'),
+ Completion(text='set_returning_func', start_position=0, display_meta='function'),
+ Completion(text='custom_func1', start_position=0, display_meta='function')])
+
+@pytest.mark.parametrize('text', [
+ 'SELECT * FROM public."Users" JOIN ',
+ 'SELECT * FROM public."Users" RIGHT OUTER JOIN ',
+ '''SELECT *
+ FROM public."Users"
+ LEFT JOIN '''
+])
+def test_suggested_joins_quoted_schema_qualified_table(completer, complete_event, text):
+ position = len(text)
+ result = set(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert set(result) == set([
+ Completion(text='public.users ON users.id = "Users".userid', start_position=0, display_meta='join'),
Completion(text='public', start_position=0, display_meta='schema'),
Completion(text='"Users"', start_position=0, display_meta='table'),
Completion(text='"select"', start_position=0, display_meta='table'),
@@ -642,7 +677,7 @@ def test_wildcard_column_expansion(completer, complete_event):
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
- col_list = 'id, email, first_name, last_name'
+ col_list = 'id, parentid, email, first_name, last_name'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
@@ -656,7 +691,7 @@ def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_even
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
- col_list = 'id, u.email, u.first_name, u.last_name'
+ col_list = 'id, u.parentid, u.email, u.first_name, u.last_name'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
@@ -664,9 +699,9 @@ def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_even
@pytest.mark.parametrize('text,expected', [
('SELECT users.* FROM users',
- 'id, users.email, users.first_name, users.last_name'),
+ 'id, users.parentid, users.email, users.first_name, users.last_name'),
('SELECT Users.* FROM Users',
- 'id, Users.email, Users.first_name, Users.last_name'),
+ 'id, Users.parentid, Users.email, Users.first_name, Users.last_name'),
])
def test_wildcard_column_expansion_with_table_qualifier(completer, complete_event, text, expected):
pos = len('SELECT users.*')
@@ -687,7 +722,7 @@ def test_wildcard_column_expansion_with_two_tables(completer, complete_event):
Document(text=sql, cursor_position=pos), complete_event)
cols = ('"select".id, "select"."insert", "select"."ABC", '
- 'u.id, u.email, u.first_name, u.last_name')
+ 'u.id, u.parentid, u.email, u.first_name, u.last_name')
expected = [Completion(text=cols, start_position=-1,
display='*', display_meta='columns')]
assert completions == expected
@@ -718,6 +753,7 @@ def test_suggest_columns_from_unquoted_table(completer, complete_event, text):
complete_event)
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
+ Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])
diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py
index de95936b..a7f5330d 100644
--- a/tests/test_sqlcompletion.py
+++ b/tests/test_sqlcompletion.py
@@ -23,8 +23,19 @@ def test_select_suggests_cols_with_qualified_table_scope():
@pytest.mark.parametrize('expression', [
- 'SELECT * FROM tabl WHERE ',
'SELECT * FROM "tabl" WHERE ',
+])
+def test_where_suggests_columns_functions_quoted_table(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == set([
+ Column(tables=((None, 'tabl', '"tabl"', False),)),
+ Function(schema=None),
+ Keyword(),
+ ])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM tabl WHERE ',
'SELECT * FROM tabl WHERE (',
'SELECT * FROM tabl WHERE foo = ',
'SELECT * FROM tabl WHERE bar OR ',