diff options
author | Joakim Koljonen <koljonen@outlook.com> | 2017-03-20 21:28:40 +0100 |
---|---|---|
committer | Joakim Koljonen <koljonen@outlook.com> | 2017-06-25 01:35:22 +0200 |
commit | f355c30ef79a539943c641be0376f5638caf28b0 (patch) | |
tree | 9de56750f1fc71765cc86efc630f1ad4b575618b | |
parent | 15b34c1dc955a35584ff0f33c5fc1eaac5cc5648 (diff) |
Include arguments in function completions
E.g. instead of suggesting `my_func()`, suggest `my_func(arg1 :=, arg2 :=)`
or `my_func(arg1 text, arg2 bigint)`, depending on the context.
-rw-r--r-- | changelog.rst | 1 | ||||
-rw-r--r-- | pgcli/completion_refresher.py | 11 | ||||
-rw-r--r-- | pgcli/packages/parseutils/meta.py | 108 | ||||
-rw-r--r-- | pgcli/packages/sqlcompletion.py | 18 | ||||
-rw-r--r-- | pgcli/pgcompleter.py | 186 | ||||
-rw-r--r-- | pgcli/pgexecute.py | 12 | ||||
-rw-r--r-- | tests/metadata.py | 29 | ||||
-rw-r--r-- | tests/parseutils/test_function_metadata.py | 15 | ||||
-rw-r--r-- | tests/test_completion_refresher.py | 4 | ||||
-rw-r--r-- | tests/test_pgexecute.py | 45 | ||||
-rw-r--r-- | tests/test_smart_completion_multiple_schemata.py | 10 | ||||
-rw-r--r-- | tests/test_smart_completion_public_schema_only.py | 97 |
12 files changed, 428 insertions, 108 deletions
diff --git a/changelog.rst b/changelog.rst index a547136e..4cfb31d2 100644 --- a/changelog.rst +++ b/changelog.rst @@ -6,6 +6,7 @@ Upcoming * Use dbcli's Homebrew tap for installing pgcli on macOS (issue #718) (Thanks: `Thomas Roten`_). * Only set `LESS` environment variable if it's unset. (Thanks: `Irina Truong`_) * Quote schema in `SET SCHEMA` statement (issue #469) (Thanks: `Irina Truong`_) +* Include arguments in function suggestions (Thanks: `Joakim Koljonen`_) 1.6.0 ===== diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py index a7427e79..f137100f 100644 --- a/pgcli/completion_refresher.py +++ b/pgcli/completion_refresher.py @@ -119,12 +119,6 @@ def refresh_views(completer, executor): completer.extend_relations(executor.views(), kind='views') completer.extend_columns(executor.view_columns(), kind='views') - -@refresher('functions') -def refresh_functions(completer, executor): - completer.extend_functions(executor.functions()) - - @refresher('types') def refresh_types(completer, executor): completer.extend_datatypes(executor.datatypes()) @@ -148,3 +142,8 @@ def refresh_casing(completer, executor): if os.path.isfile(casing_file): with open(casing_file, 'r') as f: completer.extend_casing([line.strip() for line in f]) + + +@refresher('functions') +def refresh_functions(completer, executor): + completer.extend_functions(executor.functions()) diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py index 6e64bffd..055478c0 100644 --- a/pgcli/packages/parseutils/meta.py +++ b/pgcli/packages/parseutils/meta.py @@ -1,15 +1,59 @@ from collections import namedtuple -ColumnMetadata = namedtuple('ColumnMetadata', ['name', 'datatype', 'foreignkeys']) -ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable', - 'parentcolumn', 'childschema', 'childtable', 'childcolumn']) +_ColumnMetadata = namedtuple( + 'ColumnMetadata', + ['name', 'datatype', 'foreignkeys', 'default', 'has_default'] +) + + +def ColumnMetadata( + name, datatype, foreignkeys=None, default=None, has_default=False +): + return _ColumnMetadata( + name, datatype, foreignkeys or [], default, has_default + ) + + +ForeignKey = namedtuple( + 'ForeignKey', + 'parentschema,parenttable,parentcolumn,childschema,childtable,childcolumn' +) TableMetadata = namedtuple('TableMetadata', 'name columns') +def parse_defaults(defaults_string): + """Yields default values for a function, given the string provided by + pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)""" + if not defaults_string: + return + current = '' + in_quote = None + for char in defaults_string: + if current == '' and char == ' ': + # Skip space after comma separating default expressions + continue + if char == '"' or char == '\'': + if in_quote and char == in_quote: + # End quote + in_quote = None + elif not in_quote: + # Begin quote + in_quote = char + elif char == ',' and not in_quote: + # End of expression + yield current + current = '' + continue + current += char + yield current + + class FunctionMetadata(object): - def __init__(self, schema_name, func_name, arg_names, arg_types, - arg_modes, return_type, is_aggregate, is_window, is_set_returning): + def __init__( + self, schema_name, func_name, arg_names, arg_types, arg_modes, + return_type, is_aggregate, is_window, is_set_returning, arg_defaults + ): """Class for describing a postgresql function""" self.schema_name = schema_name @@ -30,6 +74,8 @@ class FunctionMetadata(object): else: self.arg_types = None + self.arg_defaults = tuple(parse_defaults(arg_defaults)) + self.return_type = return_type.strip() self.is_aggregate = is_aggregate self.is_window = is_window @@ -42,19 +88,51 @@ class FunctionMetadata(object): def __ne__(self, other): return not self.__eq__(other) + def _signature(self): + return ( + self.schema_name, self.func_name, self.arg_names, self.arg_types, + self.arg_modes, self.return_type, self.is_aggregate, + self.is_window, self.is_set_returning, self.arg_defaults + ) + def __hash__(self): - return hash((self.schema_name, self.func_name, self.arg_names, - self.arg_types, self.arg_modes, self.return_type, - self.is_aggregate, self.is_window, self.is_set_returning)) + return hash(self._signature()) def __repr__(self): - return (('%s(schema_name=%r, func_name=%r, arg_names=%r, ' - 'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, ' - 'is_window=%r, is_set_returning=%r)') - % (self.__class__.__name__, self.schema_name, self.func_name, - self.arg_names, self.arg_types, self.arg_modes, - self.return_type, self.is_aggregate, self.is_window, - self.is_set_returning)) + return ( + ( + '%s(schema_name=%r, func_name=%r, arg_names=%r, ' + 'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, ' + 'is_window=%r, is_set_returning=%r, arg_defaults=%r)' + ) % (self.__class__.__name__,) + self._signature() + ) + + def has_variadic(self): + return self.arg_modes and any(arg_mode == 'v' for arg_mode in self.arg_modes) + + def args(self): + """Returns a list of input-parameter ColumnMetadata namedtuples.""" + if not self.arg_names: + return [] + modes = self.arg_modes or ['i'] * len(self.arg_names) + args = [ + (name, typ) + for name, typ, mode in zip(self.arg_names, self.arg_types, modes) + if mode in ('i', 'b', 'v') # IN, INOUT, VARIADIC + ] + + def arg(name, typ, num): + num_args = len(args) + num_defaults = len(self.arg_defaults) + has_default = num + num_defaults >= num_args + default = ( + self.arg_defaults[num - num_args + num_defaults] if has_default + else None + ) + return ColumnMetadata(name, typ, [], default, has_default) + + return [arg(name, typ, num) for num, (name, typ) in enumerate(args)] + def fields(self): """Returns a list of output-field ColumnMetadata namedtuples""" diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index 0fabd70b..fbdde900 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -112,6 +112,9 @@ class SqlStatement(object): tables = tables[1:] return tables + def get_previous_token(self, token): + return self.parsed.token_prev(self.parsed.token_index(token))[1] + def get_identifier_schema(self): schema = (self.identifier and self.identifier.get_parent_name()) or None # If schema name is unquoted, lower-case it @@ -432,8 +435,19 @@ def suggest_based_on_last_token(token, stmt): return tuple(suggest) - elif token_v in ('table', 'view', 'function'): - # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>' + elif token_v == 'function': + schema = stmt.get_identifier_schema() + # stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:` + try: + if stmt.get_previous_token(token).value.lower() in('drop', 'alter'): + return (Function(schema=schema, usage='signature'),) + else: + return (Function(schema=schema),) + except ValueError: + return tuple() + + elif token_v in ('table', 'view'): + # E.g. 'ALTER TABLE <tablname>' rel_type = {'table': Table, 'view': View, 'function': Function}[token_v] schema = stmt.get_identifier_schema() if schema: diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 220cd774..aa118dd7 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -29,17 +29,31 @@ NamedQueries.instance = NamedQueries.from_config( load_config(config_location() + 'config')) Match = namedtuple('Match', ['completion', 'priority']) -_SchemaObject = namedtuple('SchemaObject', ['name', 'schema', 'function']) +_SchemaObject = namedtuple('SchemaObject', 'name schema meta') + + +def SchemaObject(name, schema=None, meta=None): + return _SchemaObject(name, schema, meta) -def SchemaObject(name, schema=None, function=False): - return _SchemaObject(name, schema, function) _Candidate = namedtuple( - 'Candidate', ['completion', 'prio', 'meta', 'synonyms', 'prio2'] + 'Candidate', 'completion prio meta synonyms prio2 display' ) -def Candidate(completion, prio=None, meta=None, synonyms=None, prio2=None): - return _Candidate(completion, prio, meta, synonyms or [completion], prio2) + + +def Candidate( + completion, prio=None, meta=None, synonyms=None, prio2=None, + display=None +): + return _Candidate( + completion, prio, meta, synonyms or [completion], prio2, + display or completion + ) + + +# Used to strip trailing '::some_type' from default-value expressions +arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$') normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' @@ -67,6 +81,16 @@ class PGCompleter(Completer): self.pgspecial = pgspecial self.prioritizer = PrevalenceCounter() settings = settings or {} + self.signature_arg_style = settings.get( + 'signature_arg_style', '{arg_name} {arg_type}' + ) + self.call_arg_style = settings.get( + 'call_arg_style', '{arg_name: <{max_arg_len}} := {arg_default}' + ) + self.call_arg_display_style = settings.get( + 'call_arg_display_style', '{arg_name}' + ) + self.call_arg_oneliner_max = settings.get('call_arg_oneliner_max', 2) self.search_path_filter = settings.get('search_path_filter') self.generate_aliases = settings.get('generate_aliases') self.casing_file = settings.get('casing_file') @@ -186,8 +210,6 @@ class PGCompleter(Completer): def extend_functions(self, func_data): # func_data is a list of function metadata namedtuples - # with fields schema_name, func_name, arg_list, result, - # is_aggregate, is_window, is_set_returning # dbmetadata['schema_name']['functions']['function_name'] should return # the function metadata namedtuple for the corresponding function @@ -203,6 +225,22 @@ class PGCompleter(Completer): self.all_completions.add(func) + self._refresh_arg_list_cache() + + def _refresh_arg_list_cache(self): + # We keep a cache of {function_usage:{function_metadata: function_arg_list_string}} + # This is used when suggesting functions, to avoid the latency that would result + # if we'd recalculate the arg lists each time we suggest functions (in large DBs) + self._arg_list_cache = { + usage: { + meta: self._arg_list(meta, usage) + for sch, funcs in self.dbmetadata['functions'].items() + for func, metas in funcs.items() + for meta in metas + } + for usage in ('call', 'call_display', 'signature') + } + def extend_foreignkeys(self, fk_data): # fk_data is a list of ForeignKey namedtuples, with fields @@ -329,7 +367,7 @@ class PGCompleter(Completer): matches = [] for cand in collection: if isinstance(cand, _Candidate): - item, prio, display_meta, synonyms, prio2 = cand + item, prio, display_meta, synonyms, prio2, display = cand if display_meta is None: display_meta = meta syn_matches = (_match(x) for x in synonyms) @@ -337,7 +375,7 @@ class PGCompleter(Completer): syn_matches = [m for m in syn_matches if m] sort_key = max(syn_matches) if syn_matches else None else: - item, display_meta, prio, prio2 = cand, meta, 0, 0 + item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand sort_key = _match(cand) if sort_key: @@ -359,15 +397,22 @@ class PGCompleter(Completer): + tuple(c for c in item)) item = self.case(item) + display = self.case(display) priority = ( sort_key, type_priority, prio, priority_func(item), prio2, lexical_priority ) - - matches.append(Match( - completion=Completion(item, -text_len, - display_meta=display_meta), - priority=priority)) + matches.append( + Match( + completion=Completion( + text=item, + start_position=-text_len, + display_meta=display_meta, + display=display + ), + priority=priority + ) + ) return matches def case(self, word): @@ -405,7 +450,6 @@ class PGCompleter(Completer): return [m.completion for m in matches] - def get_column_matches(self, suggestion, word_before_cursor): tables = suggestion.table_refs do_qualify = suggestion.qualifiable and {'always': True, 'never': False, @@ -569,15 +613,16 @@ class PGCompleter(Completer): def get_function_matches(self, suggestion, word_before_cursor, alias=False): if suggestion.usage == 'from': # Only suggest functions allowed in FROM clause - filt = lambda f: not f.is_aggregate and not f.is_window + def filt(f): return not f.is_aggregate and not f.is_window else: alias = False - filt = lambda f: True + def filt(f): return True + arg_mode = 'signature' if suggestion.usage == 'signature' else 'call' # Function overloading means we way have multiple functions of the same # name at this point, so keep unique names only funcs = set( - self._make_cand(f, alias, suggestion) + self._make_cand(f, alias, suggestion, arg_mode) for f in self.populate_functions(suggestion.schema, filt) ) @@ -613,22 +658,84 @@ class PGCompleter(Completer): t_sug = Table(s.schema, s.table_refs, s.local_tables) v_sug = View(s.schema, s.table_refs) f_sug = Function(s.schema, s.table_refs, usage='from') - return (self.get_table_matches(t_sug, word_before_cursor, alias) + return ( + self.get_table_matches(t_sug, word_before_cursor, alias) + self.get_view_matches(v_sug, word_before_cursor, alias) - + self.get_function_matches(f_sug, word_before_cursor, alias)) + + self.get_function_matches(f_sug, word_before_cursor, alias) + ) + + def _arg_list(self, func, usage): + """Returns a an arg list string, e.g. `(_foo:=23)` for a func. - # Note: tbl is a SchemaObject - def _make_cand(self, tbl, do_alias, suggestion): + :param func is a FunctionMetadata object + :param usage is 'call', 'call_display' or 'signature' + + """ + template = { + 'call': self.call_arg_style, + 'call_display': self.call_arg_display_style, + 'signature': self.signature_arg_style + }[usage] + args = func.args() + if not template: + return '()' + elif usage == 'call' and len(args) < 2: + return '()' + elif usage == 'call' and func.has_variadic(): + return '()' + multiline = usage == 'call' and len(args) > self.call_arg_oneliner_max + max_arg_len = max(len(a.name) for a in args) if multiline else 0 + args = ( + self._format_arg(template, arg, arg_num + 1, max_arg_len) + for arg_num, arg in enumerate(args) + ) + if multiline: + return '(' + ','.join('\n ' + a for a in args if a) + '\n)' + else: + return '(' + ', '.join(a for a in args if a) + ')' + + def _format_arg(self, template, arg, arg_num, max_arg_len): + if not template: + return None + if arg.has_default: + arg_default = 'NULL' if arg.default is None else arg.default + # Remove trailing ::(schema.)type + arg_default = arg_default_type_strip_regex.sub('', arg_default) + else: + arg_default = '' + return template.format( + max_arg_len=max_arg_len, + arg_name=self.case(arg.name), + arg_num=arg_num, + arg_type=arg.datatype, + arg_default=arg_default + ) + + def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None): + """Returns a Candidate namedtuple. + + :param tbl is a SchemaObject + :param arg_mode determines what type of arg list to suffix for functions. + Possible values: call, signature + + """ cased_tbl = self.case(tbl.name) if do_alias: alias = self.alias(cased_tbl, suggestion.table_refs) synonyms = (cased_tbl, generate_alias(cased_tbl)) - maybe_parens = '()' if tbl.function else '' maybe_alias = (' ' + alias) if do_alias else '' maybe_schema = (self.case(tbl.schema) + '.') if tbl.schema else '' - item = maybe_schema + cased_tbl + maybe_parens + maybe_alias + suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else '' + if arg_mode == 'call': + display_suffix = self._arg_list_cache['call_display'][tbl.meta] + elif arg_mode == 'signature': + display_suffix = self._arg_list_cache['signature'][tbl.meta] + else: + display_suffix = '' + item = maybe_schema + cased_tbl + suffix + maybe_alias + display = maybe_schema + cased_tbl + display_suffix + maybe_alias prio2 = 0 if tbl.schema else 1 - return Candidate(item, synonyms=synonyms, prio2=prio2) + return Candidate(item, synonyms=synonyms, prio2=prio2, display=display) def get_table_matches(self, suggestion, word_before_cursor, alias=False): tables = self.populate_schema_objects(suggestion.schema, 'tables') @@ -736,10 +843,12 @@ class PGCompleter(Completer): } def populate_scoped_cols(self, scoped_tbls, local_tbls=()): - """ Find all columns in a set of scoped_tables + """Find all columns in a set of scoped_tables. + :param scoped_tbls: list of TableReference namedtuples :param local_tbls: tuple(TableMetadata) :return: {TableReference:{colname:ColumnMetaData}} + """ ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls) columns = OrderedDict() @@ -780,8 +889,10 @@ class PGCompleter(Completer): return columns def _get_schemas(self, obj_typ, schema): - """ Returns a list of schemas from which to suggest objects - schema is the schema qualification input by the user (if any) + """Returns a list of schemas from which to suggest objects. + + :param schema is the schema qualification input by the user (if any) + """ metadata = self.dbmetadata[obj_typ] if schema: @@ -793,8 +904,10 @@ class PGCompleter(Completer): return None if parent or schema in self.search_path else schema def populate_schema_objects(self, schema, obj_type): - """Returns a list of SchemaObjects representing tables or views - schema is the schema qualification input by the user (if any) + """Returns a list of SchemaObjects representing tables or views. + + :param schema is the schema qualification input by the user (if any) + """ return [ @@ -807,11 +920,12 @@ class PGCompleter(Completer): ] def populate_functions(self, schema, filter_func): - """Returns a list of function SchemaObjects + """Returns a list of function SchemaObjects. + + :param filter_func is a function that accepts a FunctionMetadata + namedtuple and returns a boolean indicating whether that + function should be kept or discarded - filter_func is a function that accepts a FunctionMetadata namedtuple - and returns a boolean indicating whether that function should be - kept or discarded """ # Because of multiple dispatch, we can have multiple functions @@ -821,7 +935,7 @@ class PGCompleter(Completer): SchemaObject( name=func, schema=(self._maybe_schema(schema=sch, parent=schema)), - function=True + meta=meta ) for sch in self._get_schemas('functions', schema) for (func, metas) in self.dbmetadata['functions'][sch].items() diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 47b17910..8b52e455 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -517,7 +517,8 @@ class PGExecute(object): prorettype::regtype::text return_type, p.proisagg is_aggregate, p.proiswindow is_window, - p.proretset is_set_returning + p.proretset is_set_returning, + pg_get_expr(proargdefaults, 0) AS arg_defaults FROM pg_catalog.pg_proc p INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace @@ -534,7 +535,8 @@ class PGExecute(object): prorettype::regtype::text, p.proisagg is_aggregate, false is_window, - p.proretset is_set_returning + p.proretset is_set_returning, + NULL AS arg_defaults FROM pg_catalog.pg_proc p INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace @@ -551,7 +553,8 @@ class PGExecute(object): '' ret_type, p.proisagg is_aggregate, false is_window, - p.proretset is_set_returning + p.proretset is_set_returning, + NULL AS arg_defaults FROM pg_catalog.pg_proc p INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace @@ -646,6 +649,9 @@ class PGExecute(object): UNION -- Schema names SELECT nspname FROM pg_catalog.pg_namespace + UNION -- Parameter names + SELECT unnest(proargnames) + FROM pg_proc ) SELECT Word FROM OrderWords diff --git a/tests/metadata.py b/tests/metadata.py index 37c8ff96..be19b7d8 100644 --- a/tests/metadata.py +++ b/tests/metadata.py @@ -22,6 +22,15 @@ def completion(display_meta, text, pos=0): return Completion(text, start_position=pos, display_meta=display_meta) +def function(text, pos=0, display=None): + return Completion( + text, + display=display or text, + start_position=pos, + display_meta='function' + ) + + def get_result(completer, text, position=None): position = len(text) if position is None else position return completer.get_completions( @@ -40,7 +49,6 @@ def result_set(completer, text, position=None): schema = partial(completion, 'schema') table = partial(completion, 'table') view = partial(completion, 'view') -function = partial(completion, 'function') column = partial(completion, 'column') keyword = partial(completion, 'keyword') datatype = partial(completion, 'datatype') @@ -93,8 +101,21 @@ class MetaData(object): def functions(self, parent='public', pos=0): return [ - function(escape(x[0] + '()'), pos) - for x in self.metadata.get('functions', {}).get(parent, [])] + function( + escape(x[0]) + '(' + ', '.join( + arg_name + ' := ' + for (arg_name, arg_mode) in zip(x[1], x[3]) + if arg_mode in ('b', 'i') + ) + ')', + pos, + escape(x[0]) + '(' + ', '.join( + arg_name + for (arg_name, arg_mode) in zip(x[1], x[3]) + if arg_mode in ('b', 'i') + ) + ')' + ) + for x in self.metadata.get('functions', {}).get(parent, []) + ] def schemas(self, pos=0): schemas = set(sch for schs in self.metadata.values() for sch in schs) @@ -191,7 +212,7 @@ class MetaData(object): view_cols.extend([(sch, tbl, col, 'text') for col in cols]) functions = [ - FunctionMetadata(sch, *func_meta) + FunctionMetadata(sch, *func_meta, arg_defaults=None) for sch, funcs in metadata['functions'].items() for func_meta in funcs] diff --git a/tests/parseutils/test_function_metadata.py b/tests/parseutils/test_function_metadata.py index 097ce62d..1f9c6930 100644 --- a/tests/parseutils/test_function_metadata.py +++ b/tests/parseutils/test_function_metadata.py @@ -2,12 +2,15 @@ from pgcli.packages.parseutils.meta import FunctionMetadata def test_function_metadata_eq(): - f1 = FunctionMetadata('s', 'f', ['x'], ['integer'], [], 'int', False, - False, False) - f2 = FunctionMetadata('s', 'f', ['x'], ['integer'], [], 'int', False, - False, False) - f3 = FunctionMetadata('s', 'g', ['x'], ['integer'], [], 'int', False, - False, False) + f1 = FunctionMetadata( + 's', 'f', ['x'], ['integer'], [], 'int', False, False, False, None + ) + f2 = FunctionMetadata( + 's', 'f', ['x'], ['integer'], [], 'int', False, False, False, None + ) + f3 = FunctionMetadata( + 's', 'g', ['x'], ['integer'], [], 'int', False, False, False, None + ) assert f1 == f2 assert f1 != f3 assert not (f1 != f2) diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py index 08f7811e..e6a39ba4 100644 --- a/tests/test_completion_refresher.py +++ b/tests/test_completion_refresher.py @@ -17,8 +17,8 @@ def test_ctor(refresher): """ assert len(refresher.refreshers) > 0 actual_handlers = list(refresher.refreshers.keys()) - expected_handlers = ['schemata', 'tables', 'views', 'functions', - 'types', 'databases', 'casing'] + expected_handlers = ['schemata', 'tables', 'views', + 'types', 'databases', 'casing', 'functions'] assert expected_handlers == actual_handlers diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index 315d033a..51f7cb3e 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -8,6 +8,16 @@ from textwrap import dedent from utils import run, dbtest, requires_json, requires_jsonb +def function_meta_data( + func_name, schema_name='public', arg_names=None, arg_types=None, + arg_modes=None, return_type=None, is_aggregate=False, is_window=False, + is_set_returning=False, arg_defaults=None +): + return FunctionMetadata( + schema_name, func_name, arg_names, arg_types, arg_modes, return_type, + is_aggregate, is_window, is_set_returning, arg_defaults + ) + @dbtest def test_conn(executor): run(executor, '''create table test(a text)''') @@ -94,15 +104,32 @@ def test_functions_query(executor): funcs = set(executor.functions()) assert funcs >= set([ - FunctionMetadata('public', 'func1', None, [], [], - 'integer', False, False, False), - FunctionMetadata('public', 'func3', ['x', 'y'], - ['integer', 'integer'], ['t', 't'], 'record', False, False, True), - FunctionMetadata('public', 'func4', ('x',), ('integer',), [], - 'integer', False, False, True), - FunctionMetadata('schema1', 'func2', None, [], [], - 'integer', False, False, False), - ]) + function_meta_data( + func_name='func1', + return_type='integer' + ), + function_meta_data( + func_name='func3', + arg_names=['x', 'y'], + arg_types=['integer', 'integer'], + arg_modes=['t', 't'], + return_type='record', + is_set_returning=True + ), + function_meta_data( + schema_name='public', + func_name='func4', + arg_names=('x',), + arg_types=('integer',), + return_type='integer', + is_set_returning=True + ), + function_meta_data( + schema_name='schema1', + func_name='func2', + return_type='integer' + ), + ]) @dbtest diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index 89e2f568..20fd6ca9 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -512,14 +512,20 @@ def test_join_alias_search_without_aliases2(completer): def test_function_alias_search_without_aliases(completer): text = 'SELECT blog.ees' result = get_result(completer, text) - assert result[0] == function('extract_entry_symbols()', -3) + first = result[0] + assert first.start_position == -3 + assert first.text == 'extract_entry_symbols()' + assert first.display == 'extract_entry_symbols(_entryid)' @parametrize('completer', completers()) def test_function_alias_search_with_aliases(completer): text = 'SELECT blog.ee' result = get_result(completer, text) - assert result[0] == function('enter_entry()', -2) + first = result[0] + assert first.start_position == -2 + assert first.text == 'enter_entry(_title := , _text := )' + assert first.display == 'enter_entry(_title, _text)' @parametrize('completer',completers(filtr=True, casing=True, qualify=no_qual)) diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index b87e0dc2..2c1751a7 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -12,14 +12,16 @@ metadata = { 'orders': ['id', 'ordered_date', 'status', 'email'], 'select': ['id', 'insert', 'ABC']}, 'views': { - 'user_emails': ['id', 'email']}, + 'user_emails': ['id', 'email'], + 'functions': ['function'], + }, 'functions': [ - ['custom_fun', [''], [''], [''], '', False, False, False], - ['_custom_fun', [''], [''], [''], '', False, False, False], - ['custom_func1', [''], [''], [''], '', False, False, False], - ['custom_func2', [''], [''], [''], '', False, False, False], + ['custom_fun', [], [], [], '', False, False, False], + ['_custom_fun', [], [], [], '', False, False, False], + ['custom_func1', [], [], [], '', False, False, False], + ['custom_func2', [], [], [], '', False, False, False], ['set_returning_func', ['x', 'y'], ['integer', 'integer'], - ['o', 'o'], '', False, False, True]], + ['b', 'b'], '', False, False, True]], 'datatypes': ['custom_type1', 'custom_type2'], 'foreignkeys': [ ('public', 'users', 'id', 'public', 'users', 'parentid'), @@ -33,28 +35,66 @@ testdata = MetaData(metadata) cased_users_col_names = ['ID', 'PARENTID', 'Email', 'First_Name', 'last_name'] cased_users2_col_names = ['UserID', 'UserName'] -cased_funcs = ['Custom_Fun', '_custom_fun', 'Custom_Func1', - 'custom_func2', 'set_returning_func'] +cased_funcs = [ + 'Custom_Fun', '_custom_fun', 'Custom_Func1', 'custom_func2', 'set_returning_func' +] cased_tbls = ['Users', 'Orders'] -cased_views = ['User_Emails'] -casing = (['SELECT', 'PUBLIC'] + cased_funcs + cased_tbls + cased_views - + cased_users_col_names + cased_users2_col_names) +cased_views = ['User_Emails', 'Functions'] +casing = ( + ['SELECT', 'PUBLIC'] + cased_funcs + cased_tbls + cased_views + + cased_users_col_names + cased_users2_col_names +) # Lists for use in assertions -cased_funcs = [function(f + '()') for f in cased_funcs] +cased_funcs = [ + function(f) for f in ('Custom_Fun()', '_custom_fun() |