diff options
author | Amjith Ramanujam <amjith.r@gmail.com> | 2016-08-08 22:10:29 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-08 22:10:29 -0700 |
commit | ab6e54b2cf5fd3badb1a31143146b36beaeb88d9 (patch) | |
tree | f392d3c936e8f9f600630cad545e7630ae6ab690 | |
parent | 28d01e44ae4b7741c461116d660b05bc27b5ab60 (diff) | |
parent | 2407de0e4c4c52e27b0546c374722bd4e88287ad (diff) |
Merge pull request #553 from dbcli/darikg/cte-suggestions
CTE-aware suggestions
-rw-r--r-- | pgcli/packages/parseutils/__init__.py | 0 | ||||
-rw-r--r-- | pgcli/packages/parseutils/ctes.py | 146 | ||||
-rw-r--r-- | pgcli/packages/parseutils/meta.py (renamed from pgcli/packages/function_metadata.py) | 1 | ||||
-rw-r--r-- | pgcli/packages/parseutils/tables.py (renamed from pgcli/packages/parseutils.py) | 166 | ||||
-rw-r--r-- | pgcli/packages/parseutils/utils.py | 161 | ||||
-rw-r--r-- | pgcli/packages/sqlcompletion.py | 65 | ||||
-rw-r--r-- | pgcli/pgbuffer.py | 2 | ||||
-rw-r--r-- | pgcli/pgcompleter.py | 41 | ||||
-rw-r--r-- | pgcli/pgexecute.py | 2 | ||||
-rw-r--r-- | tests/metadata.py | 2 | ||||
-rw-r--r-- | tests/parseutils/test_ctes.py | 121 | ||||
-rw-r--r-- | tests/parseutils/test_function_metadata.py (renamed from tests/test_function_metadata.py) | 2 | ||||
-rw-r--r-- | tests/parseutils/test_parseutils.py (renamed from tests/test_parseutils.py) | 4 | ||||
-rw-r--r-- | tests/test_pgexecute.py | 2 | ||||
-rw-r--r-- | tests/test_smart_completion_public_schema_only.py | 47 | ||||
-rw-r--r-- | tests/test_sqlcompletion.py | 66 |
16 files changed, 582 insertions, 246 deletions
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/pgcli/packages/parseutils/__init__.py diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py new file mode 100644 index 00000000..12c207c4 --- /dev/null +++ b/pgcli/packages/parseutils/ctes.py @@ -0,0 +1,146 @@ +from sqlparse import parse +from sqlparse.tokens import Keyword, CTE, DML +from sqlparse.sql import Identifier, IdentifierList, Parenthesis +from collections import namedtuple +from .meta import TableMetadata, ColumnMetadata + + +# TableExpression is a namedtuple representing a CTE, used internally +# name: cte alias assigned in the query +# columns: list of column names +# start: index into the original string of the left parens starting the CTE +# stop: index into the original string of the right parens ending the CTE +TableExpression = namedtuple('TableExpression', 'name columns start stop') + + +def isolate_query_ctes(full_text, text_before_cursor): + """Simplify a query by converting CTEs into table metadata objects + """ + + if not full_text: + return full_text, text_before_cursor, tuple() + + ctes, remainder = extract_ctes(full_text) + if not ctes: + return full_text, text_before_cursor, () + + current_position = len(text_before_cursor) + meta = [] + + for cte in ctes: + if cte.start < current_position < cte.stop: + # Currently editing a cte - treat its body as the current full_text + text_before_cursor = full_text[cte.start:current_position] + full_text = full_text[cte.start:cte.stop] + return full_text, text_before_cursor, meta + + # Append this cte to the list of available table metadata + cols = (ColumnMetadata(name, None, ()) for name in cte.columns) + meta.append(TableMetadata(cte.name, cols)) + + # Editing past the last cte (ie the main body of the query) + full_text = full_text[ctes[-1].stop:] + text_before_cursor = text_before_cursor[ctes[-1].stop:current_position] + + return full_text, text_before_cursor, tuple(meta) + + +def extract_ctes(sql): + """ Extract constant table expresseions from a query + + Returns tuple (ctes, remainder_sql) + + ctes is a list of TableExpression namedtuples + remainder_sql is the text from the original query after the CTEs have + been stripped. + """ + + p = parse(sql)[0] + + # Make sure the first meaningful token is "WITH" which is necessary to + # define CTEs + idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) + if not (tok and tok.ttype == CTE): + return [], sql + + # Get the next (meaningful) token, which should be the first CTE + idx, tok = p.token_next(idx) + start_pos = token_start_pos(p.tokens, idx) + ctes = [] + + if isinstance(tok, IdentifierList): + # Multiple ctes + for t in tok.get_identifiers(): + cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t)) + cte = get_cte_from_token(t, start_pos + cte_start_offset) + if not cte: + continue + ctes.append(cte) + elif isinstance(tok, Identifier): + # A single CTE + cte = get_cte_from_token(tok, start_pos) + if cte: + ctes.append(cte) + + idx = p.token_index(tok) + 1 + + # Collapse everything after the ctes into a remainder query + remainder = u''.join(str(tok) for tok in p.tokens[idx:]) + + return ctes, remainder + + +def get_cte_from_token(tok, pos0): + cte_name = tok.get_real_name() + if not cte_name: + return None + + # Find the start position of the opening parens enclosing the cte body + idx, parens = tok.token_next_by(Parenthesis) + if not parens: + return None + + start_pos = pos0 + token_start_pos(tok.tokens, idx) + cte_len = len(str(parens)) # includes parens + stop_pos = start_pos + cte_len + + column_names = extract_column_names(parens) + + return TableExpression(cte_name, column_names, start_pos, stop_pos) + + +def extract_column_names(parsed): + # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE + idx, tok = parsed.token_next_by(t=DML) + tok_val = tok and tok.value.lower() + + if tok_val in ('insert', 'update', 'delete'): + # Jump ahead to the RETURNING clause where the list of column names is + idx, tok = parsed.token_next_by(idx, (Keyword, 'returning')) + elif not tok_val == 'select': + # Must be invalid CTE + return () + + # The next token should be either a column name, or a list of column names + idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True) + return tuple(t.get_name() for t in _identifiers(tok)) + + +def token_start_pos(tokens, idx): + return sum(len(str(t)) for t in tokens[:idx]) + + +def _identifiers(tok): + if isinstance(tok, IdentifierList): + for t in tok.get_identifiers(): + # NB: IdentifierList.get_identifiers() can return non-identifiers! + if isinstance(t, Identifier): + yield t + elif isinstance(tok, Identifier): + yield tok + + + + + + diff --git a/pgcli/packages/function_metadata.py b/pgcli/packages/parseutils/meta.py index 88d9fc49..48ba3a5e 100644 --- a/pgcli/packages/function_metadata.py +++ b/pgcli/packages/parseutils/meta.py @@ -3,6 +3,7 @@ from collections import namedtuple ColumnMetadata = namedtuple('ColumnMetadata', ['name', 'datatype', 'foreignkeys']) ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable', 'parentcolumn', 'childschema', 'childtable', 'childcolumn']) +TableMetadata = namedtuple('TableMetadata', 'name columns') class FunctionMetadata(object): diff --git a/pgcli/packages/parseutils.py b/pgcli/packages/parseutils/tables.py index 4a5b47c2..72cdcc6b 100644 --- a/pgcli/packages/parseutils.py +++ b/pgcli/packages/parseutils/tables.py @@ -1,68 +1,8 @@ from __future__ import print_function -import re import sqlparse from collections import namedtuple from sqlparse.sql import IdentifierList, Identifier, Function -from sqlparse.tokens import Keyword, DML, Punctuation, Token, Error - -cleanup_regex = { - # This matches only alphanumerics and underscores. - 'alphanum_underscore': re.compile(r'(\w+)$'), - # This matches everything except spaces, parens, colon, and comma - 'many_punctuations': re.compile(r'([^():,\s]+)$'), - # This matches everything except spaces, parens, colon, comma, and period - 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), - # This matches everything except a space. - 'all_punctuations': re.compile('([^\s]+)$'), - } - -def last_word(text, include='alphanum_underscore'): - """ - Find the last word in a sentence. - - >>> last_word('abc') - 'abc' - >>> last_word(' abc') - 'abc' - >>> last_word('') - '' - >>> last_word(' ') - '' - >>> last_word('abc ') - '' - >>> last_word('abc def') - 'def' - >>> last_word('abc def ') - '' - >>> last_word('abc def;') - '' - >>> last_word('bac $def') - 'def' - >>> last_word('bac $def', include='most_punctuations') - '$def' - >>> last_word('bac \def', include='most_punctuations') - '\\\\def' - >>> last_word('bac \def;', include='most_punctuations') - '\\\\def;' - >>> last_word('bac::def', include='most_punctuations') - 'def' - >>> last_word('"foo*bar', include='most_punctuations') - '"foo*bar' - """ - - if not text: # Empty string - return '' - - if text[-1].isspace(): - return '' - else: - regex = cleanup_regex[include] - matches = regex.search(text) - if matches: - return matches.group(0) - else: - return '' - +from sqlparse.tokens import Keyword, DML, Punctuation TableReference = namedtuple('TableReference', ['schema', 'name', 'alias', 'is_function']) @@ -86,6 +26,7 @@ def is_subselect(parsed): def _identifier_is_function(identifier): return any(isinstance(t, Function) for t in identifier.tokens) + def extract_from_part(parsed, stop_at_punctuation=True): tbl_prefix_seen = False for item in parsed.tokens: @@ -204,106 +145,3 @@ def extract_tables(sql): allow_functions=not insert_stmt) # In the case 'sche.<cursor>', we get an empty TableReference; remove that return tuple(i for i in identifiers if i.name) - - - -def find_prev_keyword(sql): - """ Find the last sql keyword in an SQL statement - - Returns the value of the last keyword, and the text of the query with - everything after the last keyword stripped - """ - if not sql.strip(): - return None, '' - - parsed = sqlparse.parse(sql)[0] - flattened = list(parsed.flatten()) - - logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN') - - for t in reversed(flattened): - if t.value == '(' or (t.is_keyword and ( - t.value.upper() not in logical_operators)): - # Find the location of token t in the original parsed statement - # We can't use parsed.token_index(t) because t may be a child token - # inside a TokenList, in which case token_index thows an error - # Minimal example: - # p = sqlparse.parse('select * from foo where bar') - # t = list(p.flatten())[-3] # The "Where" token - # p.token_index(t) # Throws ValueError: not in list - idx = flattened.index(t) - - # Combine the string values of all tokens in the original list - # up to and including the target keyword token t, to produce a - # query string with everything after the keyword token removed - text = ''.join(tok.value for tok in flattened[:idx+1]) - return t, text - - return None, '' - - -# Postgresql dollar quote signs look like `$$` or `$tag$` -dollar_quote_regex = re.compile(r'^\$[^$]*\$$') - - -def is_open_quote(sql): - """Returns true if the query contains an unclosed quote""" - - # parsed can contain one or more semi-colon separated commands - parsed = sqlparse.parse(sql) - return any(_parsed_is_open_quote(p) for p in parsed) - - -def _parsed_is_open_quote(parsed): - tokens = list(parsed.flatten()) - - i = 0 - while i < len(tokens): - tok = tokens[i] - if tok.match(Token.Error, "'"): - # An unmatched single quote - return True - elif (tok.ttype in Token.Name.Builtin - and dollar_quote_regex.match(tok.value)): - # Find the matching closing dollar quote sign - for (j, tok2) in enumerate(tokens[i+1:], i+1): - if tok2.match(Token.Name.Builtin, tok.value): - # Found the matching closing quote - continue our scan for - # open quotes thereafter - i = j - break - else: - # No matching dollar sign quote - return True - - i += 1 - - return False - - -def parse_partial_identifier(word): - """Attempt to parse a (partially typed) word as an identifier - - word may include a schema qualification, like `schema_name.partial_name` - or `schema_name.` There may also be unclosed quotation marks, like - `"schema`, or `schema."partial_name` - - :param word: string representing a (partially complete) identifier - :return: sqlparse.sql.Identifier, or None - """ - - p = sqlparse.parse(word)[0] - n_tok = len(p.tokens) - if n_tok == 1 and isinstance(p.tokens[0], Identifier): - return p.tokens[0] - elif p.token_next_by(m=(Error, '"'))[1]: - # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar' - # Close the double quote, then reparse - return parse_partial_identifier(word + '"') - else: - return None - - -if __name__ == '__main__': - sql = 'select * from (select t. from tabl t' - print (extract_tables(sql)) diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py new file mode 100644 index 00000000..e5baa0b0 --- /dev/null +++ b/pgcli/packages/parseutils/utils.py @@ -0,0 +1,161 @@ +from __future__ import print_function +import re +import sqlparse +from sqlparse.sql import Identifier +from sqlparse.tokens import Token, Error + +cleanup_regex = { + # This matches only alphanumerics and underscores. + 'alphanum_underscore': re.compile(r'(\w+)$'), + # This matches everything except spaces, parens, colon, and comma + 'many_punctuations': re.compile(r'([^():,\s]+)$'), + # This matches everything except spaces, parens, colon, comma, and period + 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), + # This matches everything except a space. + 'all_punctuations': re.compile('([^\s]+)$'), + } + +def last_word(text, include='alphanum_underscore'): + """ + Find the last word in a sentence. + + >>> last_word('abc') + 'abc' + >>> last_word(' abc') + 'abc' + >>> last_word('') + '' + >>> last_word(' ') + '' + >>> last_word('abc ') + '' + >>> last_word('abc def') + 'def' + >>> last_word('abc def ') + '' + >>> last_word('abc def;') + '' + >>> last_word('bac $def') + 'def' + >>> last_word('bac $def', include='most_punctuations') + '$def' + >>> last_word('bac \def', include='most_punctuations') + '\\\\def' + >>> last_word('bac \def;', include='most_punctuations') + '\\\\def;' + >>> last_word('bac::def', include='most_punctuations') + 'def' + >>> last_word('"foo*bar', include='most_punctuations') + '"foo*bar' + """ + + if not text: # Empty string + return '' + + if text[-1].isspace(): + return '' + else: + regex = cleanup_regex[include] + matches = regex.search(text) + if matches: + return matches.group(0) + else: + return '' + + +def find_prev_keyword(sql): + """ Find the last sql keyword in an SQL statement + + Returns the value of the last keyword, and the text of the query with + everything after the last keyword stripped + """ + if not sql.strip(): + return None, '' + + parsed = sqlparse.parse(sql)[0] + flattened = list(parsed.flatten()) + + logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN') + + for t in reversed(flattened): + if t.value == '(' or (t.is_keyword and ( + t.value.upper() not in logical_operators)): + # Find the location of token t in the original parsed statement + # We can't use parsed.token_index(t) because t may be a child token + # inside a TokenList, in which case token_index thows an error + # Minimal example: + # p = sqlparse.parse('select * from foo where bar') + # t = list(p.flatten())[-3] # The "Where" token + # p.token_index(t) # Throws ValueError: not in list + idx = flattened.index(t) + + # Combine the string values of all tokens in the original list + # up to and including the target keyword token t, to produce a + # query string with everything after the keyword token removed + text = ''.join(tok.value for tok in flattened[:idx+1]) + return t, text + + return None, '' + + +# Postgresql dollar quote signs look like `$$` or `$tag$` +dollar_quote_regex = re.compile(r'^\$[^$]*\$$') + + +def is_open_quote(sql): + """Returns true if the query contains an unclosed quote""" + + # parsed can contain one or more semi-colon separated commands + parsed = sqlparse.parse(sql) + return any(_parsed_is_open_quote(p) for p in parsed) + + +def _parsed_is_open_quote(parsed): + tokens = list(parsed.flatten()) + + i = 0 + while i < len(tokens): + tok = tokens[i] + if tok.match(Token.Error, "'"): + # An unmatched single quote + return True + elif (tok.ttype in Token.Name.Builtin + and dollar_quote_regex.match(tok.value)): + # Find the matching closing dollar quote sign + for (j, tok2) in enumerate(tokens[i+1:], i+1): + if tok2.match(Token.Name.Builtin, tok.value): + # Found the matching closing quote - continue our scan for + # open quotes thereafter + i = j + break + else: + # No matching dollar sign quote + return True + + i += 1 + + return False + + +def parse_partial_identifier(word): + """Attempt to parse a (partially typed) word as an identifier + + word may include a schema qualification, like `schema_name.partial_name` + or `schema_name.` There may also be unclosed quotation marks, like + `"schema`, or `schema."partial_name` + + :param word: string representing a (partially complete) identifier + :return: sqlparse.sql.Identifier, or None + """ + + p = sqlparse.parse(word)[0] + n_tok = len(p.tokens) + if n_tok == 1 and isinstance(p.tokens[0], Identifier): + return p.tokens[0] + elif p.token_next_by(m=(Error, '"'))[1]: + # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar' + # Close the double quote, then reparse + return parse_partial_identifier(word + '"') + else: + return None + diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index e339e2b4..8d00fc28 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -4,8 +4,10 @@ import re import sqlparse from collections import namedtuple from sqlparse.sql import Comparison, Identifier, Where -from .parseutils import ( - last_word, extract_tables, find_prev_keyword, parse_partial_identifier) +from .parseutils.utils import ( + last_word, find_prev_keyword, parse_partial_identifier) +from .parseutils.tables import extract_tables +from .parseutils.ctes import isolate_query_ctes from pgspecial.main import parse_special_command PY2 = sys.version_info[0] == 2 @@ -21,25 +23,26 @@ Special = namedtuple('Special', []) Database = namedtuple('Database', []) Schema = namedtuple('Schema', []) # FromClauseItem is a table/view/function used in the FROM clause -# `tables` contains the list of tables/... already in the statement, +# `table_refs` contains the list of tables/... already in the statement, # used to ensure that the alias we suggest is unique -FromClauseItem = namedtuple('FromClauseItem', 'schema tables') -Table = namedtuple('Table', ['schema', 'tables']) -View = namedtuple('View', ['schema', 'tables']) +FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables') +Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables']) +View = namedtuple('View', ['schema', 'table_refs']) # JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid' -JoinCondition = namedtuple('JoinCondition', ['tables', 'parent']) +JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent']) # Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid' -Join = namedtuple('Join', ['tables', 'schema']) +Join = namedtuple('Join', ['table_refs', 'schema']) -Function = namedtuple('Function', ['schema', 'tables', 'filter']) +Function = namedtuple('Function', ['schema', 'table_refs', 'filter']) # For convenience, don't require the `filter` argument in Function constructor Function.__new__.__defaults__ = (None, tuple(), None) -Table.__new__.__defaults__ = (None, tuple()) +Table.__new__.__defaults__ = (None, tuple(), tuple()) View.__new__.__defaults__ = (None, tuple()) -FromClauseItem.__new__.__defaults__ = (None, tuple()) +FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple()) -Column = namedtuple('Column', ['tables', 'require_last_table']) -Column.__new__.__defaults__ = (None, None) +Column = namedtuple( + 'Column', ['table_refs', 'require_last_table', 'local_tables']) +Column.__new__.__defaults__ = (None, None, tuple()) Keyword = namedtuple('Keyword', []) NamedQuery = namedtuple('NamedQuery', []) @@ -56,6 +59,10 @@ class SqlStatement(object): text_before_cursor, include='many_punctuations') full_text = _strip_named_query(full_text) text_before_cursor = _strip_named_query(text_before_cursor) + + full_text, text_before_cursor, self.local_tables = \ + isolate_query_ctes(full_text, text_before_cursor) + self.text_before_cursor_including_last_word = text_before_cursor # If we've partially typed a word then word_before_cursor won't be an @@ -313,7 +320,9 @@ def suggest_based_on_last_token(token, stmt): tables = stmt.get_tables('before') # suggest columns that are present in more than one table - return (Column(tables=tables, require_last_table=True),) + return (Column(table_refs=tables, + require_last_table=True, + local_tables=stmt.local_tables),) elif p.token_first().value.lower() == 'select': # If the lparen is preceeded by a space chances are we're about to @@ -323,23 +332,25 @@ def suggest_based_on_last_token(token, stmt): return (Keyword(),) prev_prev_tok = p.token_prev(p.token_index(prev_tok))[1] if prev_prev_tok and prev_prev_tok.normalized == 'INTO': - return (Column(tables=stmt.get_tables('insert')),) + return (Column(table_refs=stmt.get_tables('insert')),) # We're probably in a function argument list - return (Column(tables=extract_tables(stmt.full_text)),) + return (Column(table_refs=extract_tables(stmt.full_text), + local_tables=stmt.local_tables),) elif token_v in ('set', 'by', 'distinct'): - return (Column(tables=stmt.get_tables()),) + return (Column(table_refs=stmt.get_tables(), + local_tables=stmt.local_tables),) elif token_v in ('select', 'where', 'having'): # Check for a table alias or schema qualification parent = (stmt.identifier and stmt.identifier.get_parent_name()) or [] tables = stmt.get_tables() if parent: tables = tuple(t for t in tables if identifies(parent, t)) - return (Column(tables=tables), + return (Column(table_refs=tables, local_tables=stmt.local_tables), Table(schema=parent), View(schema=parent), Function(schema=parent),) else: - return (Column(tables=tables), + return (Column(table_refs=tables, local_tables=stmt.local_tables), Function(schema=None), Keyword(),) @@ -358,9 +369,10 @@ def suggest_based_on_last_token(token, stmt): # Suggest schemas suggest.insert(0, Schema()) - # Suggest set-returning functions in the FROM clause if token_v == 'from' or is_join: - suggest.append(FromClauseItem(schema=schema, tables=tables)) + suggest.append(FromClauseItem(schema=schema, + table_refs=tables, + local_tables=stmt.local_tables)) elif token_v == 'truncate': suggest.append(Table(schema)) else: @@ -368,7 +380,7 @@ def suggest_based_on_last_token(token, stmt): if is_join and _allow_join(stmt.parsed): tables = stmt.get_tables('before') - suggest.append(Join(tables=tables, schema=schema)) + suggest.append(Join(table_refs=tables, schema=schema)) return tuple(suggest) @@ -383,7 +395,7 @@ def suggest_based_on_last_token(token, stmt): elif token_v == 'column': # E.g. 'ALTER TABLE foo ALTER COLUMN bar - return (Column(tables=stmt.get_tables()),) + return (Column(table_refs=stmt.get_tables()),) elif token_v == 'on': tables = stmt.get_tables('before') @@ -392,12 +404,13 @@ def suggest_based_on_last_token(token, stmt): # "ON parent.<suggestion>" # parent can be either a schema name or table alias filteredtables = tuple(t for t in tables if identifies(parent, t)) - sugs = [Column(tables=filteredtables), + sugs = [Column(table_refs=filteredtables, + local_tables=stmt.local_tables), Table(schema=parent), View(schema=parent), Function(schema=parent)] if filteredtables and _allow_join_condition(stmt.parsed): - sugs.append(JoinCondition(tables=tables, + sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1])) return tuple(sugs) else: @@ -406,7 +419,7 @@ def suggest_based_on_last_token(token, stmt): aliases = tuple(t.ref for t in tables) if _allow_join_condition(stmt.parsed): return (Alias(aliases=aliases), JoinCondition( - tables=tables, parent=None)) + table_refs=tables, parent=None)) else: return (Alias(aliases=aliases),) diff --git a/pgcli/pgbuffer.py b/pgcli/pgbuffer.py index df797cf8..b9e9b930 100644 --- a/pgcli/pgbuffer.py +++ b/pgcli/pgbuffer.py @@ -1,6 +1,6 @@ from prompt_toolkit.buffer import Buffer from prompt_toolkit.filters import Condition -from .packages.parseutils import is_open_quote +from .packages.parseutils.utils import is_open_quote class PGBuffer(Buffer): diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 6a32e0c0..3c876aba 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -11,8 +11,9 @@ from prompt_toolkit.document import Document from .packages.sqlcompletion import (FromClauseItem, suggest_type, Special, Database, Schema, Table, Function, Column, View, Keyword, NamedQuery, Datatype, Alias, Path, JoinCondition, Join) -from .packages.function_metadata import ColumnMetadata, ForeignKey -from .packages.parseutils import last_word, TableReference +from .packages.parseutils.meta import ColumnMetadata, ForeignKey +from .packages.parseutils.utils import last_word +from .packages.parseutils.tables import TableReference from .packages.pgliterals.main import get_literals from .packages.prioritization import PrevalenceCounter from .config import load_config, config_location @@ -367,9 +368,10 @@ class PGCompleter(Completer): def get_column_matches(self, suggestion, word_before_cursor): - tables = suggestion.tables + tables = suggestion.table_refs _logger.debug("Completion column scope: %r", tables) - scoped_cols = self.populate_scoped_cols(tables) + scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables) + colit = scoped_cols.items flat_cols = list(chain(*((c.name for c in cols) for t, cols in colit()))) @@ -424,7 +426,7 @@ class PGCompleter(Completer): return next(a for a in aliases if normalize_ref(a) not in tbls) def get_join_matches(self, suggestion, word_before_cursor): - tbls = suggestion.tables + tbls = suggestion.table_refs cols = self.populate_scoped_cols(tbls) # Set up some data structures for efficient access qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls) @@ -446,7 +448,7 @@ class PGCompleter(Completer): continue c = self.case if self.generate_aliases or normalize_ref(left.tbl) in refs: - lref = self.alias(left.tbl, suggestion.tables) + lref = self.alias(left.tbl, suggestion.table_refs) join = '{0} {4} ON {4}.{1} = {2}.{3}'.format( c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref) else: @@ -467,10 +469,10 @@ class PGCompleter(Completer): def get_join_condition_matches(self, suggestion, word_before_cursor): col = namedtuple('col', 'schema tbl col') - tbls = self.populate_scoped_cols(suggestion.tables).items + tbls = self.populate_scoped_cols(suggestion.table_refs).items cols = [(t, c) for t, cs in tbls() for c in cs] try: - lref = (suggestion.parent or suggestion.tables[-1]).ref + lref = (suggestion.parent or suggestion.table_refs[-1]).ref ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1] except IndexError: # The user typed an incorrect table qualifier return [] @@ -493,7 +495,7 @@ class PGCompleter(Completer): # Tables that are closer to the cursor get higher prio ref_prio = dict((tbl.ref, num) for num, tbl - in enumerate(suggestion.tables)) + in enumerate(suggestion.table_refs)) # Map (schema, table, col) to tables coldict = list_dict(((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref) @@ -529,7 +531,7 @@ class PGCompleter(Completer): funcs = self.populate_functions(suggestion.schema, filt) if alias: funcs = [self.case(f) + '() ' + self.alias(f, - suggestion.tables) for f in funcs] + suggestion.table_refs) for f in funcs] else: funcs = [self.case(f) + '()' for f in funcs] else: @@ -564,15 +566,17 @@ class PGCompleter(Completer): def get_from_clause_item_matches(self, suggestion, word_before_cursor): alias = self.generate_aliases - t_sug = Table(*suggestion) - v_sug = View(*suggestion) - f_sug = Function(*suggestion, filter='for_from_clause') + s = suggestion + t_sug = Table(s.schema, s.table_refs, s.local_tables) + v_sug = View(s.schema, s.table_refs) + f_sug = Function(s.schema, s.table_refs, filter='for_from_clause') return (self.get_table_matches(t_sug, word_before_cursor, alias) + self.get_view_matches(v_sug, word_before_cursor, alias) + self.get_function_matches(f_sug, word_before_cursor, alias)) def get_table_matches(self, suggestion, word_before_cursor, alias=False): tables = self.populate_schema_objects(suggestion.schema, 'tables') + tables.extend(tbl.name for tbl in suggestion.local_tables) # Unless we're sure the user really wants them, don't suggest the # pg_catalog tables that are implicitly on the search path @@ -580,7 +584,7 @@ class PGCompleter(Completer): not word_before_cursor.startswith('pg_')): tables = [t for t in tables if not t.startswith('pg_')] if alias: - tables = [self.case(t) + ' ' + self.alias(t, suggestion.tables) + tables = [self.case(t) + ' ' + self.alias(t, suggestion.table_refs) for t in tables] return self.find_matches(word_before_cursor, tables, meta='table') @@ -593,7 +597,7 @@ class PGCompleter(Completer): not word_before_cursor.startswith('pg_')): views = [v for v in views if not v.startswith('pg_')] if alias: - views = [self.case(v) + ' ' + self.alias(v, suggestion.tables) + views = [self.case(v) + ' ' + self.alias(v, suggestion.table_refs) for v in views] return self.find_matches(word_before_cursor, views, meta='view') @@ -661,9 +665,10 @@ class PGCompleter(Completer): Path: get_path_matches, } - def populate_scoped_cols(self, scoped_tbls): + def populate_scoped_cols(self, scoped_tbls, local_tbls=()): """ Find all columns in a set of scoped_tables :param scoped_tbls: list of TableReference namedtuples + :param local_tbls: tuple(TableMetadata) :return: {TableReference:{colname:ColumnMetaData}} """ @@ -696,6 +701,10 @@ class PGCompleter(Completer): cols = cols.values() addcols(schema, relname, tbl.alias, reltype, cols) break + + # Local tables sho |