summaryrefslogtreecommitdiffstats
path: root/pgcli/packages/parseutils/ctes.py
blob: 75e4e40f79e7e3b1160094bf1e9a5d9a23e6d0ad (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from sqlparse import parse
from sqlparse.tokens import Keyword, CTE, DML
from sqlparse.sql import Identifier, IdentifierList, Parenthesis
from collections import namedtuple
from .meta import TableMetadata, ColumnMetadata


# TableExpression describes a single CTE found in a query (internal use only).
#   name    -- the alias the query assigns to the CTE
#   columns -- the column names the CTE body exposes
#   start   -- offset in the original SQL string of the "(" that opens the
#              CTE body
#   stop    -- offset in the original SQL string just past the ")" that
#              closes the CTE body
TableExpression = namedtuple("TableExpression", ["name", "columns", "start", "stop"])


def isolate_query_ctes(full_text, text_before_cursor):
    """Simplify a query by converting CTEs into table metadata objects.

    :param full_text: complete text of the query being edited
    :param text_before_cursor: the part of ``full_text`` before the cursor
    :return: tuple ``(full_text, text_before_cursor, meta)`` where ``meta`` is
        always a tuple of TableMetadata.

        - Cursor inside a CTE body: both texts are narrowed to that body,
          parentheses included. ``meta`` describes the CTEs that precede it.
        - Cursor not inside any CTE body: both texts are narrowed to the
          text following the last CTE, and ``meta`` describes every CTE.
        - Empty query or no CTEs: the texts are returned unchanged with an
          empty ``meta``.
    """
    # Empty or whitespace-only text cannot contain CTEs; return early so such
    # input never reaches the parser.
    if not full_text or not full_text.strip():
        return full_text, text_before_cursor, ()

    ctes, _remainder = extract_ctes(full_text)
    if not ctes:
        return full_text, text_before_cursor, ()

    current_position = len(text_before_cursor)
    meta = []

    for cte in ctes:
        if cte.start < current_position < cte.stop:
            # Currently editing a cte - treat its body as the current full_text
            text_before_cursor = full_text[cte.start : current_position]
            full_text = full_text[cte.start : cte.stop]
            # Return a tuple here too, so callers get the same type on
            # every path.
            return full_text, text_before_cursor, tuple(meta)

        # Append this cte to the list of available table metadata. Only the
        # column name is known here; the other fields are left empty.
        # Materialise the columns: a generator can be iterated only once, so
        # any later reader of the metadata would see no columns.
        cols = tuple(ColumnMetadata(name, None, ()) for name in cte.columns)
        meta.append(TableMetadata(cte.name, cols))

    # Editing past the last cte (ie the main body of the query)
    last_stop = ctes[-1].stop
    full_text = full_text[last_stop:]
    text_before_cursor = text_before_cursor[last_stop:current_position]

    return full_text, text_before_cursor, tuple(meta)


def extract_ctes(sql):
    """Extract common table expressions (CTEs) from a query.

    Returns tuple (ctes, remainder_sql)

        ctes is a list of TableExpression namedtuples
        remainder_sql is the text from the original query after the CTEs have
        been stripped.

    ``([], sql)`` is returned in two cases: the parser finds no statement in
    ``sql``, or the first statement does not begin with WITH (leading
    whitespace and comments are ignored). ``([], "")`` is returned when
    nothing follows WITH (or WITH RECURSIVE).
    """
    statements = parse(sql)
    if not statements:
        # parse() can return no statements at all (e.g. for empty input);
        # indexing [0] unconditionally would raise IndexError.
        return [], sql
    p = statements[0]

    # Make sure the first meaningful token is "WITH" which is necessary to
    # define CTEs
    idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
    if not (tok and tok.ttype == CTE):
        return [], sql

    # Get the next meaningful token, which should be the first CTE. Comments
    # are skipped here as well; otherwise a comment right after WITH would be
    # taken for the CTE list and every CTE would be missed.
    idx, tok = p.token_next(idx, skip_ws=True, skip_cm=True)

    # WITH RECURSIVE: step over the modifier so the definitions after it are
    # still found. If the token is not a RECURSIVE keyword, nothing changes.
    if tok is not None and tok.is_keyword and tok.normalized == "RECURSIVE":
        idx, tok = p.token_next(idx, skip_ws=True, skip_cm=True)

    if not tok:
        return [], ""
    start_pos = token_start_pos(p.tokens, idx)
    ctes = []

    if isinstance(tok, IdentifierList):
        # Multiple ctes
        for ident in tok.get_identifiers():
            # get_identifiers() can also yield plain tokens (e.g. literals in
            # a malformed list). Only Identifier groups have a name and body,
            # so skip anything else instead of failing on it.
            if not isinstance(ident, Identifier):
                continue
            cte_start_offset = token_start_pos(tok.tokens, tok.token_index(ident))
            cte = get_cte_from_token(ident, start_pos + cte_start_offset)
            if cte:
                ctes.append(cte)
    elif isinstance(tok, Identifier):
        # A single CTE
        cte = get_cte_from_token(tok, start_pos)
        if cte:
            ctes.append(cte)

    # Collapse everything after the ctes into a remainder query. ``idx`` is
    # already the index of the CTE token within ``p.tokens``.
    remainder = "".join(str(t) for t in p.tokens[idx + 1 :])

    return ctes, remainder


def get_cte_from_token(tok, pos0):
    """Build a TableExpression from one ``name AS (...)`` identifier token.

    :param tok: the identifier token holding a single CTE definition
    :param pos0: character offset of ``tok`` within the original SQL string
    :return: a TableExpression, or None when the token has no real name or
        contains no parenthesised body
    """
    name = tok.get_real_name()
    if not name:
        return None

    # The CTE body is the first parenthesised group inside the identifier
    body_idx, body = tok.token_next_by(Parenthesis)
    if not body:
        return None

    begin = pos0 + token_start_pos(tok.tokens, body_idx)
    # The body's text length covers both the opening and closing paren
    end = begin + len(str(body))

    return TableExpression(name, extract_column_names(body), begin, end)


def extract_column_names(parsed):
    """Return the names of the columns a CTE body produces.

    :param parsed: parsed sqlparse token list, e.g. the Parenthesis holding a
        CTE body or a whole statement
    :return: tuple of column names. For SELECT these come from the select
        list; for INSERT/UPDATE/DELETE they come from the RETURNING clause.
        The tuple is empty for any other statement, or when no column
        identifiers are found.
    """
    # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE
    idx, tok = parsed.token_next_by(t=DML)
    tok_val = tok and tok.value.lower()

    if tok_val in ("insert", "update", "delete"):
        # Jump ahead to the RETURNING clause where the list of column names is.
        # token_next_by's first positional parameter is the token *class* to
        # match (see its use with Parenthesis elsewhere in this module). The
        # start index must therefore go in the ``idx`` keyword. Passing it
        # positionally breaks whenever the DML token is not at index 0, as
        # inside a CTE's parentheses.
        idx, tok = parsed.token_next_by(m=(Keyword, "returning"), idx=idx)
    elif tok_val != "select":
        # Must be invalid CTE
        return ()

    # The next token should be either a column name, or a list of column names
    idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
    return tuple(t.get_name() for t in _identifiers(tok))


def token_start_pos(tokens, idx):
    """Return the character offset at which ``tokens[idx]`` begins.

    The offset is the combined string length of every token that precedes
    position ``idx`` (standard slice semantics apply to ``idx``).
    """
    offset = 0
    for preceding in tokens[:idx]:
        offset += len(str(preceding))
    return offset


def _identifiers(tok):
    """Yield the Identifier tokens contained in ``tok``.

    An IdentifierList yields each of its Identifier members, and a lone
    Identifier yields itself. Anything else yields nothing.
    """
    if isinstance(tok, IdentifierList):
        # NB: IdentifierList.get_identifiers() may also return tokens that
        # are not Identifier instances; those are filtered out here.
        yield from (
            member for member in tok.get_identifiers() if isinstance(member, Identifier)
        )
        return
    if isinstance(tok, Identifier):
        yield tok