summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAmjith Ramanujam <amjith.r@gmail.com>2016-08-08 22:10:29 -0700
committerGitHub <noreply@github.com>2016-08-08 22:10:29 -0700
commitab6e54b2cf5fd3badb1a31143146b36beaeb88d9 (patch)
treef392d3c936e8f9f600630cad545e7630ae6ab690
parent28d01e44ae4b7741c461116d660b05bc27b5ab60 (diff)
parent2407de0e4c4c52e27b0546c374722bd4e88287ad (diff)
Merge pull request #553 from dbcli/darikg/cte-suggestions
CTE-aware suggestions
-rw-r--r--pgcli/packages/parseutils/__init__.py0
-rw-r--r--pgcli/packages/parseutils/ctes.py146
-rw-r--r--pgcli/packages/parseutils/meta.py (renamed from pgcli/packages/function_metadata.py)1
-rw-r--r--pgcli/packages/parseutils/tables.py (renamed from pgcli/packages/parseutils.py)166
-rw-r--r--pgcli/packages/parseutils/utils.py161
-rw-r--r--pgcli/packages/sqlcompletion.py65
-rw-r--r--pgcli/pgbuffer.py2
-rw-r--r--pgcli/pgcompleter.py41
-rw-r--r--pgcli/pgexecute.py2
-rw-r--r--tests/metadata.py2
-rw-r--r--tests/parseutils/test_ctes.py121
-rw-r--r--tests/parseutils/test_function_metadata.py (renamed from tests/test_function_metadata.py)2
-rw-r--r--tests/parseutils/test_parseutils.py (renamed from tests/test_parseutils.py)4
-rw-r--r--tests/test_pgexecute.py2
-rw-r--r--tests/test_smart_completion_public_schema_only.py47
-rw-r--r--tests/test_sqlcompletion.py66
16 files changed, 582 insertions, 246 deletions
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/pgcli/packages/parseutils/__init__.py
diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py
new file mode 100644
index 00000000..12c207c4
--- /dev/null
+++ b/pgcli/packages/parseutils/ctes.py
@@ -0,0 +1,146 @@
+from sqlparse import parse
+from sqlparse.tokens import Keyword, CTE, DML
+from sqlparse.sql import Identifier, IdentifierList, Parenthesis
+from collections import namedtuple
+from .meta import TableMetadata, ColumnMetadata
+
+
+# TableExpression is a namedtuple representing a CTE, used internally
+# name: cte alias assigned in the query
+# columns: list of column names
+# start: index into the original string of the left parens starting the CTE
+# stop: index into the original string of the right parens ending the CTE
+TableExpression = namedtuple('TableExpression', 'name columns start stop')
+
+
+def isolate_query_ctes(full_text, text_before_cursor):
+ """Simplify a query by converting CTEs into table metadata objects
+ """
+
+ if not full_text:
+ return full_text, text_before_cursor, tuple()
+
+ ctes, remainder = extract_ctes(full_text)
+ if not ctes:
+ return full_text, text_before_cursor, ()
+
+ current_position = len(text_before_cursor)
+ meta = []
+
+ for cte in ctes:
+ if cte.start < current_position < cte.stop:
+ # Currently editing a cte - treat its body as the current full_text
+ text_before_cursor = full_text[cte.start:current_position]
+ full_text = full_text[cte.start:cte.stop]
+ return full_text, text_before_cursor, meta
+
+ # Append this cte to the list of available table metadata
+ cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
+ meta.append(TableMetadata(cte.name, cols))
+
+ # Editing past the last cte (ie the main body of the query)
+ full_text = full_text[ctes[-1].stop:]
+ text_before_cursor = text_before_cursor[ctes[-1].stop:current_position]
+
+ return full_text, text_before_cursor, tuple(meta)
+
+
+def extract_ctes(sql):
+    """ Extract common table expressions (CTEs) from a query
+
+ Returns tuple (ctes, remainder_sql)
+
+ ctes is a list of TableExpression namedtuples
+ remainder_sql is the text from the original query after the CTEs have
+ been stripped.
+ """
+
+ p = parse(sql)[0]
+
+ # Make sure the first meaningful token is "WITH" which is necessary to
+ # define CTEs
+ idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
+ if not (tok and tok.ttype == CTE):
+ return [], sql
+
+ # Get the next (meaningful) token, which should be the first CTE
+ idx, tok = p.token_next(idx)
+ start_pos = token_start_pos(p.tokens, idx)
+ ctes = []
+
+ if isinstance(tok, IdentifierList):
+ # Multiple ctes
+ for t in tok.get_identifiers():
+ cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
+ cte = get_cte_from_token(t, start_pos + cte_start_offset)
+ if not cte:
+ continue
+ ctes.append(cte)
+ elif isinstance(tok, Identifier):
+ # A single CTE
+ cte = get_cte_from_token(tok, start_pos)
+ if cte:
+ ctes.append(cte)
+
+ idx = p.token_index(tok) + 1
+
+ # Collapse everything after the ctes into a remainder query
+ remainder = u''.join(str(tok) for tok in p.tokens[idx:])
+
+ return ctes, remainder
+
+
+def get_cte_from_token(tok, pos0):
+ cte_name = tok.get_real_name()
+ if not cte_name:
+ return None
+
+ # Find the start position of the opening parens enclosing the cte body
+ idx, parens = tok.token_next_by(Parenthesis)
+ if not parens:
+ return None
+
+ start_pos = pos0 + token_start_pos(tok.tokens, idx)
+ cte_len = len(str(parens)) # includes parens
+ stop_pos = start_pos + cte_len
+
+ column_names = extract_column_names(parens)
+
+ return TableExpression(cte_name, column_names, start_pos, stop_pos)
+
+
+def extract_column_names(parsed):
+ # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE
+ idx, tok = parsed.token_next_by(t=DML)
+ tok_val = tok and tok.value.lower()
+
+ if tok_val in ('insert', 'update', 'delete'):
+ # Jump ahead to the RETURNING clause where the list of column names is
+ idx, tok = parsed.token_next_by(idx, (Keyword, 'returning'))
+ elif not tok_val == 'select':
+ # Must be invalid CTE
+ return ()
+
+ # The next token should be either a column name, or a list of column names
+ idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
+ return tuple(t.get_name() for t in _identifiers(tok))
+
+
+def token_start_pos(tokens, idx):
+ return sum(len(str(t)) for t in tokens[:idx])
+
+
+def _identifiers(tok):
+ if isinstance(tok, IdentifierList):
+ for t in tok.get_identifiers():
+ # NB: IdentifierList.get_identifiers() can return non-identifiers!
+ if isinstance(t, Identifier):
+ yield t
+ elif isinstance(tok, Identifier):
+ yield tok
+
+
+
+
+
+
diff --git a/pgcli/packages/function_metadata.py b/pgcli/packages/parseutils/meta.py
index 88d9fc49..48ba3a5e 100644
--- a/pgcli/packages/function_metadata.py
+++ b/pgcli/packages/parseutils/meta.py
@@ -3,6 +3,7 @@ from collections import namedtuple
ColumnMetadata = namedtuple('ColumnMetadata', ['name', 'datatype', 'foreignkeys'])
ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable',
'parentcolumn', 'childschema', 'childtable', 'childcolumn'])
+TableMetadata = namedtuple('TableMetadata', 'name columns')
class FunctionMetadata(object):
diff --git a/pgcli/packages/parseutils.py b/pgcli/packages/parseutils/tables.py
index 4a5b47c2..72cdcc6b 100644
--- a/pgcli/packages/parseutils.py
+++ b/pgcli/packages/parseutils/tables.py
@@ -1,68 +1,8 @@
from __future__ import print_function
-import re
import sqlparse
from collections import namedtuple
from sqlparse.sql import IdentifierList, Identifier, Function
-from sqlparse.tokens import Keyword, DML, Punctuation, Token, Error
-
-cleanup_regex = {
- # This matches only alphanumerics and underscores.
- 'alphanum_underscore': re.compile(r'(\w+)$'),
- # This matches everything except spaces, parens, colon, and comma
- 'many_punctuations': re.compile(r'([^():,\s]+)$'),
- # This matches everything except spaces, parens, colon, comma, and period
- 'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
- # This matches everything except a space.
- 'all_punctuations': re.compile('([^\s]+)$'),
- }
-
-def last_word(text, include='alphanum_underscore'):
- """
- Find the last word in a sentence.
-
- >>> last_word('abc')
- 'abc'
- >>> last_word(' abc')
- 'abc'
- >>> last_word('')
- ''
- >>> last_word(' ')
- ''
- >>> last_word('abc ')
- ''
- >>> last_word('abc def')
- 'def'
- >>> last_word('abc def ')
- ''
- >>> last_word('abc def;')
- ''
- >>> last_word('bac $def')
- 'def'
- >>> last_word('bac $def', include='most_punctuations')
- '$def'
- >>> last_word('bac \def', include='most_punctuations')
- '\\\\def'
- >>> last_word('bac \def;', include='most_punctuations')
- '\\\\def;'
- >>> last_word('bac::def', include='most_punctuations')
- 'def'
- >>> last_word('"foo*bar', include='most_punctuations')
- '"foo*bar'
- """
-
- if not text: # Empty string
- return ''
-
- if text[-1].isspace():
- return ''
- else:
- regex = cleanup_regex[include]
- matches = regex.search(text)
- if matches:
- return matches.group(0)
- else:
- return ''
-
+from sqlparse.tokens import Keyword, DML, Punctuation
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
'is_function'])
@@ -86,6 +26,7 @@ def is_subselect(parsed):
def _identifier_is_function(identifier):
return any(isinstance(t, Function) for t in identifier.tokens)
+
def extract_from_part(parsed, stop_at_punctuation=True):
tbl_prefix_seen = False
for item in parsed.tokens:
@@ -204,106 +145,3 @@ def extract_tables(sql):
allow_functions=not insert_stmt)
# In the case 'sche.<cursor>', we get an empty TableReference; remove that
return tuple(i for i in identifiers if i.name)
-
-
-
-def find_prev_keyword(sql):
- """ Find the last sql keyword in an SQL statement
-
- Returns the value of the last keyword, and the text of the query with
- everything after the last keyword stripped
- """
- if not sql.strip():
- return None, ''
-
- parsed = sqlparse.parse(sql)[0]
- flattened = list(parsed.flatten())
-
- logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
-
- for t in reversed(flattened):
- if t.value == '(' or (t.is_keyword and (
- t.value.upper() not in logical_operators)):
- # Find the location of token t in the original parsed statement
- # We can't use parsed.token_index(t) because t may be a child token
- # inside a TokenList, in which case token_index thows an error
- # Minimal example:
- # p = sqlparse.parse('select * from foo where bar')
- # t = list(p.flatten())[-3] # The "Where" token
- # p.token_index(t) # Throws ValueError: not in list
- idx = flattened.index(t)
-
- # Combine the string values of all tokens in the original list
- # up to and including the target keyword token t, to produce a
- # query string with everything after the keyword token removed
- text = ''.join(tok.value for tok in flattened[:idx+1])
- return t, text
-
- return None, ''
-
-
-# Postgresql dollar quote signs look like `$$` or `$tag$`
-dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
-
-
-def is_open_quote(sql):
- """Returns true if the query contains an unclosed quote"""
-
- # parsed can contain one or more semi-colon separated commands
- parsed = sqlparse.parse(sql)
- return any(_parsed_is_open_quote(p) for p in parsed)
-
-
-def _parsed_is_open_quote(parsed):
- tokens = list(parsed.flatten())
-
- i = 0
- while i < len(tokens):
- tok = tokens[i]
- if tok.match(Token.Error, "'"):
- # An unmatched single quote
- return True
- elif (tok.ttype in Token.Name.Builtin
- and dollar_quote_regex.match(tok.value)):
- # Find the matching closing dollar quote sign
- for (j, tok2) in enumerate(tokens[i+1:], i+1):
- if tok2.match(Token.Name.Builtin, tok.value):
- # Found the matching closing quote - continue our scan for
- # open quotes thereafter
- i = j
- break
- else:
- # No matching dollar sign quote
- return True
-
- i += 1
-
- return False
-
-
-def parse_partial_identifier(word):
- """Attempt to parse a (partially typed) word as an identifier
-
- word may include a schema qualification, like `schema_name.partial_name`
- or `schema_name.` There may also be unclosed quotation marks, like
- `"schema`, or `schema."partial_name`
-
- :param word: string representing a (partially complete) identifier
- :return: sqlparse.sql.Identifier, or None
- """
-
- p = sqlparse.parse(word)[0]
- n_tok = len(p.tokens)
- if n_tok == 1 and isinstance(p.tokens[0], Identifier):
- return p.tokens[0]
- elif p.token_next_by(m=(Error, '"'))[1]:
- # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
- # Close the double quote, then reparse
- return parse_partial_identifier(word + '"')
- else:
- return None
-
-
-if __name__ == '__main__':
- sql = 'select * from (select t. from tabl t'
- print (extract_tables(sql))
diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py
new file mode 100644
index 00000000..e5baa0b0
--- /dev/null
+++ b/pgcli/packages/parseutils/utils.py
@@ -0,0 +1,161 @@
+from __future__ import print_function
+import re
+import sqlparse
+from sqlparse.sql import Identifier
+from sqlparse.tokens import Token, Error
+
+cleanup_regex = {
+ # This matches only alphanumerics and underscores.
+ 'alphanum_underscore': re.compile(r'(\w+)$'),
+ # This matches everything except spaces, parens, colon, and comma
+ 'many_punctuations': re.compile(r'([^():,\s]+)$'),
+ # This matches everything except spaces, parens, colon, comma, and period
+ 'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
+ # This matches everything except a space.
+ 'all_punctuations': re.compile('([^\s]+)$'),
+ }
+
+def last_word(text, include='alphanum_underscore'):
+ """
+ Find the last word in a sentence.
+
+ >>> last_word('abc')
+ 'abc'
+ >>> last_word(' abc')
+ 'abc'
+ >>> last_word('')
+ ''
+ >>> last_word(' ')
+ ''
+ >>> last_word('abc ')
+ ''
+ >>> last_word('abc def')
+ 'def'
+ >>> last_word('abc def ')
+ ''
+ >>> last_word('abc def;')
+ ''
+ >>> last_word('bac $def')
+ 'def'
+ >>> last_word('bac $def', include='most_punctuations')
+ '$def'
+ >>> last_word('bac \def', include='most_punctuations')
+ '\\\\def'
+ >>> last_word('bac \def;', include='most_punctuations')
+ '\\\\def;'
+ >>> last_word('bac::def', include='most_punctuations')
+ 'def'
+ >>> last_word('"foo*bar', include='most_punctuations')
+ '"foo*bar'
+ """
+
+ if not text: # Empty string
+ return ''
+
+ if text[-1].isspace():
+ return ''
+ else:
+ regex = cleanup_regex[include]
+ matches = regex.search(text)
+ if matches:
+ return matches.group(0)
+ else:
+ return ''
+
+
+def find_prev_keyword(sql):
+ """ Find the last sql keyword in an SQL statement
+
+ Returns the value of the last keyword, and the text of the query with
+ everything after the last keyword stripped
+ """
+ if not sql.strip():
+ return None, ''
+
+ parsed = sqlparse.parse(sql)[0]
+ flattened = list(parsed.flatten())
+
+ logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
+
+ for t in reversed(flattened):
+ if t.value == '(' or (t.is_keyword and (
+ t.value.upper() not in logical_operators)):
+ # Find the location of token t in the original parsed statement
+ # We can't use parsed.token_index(t) because t may be a child token
+            # inside a TokenList, in which case token_index throws an error
+ # Minimal example:
+ # p = sqlparse.parse('select * from foo where bar')
+ # t = list(p.flatten())[-3] # The "Where" token
+ # p.token_index(t) # Throws ValueError: not in list
+ idx = flattened.index(t)
+
+ # Combine the string values of all tokens in the original list
+ # up to and including the target keyword token t, to produce a
+ # query string with everything after the keyword token removed
+ text = ''.join(tok.value for tok in flattened[:idx+1])
+ return t, text
+
+ return None, ''
+
+
+# Postgresql dollar quote signs look like `$$` or `$tag$`
+dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
+
+
+def is_open_quote(sql):
+ """Returns true if the query contains an unclosed quote"""
+
+ # parsed can contain one or more semi-colon separated commands
+ parsed = sqlparse.parse(sql)
+ return any(_parsed_is_open_quote(p) for p in parsed)
+
+
+def _parsed_is_open_quote(parsed):
+ tokens = list(parsed.flatten())
+
+ i = 0
+ while i < len(tokens):
+ tok = tokens[i]
+ if tok.match(Token.Error, "'"):
+ # An unmatched single quote
+ return True
+ elif (tok.ttype in Token.Name.Builtin
+ and dollar_quote_regex.match(tok.value)):
+ # Find the matching closing dollar quote sign
+ for (j, tok2) in enumerate(tokens[i+1:], i+1):
+ if tok2.match(Token.Name.Builtin, tok.value):
+ # Found the matching closing quote - continue our scan for
+ # open quotes thereafter
+ i = j
+ break
+ else:
+ # No matching dollar sign quote
+ return True
+
+ i += 1
+
+ return False
+
+
+def parse_partial_identifier(word):
+ """Attempt to parse a (partially typed) word as an identifier
+
+ word may include a schema qualification, like `schema_name.partial_name`
+ or `schema_name.` There may also be unclosed quotation marks, like
+ `"schema`, or `schema."partial_name`
+
+ :param word: string representing a (partially complete) identifier
+ :return: sqlparse.sql.Identifier, or None
+ """
+
+ p = sqlparse.parse(word)[0]
+ n_tok = len(p.tokens)
+ if n_tok == 1 and isinstance(p.tokens[0], Identifier):
+ return p.tokens[0]
+ elif p.token_next_by(m=(Error, '"'))[1]:
+ # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
+ # Close the double quote, then reparse
+ return parse_partial_identifier(word + '"')
+ else:
+ return None
+
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
index e339e2b4..8d00fc28 100644
--- a/pgcli/packages/sqlcompletion.py
+++ b/pgcli/packages/sqlcompletion.py
@@ -4,8 +4,10 @@ import re
import sqlparse
from collections import namedtuple
from sqlparse.sql import Comparison, Identifier, Where
-from .parseutils import (
- last_word, extract_tables, find_prev_keyword, parse_partial_identifier)
+from .parseutils.utils import (
+ last_word, find_prev_keyword, parse_partial_identifier)
+from .parseutils.tables import extract_tables
+from .parseutils.ctes import isolate_query_ctes
from pgspecial.main import parse_special_command
PY2 = sys.version_info[0] == 2
@@ -21,25 +23,26 @@ Special = namedtuple('Special', [])
Database = namedtuple('Database', [])
Schema = namedtuple('Schema', [])
# FromClauseItem is a table/view/function used in the FROM clause
-# `tables` contains the list of tables/... already in the statement,
+# `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
-FromClauseItem = namedtuple('FromClauseItem', 'schema tables')
-Table = namedtuple('Table', ['schema', 'tables'])
-View = namedtuple('View', ['schema', 'tables'])
+FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables')
+Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables'])
+View = namedtuple('View', ['schema', 'table_refs'])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
-JoinCondition = namedtuple('JoinCondition', ['tables', 'parent'])
+JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent'])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
-Join = namedtuple('Join', ['tables', 'schema'])
+Join = namedtuple('Join', ['table_refs', 'schema'])
-Function = namedtuple('Function', ['schema', 'tables', 'filter'])
+Function = namedtuple('Function', ['schema', 'table_refs', 'filter'])
# For convenience, don't require the `filter` argument in Function constructor
Function.__new__.__defaults__ = (None, tuple(), None)
-Table.__new__.__defaults__ = (None, tuple())
+Table.__new__.__defaults__ = (None, tuple(), tuple())
View.__new__.__defaults__ = (None, tuple())
-FromClauseItem.__new__.__defaults__ = (None, tuple())
+FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
-Column = namedtuple('Column', ['tables', 'require_last_table'])
-Column.__new__.__defaults__ = (None, None)
+Column = namedtuple(
+ 'Column', ['table_refs', 'require_last_table', 'local_tables'])
+Column.__new__.__defaults__ = (None, None, tuple())
Keyword = namedtuple('Keyword', [])
NamedQuery = namedtuple('NamedQuery', [])
@@ -56,6 +59,10 @@ class SqlStatement(object):
text_before_cursor, include='many_punctuations')
full_text = _strip_named_query(full_text)
text_before_cursor = _strip_named_query(text_before_cursor)
+
+ full_text, text_before_cursor, self.local_tables = \
+ isolate_query_ctes(full_text, text_before_cursor)
+
self.text_before_cursor_including_last_word = text_before_cursor
# If we've partially typed a word then word_before_cursor won't be an
@@ -313,7 +320,9 @@ def suggest_based_on_last_token(token, stmt):
tables = stmt.get_tables('before')
# suggest columns that are present in more than one table
- return (Column(tables=tables, require_last_table=True),)
+ return (Column(table_refs=tables,
+ require_last_table=True,
+ local_tables=stmt.local_tables),)
elif p.token_first().value.lower() == 'select':
# If the lparen is preceeded by a space chances are we're about to
@@ -323,23 +332,25 @@ def suggest_based_on_last_token(token, stmt):
return (Keyword(),)
prev_prev_tok = p.token_prev(p.token_index(prev_tok))[1]
if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
- return (Column(tables=stmt.get_tables('insert')),)
+ return (Column(table_refs=stmt.get_tables('insert')),)
# We're probably in a function argument list
- return (Column(tables=extract_tables(stmt.full_text)),)
+ return (Column(table_refs=extract_tables(stmt.full_text),
+ local_tables=stmt.local_tables),)
elif token_v in ('set', 'by', 'distinct'):
- return (Column(tables=stmt.get_tables()),)
+ return (Column(table_refs=stmt.get_tables(),
+ local_tables=stmt.local_tables),)
elif token_v in ('select', 'where', 'having'):
# Check for a table alias or schema qualification
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or []
tables = stmt.get_tables()
if parent:
tables = tuple(t for t in tables if identifies(parent, t))
- return (Column(tables=tables),
+ return (Column(table_refs=tables, local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),)
else:
- return (Column(tables=tables),
+ return (Column(table_refs=tables, local_tables=stmt.local_tables),
Function(schema=None),
Keyword(),)
@@ -358,9 +369,10 @@ def suggest_based_on_last_token(token, stmt):
# Suggest schemas
suggest.insert(0, Schema())
- # Suggest set-returning functions in the FROM clause
if token_v == 'from' or is_join:
- suggest.append(FromClauseItem(schema=schema, tables=tables))
+ suggest.append(FromClauseItem(schema=schema,
+ table_refs=tables,
+ local_tables=stmt.local_tables))
elif token_v == 'truncate':
suggest.append(Table(schema))
else:
@@ -368,7 +380,7 @@ def suggest_based_on_last_token(token, stmt):
if is_join and _allow_join(stmt.parsed):
tables = stmt.get_tables('before')
- suggest.append(Join(tables=tables, schema=schema))
+ suggest.append(Join(table_refs=tables, schema=schema))
return tuple(suggest)
@@ -383,7 +395,7 @@ def suggest_based_on_last_token(token, stmt):
elif token_v == 'column':
# E.g. 'ALTER TABLE foo ALTER COLUMN bar
- return (Column(tables=stmt.get_tables()),)
+ return (Column(table_refs=stmt.get_tables()),)
elif token_v == 'on':
tables = stmt.get_tables('before')
@@ -392,12 +404,13 @@ def suggest_based_on_last_token(token, stmt):
# "ON parent.<suggestion>"
# parent can be either a schema name or table alias
filteredtables = tuple(t for t in tables if identifies(parent, t))
- sugs = [Column(tables=filteredtables),
+ sugs = [Column(table_refs=filteredtables,
+ local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent)]
if filteredtables and _allow_join_condition(stmt.parsed):
- sugs.append(JoinCondition(tables=tables,
+ sugs.append(JoinCondition(table_refs=tables,
parent=filteredtables[-1]))
return tuple(sugs)
else:
@@ -406,7 +419,7 @@ def suggest_based_on_last_token(token, stmt):
aliases = tuple(t.ref for t in tables)
if _allow_join_condition(stmt.parsed):
return (Alias(aliases=aliases), JoinCondition(
- tables=tables, parent=None))
+ table_refs=tables, parent=None))
else:
return (Alias(aliases=aliases),)
diff --git a/pgcli/pgbuffer.py b/pgcli/pgbuffer.py
index df797cf8..b9e9b930 100644
--- a/pgcli/pgbuffer.py
+++ b/pgcli/pgbuffer.py
@@ -1,6 +1,6 @@
from prompt_toolkit.buffer import Buffer
from prompt_toolkit.filters import Condition
-from .packages.parseutils import is_open_quote
+from .packages.parseutils.utils import is_open_quote
class PGBuffer(Buffer):
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
index 6a32e0c0..3c876aba 100644
--- a/pgcli/pgcompleter.py
+++ b/pgcli/pgcompleter.py
@@ -11,8 +11,9 @@ from prompt_toolkit.document import Document
from .packages.sqlcompletion import (FromClauseItem,
suggest_type, Special, Database, Schema, Table, Function, Column, View,
Keyword, NamedQuery, Datatype, Alias, Path, JoinCondition, Join)
-from .packages.function_metadata import ColumnMetadata, ForeignKey
-from .packages.parseutils import last_word, TableReference
+from .packages.parseutils.meta import ColumnMetadata, ForeignKey
+from .packages.parseutils.utils import last_word
+from .packages.parseutils.tables import TableReference
from .packages.pgliterals.main import get_literals
from .packages.prioritization import PrevalenceCounter
from .config import load_config, config_location
@@ -367,9 +368,10 @@ class PGCompleter(Completer):
def get_column_matches(self, suggestion, word_before_cursor):
- tables = suggestion.tables
+ tables = suggestion.table_refs
_logger.debug("Completion column scope: %r", tables)
- scoped_cols = self.populate_scoped_cols(tables)
+ scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables)
+
colit = scoped_cols.items
flat_cols = list(chain(*((c.name for c in cols)
for t, cols in colit())))
@@ -424,7 +426,7 @@ class PGCompleter(Completer):
return next(a for a in aliases if normalize_ref(a) not in tbls)
def get_join_matches(self, suggestion, word_before_cursor):
- tbls = suggestion.tables
+ tbls = suggestion.table_refs
cols = self.populate_scoped_cols(tbls)
# Set up some data structures for efficient access
qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
@@ -446,7 +448,7 @@ class PGCompleter(Completer):
continue
c = self.case
if self.generate_aliases or normalize_ref(left.tbl) in refs:
- lref = self.alias(left.tbl, suggestion.tables)
+ lref = self.alias(left.tbl, suggestion.table_refs)
join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref)
else:
@@ -467,10 +469,10 @@ class PGCompleter(Completer):
def get_join_condition_matches(self, suggestion, word_before_cursor):
col = namedtuple('col', 'schema tbl col')
- tbls = self.populate_scoped_cols(suggestion.tables).items
+ tbls = self.populate_scoped_cols(suggestion.table_refs).items
cols = [(t, c) for t, cs in tbls() for c in cs]
try:
- lref = (suggestion.parent or suggestion.tables[-1]).ref
+ lref = (suggestion.parent or suggestion.table_refs[-1]).ref
ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
except IndexError: # The user typed an incorrect table qualifier
return []
@@ -493,7 +495,7 @@ class PGCompleter(Completer):
# Tables that are closer to the cursor get higher prio
ref_prio = dict((tbl.ref, num) for num, tbl
- in enumerate(suggestion.tables))
+ in enumerate(suggestion.table_refs))
# Map (schema, table, col) to tables
coldict = list_dict(((t.schema, t.name, c.name), t)
for t, c in cols if t.ref != lref)
@@ -529,7 +531,7 @@ class PGCompleter(Completer):
funcs = self.populate_functions(suggestion.schema, filt)
if alias:
funcs = [self.case(f) + '() ' + self.alias(f,
- suggestion.tables) for f in funcs]
+ suggestion.table_refs) for f in funcs]
else:
funcs = [self.case(f) + '()' for f in funcs]
else:
@@ -564,15 +566,17 @@ class PGCompleter(Completer):
def get_from_clause_item_matches(self, suggestion, word_before_cursor):
alias = self.generate_aliases
- t_sug = Table(*suggestion)
- v_sug = View(*suggestion)
- f_sug = Function(*suggestion, filter='for_from_clause')
+ s = suggestion
+ t_sug = Table(s.schema, s.table_refs, s.local_tables)
+ v_sug = View(s.schema, s.table_refs)
+ f_sug = Function(s.schema, s.table_refs, filter='for_from_clause')
return (self.get_table_matches(t_sug, word_before_cursor, alias)
+ self.get_view_matches(v_sug, word_before_cursor, alias)
+ self.get_function_matches(f_sug, word_before_cursor, alias))
def get_table_matches(self, suggestion, word_before_cursor, alias=False):
tables = self.populate_schema_objects(suggestion.schema, 'tables')
+ tables.extend(tbl.name for tbl in suggestion.local_tables)
# Unless we're sure the user really wants them, don't suggest the
# pg_catalog tables that are implicitly on the search path
@@ -580,7 +584,7 @@ class PGCompleter(Completer):
not word_before_cursor.startswith('pg_')):
tables = [t for t in tables if not t.startswith('pg_')]
if alias:
- tables = [self.case(t) + ' ' + self.alias(t, suggestion.tables)
+ tables = [self.case(t) + ' ' + self.alias(t, suggestion.table_refs)
for t in tables]
return self.find_matches(word_before_cursor, tables,
meta='table')
@@ -593,7 +597,7 @@ class PGCompleter(Completer):
not word_before_cursor.startswith('pg_')):
views = [v for v in views if not v.startswith('pg_')]
if alias:
- views = [self.case(v) + ' ' + self.alias(v, suggestion.tables)
+ views = [self.case(v) + ' ' + self.alias(v, suggestion.table_refs)
for v in views]
return self.find_matches(word_before_cursor, views, meta='view')
@@ -661,9 +665,10 @@ class PGCompleter(Completer):
Path: get_path_matches,
}
- def populate_scoped_cols(self, scoped_tbls):
+ def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
""" Find all columns in a set of scoped_tables
:param scoped_tbls: list of TableReference namedtuples
+ :param local_tbls: tuple(TableMetadata)
:return: {TableReference:{colname:ColumnMetaData}}
"""
@@ -696,6 +701,10 @@ class PGCompleter(Completer):
cols = cols.values()
addcols(schema, relname, tbl.alias, reltype, cols)
break
+
+ # Local tables sho