diff options
author | Darik Gamble <darik.gamble.spam@gmail.com> | 2016-07-22 19:10:56 -0400 |
---|---|---|
committer | Darik Gamble <darik.gamble.spam@gmail.com> | 2016-07-27 15:33:59 -0400 |
commit | 2407de0e4c4c52e27b0546c374722bd4e88287ad (patch) | |
tree | fc273eba9e0bf03343e0d5a93c9b5d6cfd8b9ddc | |
parent | d113d4e38f02fc3120e81bdb65c159ee24dd392a (diff) |
Make suggestions based on local tables
-rw-r--r-- | pgcli/pgcompleter.py | 11 | ||||
-rw-r--r-- | tests/test_smart_completion_public_schema_only.py | 47 |
2 files changed, 56 insertions, 2 deletions
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index a521e8d7..dd171d84 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -366,7 +366,8 @@ class PGCompleter(Completer): def get_column_matches(self, suggestion, word_before_cursor): tables = suggestion.table_refs _logger.debug("Completion column scope: %r", tables) - scoped_cols = self.populate_scoped_cols(tables) + scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables) + colit = scoped_cols.items flat_cols = list(chain(*((c.name for c in cols) for t, cols in colit()))) @@ -567,6 +568,7 @@ class PGCompleter(Completer): def get_table_matches(self, suggestion, word_before_cursor, alias=False): tables = self.populate_schema_objects(suggestion.schema, 'tables') + tables.extend(tbl.name for tbl in suggestion.local_tables) # Unless we're sure the user really wants them, don't suggest the # pg_catalog tables that are implicitly on the search path @@ -654,9 +656,10 @@ class PGCompleter(Completer): Path: get_path_matches, } - def populate_scoped_cols(self, scoped_tbls): + def populate_scoped_cols(self, scoped_tbls, local_tbls=()): """ Find all columns in a set of scoped_tables :param scoped_tbls: list of TableReference namedtuples + :param local_tbls: tuple(TableMetadata) :return: {TableReference:{colname:ColumnMetaData}} """ @@ -689,6 +692,10 @@ class PGCompleter(Completer): cols = cols.values() addcols(schema, relname, tbl.alias, reltype, cols) break + + # Local tables should shadow database tables + for tbl in local_tbls: + columns[tbl.name] = tbl.columns return columns diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index df73d51b..cc051d06 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -3,6 +3,7 @@ import pytest from metadata import (MetaData, alias, name_join, fk_join, join, keyword, schema, table, view, function, column, wildcard_expansion) from prompt_toolkit.document import Document +from prompt_toolkit.completion import Completion metadata = { 'tables': { @@ -867,3 +868,49 @@ def test_insert(completer, complete_event, text): result = completer.get_completions(Document(text=text, cursor_position=pos), complete_event) assert set(result) == set(testdata.columns('users')) + + +def test_suggest_cte_names(completer, complete_event): + text = ''' + WITH cte1 AS (SELECT a, b, c FROM foo), + cte2 AS (SELECT d, e, f FROM bar) + SELECT * FROM + ''' + pos = len(text) + result = completer.get_completions( + Document(text=text, cursor_position=pos), + complete_event) + expected = set([ + Completion('cte1', 0, display_meta='table'), + Completion('cte2', 0, display_meta='table'), + ]) + assert expected <= set(result) + + +def test_suggest_columns_from_cte(completer, complete_event): + text = 'WITH cte AS (SELECT foo, bar FROM baz) SELECT FROM cte' + pos = len('WITH cte AS (SELECT foo, bar FROM baz) SELECT ') + result = completer.get_completions(Document(text=text, cursor_position=pos), + complete_event) + expected = ([Completion('foo', 0, display_meta='column'), + Completion('bar', 0, display_meta='column'), + ] + + testdata.functions() + + testdata.builtin_functions() + + testdata.keywords() + ) + + assert set(expected) == set(result) + + +@pytest.mark.parametrize('text', [ + 'WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte WHERE cte.', + 'WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte c WHERE c.', +]) +def test_cte_qualified_columns(completer, complete_event, text): + pos = len(text) + result = completer.get_completions( + Document(text=text, cursor_position=pos), + complete_event) + expected = [Completion('foo', 0, display_meta='column')] + assert set(expected) == set(result) |