summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDarik Gamble <darik.gamble.spam@gmail.com>2016-07-22 07:28:21 -0400
committerDarik Gamble <darik.gamble.spam@gmail.com>2016-07-27 15:33:58 -0400
commitc66fbbf5987e91082434e3cb1ad73734c3858d95 (patch)
tree1ea0f31dbf20cf2f5230ce25c5ef3ac78c242454
parenta52ee78fcb41b4ab1fc17d0dd05f1a8cb6da4b92 (diff)
Add parseutils module for processing CTEs (not hooked up yet)
-rw-r--r--pgcli/packages/parseutils/ctes.py146
-rw-r--r--tests/parseutils/test_ctes.py121
2 files changed, 267 insertions, 0 deletions
diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py
new file mode 100644
index 00000000..12c207c4
--- /dev/null
+++ b/pgcli/packages/parseutils/ctes.py
@@ -0,0 +1,146 @@
+from sqlparse import parse
+from sqlparse.tokens import Keyword, CTE, DML
+from sqlparse.sql import Identifier, IdentifierList, Parenthesis
+from collections import namedtuple
+from .meta import TableMetadata, ColumnMetadata
+
+
+# TableExpression is a namedtuple representing a CTE, used internally
+# name: cte alias assigned in the query
+# columns: list of column names
+# start: index into the original string of the left parens starting the CTE
+# stop: index into the original string of the right parens ending the CTE
+TableExpression = namedtuple('TableExpression', 'name columns start stop')
+
+
+def isolate_query_ctes(full_text, text_before_cursor):
+ """Simplify a query by converting CTEs into table metadata objects
+ """
+
+ if not full_text:
+ return full_text, text_before_cursor, tuple()
+
+ ctes, remainder = extract_ctes(full_text)
+ if not ctes:
+ return full_text, text_before_cursor, ()
+
+ current_position = len(text_before_cursor)
+ meta = []
+
+ for cte in ctes:
+ if cte.start < current_position < cte.stop:
+ # Currently editing a cte - treat its body as the current full_text
+ text_before_cursor = full_text[cte.start:current_position]
+ full_text = full_text[cte.start:cte.stop]
+ return full_text, text_before_cursor, meta
+
+ # Append this cte to the list of available table metadata
+ cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
+ meta.append(TableMetadata(cte.name, cols))
+
+ # Editing past the last cte (ie the main body of the query)
+ full_text = full_text[ctes[-1].stop:]
+ text_before_cursor = text_before_cursor[ctes[-1].stop:current_position]
+
+ return full_text, text_before_cursor, tuple(meta)
+
+
+def extract_ctes(sql):
+ """ Extract constant table expresseions from a query
+
+ Returns tuple (ctes, remainder_sql)
+
+ ctes is a list of TableExpression namedtuples
+ remainder_sql is the text from the original query after the CTEs have
+ been stripped.
+ """
+
+ p = parse(sql)[0]
+
+ # Make sure the first meaningful token is "WITH" which is necessary to
+ # define CTEs
+ idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
+ if not (tok and tok.ttype == CTE):
+ return [], sql
+
+ # Get the next (meaningful) token, which should be the first CTE
+ idx, tok = p.token_next(idx)
+ start_pos = token_start_pos(p.tokens, idx)
+ ctes = []
+
+ if isinstance(tok, IdentifierList):
+ # Multiple ctes
+ for t in tok.get_identifiers():
+ cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
+ cte = get_cte_from_token(t, start_pos + cte_start_offset)
+ if not cte:
+ continue
+ ctes.append(cte)
+ elif isinstance(tok, Identifier):
+ # A single CTE
+ cte = get_cte_from_token(tok, start_pos)
+ if cte:
+ ctes.append(cte)
+
+ idx = p.token_index(tok) + 1
+
+ # Collapse everything after the ctes into a remainder query
+ remainder = u''.join(str(tok) for tok in p.tokens[idx:])
+
+ return ctes, remainder
+
+
+def get_cte_from_token(tok, pos0):
+ cte_name = tok.get_real_name()
+ if not cte_name:
+ return None
+
+ # Find the start position of the opening parens enclosing the cte body
+ idx, parens = tok.token_next_by(Parenthesis)
+ if not parens:
+ return None
+
+ start_pos = pos0 + token_start_pos(tok.tokens, idx)
+ cte_len = len(str(parens)) # includes parens
+ stop_pos = start_pos + cte_len
+
+ column_names = extract_column_names(parens)
+
+ return TableExpression(cte_name, column_names, start_pos, stop_pos)
+
+
+def extract_column_names(parsed):
+ # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE
+ idx, tok = parsed.token_next_by(t=DML)
+ tok_val = tok and tok.value.lower()
+
+ if tok_val in ('insert', 'update', 'delete'):
+ # Jump ahead to the RETURNING clause where the list of column names is
+ idx, tok = parsed.token_next_by(idx, (Keyword, 'returning'))
+ elif not tok_val == 'select':
+ # Must be invalid CTE
+ return ()
+
+ # The next token should be either a column name, or a list of column names
+ idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
+ return tuple(t.get_name() for t in _identifiers(tok))
+
+
+def token_start_pos(tokens, idx):
+ return sum(len(str(t)) for t in tokens[:idx])
+
+
+def _identifiers(tok):
+ if isinstance(tok, IdentifierList):
+ for t in tok.get_identifiers():
+ # NB: IdentifierList.get_identifiers() can return non-identifiers!
+ if isinstance(t, Identifier):
+ yield t
+ elif isinstance(tok, Identifier):
+ yield tok
+
+
+
+
+
+
diff --git a/tests/parseutils/test_ctes.py b/tests/parseutils/test_ctes.py
new file mode 100644
index 00000000..9566cf65
--- /dev/null
+++ b/tests/parseutils/test_ctes.py
@@ -0,0 +1,121 @@
+import pytest
+from sqlparse import parse
+from pgcli.packages.parseutils.ctes import (
+ token_start_pos, extract_ctes,
+ extract_column_names as _extract_column_names)
+
+
+def extract_column_names(sql):
+ p = parse(sql)[0]
+ return _extract_column_names(p)
+
+
+def test_token_str_pos():
+ sql = 'SELECT * FROM xxx'
+ p = parse(sql)[0]
+ idx = p.token_index(p.tokens[-1])
+ assert token_start_pos(p.tokens, idx) == len('SELECT * FROM ')
+
+ sql = 'SELECT * FROM \nxxx'
+ p = parse(sql)[0]
+ idx = p.token_index(p.tokens[-1])
+ assert token_start_pos(p.tokens, idx) == len('SELECT * FROM \n')
+
+
+def test_single_column_name_extraction():
+ sql = 'SELECT abc FROM xxx'
+ assert extract_column_names(sql) == ('abc',)
+
+
+def test_aliased_single_column_name_extraction():
+ sql = 'SELECT abc def FROM xxx'
+ assert extract_column_names(sql) == ('def',)
+
+
+def test_aliased_expression_name_extraction():
+ sql = 'SELECT 99 abc FROM xxx'
+ assert extract_column_names(sql) == ('abc',)
+
+
+def test_multiple_column_name_extraction():
+ sql = 'SELECT abc, def FROM xxx'
+ assert extract_column_names(sql) == ('abc', 'def')
+
+
+def test_missing_column_name_handled_gracefully():
+ sql = 'SELECT abc, 99 FROM xxx'
+ assert extract_column_names(sql) == ('abc',)
+
+ sql = 'SELECT abc, 99, def FROM xxx'
+ assert extract_column_names(sql) == ('abc', 'def')
+
+
+def test_aliased_multiple_column_name_extraction():
+ sql = 'SELECT abc def, ghi jkl FROM xxx'
+ assert extract_column_names(sql) == ('def', 'jkl')
+
+
+def test_table_qualified_column_name_extraction():
+ sql = 'SELECT abc.def, ghi.jkl FROM xxx'
+ assert extract_column_names(sql) == ('def', 'jkl')
+
+
+@pytest.mark.parametrize('sql', [
+ 'INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y',
+ 'DELETE FROM foo WHERE x > y RETURNING x, y',
+ 'UPDATE foo SET x = 9 RETURNING x, y',
+])
+def test_extract_column_names_from_returning_clause(sql):
+ assert extract_column_names(sql) == ('x', 'y')
+
+
+def test_simple_cte_extraction():
+ sql = 'WITH a AS (SELECT abc FROM xxx) SELECT * FROM a'
+ start_pos = len('WITH a AS ')
+ stop_pos = len('WITH a AS (SELECT abc FROM xxx)')
+ ctes, remainder = extract_ctes(sql)
+
+ assert tuple(ctes) == (('a', ('abc',), start_pos, stop_pos),)
+ assert remainder.strip() == 'SELECT * FROM a'
+
+
+def test_cte_extraction_around_comments():
+ sql = '''--blah blah blah
+ WITH a AS (SELECT abc def FROM x)
+ SELECT * FROM a'''
+ start_pos = len('''--blah blah blah
+ WITH a AS ''')
+ stop_pos = len('''--blah blah blah
+ WITH a AS (SELECT abc def FROM x)''')
+
+ ctes, remainder = extract_ctes(sql)
+ assert tuple(ctes) == (('a', ('def',), start_pos, stop_pos),)
+ assert remainder.strip() == 'SELECT * FROM a'
+
+
+def test_multiple_cte_extraction():
+ sql = '''WITH
+ x AS (SELECT abc, def FROM x),
+ y AS (SELECT ghi, jkl FROM y)
+ SELECT * FROM a, b'''
+
+ start1 = len('''WITH
+ x AS ''')
+
+ stop1 = len('''WITH
+ x AS (SELECT abc, def FROM x)''')
+
+ start2 = len('''WITH
+ x AS (SELECT abc, def FROM x),
+ y AS ''')
+
+ stop2 = len('''WITH
+ x AS (SELECT abc, def FROM x),
+ y AS (SELECT ghi, jkl FROM y)''')
+
+ ctes, remainder = extract_ctes(sql)
+ assert tuple(ctes) == (
+ ('x', ('abc', 'def'), start1, stop1),
+ ('y', ('ghi', 'jkl'), start2, stop2))
+
+