summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorkoljonen <koljonen@outlook.com>2016-06-30 00:41:54 +0200
committerJoakim Koljonen <koljonen@Joakims-MacBook-Pro-2.local>2016-07-06 20:03:24 +0200
commitf09bb42d67d879ef292a8dc9654f41308fd1a6d8 (patch)
tree7258baf716759464dbe423a46d4d6fffe1daa854
parent1605bf1cdb7c4f7bd10f3f215451195d3286fedf (diff)
Better scoping for tables in insert statements
This commit makes it so that given `INSERT INTO foo(<cursor1>) SELECT <cursor2> FROM bar;`, we suggest `bar` columns for `<cursor2>` and `foo` columns for `<cursor1>`. Previous behaviour is sugggesting columns from both tables in both cases.
-rw-r--r--pgcli/packages/sqlcompletion.py35
-rw-r--r--tests/test_smart_completion_multiple_schemata.py1
-rw-r--r--tests/test_smart_completion_public_schema_only.py51
-rw-r--r--tests/test_sqlcompletion.py33
4 files changed, 94 insertions, 26 deletions
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
index 4dab4897..54a0ec5e 100644
--- a/pgcli/packages/sqlcompletion.py
+++ b/pgcli/packages/sqlcompletion.py
@@ -83,6 +83,24 @@ class SqlStatement(object):
self.last_token = parsed and parsed.token_prev(len(parsed.tokens)) or ''
+ def is_insert(self):
+ return self.parsed.token_first().value.lower() == 'insert'
+
+ def get_tables(self, scope='full'):
+ """ Gets the tables available in the statement.
+ param `scope:` possible values: 'full', 'insert', 'before'
+ If 'insert', only the first table is returned.
+ If 'before', only tables before the cursor are returned.
+ If not 'insert' and the stmt is an insert, the first table is skipped.
+ """
+ tables = extract_tables(
+ self.full_text if scope == 'full' else self.text_before_cursor)
+ if scope == 'insert':
+ tables = tables[:1]
+ elif self.is_insert():
+ tables = tables[1:]
+ return tables
+
def get_identifier_schema(self):
schema = (self.identifier and self.identifier.get_parent_name()) or None
# If schema name is unquoted, lower-case it
@@ -292,7 +310,7 @@ def suggest_based_on_last_token(token, stmt):
if (prev_tok and prev_tok.value
and prev_tok.value.lower().split(' ')[-1] == 'using'):
# tbl1 INNER JOIN tbl2 USING (col1, col2)
- tables = extract_tables(stmt.text_before_cursor)
+ tables = stmt.get_tables('before')
# suggest columns that are present in more than one table
return (Column(tables=tables, require_last_table=True),)
@@ -303,23 +321,25 @@ def suggest_based_on_last_token(token, stmt):
if last_word(stmt.text_before_cursor,
'all_punctuations').startswith('('):
return (Keyword(),)
+ prev_prev_tok = p.token_prev(p.token_index(prev_tok))
+ if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
+ return (Column(tables=stmt.get_tables('insert')),)
# We're probably in a function argument list
return (Column(tables=extract_tables(stmt.full_text)),)
elif token_v in ('set', 'by', 'distinct'):
- return (Column(tables=extract_tables(stmt.full_text)),)
+ return (Column(tables=stmt.get_tables()),)
elif token_v in ('select', 'where', 'having'):
# Check for a table alias or schema qualification
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or []
-
+ tables = stmt.get_tables()
if parent:
- tables = extract_tables(stmt.full_text)
tables = tuple(t for t in tables if identifies(parent, t))
return (Column(tables=tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),)
else:
- return (Column(tables=extract_tables(stmt.full_text)),
+ return (Column(tables=tables),
Function(schema=None),
Keyword(),)
@@ -347,6 +367,7 @@ def suggest_based_on_last_token(token, stmt):
suggest.extend((Table(schema), View(schema)))
if is_join and _allow_join(stmt.parsed):
+ tables = stmt.get_tables('before')
suggest.append(Join(tables=tables, schema=schema))
return tuple(suggest)
@@ -362,10 +383,10 @@ def suggest_based_on_last_token(token, stmt):
elif token_v == 'column':
# E.g. 'ALTER TABLE foo ALTER COLUMN bar
- return (Column(tables=extract_tables(stmt.text_before_cursor)),)
+ return (Column(tables=stmt.get_tables()),)
elif token_v == 'on':
- tables = extract_tables(stmt.text_before_cursor) # [(schema, table, alias), ...]
+ tables = stmt.get_tables('before')
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None
if parent:
# "ON parent.<suggestion>"
diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py
index efa00f3c..cd3b65ba 100644
--- a/tests/test_smart_completion_multiple_schemata.py
+++ b/tests/test_smart_completion_multiple_schemata.py
@@ -4,7 +4,6 @@ import itertools
from metadata import (MetaData, alias, name_join, fk_join, join,
schema, table, function, wildcard_expansion)
from prompt_toolkit.document import Document
-from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey
metadata = {
'tables': {
diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py
index 61b91972..df73d51b 100644
--- a/tests/test_smart_completion_public_schema_only.py
+++ b/tests/test_smart_completion_public_schema_only.py
@@ -3,7 +3,6 @@ import pytest
from metadata import (MetaData, alias, name_join, fk_join, join, keyword,
schema, table, view, function, column, wildcard_expansion)
from prompt_toolkit.document import Document
-from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey
metadata = {
'tables': {
@@ -294,6 +293,9 @@ def test_suggest_columns_after_three_way_join(completer, complete_event):
set(result))
join_condition_texts = [
+ 'INSERT INTO orders SELECT * FROM users U JOIN "Users" U2 ON ',
+ '''INSERT INTO public.orders(orderid)
+ SELECT * FROM users U JOIN "Users" U2 ON ''',
'SELECT * FROM users U JOIN "Users" U2 ON ',
'SELECT * FROM users U INNER join "Users" U2 ON ',
'SELECT * FROM USERS U right JOIN "Users" U2 ON ',
@@ -388,6 +390,14 @@ def test_suggested_joins_fuzzy(completer, complete_event, text):
join_texts = [
'SELECT * FROM Users JOIN ',
+ '''INSERT INTO "Users"
+ SELECT *
+ FROM Users
+ INNER JOIN ''',
+ '''INSERT INTO public."Users"(username)
+ SELECT *
+ FROM Users
+ INNER JOIN ''',
'''SELECT *
FROM Users
INNER JOIN '''
@@ -689,10 +699,15 @@ def test_columns_before_keywords(completer, complete_event):
assert completions.index(col) < completions.index(kw)
-
-def test_wildcard_column_expansion(completer, complete_event):
- sql = 'SELECT * FROM users'
- pos = len('SELECT *')
+@pytest.mark.parametrize('sql', [
+ 'SELECT * FROM users',
+ 'INSERT INTO users SELECT * FROM users u',
+ '''INSERT INTO users(id, parentid, email, first_name, last_name)
+ SELECT *
+ FROM users u''',
+ ])
+def test_wildcard_column_expansion(completer, complete_event, sql):
+ pos = sql.find('*') + 1
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
@@ -702,10 +717,15 @@ def test_wildcard_column_expansion(completer, complete_event):
assert expected == completions
-
-def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_event):
- sql = 'SELECT u.* FROM users u'
- pos = len('SELECT u.*')
+@pytest.mark.parametrize('sql', [
+ 'SELECT u.* FROM users u',
+ 'INSERT INTO public.users SELECT u.* FROM users u',
+ '''INSERT INTO users(id, parentid, email, first_name, last_name)
+ SELECT u.*
+ FROM users u''',
+ ])
+def test_wildcard_column_expansion_with_alias(completer, complete_event, sql):
+ pos = sql.find('*') + 1
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
@@ -834,3 +854,16 @@ def test_table_casing(cased_completer, complete_event, text):
result = cased_completer.get_completions(
Document(text=text), complete_event)
assert set(result) == set([schema('PUBLIC')] + cased_rels)
+
+
+@pytest.mark.parametrize('text', [
+ 'INSERT INTO users ()',
+ 'INSERT INTO users()',
+ 'INSERT INTO users () SELECT * FROM orders;',
+ 'INSERT INTO users() SELECT * FROM users u cross join orders o',
+])
+def test_insert(completer, complete_event, text):
+ pos = text.find('(') + 1
+ result = completer.get_completions(Document(text=text, cursor_position=pos),
+ complete_event)
+ assert set(result) == set(testdata.columns('users'))
diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py
index de4cbe52..fee7be4d 100644
--- a/tests/test_sqlcompletion.py
+++ b/tests/test_sqlcompletion.py
@@ -29,6 +29,8 @@ def test_where_suggests_columns_functions_quoted_table(expression):
@pytest.mark.parametrize('expression', [
+ 'INSERT INTO OtherTabl(ID, Name) SELECT * FROM tabl WHERE ',
+ 'INSERT INTO OtherTabl SELECT * FROM tabl WHERE ',
'SELECT * FROM tabl WHERE ',
'SELECT * FROM tabl WHERE (',
'SELECT * FROM tabl WHERE foo = ',
@@ -189,9 +191,12 @@ def test_truncate_suggests_qualified_tables():
assert set(suggestions) == set([
Table(schema='sch')])
-
-def test_distinct_suggests_cols():
- suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ')
+@pytest.mark.parametrize('text', [
+ 'SELECT DISTINCT ',
+ 'INSERT INTO foo SELECT DISTINCT '
+])
+def test_distinct_suggests_cols(text):
+ suggestions = suggest_type(text, text)
assert suggestions ==(Column(tables=()),)
@@ -221,9 +226,12 @@ def test_into_suggests_tables_and_schemas():
Schema(),
])
-
-def test_insert_into_lparen_suggests_cols():
- suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (')
+@pytest.mark.parametrize('text', [
+ 'INSERT INTO abc (',
+ 'INSERT INTO abc () SELECT * FROM hij;',
+])
+def test_insert_into_lparen_suggests_cols(text):
+ suggestions = suggest_type(text, 'INSERT INTO abc (')
assert suggestions ==(Column(tables=((None, 'abc', None, False),)),)
@@ -483,9 +491,16 @@ def test_on_suggests_tables_and_join_conditions_right_side(sql):
Alias(aliases=('abc', 'bcd',)),))
-@pytest.mark.parametrize('col_list', ('', 'col1, ',))
-def test_join_using_suggests_common_columns(col_list):
- text = 'select * from abc inner join def using (' + col_list
+@pytest.mark.parametrize('text', (
+ 'select * from abc inner join def using (',
+ 'select * from abc inner join def using (col1, ',
+ 'insert into hij select * from abc inner join def using (',
+ '''insert into hij(x, y, z)
+ select * from abc inner join def using (col1, ''',
+ '''insert into hij (a,b,c)
+ select * from abc inner join def using (col1, ''',
+))
+def test_join_using_suggests_common_columns(text):
tables = ((None, 'abc', None, False), (None, 'def', None, False))
assert set(suggest_type(text, text)) == set([
Column(tables=tables, require_last_table=True),])