summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pgcli/packages/sqlcompletion.py27
1 files changed, 19 insertions, 8 deletions
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
index 5cce2340..8d00fc28 100644
--- a/pgcli/packages/sqlcompletion.py
+++ b/pgcli/packages/sqlcompletion.py
@@ -7,6 +7,7 @@ from sqlparse.sql import Comparison, Identifier, Where
from .parseutils.utils import (
last_word, find_prev_keyword, parse_partial_identifier)
from .parseutils.tables import extract_tables
+from .parseutils.ctes import isolate_query_ctes
from pgspecial.main import parse_special_command
PY2 = sys.version_info[0] == 2
@@ -58,6 +59,10 @@ class SqlStatement(object):
text_before_cursor, include='many_punctuations')
full_text = _strip_named_query(full_text)
text_before_cursor = _strip_named_query(text_before_cursor)
+
+ full_text, text_before_cursor, self.local_tables = \
+ isolate_query_ctes(full_text, text_before_cursor)
+
self.text_before_cursor_including_last_word = text_before_cursor
# If we've partially typed a word then word_before_cursor won't be an
@@ -315,7 +320,9 @@ def suggest_based_on_last_token(token, stmt):
tables = stmt.get_tables('before')
# suggest columns that are present in more than one table
- return (Column(table_refs=tables, require_last_table=True),)
+ return (Column(table_refs=tables,
+ require_last_table=True,
+ local_tables=stmt.local_tables),)
elif p.token_first().value.lower() == 'select':
# If the lparen is preceeded by a space chances are we're about to
@@ -327,21 +334,23 @@ def suggest_based_on_last_token(token, stmt):
if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
return (Column(table_refs=stmt.get_tables('insert')),)
# We're probably in a function argument list
- return (Column(table_refs=extract_tables(stmt.full_text)),)
+ return (Column(table_refs=extract_tables(stmt.full_text),
+ local_tables=stmt.local_tables),)
elif token_v in ('set', 'by', 'distinct'):
- return (Column(table_refs=stmt.get_tables()),)
+ return (Column(table_refs=stmt.get_tables(),
+ local_tables=stmt.local_tables),)
elif token_v in ('select', 'where', 'having'):
# Check for a table alias or schema qualification
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or []
tables = stmt.get_tables()
if parent:
tables = tuple(t for t in tables if identifies(parent, t))
- return (Column(table_refs=tables),
+ return (Column(table_refs=tables, local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),)
else:
- return (Column(table_refs=tables),
+ return (Column(table_refs=tables, local_tables=stmt.local_tables),
Function(schema=None),
Keyword(),)
@@ -360,9 +369,10 @@ def suggest_based_on_last_token(token, stmt):
# Suggest schemas
suggest.insert(0, Schema())
- # Suggest set-returning functions in the FROM clause
if token_v == 'from' or is_join:
- suggest.append(FromClauseItem(schema=schema, table_refs=tables))
+ suggest.append(FromClauseItem(schema=schema,
+ table_refs=tables,
+ local_tables=stmt.local_tables))
elif token_v == 'truncate':
suggest.append(Table(schema))
else:
@@ -394,7 +404,8 @@ def suggest_based_on_last_token(token, stmt):
# "ON parent.<suggestion>"
# parent can be either a schema name or table alias
filteredtables = tuple(t for t in tables if identifies(parent, t))
- sugs = [Column(table_refs=filteredtables),
+ sugs = [Column(table_refs=filteredtables,
+ local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent)]