diff options
author | Darik Gamble <darik.gamble.spam@gmail.com> | 2016-07-22 07:28:21 -0400 |
---|---|---|
committer | Darik Gamble <darik.gamble.spam@gmail.com> | 2016-07-27 15:33:58 -0400 |
commit | c66fbbf5987e91082434e3cb1ad73734c3858d95 (patch) | |
tree | 1ea0f31dbf20cf2f5230ce25c5ef3ac78c242454 | |
parent | a52ee78fcb41b4ab1fc17d0dd05f1a8cb6da4b92 (diff) |
Add parseutils module for processing CTEs (not hooked up yet)
-rw-r--r-- | pgcli/packages/parseutils/ctes.py | 146 | ||||
-rw-r--r-- | tests/parseutils/test_ctes.py | 121 |
2 files changed, 267 insertions, 0 deletions
diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py new file mode 100644 index 00000000..12c207c4 --- /dev/null +++ b/pgcli/packages/parseutils/ctes.py @@ -0,0 +1,146 @@ +from sqlparse import parse +from sqlparse.tokens import Keyword, CTE, DML +from sqlparse.sql import Identifier, IdentifierList, Parenthesis +from collections import namedtuple +from .meta import TableMetadata, ColumnMetadata + + +# TableExpression is a namedtuple representing a CTE, used internally +# name: cte alias assigned in the query +# columns: list of column names +# start: index into the original string of the left parens starting the CTE +# stop: index into the original string just past the right parens ending the CTE +TableExpression = namedtuple('TableExpression', 'name columns start stop') + + +def isolate_query_ctes(full_text, text_before_cursor): + """Simplify a query by converting CTEs into table metadata objects + """ + + if not full_text: + return full_text, text_before_cursor, tuple() + + ctes, remainder = extract_ctes(full_text) + if not ctes: + return full_text, text_before_cursor, () + + current_position = len(text_before_cursor) + meta = [] + + for cte in ctes: + if cte.start < current_position < cte.stop: + # Currently editing a cte - treat its body as the current full_text + text_before_cursor = full_text[cte.start:current_position] + full_text = full_text[cte.start:cte.stop] + return full_text, text_before_cursor, meta + + # Append this cte to the list of available table metadata + cols = (ColumnMetadata(name, None, ()) for name in cte.columns) + meta.append(TableMetadata(cte.name, cols)) + + # Editing past the last cte (ie the main body of the query) + full_text = full_text[ctes[-1].stop:] + text_before_cursor = text_before_cursor[ctes[-1].stop:current_position] + + return full_text, text_before_cursor, tuple(meta) + + +def extract_ctes(sql): + """ Extract common table expressions from a query + + Returns tuple (ctes, remainder_sql) + + ctes is 
a list of TableExpression namedtuples + remainder_sql is the text from the original query after the CTEs have + been stripped. + """ + + p = parse(sql)[0] + + # Make sure the first meaningful token is "WITH" which is necessary to + # define CTEs + idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) + if not (tok and tok.ttype == CTE): + return [], sql + + # Get the next (meaningful) token, which should be the first CTE + idx, tok = p.token_next(idx) + start_pos = token_start_pos(p.tokens, idx) + ctes = [] + + if isinstance(tok, IdentifierList): + # Multiple ctes + for t in tok.get_identifiers(): + cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t)) + cte = get_cte_from_token(t, start_pos + cte_start_offset) + if not cte: + continue + ctes.append(cte) + elif isinstance(tok, Identifier): + # A single CTE + cte = get_cte_from_token(tok, start_pos) + if cte: + ctes.append(cte) + + idx = p.token_index(tok) + 1 + + # Collapse everything after the ctes into a remainder query + remainder = u''.join(str(tok) for tok in p.tokens[idx:]) + + return ctes, remainder + + +def get_cte_from_token(tok, pos0): + cte_name = tok.get_real_name() + if not cte_name: + return None + + # Find the start position of the opening parens enclosing the cte body + idx, parens = tok.token_next_by(Parenthesis) + if not parens: + return None + + start_pos = pos0 + token_start_pos(tok.tokens, idx) + cte_len = len(str(parens)) # includes parens + stop_pos = start_pos + cte_len + + column_names = extract_column_names(parens) + + return TableExpression(cte_name, column_names, start_pos, stop_pos) + + +def extract_column_names(parsed): + # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE + idx, tok = parsed.token_next_by(t=DML) + tok_val = tok and tok.value.lower() + + if tok_val in ('insert', 'update', 'delete'): + # Jump ahead to the RETURNING clause where the list of column names is + idx, tok = parsed.token_next_by(idx, (Keyword, 'returning')) + elif not 
tok_val == 'select': + # Must be invalid CTE + return () + + # The next token should be either a column name, or a list of column names + idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True) + return tuple(t.get_name() for t in _identifiers(tok)) + + +def token_start_pos(tokens, idx): + return sum(len(str(t)) for t in tokens[:idx]) + + +def _identifiers(tok): + if isinstance(tok, IdentifierList): + for t in tok.get_identifiers(): + # NB: IdentifierList.get_identifiers() can return non-identifiers! + if isinstance(t, Identifier): + yield t + elif isinstance(tok, Identifier): + yield tok + + + + + + diff --git a/tests/parseutils/test_ctes.py b/tests/parseutils/test_ctes.py new file mode 100644 index 00000000..9566cf65 --- /dev/null +++ b/tests/parseutils/test_ctes.py @@ -0,0 +1,121 @@ +import pytest +from sqlparse import parse +from pgcli.packages.parseutils.ctes import ( + token_start_pos, extract_ctes, + extract_column_names as _extract_column_names) + + +def extract_column_names(sql): + p = parse(sql)[0] + return _extract_column_names(p) + + +def test_token_str_pos(): + sql = 'SELECT * FROM xxx' + p = parse(sql)[0] + idx = p.token_index(p.tokens[-1]) + assert token_start_pos(p.tokens, idx) == len('SELECT * FROM ') + + sql = 'SELECT * FROM \nxxx' + p = parse(sql)[0] + idx = p.token_index(p.tokens[-1]) + assert token_start_pos(p.tokens, idx) == len('SELECT * FROM \n') + + +def test_single_column_name_extraction(): + sql = 'SELECT abc FROM xxx' + assert extract_column_names(sql) == ('abc',) + + +def test_aliased_single_column_name_extraction(): + sql = 'SELECT abc def FROM xxx' + assert extract_column_names(sql) == ('def',) + + +def test_aliased_expression_name_extraction(): + sql = 'SELECT 99 abc FROM xxx' + assert extract_column_names(sql) == ('abc',) + + +def test_multiple_column_name_extraction(): + sql = 'SELECT abc, def FROM xxx' + assert extract_column_names(sql) == ('abc', 'def') + + +def test_missing_column_name_handled_gracefully(): + sql = 'SELECT 
abc, 99 FROM xxx' + assert extract_column_names(sql) == ('abc',) + + sql = 'SELECT abc, 99, def FROM xxx' + assert extract_column_names(sql) == ('abc', 'def') + + +def test_aliased_multiple_column_name_extraction(): + sql = 'SELECT abc def, ghi jkl FROM xxx' + assert extract_column_names(sql) == ('def', 'jkl') + + +def test_table_qualified_column_name_extraction(): + sql = 'SELECT abc.def, ghi.jkl FROM xxx' + assert extract_column_names(sql) == ('def', 'jkl') + + +@pytest.mark.parametrize('sql', [ + 'INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y', + 'DELETE FROM foo WHERE x > y RETURNING x, y', + 'UPDATE foo SET x = 9 RETURNING x, y', +]) +def test_extract_column_names_from_returning_clause(sql): + assert extract_column_names(sql) == ('x', 'y') + + +def test_simple_cte_extraction(): + sql = 'WITH a AS (SELECT abc FROM xxx) SELECT * FROM a' + start_pos = len('WITH a AS ') + stop_pos = len('WITH a AS (SELECT abc FROM xxx)') + ctes, remainder = extract_ctes(sql) + + assert tuple(ctes) == (('a', ('abc',), start_pos, stop_pos),) + assert remainder.strip() == 'SELECT * FROM a' + + +def test_cte_extraction_around_comments(): + sql = '''--blah blah blah + WITH a AS (SELECT abc def FROM x) + SELECT * FROM a''' + start_pos = len('''--blah blah blah + WITH a AS ''') + stop_pos = len('''--blah blah blah + WITH a AS (SELECT abc def FROM x)''') + + ctes, remainder = extract_ctes(sql) + assert tuple(ctes) == (('a', ('def',), start_pos, stop_pos),) + assert remainder.strip() == 'SELECT * FROM a' + + +def test_multiple_cte_extraction(): + sql = '''WITH + x AS (SELECT abc, def FROM x), + y AS (SELECT ghi, jkl FROM y) + SELECT * FROM a, b''' + + start1 = len('''WITH + x AS ''') + + stop1 = len('''WITH + x AS (SELECT abc, def FROM x)''') + + start2 = len('''WITH + x AS (SELECT abc, def FROM x), + y AS ''') + + stop2 = len('''WITH + x AS (SELECT abc, def FROM x), + y AS (SELECT ghi, jkl FROM y)''') + + ctes, remainder = extract_ctes(sql) + assert tuple(ctes) == ( + ('x', 
('abc', 'def'), start1, stop1), + ('y', ('ghi', 'jkl'), start2, stop2)) + + |