diff options
author | Amjith Ramanujam <amjith.r@gmail.com> | 2016-06-24 06:43:06 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-06-24 06:43:06 -0700 |
commit | 17790dc2c426c4cad9b886d4d62c6721373d166c (patch) | |
tree | c1a4f83052ed6da3b4f3bd5059f6992274f42e05 | |
parent | 5253e57fb8cbd6f51cd4026c24f79e09f3ef1354 (diff) | |
parent | 5f6876165818ea68ab323a62a35ee543c66760e5 (diff) |
Merge pull request #532 from dbcli/darikg/doc-object
Some sqlcompletion refactoring
-rw-r--r-- | pgcli/packages/sqlcompletion.py | 194 |
1 file changed, 107 insertions, 87 deletions
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index 917086e5..a313a957 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -42,6 +42,54 @@ Alias = namedtuple('Alias', ['aliases']) Path = namedtuple('Path', []) +class SqlStatement(object): + def __init__(self, full_text, text_before_cursor): + self.identifier = None + self.word_before_cursor = word_before_cursor = last_word( + text_before_cursor, include='many_punctuations') + full_text = _strip_named_query(full_text) + text_before_cursor = _strip_named_query(text_before_cursor) + self.text_before_cursor_including_last_word = text_before_cursor + + # If we've partially typed a word then word_before_cursor won't be an + # empty string. In that case we want to remove the partially typed + # string before sending it to the sqlparser. Otherwise the last token + # will always be the partially typed string which renders the smart + # completion useless because it will always return the list of + # keywords as completion. 
+ if self.word_before_cursor: + if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\': + parsed = sqlparse.parse(text_before_cursor) + else: + text_before_cursor = text_before_cursor[:-len(word_before_cursor)] + parsed = sqlparse.parse(text_before_cursor) + self.identifier = parse_partial_identifier(word_before_cursor) + else: + parsed = sqlparse.parse(text_before_cursor) + + full_text, text_before_cursor, parsed = \ + _split_multiple_statements(full_text, text_before_cursor, parsed) + + self.full_text = full_text + self.text_before_cursor = text_before_cursor + self.parsed = parsed + + self.last_token = parsed and parsed.token_prev(len(parsed.tokens)) or '' + + def get_identifier_schema(self): + schema = (self.identifier and self.identifier.get_parent_name()) or None + # If schema name is unquoted, lower-case it + if schema and self.identifier.value[0] != '"': + schema = schema.lower() + + return schema + + def reduce_to_prev_keyword(self): + prev_keyword, self.text_before_cursor = \ + find_prev_keyword(self.text_before_cursor) + return prev_keyword + + def suggest_type(full_text, text_before_cursor): """Takes the full_text that is typed so far and also the text before the cursor to suggest completion type and scope. 
@@ -53,46 +101,40 @@ def suggest_type(full_text, text_before_cursor): if full_text.startswith('\\i '): return (Path(),) - word_before_cursor = last_word(text_before_cursor, - include='many_punctuations') - - identifier = None - - def strip_named_query(txt): - """ - This will strip "save named query" command in the beginning of the line: - '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' - ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' - """ - pattern = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+') - if pattern.match(txt): - txt = pattern.sub('', txt) - return txt - - full_text = strip_named_query(full_text) - text_before_cursor = strip_named_query(text_before_cursor) - text_before_cursor_including_last_word = text_before_cursor - - # If we've partially typed a word then word_before_cursor won't be an empty - # string. In that case we want to remove the partially typed string before - # sending it to the sqlparser. Otherwise the last token will always be the - # partially typed string which renders the smart completion useless because - # it will always return the list of keywords as completion. 
- if word_before_cursor: - if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\': - parsed = sqlparse.parse(text_before_cursor) - else: - text_before_cursor = text_before_cursor[:-len(word_before_cursor)] - parsed = sqlparse.parse(text_before_cursor) + stmt = SqlStatement(full_text, text_before_cursor) - identifier = parse_partial_identifier(word_before_cursor) - else: - parsed = sqlparse.parse(text_before_cursor) + # Check for special commands and handle those separately + if stmt.parsed: + # Be careful here because trivial whitespace is parsed as a + # statement, but the statement won't have a first token + tok1 = stmt.parsed.token_first() + if tok1 and tok1.value == '\\': + text = stmt.text_before_cursor + stmt.word_before_cursor + return suggest_special(text) + + return suggest_based_on_last_token(stmt.last_token, stmt) + + +named_query_regex = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+') + + +def _strip_named_query(txt): + """ + This will strip "save named query" command in the beginning of the line: + '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + """ + + if named_query_regex.match(txt): + txt = named_query_regex.sub('', txt) + return txt + +def _split_multiple_statements(full_text, text_before_cursor, parsed): if len(parsed) > 1: # Multiple statements being edited -- isolate the current one by - # cumulatively summing statement lengths to find the one that bounds the - # current position + # cumulatively summing statement lengths to find the one that bounds + # the current position current_pos = len(text_before_cursor) stmt_start, stmt_end = 0, 0 @@ -112,19 +154,7 @@ def suggest_type(full_text, text_before_cursor): # The empty string statement = None - # Check for special commands and handle those separately - if statement: - # Be careful here because trivial whitespace is parsed as a statement, - # but the statement won't have a first token - tok1 = statement.token_first() - if tok1 and 
tok1.value == '\\': - return suggest_special(text_before_cursor_including_last_word) - - last_token = statement and statement.token_prev(len(statement.tokens)) or '' - - return suggest_based_on_last_token( - last_token, text_before_cursor, full_text, identifier, - parsed_statement=statement) + return full_text, text_before_cursor, statement def suggest_special(text): @@ -179,8 +209,7 @@ def suggest_special(text): return (Keyword(), Special()) -def suggest_based_on_last_token(token, text_before_cursor, full_text, - identifier, parsed_statement=None): +def suggest_based_on_last_token(token, stmt): if isinstance(token, string_types): token_v = token.lower() @@ -197,9 +226,8 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, # list. This means that token.value may be something like # 'where foo > 5 and '. We need to look "inside" token.tokens to handle # suggestions in complicated where clauses correctly - prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) - return suggest_based_on_last_token(prev_keyword, text_before_cursor, - full_text, identifier, parsed_statement) + prev_keyword = stmt.reduce_to_prev_keyword() + return suggest_based_on_last_token(prev_keyword, stmt) elif isinstance(token, Identifier): # If the previous token is an identifier, we can suggest datatypes if # we're in a parenthesized column/field list, e.g.: @@ -209,11 +237,10 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, # user is about to specify an alias, e.g.: # SELECT Identifier <CURSOR> # SELECT foo FROM Identifier <CURSOR> - prev_keyword, _ = find_prev_keyword(text_before_cursor) + prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor) if prev_keyword and prev_keyword.value == '(': # Suggest datatypes - return suggest_based_on_last_token('type', text_before_cursor, - full_text, identifier, parsed_statement) + return suggest_based_on_last_token('type', stmt) else: return (Keyword(),) else: @@ -222,7 +249,7 @@ def 
suggest_based_on_last_token(token, text_before_cursor, full_text, if not token: return (Keyword(), Special()) elif token_v.endswith('('): - p = sqlparse.parse(text_before_cursor)[0] + p = sqlparse.parse(stmt.text_before_cursor)[0] if p.tokens and isinstance(p.tokens[-1], Where): # Four possibilities: @@ -236,9 +263,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, # Suggest columns/functions AND keywords. (If we wanted to be # really fancy, we could suggest only array-typed columns) - column_suggestions = suggest_based_on_last_token( - 'where', text_before_cursor, full_text, identifier, - parsed_statement) + column_suggestions = suggest_based_on_last_token('where', stmt) # Check for a subquery expression (cases 3 & 4) where = p.tokens[-1] @@ -260,7 +285,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, if (prev_tok and prev_tok.value and prev_tok.value.lower().split(' ')[-1] == 'using'): # tbl1 INNER JOIN tbl2 USING (col1, col2) - tables = extract_tables(text_before_cursor) + tables = extract_tables(stmt.text_before_cursor) # suggest columns that are present in more than one table return (Column(tables=tables, require_last_table=True),) @@ -268,36 +293,34 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, elif p.token_first().value.lower() == 'select': # If the lparen is preceeded by a space chances are we're about to # do a sub-select. 
- if last_word(text_before_cursor, + if last_word(stmt.text_before_cursor, 'all_punctuations').startswith('('): return (Keyword(),) # We're probably in a function argument list - return (Column(tables=extract_tables(full_text)),) + return (Column(tables=extract_tables(stmt.full_text)),) elif token_v in ('set', 'by', 'distinct'): - return (Column(tables=extract_tables(full_text)),) + return (Column(tables=extract_tables(stmt.full_text)),) elif token_v in ('select', 'where', 'having'): # Check for a table alias or schema qualification - parent = (identifier and identifier.get_parent_name()) or [] + parent = (stmt.identifier and stmt.identifier.get_parent_name()) or [] if parent: - tables = extract_tables(full_text) + tables = extract_tables(stmt.full_text) tables = tuple(t for t in tables if identifies(parent, t)) return (Column(tables=tables), Table(schema=parent), View(schema=parent), Function(schema=parent),) else: - return (Column(tables=extract_tables(full_text)), + return (Column(tables=extract_tables(stmt.full_text)), Function(schema=None), Keyword(),) elif (token_v.endswith('join') and token.is_keyword) or (token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate')): - schema = (identifier and identifier.get_parent_name()) or None - # If schema name is unquoted, lower-case it - if schema and identifier.value[0] != '"': - schema = schema.lower() + schema = stmt.get_identifier_schema() + # Suggest tables from either the currently-selected schema or the # public schema if no schema has been specified suggest = [Table(schema=schema)] @@ -315,8 +338,8 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, suggest.append(Function(schema=schema, filter='for_from_clause')) if (token_v.endswith('join') and token.is_keyword - and _allow_join(parsed_statement)): - tables = extract_tables(text_before_cursor) + and _allow_join(stmt.parsed)): + tables = extract_tables(stmt.text_before_cursor) suggest.append(Join(tables=tables, schema=schema)) 
return tuple(suggest) @@ -324,7 +347,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, elif token_v in ('table', 'view', 'function'): # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>' rel_type = {'table': Table, 'view': View, 'function': Function}[token_v] - schema = (identifier and identifier.get_parent_name()) or None + schema = stmt.get_identifier_schema() if schema: return (rel_type(schema=schema),) else: @@ -332,11 +355,11 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, elif token_v == 'column': # E.g. 'ALTER TABLE foo ALTER COLUMN bar - return (Column(tables=extract_tables(text_before_cursor)),) + return (Column(tables=extract_tables(stmt.text_before_cursor)),) elif token_v == 'on': - tables = extract_tables(text_before_cursor) # [(schema, table, alias), ...] - parent = (identifier and identifier.get_parent_name()) or None + tables = extract_tables(stmt.text_before_cursor) # [(schema, table, alias), ...] + parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None if parent: # "ON parent.<suggestion>" # parent can be either a schema name or table alias @@ -345,16 +368,15 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, Table(schema=parent), View(schema=parent), Function(schema=parent)] - last_token = parsed_statement - if filteredtables and _allow_join_condition(parsed_statement): + if filteredtables and _allow_join_condition(stmt.parsed): sugs.append(JoinCondition(tables=tables, - parent=filteredtables[-1])) + parent=filteredtables[-1])) return tuple(sugs) else: # ON <suggestion> # Use table alias if there is one, otherwise the table name aliases = tuple(t.ref for t in tables) - if _allow_join_condition(parsed_statement): + if _allow_join_condition(stmt.parsed): return (Alias(aliases=aliases), JoinCondition( tables=tables, parent=None)) else: @@ -368,11 +390,9 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, # DROP SCHEMA schema_name 
return (Schema(),) elif token_v.endswith(',') or token_v in ('=', 'and', 'or'): - prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) + prev_keyword = stmt.reduce_to_prev_keyword() if prev_keyword: - return suggest_based_on_last_token( - prev_keyword, text_before_cursor, full_text, identifier, - parsed_statement) + return suggest_based_on_last_token(prev_keyword, stmt) else: return () elif token_v in ('type', '::'): @@ -380,7 +400,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, # SELECT foo::bar # Note that tables are a form of composite type in postgresql, so # they're suggested here as well - schema = (identifier and identifier.get_parent_name()) or None + schema = stmt.get_identifier_schema() suggestions = [Datatype(schema=schema), Table(schema=schema)] if not schema: |