diff options
author | koljonen <koljonen@outlook.com> | 2016-06-30 00:41:54 +0200 |
---|---|---|
committer | Joakim Koljonen <koljonen@Joakims-MacBook-Pro-2.local> | 2016-07-06 20:03:24 +0200 |
commit | f09bb42d67d879ef292a8dc9654f41308fd1a6d8 (patch) | |
tree | 7258baf716759464dbe423a46d4d6fffe1daa854 | |
parent | 1605bf1cdb7c4f7bd10f3f215451195d3286fedf (diff) |
Better scoping for tables in insert statements
This commit makes it so that given `INSERT INTO foo(<cursor1>) SELECT <cursor2> FROM bar;`, we suggest `bar` columns for `<cursor2>` and `foo` columns for `<cursor1>`. Previous behaviour is sugggesting columns from both tables in both cases.
-rw-r--r-- | pgcli/packages/sqlcompletion.py | 35 | ||||
-rw-r--r-- | tests/test_smart_completion_multiple_schemata.py | 1 | ||||
-rw-r--r-- | tests/test_smart_completion_public_schema_only.py | 51 | ||||
-rw-r--r-- | tests/test_sqlcompletion.py | 33 |
4 files changed, 94 insertions, 26 deletions
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index 4dab4897..54a0ec5e 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -83,6 +83,24 @@ class SqlStatement(object): self.last_token = parsed and parsed.token_prev(len(parsed.tokens)) or '' + def is_insert(self): + return self.parsed.token_first().value.lower() == 'insert' + + def get_tables(self, scope='full'): + """ Gets the tables available in the statement. + param `scope:` possible values: 'full', 'insert', 'before' + If 'insert', only the first table is returned. + If 'before', only tables before the cursor are returned. + If not 'insert' and the stmt is an insert, the first table is skipped. + """ + tables = extract_tables( + self.full_text if scope == 'full' else self.text_before_cursor) + if scope == 'insert': + tables = tables[:1] + elif self.is_insert(): + tables = tables[1:] + return tables + def get_identifier_schema(self): schema = (self.identifier and self.identifier.get_parent_name()) or None # If schema name is unquoted, lower-case it @@ -292,7 +310,7 @@ def suggest_based_on_last_token(token, stmt): if (prev_tok and prev_tok.value and prev_tok.value.lower().split(' ')[-1] == 'using'): # tbl1 INNER JOIN tbl2 USING (col1, col2) - tables = extract_tables(stmt.text_before_cursor) + tables = stmt.get_tables('before') # suggest columns that are present in more than one table return (Column(tables=tables, require_last_table=True),) @@ -303,23 +321,25 @@ def suggest_based_on_last_token(token, stmt): if last_word(stmt.text_before_cursor, 'all_punctuations').startswith('('): return (Keyword(),) + prev_prev_tok = p.token_prev(p.token_index(prev_tok)) + if prev_prev_tok and prev_prev_tok.normalized == 'INTO': + return (Column(tables=stmt.get_tables('insert')),) # We're probably in a function argument list return (Column(tables=extract_tables(stmt.full_text)),) elif token_v in ('set', 'by', 'distinct'): - return (Column(tables=extract_tables(stmt.full_text)),) + return (Column(tables=stmt.get_tables()),) elif token_v in ('select', 'where', 'having'): # Check for a table alias or schema qualification parent = (stmt.identifier and stmt.identifier.get_parent_name()) or [] - + tables = stmt.get_tables() if parent: - tables = extract_tables(stmt.full_text) tables = tuple(t for t in tables if identifies(parent, t)) return (Column(tables=tables), Table(schema=parent), View(schema=parent), Function(schema=parent),) else: - return (Column(tables=extract_tables(stmt.full_text)), + return (Column(tables=tables), Function(schema=None), Keyword(),) @@ -347,6 +367,7 @@ def suggest_based_on_last_token(token, stmt): suggest.extend((Table(schema), View(schema))) if is_join and _allow_join(stmt.parsed): + tables = stmt.get_tables('before') suggest.append(Join(tables=tables, schema=schema)) return tuple(suggest) @@ -362,10 +383,10 @@ def suggest_based_on_last_token(token, stmt): elif token_v == 'column': # E.g. 'ALTER TABLE foo ALTER COLUMN bar - return (Column(tables=extract_tables(stmt.text_before_cursor)),) + return (Column(tables=stmt.get_tables()),) elif token_v == 'on': - tables = extract_tables(stmt.text_before_cursor) # [(schema, table, alias), ...] + tables = stmt.get_tables('before') parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None if parent: # "ON parent.<suggestion>" diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index efa00f3c..cd3b65ba 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -4,7 +4,6 @@ import itertools from metadata import (MetaData, alias, name_join, fk_join, join, schema, table, function, wildcard_expansion) from prompt_toolkit.document import Document -from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey metadata = { 'tables': { diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index 61b91972..df73d51b 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -3,7 +3,6 @@ import pytest from metadata import (MetaData, alias, name_join, fk_join, join, keyword, schema, table, view, function, column, wildcard_expansion) from prompt_toolkit.document import Document -from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey metadata = { 'tables': { @@ -294,6 +293,9 @@ def test_suggest_columns_after_three_way_join(completer, complete_event): set(result)) join_condition_texts = [ + 'INSERT INTO orders SELECT * FROM users U JOIN "Users" U2 ON ', + '''INSERT INTO public.orders(orderid) + SELECT * FROM users U JOIN "Users" U2 ON ''', 'SELECT * FROM users U JOIN "Users" U2 ON ', 'SELECT * FROM users U INNER join "Users" U2 ON ', 'SELECT * FROM USERS U right JOIN "Users" U2 ON ', @@ -388,6 +390,14 @@ def test_suggested_joins_fuzzy(completer, complete_event, text): join_texts = [ 'SELECT * FROM Users JOIN ', + '''INSERT INTO "Users" + SELECT * + FROM Users + INNER JOIN ''', + '''INSERT INTO public."Users"(username) + SELECT * + FROM Users + INNER JOIN ''', '''SELECT * FROM Users INNER JOIN ''' @@ -689,10 +699,15 @@ def test_columns_before_keywords(completer, complete_event): assert completions.index(col) < completions.index(kw) - -def test_wildcard_column_expansion(completer, complete_event): - sql = 'SELECT * FROM users' - pos = len('SELECT *') +@pytest.mark.parametrize('sql', [ + 'SELECT * FROM users', + 'INSERT INTO users SELECT * FROM users u', + '''INSERT INTO users(id, parentid, email, first_name, last_name) + SELECT * + FROM users u''', + ]) +def test_wildcard_column_expansion(completer, complete_event, sql): + pos = sql.find('*') + 1 completions = completer.get_completions( Document(text=sql, cursor_position=pos), complete_event) @@ -702,10 +717,15 @@ def test_wildcard_column_expansion(completer, complete_event): assert expected == completions - -def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_event): - sql = 'SELECT u.* FROM users u' - pos = len('SELECT u.*') +@pytest.mark.parametrize('sql', [ + 'SELECT u.* FROM users u', + 'INSERT INTO public.users SELECT u.* FROM users u', + '''INSERT INTO users(id, parentid, email, first_name, last_name) + SELECT u.* + FROM users u''', + ]) +def test_wildcard_column_expansion_with_alias(completer, complete_event, sql): + pos = sql.find('*') + 1 completions = completer.get_completions( Document(text=sql, cursor_position=pos), complete_event) @@ -834,3 +854,16 @@ def test_table_casing(cased_completer, complete_event, text): result = cased_completer.get_completions( Document(text=text), complete_event) assert set(result) == set([schema('PUBLIC')] + cased_rels) + + +@pytest.mark.parametrize('text', [ + 'INSERT INTO users ()', + 'INSERT INTO users()', + 'INSERT INTO users () SELECT * FROM orders;', + 'INSERT INTO users() SELECT * FROM users u cross join orders o', +]) +def test_insert(completer, complete_event, text): + pos = text.find('(') + 1 + result = completer.get_completions(Document(text=text, cursor_position=pos), + complete_event) + assert set(result) == set(testdata.columns('users')) diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py index de4cbe52..fee7be4d 100644 --- a/tests/test_sqlcompletion.py +++ b/tests/test_sqlcompletion.py @@ -29,6 +29,8 @@ def test_where_suggests_columns_functions_quoted_table(expression): @pytest.mark.parametrize('expression', [ + 'INSERT INTO OtherTabl(ID, Name) SELECT * FROM tabl WHERE ', + 'INSERT INTO OtherTabl SELECT * FROM tabl WHERE ', 'SELECT * FROM tabl WHERE ', 'SELECT * FROM tabl WHERE (', 'SELECT * FROM tabl WHERE foo = ', @@ -189,9 +191,12 @@ def test_truncate_suggests_qualified_tables(): assert set(suggestions) == set([ Table(schema='sch')]) - -def test_distinct_suggests_cols(): - suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ') +@pytest.mark.parametrize('text', [ + 'SELECT DISTINCT ', + 'INSERT INTO foo SELECT DISTINCT ' +]) +def test_distinct_suggests_cols(text): + suggestions = suggest_type(text, text) assert suggestions ==(Column(tables=()),) @@ -221,9 +226,12 @@ def test_into_suggests_tables_and_schemas(): Schema(), ]) - -def test_insert_into_lparen_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (') +@pytest.mark.parametrize('text', [ + 'INSERT INTO abc (', + 'INSERT INTO abc () SELECT * FROM hij;', +]) +def test_insert_into_lparen_suggests_cols(text): + suggestions = suggest_type(text, 'INSERT INTO abc (') assert suggestions ==(Column(tables=((None, 'abc', None, False),)),) @@ -483,9 +491,16 @@ def test_on_suggests_tables_and_join_conditions_right_side(sql): Alias(aliases=('abc', 'bcd',)),)) -@pytest.mark.parametrize('col_list', ('', 'col1, ',)) -def test_join_using_suggests_common_columns(col_list): - text = 'select * from abc inner join def using (' + col_list +@pytest.mark.parametrize('text', ( + 'select * from abc inner join def using (', + 'select * from abc inner join def using (col1, ', + 'insert into hij select * from abc inner join def using (', + '''insert into hij(x, y, z) + select * from abc inner join def using (col1, ''', + '''insert into hij (a,b,c) + select * from abc inner join def using (col1, ''', +)) +def test_join_using_suggests_common_columns(text): tables = ((None, 'abc', None, False), (None, 'def', None, False)) assert set(suggest_type(text, text)) == set([ Column(tables=tables, require_last_table=True),]) |