author     Amjith Ramanujam <amjith.r@gmail.com>   2016-06-24 06:43:06 -0700
committer  GitHub <noreply@github.com>             2016-06-24 06:43:06 -0700
commit     17790dc2c426c4cad9b886d4d62c6721373d166c (patch)
tree       c1a4f83052ed6da3b4f3bd5059f6992274f42e05
parent     5253e57fb8cbd6f51cd4026c24f79e09f3ef1354 (diff)
parent     5f6876165818ea68ab323a62a35ee543c66760e5 (diff)
Merge pull request #532 from dbcli/darikg/doc-object
Some sqlcompletion refactoring
-rw-r--r--  pgcli/packages/sqlcompletion.py  194
1 file changed, 107 insertions(+), 87 deletions(-)
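
Context for the diff below (a minimal sketch, not part of the commit): the refactoring introduces a SqlStatement helper that bundles the parsing state (full_text, text_before_cursor, the parsed statement, the partial identifier, and the last token) which previously travelled through suggest_based_on_last_token as separate arguments. The public entry point keeps its signature, so usage looks roughly like this; the sample query string is hypothetical:

    from pgcli.packages.sqlcompletion import suggest_type

    # Cursor sits at the end of a partly typed statement.
    text = 'SELECT * FROM '
    suggestions = suggest_type(text, text)
    # suggest_type now wraps the two strings in a SqlStatement, dispatches on
    # the statement's last token ('from' here), and returns suggestion
    # namedtuples such as Table(...), View(...) and Function(...) for that
    # cursor position.
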
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
index 917086e5..a313a957 100644
--- a/pgcli/packages/sqlcompletion.py
+++ b/pgcli/packages/sqlcompletion.py
@@ -42,6 +42,54 @@ Alias = namedtuple('Alias', ['aliases'])
Path = namedtuple('Path', [])
+class SqlStatement(object):
+ def __init__(self, full_text, text_before_cursor):
+ self.identifier = None
+ self.word_before_cursor = word_before_cursor = last_word(
+ text_before_cursor, include='many_punctuations')
+ full_text = _strip_named_query(full_text)
+ text_before_cursor = _strip_named_query(text_before_cursor)
+ self.text_before_cursor_including_last_word = text_before_cursor
+
+ # If we've partially typed a word then word_before_cursor won't be an
+ # empty string. In that case we want to remove the partially typed
+ # string before sending it to the sqlparser. Otherwise the last token
+ # will always be the partially typed string which renders the smart
+ # completion useless because it will always return the list of
+ # keywords as completion.
+ if self.word_before_cursor:
+ if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
+ parsed = sqlparse.parse(text_before_cursor)
+ else:
+ text_before_cursor = text_before_cursor[:-len(word_before_cursor)]
+ parsed = sqlparse.parse(text_before_cursor)
+ self.identifier = parse_partial_identifier(word_before_cursor)
+ else:
+ parsed = sqlparse.parse(text_before_cursor)
+
+ full_text, text_before_cursor, parsed = \
+ _split_multiple_statements(full_text, text_before_cursor, parsed)
+
+ self.full_text = full_text
+ self.text_before_cursor = text_before_cursor
+ self.parsed = parsed
+
+ self.last_token = parsed and parsed.token_prev(len(parsed.tokens)) or ''
+
+ def get_identifier_schema(self):
+ schema = (self.identifier and self.identifier.get_parent_name()) or None
+ # If schema name is unquoted, lower-case it
+ if schema and self.identifier.value[0] != '"':
+ schema = schema.lower()
+
+ return schema
+
+ def reduce_to_prev_keyword(self):
+ prev_keyword, self.text_before_cursor = \
+ find_prev_keyword(self.text_before_cursor)
+ return prev_keyword
+
+
def suggest_type(full_text, text_before_cursor):
"""Takes the full_text that is typed so far and also the text before the
cursor to suggest completion type and scope.
@@ -53,46 +101,40 @@ def suggest_type(full_text, text_before_cursor):
if full_text.startswith('\\i '):
return (Path(),)
- word_before_cursor = last_word(text_before_cursor,
- include='many_punctuations')
-
- identifier = None
-
- def strip_named_query(txt):
- """
- This will strip "save named query" command in the beginning of the line:
- '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
- ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
- """
- pattern = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+')
- if pattern.match(txt):
- txt = pattern.sub('', txt)
- return txt
-
- full_text = strip_named_query(full_text)
- text_before_cursor = strip_named_query(text_before_cursor)
- text_before_cursor_including_last_word = text_before_cursor
-
- # If we've partially typed a word then word_before_cursor won't be an empty
- # string. In that case we want to remove the partially typed string before
- # sending it to the sqlparser. Otherwise the last token will always be the
- # partially typed string which renders the smart completion useless because
- # it will always return the list of keywords as completion.
- if word_before_cursor:
- if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
- parsed = sqlparse.parse(text_before_cursor)
- else:
- text_before_cursor = text_before_cursor[:-len(word_before_cursor)]
- parsed = sqlparse.parse(text_before_cursor)
+ stmt = SqlStatement(full_text, text_before_cursor)
- identifier = parse_partial_identifier(word_before_cursor)
- else:
- parsed = sqlparse.parse(text_before_cursor)
+ # Check for special commands and handle those separately
+ if stmt.parsed:
+ # Be careful here because trivial whitespace is parsed as a
+ # statement, but the statement won't have a first token
+ tok1 = stmt.parsed.token_first()
+ if tok1 and tok1.value == '\\':
+ text = stmt.text_before_cursor + stmt.word_before_cursor
+ return suggest_special(text)
+
+ return suggest_based_on_last_token(stmt.last_token, stmt)
+
+
+named_query_regex = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+')
+
+
+def _strip_named_query(txt):
+ """
+ This will strip "save named query" command in the beginning of the line:
+ '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
+ ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
+ """
+
+ if named_query_regex.match(txt):
+ txt = named_query_regex.sub('', txt)
+ return txt
+
+def _split_multiple_statements(full_text, text_before_cursor, parsed):
if len(parsed) > 1:
# Multiple statements being edited -- isolate the current one by
- # cumulatively summing statement lengths to find the one that bounds the
- # current position
+ # cumulatively summing statement lengths to find the one that bounds
+ # the current position
current_pos = len(text_before_cursor)
stmt_start, stmt_end = 0, 0
@@ -112,19 +154,7 @@ def suggest_type(full_text, text_before_cursor):
# The empty string
statement = None
- # Check for special commands and handle those separately
- if statement:
- # Be careful here because trivial whitespace is parsed as a statement,
- # but the statement won't have a first token
- tok1 = statement.token_first()
- if tok1 and tok1.value == '\\':
- return suggest_special(text_before_cursor_including_last_word)
-
- last_token = statement and statement.token_prev(len(statement.tokens)) or ''
-
- return suggest_based_on_last_token(
- last_token, text_before_cursor, full_text, identifier,
- parsed_statement=statement)
+ return full_text, text_before_cursor, statement
def suggest_special(text):
@@ -179,8 +209,7 @@ def suggest_special(text):
return (Keyword(), Special())
-def suggest_based_on_last_token(token, text_before_cursor, full_text,
- identifier, parsed_statement=None):
+def suggest_based_on_last_token(token, stmt):
if isinstance(token, string_types):
token_v = token.lower()
@@ -197,9 +226,8 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
# list. This means that token.value may be something like
# 'where foo > 5 and '. We need to look "inside" token.tokens to handle
# suggestions in complicated where clauses correctly
- prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
- return suggest_based_on_last_token(prev_keyword, text_before_cursor,
- full_text, identifier, parsed_statement)
+ prev_keyword = stmt.reduce_to_prev_keyword()
+ return suggest_based_on_last_token(prev_keyword, stmt)
elif isinstance(token, Identifier):
# If the previous token is an identifier, we can suggest datatypes if
# we're in a parenthesized column/field list, e.g.:
@@ -209,11 +237,10 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
# user is about to specify an alias, e.g.:
# SELECT Identifier <CURSOR>
# SELECT foo FROM Identifier <CURSOR>
- prev_keyword, _ = find_prev_keyword(text_before_cursor)
+ prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
if prev_keyword and prev_keyword.value == '(':
# Suggest datatypes
- return suggest_based_on_last_token('type', text_before_cursor,
- full_text, identifier, parsed_statement)
+ return suggest_based_on_last_token('type', stmt)
else:
return (Keyword(),)
else:
@@ -222,7 +249,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
if not token:
return (Keyword(), Special())
elif token_v.endswith('('):
- p = sqlparse.parse(text_before_cursor)[0]
+ p = sqlparse.parse(stmt.text_before_cursor)[0]
if p.tokens and isinstance(p.tokens[-1], Where):
# Four possibilities:
@@ -236,9 +263,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
# Suggest columns/functions AND keywords. (If we wanted to be
# really fancy, we could suggest only array-typed columns)
- column_suggestions = suggest_based_on_last_token(
- 'where', text_before_cursor, full_text, identifier,
- parsed_statement)
+ column_suggestions = suggest_based_on_last_token('where', stmt)
# Check for a subquery expression (cases 3 & 4)
where = p.tokens[-1]
@@ -260,7 +285,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
if (prev_tok and prev_tok.value
and prev_tok.value.lower().split(' ')[-1] == 'using'):
# tbl1 INNER JOIN tbl2 USING (col1, col2)
- tables = extract_tables(text_before_cursor)
+ tables = extract_tables(stmt.text_before_cursor)
# suggest columns that are present in more than one table
return (Column(tables=tables, require_last_table=True),)
@@ -268,36 +293,34 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
elif p.token_first().value.lower() == 'select':
# If the lparen is preceeded by a space chances are we're about to
# do a sub-select.
- if last_word(text_before_cursor,
+ if last_word(stmt.text_before_cursor,
'all_punctuations').startswith('('):
return (Keyword(),)
# We're probably in a function argument list
- return (Column(tables=extract_tables(full_text)),)
+ return (Column(tables=extract_tables(stmt.full_text)),)
elif token_v in ('set', 'by', 'distinct'):
- return (Column(tables=extract_tables(full_text)),)
+ return (Column(tables=extract_tables(stmt.full_text)),)
elif token_v in ('select', 'where', 'having'):
# Check for a table alias or schema qualification
- parent = (identifier and identifier.get_parent_name()) or []
+ parent = (stmt.identifier and stmt.identifier.get_parent_name()) or []
if parent:
- tables = extract_tables(full_text)
+ tables = extract_tables(stmt.full_text)
tables = tuple(t for t in tables if identifies(parent, t))
return (Column(tables=tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),)
else:
- return (Column(tables=extract_tables(full_text)),
+ return (Column(tables=extract_tables(stmt.full_text)),
Function(schema=None),
Keyword(),)
elif (token_v.endswith('join') and token.is_keyword) or (token_v in
('copy', 'from', 'update', 'into', 'describe', 'truncate')):
- schema = (identifier and identifier.get_parent_name()) or None
- # If schema name is unquoted, lower-case it
- if schema and identifier.value[0] != '"':
- schema = schema.lower()
+ schema = stmt.get_identifier_schema()
+
# Suggest tables from either the currently-selected schema or the
# public schema if no schema has been specified
suggest = [Table(schema=schema)]
@@ -315,8 +338,8 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
suggest.append(Function(schema=schema, filter='for_from_clause'))
if (token_v.endswith('join') and token.is_keyword
- and _allow_join(parsed_statement)):
- tables = extract_tables(text_before_cursor)
+ and _allow_join(stmt.parsed)):
+ tables = extract_tables(stmt.text_before_cursor)
suggest.append(Join(tables=tables, schema=schema))
return tuple(suggest)
@@ -324,7 +347,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
elif token_v in ('table', 'view', 'function'):
# E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
rel_type = {'table': Table, 'view': View, 'function': Function}[token_v]
- schema = (identifier and identifier.get_parent_name()) or None
+ schema = stmt.get_identifier_schema()
if schema:
return (rel_type(schema=schema),)
else:
@@ -332,11 +355,11 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
elif token_v == 'column':
# E.g. 'ALTER TABLE foo ALTER COLUMN bar
- return (Column(tables=extract_tables(text_before_cursor)),)
+ return (Column(tables=extract_tables(stmt.text_before_cursor)),)
elif token_v == 'on':
- tables = extract_tables(text_before_cursor) # [(schema, table, alias), ...]
- parent = (identifier and identifier.get_parent_name()) or None
+ tables = extract_tables(stmt.text_before_cursor) # [(schema, table, alias), ...]
+ parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None
if parent:
# "ON parent.<suggestion>"
# parent can be either a schema name or table alias
@@ -345,16 +368,15 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
Table(schema=parent),
View(schema=parent),
Function(schema=parent)]
- last_token = parsed_statement
- if filteredtables and _allow_join_condition(parsed_statement):
+ if filteredtables and _allow_join_condition(stmt.parsed):
sugs.append(JoinCondition(tables=tables,
- parent=filteredtables[-1]))
+ parent=filteredtables[-1]))
return tuple(sugs)
else:
# ON <suggestion>
# Use table alias if there is one, otherwise the table name
aliases = tuple(t.ref for t in tables)
- if _allow_join_condition(parsed_statement):
+ if _allow_join_condition(stmt.parsed):
return (Alias(aliases=aliases), JoinCondition(
tables=tables, parent=None))
else:
@@ -368,11 +390,9 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
# DROP SCHEMA schema_name
return (Schema(),)
elif token_v.endswith(',') or token_v in ('=', 'and', 'or'):
- prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
+ prev_keyword = stmt.reduce_to_prev_keyword()
if prev_keyword:
- return suggest_based_on_last_token(
- prev_keyword, text_before_cursor, full_text, identifier,
- parsed_statement)
+ return suggest_based_on_last_token(prev_keyword, stmt)
else:
return ()
elif token_v in ('type', '::'):
@@ -380,7 +400,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
# SELECT foo::bar
# Note that tables are a form of composite type in postgresql, so
# they're suggested here as well
- schema = (identifier and identifier.get_parent_name()) or None
+ schema = stmt.get_identifier_schema()
suggestions = [Datatype(schema=schema),
Table(schema=schema)]
if not schema: