summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorkoljonen <koljonen@outlook.com>2016-06-09 23:38:33 +0200
committerkoljonen <koljonen@outlook.com>2016-06-16 19:27:57 +0200
commit5b20e107b8611fc8bddbebf8459af50ada759806 (patch)
tree0a5e1c698041cb1890c0d2639564975a9a717d60
parent9e98896bb3557b9a3cd21b4d8369d5427b66b770 (diff)
Fix some join-condition issues
When self-joining a table with an FK to or from some other table, we got a false FK-join suggestion for that column. There was also a problem with quoted tables not being quoted in the join condition. And there were a couple of problems when trying to join a non-existent table or using a non-existent qualifier (`SELECT * FROM Foo JOIN Bar ON Meow.`). I also rewrote get_join_condition_matches a bit in the process, hopefully making it a bit simpler.
-rw-r--r--pgcli/packages/sqlcompletion.py12
-rw-r--r--pgcli/pgcompleter.py80
-rw-r--r--tests/test_smart_completion_multiple_schemata.py9
-rw-r--r--tests/test_smart_completion_public_schema_only.py46
4 files changed, 97 insertions, 50 deletions
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
index 94c88b41..917086e5 100644
--- a/pgcli/packages/sqlcompletion.py
+++ b/pgcli/packages/sqlcompletion.py
@@ -315,7 +315,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
suggest.append(Function(schema=schema, filter='for_from_clause'))
if (token_v.endswith('join') and token.is_keyword
- and _allow_join_suggestion(parsed_statement)):
+ and _allow_join(parsed_statement)):
tables = extract_tables(text_before_cursor)
suggest.append(Join(tables=tables, schema=schema))
@@ -346,15 +346,15 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
View(schema=parent),
Function(schema=parent)]
last_token = parsed_statement
- if _allow_join_condition_suggestion(parsed_statement):
+ if filteredtables and _allow_join_condition(parsed_statement):
sugs.append(JoinCondition(tables=tables,
parent=filteredtables[-1]))
return tuple(sugs)
else:
# ON <suggestion>
# Use table alias if there is one, otherwise the table name
- aliases = tuple(t.alias or t.name for t in tables)
- if _allow_join_condition_suggestion(parsed_statement):
+ aliases = tuple(t.ref for t in tables)
+ if _allow_join_condition(parsed_statement):
return (Alias(aliases=aliases), JoinCondition(
tables=tables, parent=None))
else:
@@ -397,7 +397,7 @@ def identifies(id, ref):
ref.schema and (id == ref.schema + '.' + ref.name))
-def _allow_join_condition_suggestion(statement):
+def _allow_join_condition(statement):
"""
Tests if a join condition should be suggested
@@ -417,7 +417,7 @@ def _allow_join_condition_suggestion(statement):
return last_tok.value.lower() in ('on', 'and', 'or')
-def _allow_join_suggestion(statement):
+def _allow_join(statement):
"""
Tests if a join should be suggested
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
index 0ae6753e..2c705db9 100644
--- a/pgcli/pgcompleter.py
+++ b/pgcli/pgcompleter.py
@@ -439,56 +439,56 @@ class PGCompleter(Completer):
priority_collection=prios, type_priority=100)
def get_join_condition_matches(self, suggestion, word_before_cursor):
- lefttable = suggestion.parent or suggestion.tables[-1]
- scoped_cols = self.populate_scoped_cols(suggestion.tables)
+ col = namedtuple('col', 'schema tbl col')
+ tbls = self.populate_scoped_cols(suggestion.tables).items
+ cols = [(t, c) for t, cs in tbls() for c in cs]
+ try:
+ lref = (suggestion.parent or suggestion.tables[-1]).ref
+ ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
+ except IndexError: # The user typed an incorrect table qualifier
+ return []
+ conds, found_conds = [], set()
- def make_cond(tbl1, tbl2, col1, col2):
- prefix = '' if suggestion.parent else tbl1 + '.'
+ def add_cond(lcol, rcol, rref, meta, prio):
+ prefix = '' if suggestion.parent else ltbl.ref + '.'
case = self.case
- return prefix + case(col1) + ' = ' + tbl2 + '.' + case(col2)
+ cond = prefix + case(lcol) + ' = ' + rref + '.' + case(rcol)
+ if cond not in found_conds:
+ found_conds.add(cond)
+ conds.append((cond, meta, prio + ref_prio[rref]))
+
+ def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
+ d = defaultdict(list)
+ for pair in pairs:
+ d[pair[0]].append(pair[1])
+ return d
# Tables that are closer to the cursor get higher prio
- refprio = dict((tbl.ref, num) for num, tbl
+ ref_prio = dict((tbl.ref, num) for num, tbl
in enumerate(suggestion.tables))
- # Map (schema, tablename) to tables and ref to columns
- tbldict = defaultdict(list)
- for t in scoped_cols.keys():
- tbldict[(t.schema, t.name)].append(t)
- refcols = dict((t.ref, cs) for t, cs in scoped_cols.items())
+ # Map (schema, table, col) to tables
+ coldict = list_dict(((t.schema, t.name, c.name), t)
+ for t, c in cols if t.ref != lref)
# For each fk from the left table, generate a join condition if
# the other table is also in the scope
- conds = []
- for lcol in refcols.get(lefttable.ref, []):
- for fk in lcol.foreignkeys:
- for rcol in ((fk.parentschema, fk.parenttable,
- fk.parentcolumn), (fk.childschema, fk.childtable,
- fk.childcolumn)):
- for rtbl in tbldict[(rcol[0], rcol[1])]:
- if rtbl and rtbl.ref != lefttable.ref:
- cond = make_cond(lefttable.ref, rtbl.ref,
- lcol.name, rcol[2])
- prio = 2000 + refprio[rtbl.ref]
- conds.append((cond, 'fk join', prio))
+ fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
+ for fk, lcol in fks:
+ left = col(ltbl.schema, ltbl.name, lcol)
+ child = col(fk.childschema, fk.childtable, fk.childcolumn)
+ par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
+ left, right = (child, par) if left == child else (par, child)
+ for rtbl in coldict[right]:
+ add_cond(left.col, right.col, rtbl.ref, 'fk join', 2000)
# For name matching, use a {(colname, coltype): TableReference} dict
- col_table = defaultdict(lambda: [])
- for tbl, col in ((t, c) for t, cs in scoped_cols.items() for c in cs):
- col_table[(col.name, col.datatype)].append(tbl)
+ coltyp = namedtuple('coltyp', 'name datatype')
+ col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
# Find all name-match join conditions
- found = set(cnd[0] for cnd in conds)
- for c in refcols.get(lefttable.ref, []):
- for rtbl in col_table[(c.name, c.datatype)]:
- if rtbl.ref != lefttable.ref:
- cond = make_cond(lefttable.ref, rtbl.ref, c.name, c.name)
- if cond not in found:
- prio = (1000 if c.datatype and c.datatype in(
- 'integer', 'bigint', 'smallint')
- else 0 + refprio[rtbl.ref])
- conds.append((cond, 'name join', prio))
-
- if not conds:
- return []
+ for c in (coltyp(c.name, c.datatype) for c in lcols):
+ for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
+ add_cond(c.name, c.name, rtbl.ref, 'name join', 1000
+ if c.datatype in ('integer', 'bigint', 'smallint') else 0)
- conds, metas, prios = zip(*conds)
+ conds, metas, prios = zip(*conds) if conds else ([], [], [])
return self.find_matches(word_before_cursor, conds,
meta_collection=metas, type_priority=100, priority_collection=prios)
diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py
index 9f74a7d4..b5ac0d3a 100644
--- a/tests/test_smart_completion_multiple_schemata.py
+++ b/tests/test_smart_completion_multiple_schemata.py
@@ -100,17 +100,18 @@ def test_schema_or_visible_table_completion(completer, complete_event):
Completion(text='orders', start_position=0, display_meta='table')])
-@pytest.mark.parametrize('text', [
- 'SELECT FROM users',
- 'SELECT FROM "users"',
+@pytest.mark.parametrize('table', [
+ 'users',
+ '"users"',
])
-def test_suggested_column_names_from_shadowed_visible_table(completer, complete_event, text):
+def test_suggested_column_names_from_shadowed_visible_table(completer, complete_event, table):
"""
Suggest column and function names when selecting from table
:param completer:
:param complete_event:
:return:
"""
+ text = 'SELECT FROM ' + table
position = len('SELECT ')
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py
index 3226484f..01c54fe5 100644
--- a/tests/test_smart_completion_public_schema_only.py
+++ b/tests/test_smart_completion_public_schema_only.py
@@ -343,6 +343,52 @@ def test_suggested_join_conditions(completer, complete_event, text):
Completion(text='u2.userid = u.id', start_position=0, display_meta='fk join')])
@pytest.mark.parametrize('text', [
+ '''SELECT *
+ FROM users
+ CROSS JOIN "Users"
+ NATURAL JOIN users u
+ JOIN "Users" u2 ON
+ '''
+])
+def test_suggested_join_conditions_with_same_table_twice(completer, complete_event, text):
+ position = len(text)
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event)
+ assert result == [
+ Completion(text='u2.userid = u.id', start_position=0, display_meta='fk join'),
+ Completion(text='u2.userid = users.id', start_position=0, display_meta='fk join'),
+ Completion(text='u2.userid = "Users".userid', start_position=0, display_meta='name join'),
+ Completion(text='u2.username = "Users".username', start_position=0, display_meta='name join'),
+ Completion(text='"Users"', start_position=0, display_meta='table alias'),
+ Completion(text='u', start_position=0, display_meta='table alias'),
+ Completion(text='u2', start_position=0, display_meta='table alias'),
+ Completion(text='users', start_position=0, display_meta='table alias')]
+
+@pytest.mark.parametrize('text', [
+ 'SELECT * FROM users JOIN users u2 on foo.'
+])
+def test_suggested_join_conditions_with_invalid_qualifier(completer, complete_event, text):
+ position = len(text)
+ result = set(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert set(result) == set()
+
+@pytest.mark.parametrize(('text', 'ref'), [
+ ('SELECT * FROM users JOIN NonTable on ', 'NonTable'),
+ ('SELECT * FROM users JOIN nontable nt on ', 'nt')
+])
+def test_suggested_join_conditions_with_invalid_table(completer, complete_event, text, ref):
+ position = len(text)
+ result = set(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert set(result) == set([
+ Completion(text='users', start_position=0, display_meta='table alias'),
+ Completion(text=ref, start_position=0, display_meta='table alias')])
+
+@pytest.mark.parametrize('text', [
'SELECT * FROM "Users" u JOIN u',
'SELECT * FROM "Users" u JOIN uid',
'SELECT * FROM "Users" u JOIN userid',