diff options
author | Amjith Ramanujam <amjith.r@gmail.com> | 2017-03-08 18:50:59 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-08 18:50:59 -0800 |
commit | 83442f8ebf3b855ccc9ea691e6797dc63a5648bc (patch) | |
tree | eff32a691b08169049dcc3e2488beaa7e9029e06 | |
parent | 4904a982dd9970f843a0c368d62de55cd6f1c0a4 (diff) | |
parent | 1277752d6287225062ed6bdb32edb9fa8b82b2c8 (diff) |
Merge pull request #655 from dbcli/koljonen/parse_function_body
Parse function bodies
-rw-r--r-- | changelog.rst | 7 | ||||
-rw-r--r-- | pgcli/packages/sqlcompletion.py | 34 | ||||
-rw-r--r-- | tests/test_smart_completion_multiple_schemata.py | 11 | ||||
-rw-r--r-- | tests/test_sqlcompletion.py | 79 |
4 files changed, 129 insertions, 2 deletions
diff --git a/changelog.rst b/changelog.rst index 156562f7..94cb634a 100644 --- a/changelog.rst +++ b/changelog.rst @@ -1,3 +1,10 @@ +Upcoming +===== + +Features: +--------- +* Better suggestions when editing functions (Thanks: `Joakim Koljonen`_) + 1.5.0 ===== diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index e4a1de72..a870d1e1 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -168,6 +168,26 @@ def _strip_named_query(txt): txt = named_query_regex.sub('', txt) return txt +function_body_pattern = re.compile('(\\$.*?\\$)([\s\S]*?)\\1', re.M) + + +def _find_function_body(text): + split = function_body_pattern.search(text) + return (split.start(2), split.end(2)) if split else (None, None) + + +def _statement_from_function(full_text, text_before_cursor, statement): + current_pos = len(text_before_cursor) + body_start, body_end = _find_function_body(full_text) + if body_start is None: + return full_text, text_before_cursor, statement + if not body_start <= current_pos < body_end: + return full_text, text_before_cursor, statement + full_text = full_text[body_start:body_end] + text_before_cursor = text_before_cursor[body_start:] + parsed = sqlparse.parse(text_before_cursor) + return _split_multiple_statements(full_text, text_before_cursor, parsed) + def _split_multiple_statements(full_text, text_before_cursor, parsed): if len(parsed) > 1: @@ -191,8 +211,18 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed): statement = parsed[0] else: # The empty string - statement = None - + return full_text, text_before_cursor, None + + token2 = None + if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'): + token1 = statement.token_first() + if token1: + token1_idx = statement.token_index(token1) + token2 = statement.token_next(token1_idx)[1] + if token2 and token2.value.upper() == 'FUNCTION': + full_text, text_before_cursor, statement = _statement_from_function( + full_text, text_before_cursor, statement + ) return full_text, text_before_cursor, statement diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index 2d2d599a..86875047 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -387,6 +387,17 @@ def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_even assert expected == completions @pytest.mark.parametrize('text', [ + ''' + SELECT count(1) FROM users; + CREATE FUNCTION foo(custom.products _products) returns custom.shipments + LANGUAGE SQL + AS $foo$ + SELECT 1 FROM custom.shipments; + INSERT INTO public.orders(*) values(-1, now(), 'preliminary'); + SELECT 2 FROM custom.users; + $foo$; + SELECT count(1) FROM custom.shipments; + ''', 'INSERT INTO public.orders(*', 'INSERT INTO public.Orders(*', 'INSERT INTO public.orders (*', diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py index 80507319..a2928df4 100644 --- a/tests/test_sqlcompletion.py +++ b/tests/test_sqlcompletion.py @@ -582,6 +582,85 @@ def test_3_statements_2nd_current(): 'select * from a; select ') assert set(suggestions) == cols_etc('b') +@pytest.mark.parametrize('text', [ +''' +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$ +SELECT FROM foo; +SELECT 2 FROM bar; +$$ language sql; + ''', + '''create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 2 FROM bar; +SELECT FROM foo; +$func$ + ''', +''' +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$ +SELECT 3 FROM foo; +SELECT 2 FROM bar; +$$ language sql; +create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 2 FROM bar; +SELECT FROM foo; +$func$ + ''', +''' +SELECT * FROM baz; +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$ +SELECT FROM foo; +SELECT 2 FROM bar; +$$ language sql; +create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 3 FROM bar; +SELECT FROM foo; +$func$ +SELECT * FROM qux; + ''' +]) +def test_statements_in_function_body(text): + suggestions = suggest_type(text, text[:text.find(' ') + 1]) + assert set(suggestions) == set([ + Column(table_refs=((None, 'foo', None, False),), qualifiable=True), + Function(schema=None), + Keyword() + ]) + +functions = [ +''' +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$ +SELECT 1 FROM foo; +SELECT 2 FROM bar; +$$ language sql; + ''', + ''' +create function func2(int, varchar) +RETURNS text +language sql AS +' +SELECT 2 FROM bar; +SELECT 1 FROM foo; +'; + ''' +] + +@pytest.mark.parametrize('text', functions) +def test_statements_with_cursor_after_function_body(text): + suggestions = suggest_type(text, text[:text.find('; ') + 1]) + assert set(suggestions) == set([Keyword()]) + +@pytest.mark.parametrize('text', functions) +def test_statements_with_cursor_before_function_body(text): + suggestions = suggest_type(text, '') + assert set(suggestions) == set([Keyword()]) def test_create_db_with_template(): suggestions = suggest_type('create database foo with template ', |