summaryrefslogtreecommitdiffstats
path: root/tests/parseutils/test_ctes.py
blob: 4d2d050f6cac63b37484dfea7fe38e30c56fd234 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import pytest
from sqlparse import parse
from pgcli.packages.parseutils.ctes import (
    token_start_pos, extract_ctes,
    extract_column_names as _extract_column_names)


def extract_column_names(sql):
    """Return the column names of the first statement in raw *sql* text.

    Convenience wrapper for the tests below. It parses the text with
    sqlparse and hands the first parsed statement to the imported
    ``_extract_column_names``, so each test can work with plain strings.
    """
    statements = parse(sql)
    return _extract_column_names(statements[0])


def test_token_str_pos():
    """token_start_pos gives the character offset of a statement's last token.

    Each case pairs a query with the text that precedes its final token.
    The expected offset is the length of that preceding text, including any
    embedded newline.
    """
    cases = (
        ('SELECT * FROM xxx', 'SELECT * FROM '),
        ('SELECT * FROM \nxxx', 'SELECT * FROM \n'),
    )
    for sql, preceding_text in cases:
        statement = parse(sql)[0]
        last_index = statement.token_index(statement.tokens[-1])
        assert token_start_pos(statement.tokens, last_index) == len(preceding_text)


def test_single_column_name_extraction():
    """A lone, unaliased column is reported under its own name."""
    columns = extract_column_names('SELECT abc FROM xxx')
    assert columns == ('abc',)


def test_aliased_single_column_name_extraction():
    """When a column carries an alias, the alias is what gets reported."""
    columns = extract_column_names('SELECT abc def FROM xxx')
    assert columns == ('def',)


def test_aliased_expression_name_extraction():
    """An aliased literal expression is reported under its alias."""
    columns = extract_column_names('SELECT 99 abc FROM xxx')
    assert columns == ('abc',)


def test_multiple_column_name_extraction():
    """Comma-separated columns come back in select-list order."""
    columns = extract_column_names('SELECT abc, def FROM xxx')
    assert columns == ('abc', 'def')


def test_missing_column_name_handled_gracefully():
    """Unnamed expressions (a bare literal) are skipped rather than reported.

    The skipped expression may sit at the end of the select list or
    between named columns; the named columns keep their order either way.
    """
    cases = [
        ('SELECT abc, 99 FROM xxx', ('abc',)),
        ('SELECT abc, 99, def FROM xxx', ('abc', 'def')),
    ]
    for sql, expected in cases:
        assert extract_column_names(sql) == expected


def test_aliased_multiple_column_name_extraction():
    """Every aliased column in a list is reported by its alias."""
    columns = extract_column_names('SELECT abc def, ghi jkl FROM xxx')
    assert columns == ('def', 'jkl')


def test_table_qualified_column_name_extraction():
    """The table qualifier is dropped; only the column part is kept."""
    columns = extract_column_names('SELECT abc.def, ghi.jkl FROM xxx')
    assert columns == ('def', 'jkl')


@pytest.mark.parametrize('sql', (
    'INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y',
    'DELETE FROM foo WHERE x > y RETURNING x, y',
    'UPDATE foo SET x = 9 RETURNING x, y',
))
def test_extract_column_names_from_returning_clause(sql):
    """INSERT, DELETE and UPDATE expose the columns of their RETURNING list."""
    expected = ('x', 'y')
    assert extract_column_names(sql) == expected


def test_simple_cte_extraction():
    """A single CTE is reported with its name, columns and body offsets.

    The query that follows the CTE is returned as the remainder.
    """
    head = 'WITH a AS '
    body = '(SELECT abc FROM xxx)'
    sql = head + body + ' SELECT * FROM a'

    ctes, remainder = extract_ctes(sql)

    # The recorded span starts at the opening parenthesis and ends just
    # past the closing one.
    expected = (('a', ('abc',), len(head), len(head + body)),)
    assert tuple(ctes) == expected
    assert remainder.strip() == 'SELECT * FROM a'


def test_cte_extraction_around_comments():
    """A leading SQL comment does not throw off the reported CTE offsets.

    Offsets are measured from the very start of the text, comment included.
    """
    sql = '''--blah blah blah
            WITH a AS (SELECT abc def FROM x)
            SELECT * FROM a'''
    # The CTE body is the first parenthesised span in the query: it starts
    # at the first '(' and stops just past the first ')'.
    body_start = sql.index('(SELECT')
    body_stop = sql.index(')') + 1

    ctes, remainder = extract_ctes(sql)

    assert tuple(ctes) == (('a', ('def',), body_start, body_stop),)
    assert remainder.strip() == 'SELECT * FROM a'


def test_multiple_cte_extraction():
    """Each CTE in a comma-separated WITH list is extracted with its own offsets.

    Every CTE is reported with its name, column names and a (start, stop)
    span in ``sql``. The start is the offset of the CTE's opening
    parenthesis and the stop is just past its closing one; the prefix
    lengths below are built to land exactly there. The query following the
    CTE list must come back as the remainder.
    """
    sql = '''WITH
            x AS (SELECT abc, def FROM x),
            y AS (SELECT ghi, jkl FROM y)
            SELECT * FROM a, b'''

    start1 = len('''WITH
            x AS ''')

    stop1 = len('''WITH
            x AS (SELECT abc, def FROM x)''')

    start2 = len('''WITH
            x AS (SELECT abc, def FROM x),
            y AS ''')

    stop2 = len('''WITH
            x AS (SELECT abc, def FROM x),
            y AS (SELECT ghi, jkl FROM y)''')

    ctes, remainder = extract_ctes(sql)
    assert tuple(ctes) == (
        ('x', ('abc', 'def'), start1, stop1),
        ('y', ('ghi', 'jkl'), start2, stop2))
    # The text after the whole CTE list must survive intact, exactly as the
    # single-CTE tests check for their remainder.
    assert remainder.strip() == 'SELECT * FROM a, b'