summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAmjith Ramanujam <amjith.r@gmail.com>2017-03-08 18:50:59 -0800
committerGitHub <noreply@github.com>2017-03-08 18:50:59 -0800
commit83442f8ebf3b855ccc9ea691e6797dc63a5648bc (patch)
treeeff32a691b08169049dcc3e2488beaa7e9029e06
parent4904a982dd9970f843a0c368d62de55cd6f1c0a4 (diff)
parent1277752d6287225062ed6bdb32edb9fa8b82b2c8 (diff)
Merge pull request #655 from dbcli/koljonen/parse_function_body
Parse function bodies
-rw-r--r--changelog.rst7
-rw-r--r--pgcli/packages/sqlcompletion.py34
-rw-r--r--tests/test_smart_completion_multiple_schemata.py11
-rw-r--r--tests/test_sqlcompletion.py79
4 files changed, 129 insertions, 2 deletions
diff --git a/changelog.rst b/changelog.rst
index 156562f7..94cb634a 100644
--- a/changelog.rst
+++ b/changelog.rst
@@ -1,3 +1,10 @@
+Upcoming
+=====
+
+Features:
+---------
+* Better suggestions when editing functions (Thanks: `Joakim Koljonen`_)
+
1.5.0
=====
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
index e4a1de72..a870d1e1 100644
--- a/pgcli/packages/sqlcompletion.py
+++ b/pgcli/packages/sqlcompletion.py
@@ -168,6 +168,26 @@ def _strip_named_query(txt):
txt = named_query_regex.sub('', txt)
return txt
+function_body_pattern = re.compile('(\\$.*?\\$)([\s\S]*?)\\1', re.M)
+
+
+def _find_function_body(text):
+ split = function_body_pattern.search(text)
+ return (split.start(2), split.end(2)) if split else (None, None)
+
+
+def _statement_from_function(full_text, text_before_cursor, statement):
+ current_pos = len(text_before_cursor)
+ body_start, body_end = _find_function_body(full_text)
+ if body_start is None:
+ return full_text, text_before_cursor, statement
+ if not body_start <= current_pos < body_end:
+ return full_text, text_before_cursor, statement
+ full_text = full_text[body_start:body_end]
+ text_before_cursor = text_before_cursor[body_start:]
+ parsed = sqlparse.parse(text_before_cursor)
+ return _split_multiple_statements(full_text, text_before_cursor, parsed)
+
def _split_multiple_statements(full_text, text_before_cursor, parsed):
if len(parsed) > 1:
@@ -191,8 +211,18 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed):
statement = parsed[0]
else:
# The empty string
- statement = None
-
+ return full_text, text_before_cursor, None
+
+ token2 = None
+ if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'):
+ token1 = statement.token_first()
+ if token1:
+ token1_idx = statement.token_index(token1)
+ token2 = statement.token_next(token1_idx)[1]
+ if token2 and token2.value.upper() == 'FUNCTION':
+ full_text, text_before_cursor, statement = _statement_from_function(
+ full_text, text_before_cursor, statement
+ )
return full_text, text_before_cursor, statement
diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py
index 2d2d599a..86875047 100644
--- a/tests/test_smart_completion_multiple_schemata.py
+++ b/tests/test_smart_completion_multiple_schemata.py
@@ -387,6 +387,17 @@ def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_even
assert expected == completions
@pytest.mark.parametrize('text', [
+ '''
+ SELECT count(1) FROM users;
+ CREATE FUNCTION foo(custom.products _products) returns custom.shipments
+ LANGUAGE SQL
+ AS $foo$
+ SELECT 1 FROM custom.shipments;
+ INSERT INTO public.orders(*) values(-1, now(), 'preliminary');
+ SELECT 2 FROM custom.users;
+ $foo$;
+ SELECT count(1) FROM custom.shipments;
+ ''',
'INSERT INTO public.orders(*',
'INSERT INTO public.Orders(*',
'INSERT INTO public.orders (*',
diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py
index 80507319..a2928df4 100644
--- a/tests/test_sqlcompletion.py
+++ b/tests/test_sqlcompletion.py
@@ -582,6 +582,85 @@ def test_3_statements_2nd_current():
'select * from a; select ')
assert set(suggestions) == cols_etc('b')
+@pytest.mark.parametrize('text', [
+'''
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$
+SELECT FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+ ''',
+ '''create function func2(int, varchar)
+RETURNS text
+language sql AS
+$func$
+SELECT 2 FROM bar;
+SELECT FROM foo;
+$func$
+ ''',
+'''
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$
+SELECT 3 FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+create function func2(int, varchar)
+RETURNS text
+language sql AS
+$func$
+SELECT 2 FROM bar;
+SELECT FROM foo;
+$func$
+ ''',
+'''
+SELECT * FROM baz;
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$
+SELECT FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+create function func2(int, varchar)
+RETURNS text
+language sql AS
+$func$
+SELECT 3 FROM bar;
+SELECT FROM foo;
+$func$
+SELECT * FROM qux;
+ '''
+])
+def test_statements_in_function_body(text):
+ suggestions = suggest_type(text, text[:text.find(' ') + 1])
+ assert set(suggestions) == set([
+ Column(table_refs=((None, 'foo', None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword()
+ ])
+
+functions = [
+'''
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$
+SELECT 1 FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+ ''',
+ '''
+create function func2(int, varchar)
+RETURNS text
+language sql AS
+'
+SELECT 2 FROM bar;
+SELECT 1 FROM foo;
+';
+ '''
+]
+
+@pytest.mark.parametrize('text', functions)
+def test_statements_with_cursor_after_function_body(text):
+ suggestions = suggest_type(text, text[:text.find('; ') + 1])
+ assert set(suggestions) == set([Keyword()])
+
+@pytest.mark.parametrize('text', functions)
+def test_statements_with_cursor_before_function_body(text):
+ suggestions = suggest_type(text, '')
+ assert set(suggestions) == set([Keyword()])
def test_create_db_with_template():
suggestions = suggest_type('create database foo with template ',