diff options
author | koljonen <koljonen@outlook.com> | 2016-06-09 23:38:33 +0200 |
---|---|---|
committer | koljonen <koljonen@outlook.com> | 2016-06-16 19:27:57 +0200 |
commit | 5b20e107b8611fc8bddbebf8459af50ada759806 (patch) | |
tree | 0a5e1c698041cb1890c0d2639564975a9a717d60 | |
parent | 9e98896bb3557b9a3cd21b4d8369d5427b66b770 (diff) |
Fix some join-condition issues
When self-joining a table with an FK to or from some other table, we got a false FK-join suggestion for that column.
There was also a problem with quoted tables not being quoted in the join condition.
And there were a couple of problems when trying to join a non-existent table or using a non-existent qualifier (`SELECT * FROM Foo JOIN Bar ON Meow.`).
I also rewrote get_join_condition_matches a bit in the process, hopefully making it a bit simpler.
-rw-r--r-- | pgcli/packages/sqlcompletion.py | 12 | ||||
-rw-r--r-- | pgcli/pgcompleter.py | 80 | ||||
-rw-r--r-- | tests/test_smart_completion_multiple_schemata.py | 9 | ||||
-rw-r--r-- | tests/test_smart_completion_public_schema_only.py | 46 |
4 files changed, 97 insertions, 50 deletions
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index 94c88b41..917086e5 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -315,7 +315,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, suggest.append(Function(schema=schema, filter='for_from_clause')) if (token_v.endswith('join') and token.is_keyword - and _allow_join_suggestion(parsed_statement)): + and _allow_join(parsed_statement)): tables = extract_tables(text_before_cursor) suggest.append(Join(tables=tables, schema=schema)) @@ -346,15 +346,15 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, View(schema=parent), Function(schema=parent)] last_token = parsed_statement - if _allow_join_condition_suggestion(parsed_statement): + if filteredtables and _allow_join_condition(parsed_statement): sugs.append(JoinCondition(tables=tables, parent=filteredtables[-1])) return tuple(sugs) else: # ON <suggestion> # Use table alias if there is one, otherwise the table name - aliases = tuple(t.alias or t.name for t in tables) - if _allow_join_condition_suggestion(parsed_statement): + aliases = tuple(t.ref for t in tables) + if _allow_join_condition(parsed_statement): return (Alias(aliases=aliases), JoinCondition( tables=tables, parent=None)) else: @@ -397,7 +397,7 @@ def identifies(id, ref): ref.schema and (id == ref.schema + '.' + ref.name)) -def _allow_join_condition_suggestion(statement): +def _allow_join_condition(statement): """ Tests if a join condition should be suggested @@ -417,7 +417,7 @@ def _allow_join_condition_suggestion(statement): return last_tok.value.lower() in ('on', 'and', 'or') -def _allow_join_suggestion(statement): +def _allow_join(statement): """ Tests if a join should be suggested diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 0ae6753e..2c705db9 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -439,56 +439,56 @@ class PGCompleter(Completer): priority_collection=prios, type_priority=100) def get_join_condition_matches(self, suggestion, word_before_cursor): - lefttable = suggestion.parent or suggestion.tables[-1] - scoped_cols = self.populate_scoped_cols(suggestion.tables) + col = namedtuple('col', 'schema tbl col') + tbls = self.populate_scoped_cols(suggestion.tables).items + cols = [(t, c) for t, cs in tbls() for c in cs] + try: + lref = (suggestion.parent or suggestion.tables[-1]).ref + ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1] + except IndexError: # The user typed an incorrect table qualifier + return [] + conds, found_conds = [], set() - def make_cond(tbl1, tbl2, col1, col2): - prefix = '' if suggestion.parent else tbl1 + '.' + def add_cond(lcol, rcol, rref, meta, prio): + prefix = '' if suggestion.parent else ltbl.ref + '.' case = self.case - return prefix + case(col1) + ' = ' + tbl2 + '.' + case(col2) + cond = prefix + case(lcol) + ' = ' + rref + '.' + case(rcol) + if cond not in found_conds: + found_conds.add(cond) + conds.append((cond, meta, prio + ref_prio[rref])) + + def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]} + d = defaultdict(list) + for pair in pairs: + d[pair[0]].append(pair[1]) + return d # Tables that are closer to the cursor get higher prio - refprio = dict((tbl.ref, num) for num, tbl + ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.tables)) - # Map (schema, tablename) to tables and ref to columns - tbldict = defaultdict(list) - for t in scoped_cols.keys(): - tbldict[(t.schema, t.name)].append(t) - refcols = dict((t.ref, cs) for t, cs in scoped_cols.items()) + # Map (schema, table, col) to tables + coldict = list_dict(((t.schema, t.name, c.name), t) + for t, c in cols if t.ref != lref) # For each fk from the left table, generate a join condition if # the other table is also in the scope - conds = [] - for lcol in refcols.get(lefttable.ref, []): - for fk in lcol.foreignkeys: - for rcol in ((fk.parentschema, fk.parenttable, - fk.parentcolumn), (fk.childschema, fk.childtable, - fk.childcolumn)): - for rtbl in tbldict[(rcol[0], rcol[1])]: - if rtbl and rtbl.ref != lefttable.ref: - cond = make_cond(lefttable.ref, rtbl.ref, - lcol.name, rcol[2]) - prio = 2000 + refprio[rtbl.ref] - conds.append((cond, 'fk join', prio)) + fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys) + for fk, lcol in fks: + left = col(ltbl.schema, ltbl.name, lcol) + child = col(fk.childschema, fk.childtable, fk.childcolumn) + par = col(fk.parentschema, fk.parenttable, fk.parentcolumn) + left, right = (child, par) if left == child else (par, child) + for rtbl in coldict[right]: + add_cond(left.col, right.col, rtbl.ref, 'fk join', 2000) # For name matching, use a {(colname, coltype): TableReference} dict - col_table = defaultdict(lambda: []) - for tbl, col in ((t, c) for t, cs in scoped_cols.items() for c in cs): - col_table[(col.name, col.datatype)].append(tbl) + coltyp = namedtuple('coltyp', 'name datatype') + col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols) # Find all name-match join conditions - found = set(cnd[0] for cnd in conds) - for c in refcols.get(lefttable.ref, []): - for rtbl in col_table[(c.name, c.datatype)]: - if rtbl.ref != lefttable.ref: - cond = make_cond(lefttable.ref, rtbl.ref, c.name, c.name) - if cond not in found: - prio = (1000 if c.datatype and c.datatype in( - 'integer', 'bigint', 'smallint') - else 0 + refprio[rtbl.ref]) - conds.append((cond, 'name join', prio)) - - if not conds: - return [] + for c in (coltyp(c.name, c.datatype) for c in lcols): + for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref): + add_cond(c.name, c.name, rtbl.ref, 'name join', 1000 + if c.datatype in ('integer', 'bigint', 'smallint') else 0) - conds, metas, prios = zip(*conds) + conds, metas, prios = zip(*conds) if conds else ([], [], []) return self.find_matches(word_before_cursor, conds, meta_collection=metas, type_priority=100, priority_collection=prios) diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index 9f74a7d4..b5ac0d3a 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -100,17 +100,18 @@ def test_schema_or_visible_table_completion(completer, complete_event): Completion(text='orders', start_position=0, display_meta='table')]) -@pytest.mark.parametrize('text', [ - 'SELECT FROM users', - 'SELECT FROM "users"', +@pytest.mark.parametrize('table', [ + 'users', + '"users"', ]) -def test_suggested_column_names_from_shadowed_visible_table(completer, complete_event, text): +def test_suggested_column_names_from_shadowed_visible_table(completer, complete_event, table): """ Suggest column and function names when selecting from table :param completer: :param complete_event: :return: """ + text = 'SELECT FROM ' + table position = len('SELECT ') result = set(completer.get_completions( Document(text=text, cursor_position=position), diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index 3226484f..01c54fe5 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -343,6 +343,52 @@ def test_suggested_join_conditions(completer, complete_event, text): Completion(text='u2.userid = u.id', start_position=0, display_meta='fk join')]) @pytest.mark.parametrize('text', [ + '''SELECT * + FROM users + CROSS JOIN "Users" + NATURAL JOIN users u + JOIN "Users" u2 ON + ''' +]) +def test_suggested_join_conditions_with_same_table_twice(completer, complete_event, text): + position = len(text) + result = completer.get_completions( + Document(text=text, cursor_position=position), + complete_event) + assert result == [ + Completion(text='u2.userid = u.id', start_position=0, display_meta='fk join'), + Completion(text='u2.userid = users.id', start_position=0, display_meta='fk join'), + Completion(text='u2.userid = "Users".userid', start_position=0, display_meta='name join'), + Completion(text='u2.username = "Users".username', start_position=0, display_meta='name join'), + Completion(text='"Users"', start_position=0, display_meta='table alias'), + Completion(text='u', start_position=0, display_meta='table alias'), + Completion(text='u2', start_position=0, display_meta='table alias'), + Completion(text='users', start_position=0, display_meta='table alias')] + +@pytest.mark.parametrize('text', [ + 'SELECT * FROM users JOIN users u2 on foo.' +]) +def test_suggested_join_conditions_with_invalid_qualifier(completer, complete_event, text): + position = len(text) + result = set(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert set(result) == set() + +@pytest.mark.parametrize(('text', 'ref'), [ + ('SELECT * FROM users JOIN NonTable on ', 'NonTable'), + ('SELECT * FROM users JOIN nontable nt on ', 'nt') +]) +def test_suggested_join_conditions_with_invalid_table(completer, complete_event, text, ref): + position = len(text) + result = set(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert set(result) == set([ + Completion(text='users', start_position=0, display_meta='table alias'), + Completion(text=ref, start_position=0, display_meta='table alias')]) + +@pytest.mark.parametrize('text', [ 'SELECT * FROM "Users" u JOIN u', 'SELECT * FROM "Users" u JOIN uid', 'SELECT * FROM "Users" u JOIN userid', |