diff options
author | koljonen <koljonen@outlook.com> | 2016-05-25 13:01:57 +0200 |
---|---|---|
committer | koljonen <koljonen@outlook.com> | 2016-06-02 01:44:31 +0200 |
commit | 5d9dfcdcc3e6049942e35f9e3bfb831aace5f824 (patch) | |
tree | 47e9f7566048dd25d7d5bdae82d6d3a803c8c7b4 | |
parent | f912633d6ded6ed563f8b8ba88525839afa20031 (diff) |
Use pg_proc.proargmodes &c. instead of parsing arg_list
Getting the parameters from proargnames, proallargtypes and proargmodes instead of from parsing the arg_list string simplifies FunctionMetadata quite a bit.
I also made the ColumnMetadata for table/view columns use the same format for the type (i.e. regtype instead of typname). This means we now get join-condition suggestions for joining tables/views to functions, which didn't work before.
-rw-r--r-- | pgcli/packages/function_metadata.py | 137 | ||||
-rw-r--r-- | pgcli/pgcompleter.py | 3 | ||||
-rw-r--r-- | pgcli/pgexecute.py | 25 | ||||
-rw-r--r-- | tests/test_function_metadata.py | 59 | ||||
-rw-r--r-- | tests/test_pgexecute.py | 14 | ||||
-rw-r--r-- | tests/test_smart_completion_multiple_schemata.py | 16 | ||||
-rw-r--r-- | tests/test_smart_completion_public_schema_only.py | 10 |
7 files changed, 60 insertions, 204 deletions
diff --git a/pgcli/packages/function_metadata.py b/pgcli/packages/function_metadata.py index a434b67d..b183229b 100644 --- a/pgcli/packages/function_metadata.py +++ b/pgcli/packages/function_metadata.py @@ -1,22 +1,22 @@ -import re -import sqlparse from collections import namedtuple -from sqlparse.tokens import Whitespace, Comment, Keyword, Name, Punctuation -table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE) ColumnMetadata = namedtuple('ColumnMetadata', ['name', 'datatype', 'foreignkeys']) ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable', 'parentcolumn', 'childschema', 'childtable', 'childcolumn']) +TypedFieldMetadata = namedtuple('TypedFieldMetadata', ['name', 'mode', 'type']) + class FunctionMetadata(object): - def __init__(self, schema_name, func_name, arg_list, return_type, is_aggregate, - is_window, is_set_returning): + def __init__(self, schema_name, func_name, arg_names, arg_types, + arg_modes, return_type, is_aggregate, is_window, is_set_returning): """Class for describing a postgresql function""" self.schema_name = schema_name self.func_name = func_name - self.arg_list = arg_list.strip() + self.arg_names = tuple(arg_names) if arg_names else None + self.arg_types = tuple(arg_types) + self.arg_modes = tuple(arg_modes) if arg_modes else None self.return_type = return_type.strip() self.is_aggregate = is_aggregate self.is_window = is_window @@ -30,122 +30,27 @@ class FunctionMetadata(object): return not self.__eq__(other) def __hash__(self): - return hash((self.schema_name, self.func_name, self.arg_list, - self.return_type, self.is_aggregate, self.is_window, - self.is_set_returning)) + return hash((self.schema_name, self.func_name, self.arg_names, + self.arg_types, self.arg_modes, self.return_type, + self.is_aggregate, self.is_window, self.is_set_returning)) def __repr__(self): - return (('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,' - ' is_aggregate=%r, is_window=%r, is_set_returning=%r)') + return (('%s(schema_name=%r, func_name=%r, arg_names=%r, ' + 'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, ' + 'is_window=%r, is_set_returning=%r)') % (self.__class__.__name__, self.schema_name, self.func_name, - self.arg_list, self.return_type, self.is_aggregate, - self.is_window, self.is_set_returning)) + self.arg_names, self.arg_types, self.arg_modes, + self.return_type, self.is_aggregate, self.is_window, + self.is_set_returning)) def fields(self): """Returns a list of output-field ColumnMetadata namedtuples""" if self.return_type.lower() == 'void': return [] + elif not self.arg_modes: + return [ColumnMetadata(self.func_name, self.return_type, [])] - match = table_def_regex.match(self.return_type) - if match: - # Function returns a table -- get the column names - return list(fields(match.group(1), mode_filter=None)) - - # Function may have named output arguments -- find them and return - # their names - return list(fields(self.arg_list, mode_filter=('OUT', 'INOUT'))) - - -class TypedFieldMetadata(object): - """Describes typed field from a function signature or table definition - - Attributes are: - name The name of the argument/column - mode 'IN', 'OUT', 'INOUT', 'VARIADIC' - type A list of tokens denoting the type - default A list of tokens denoting the default value - unknown A list of tokens not assigned to type or default - """ - - __slots__ = ['name', 'mode', 'type', 'default', 'unknown'] - - def __init__(self): - self.name = None - self.mode = 'IN' - self.type = [] - self.default = [] - self.unknown = [] - - def __getitem__(self, attr): - return getattr(self, attr) - - -def parse_typed_field_list(tokens): - """Parses a argument/column list, yielding TypedFieldMetadata objects - - Field/column lists are used in function signatures and table - definitions. This function parses a flattened list of sqlparse tokens - and yields one metadata argument per argument / column. - """ - - # postgres function argument list syntax: - # " ( [ [ argmode ] [ argname ] argtype - # [ { DEFAULT | = } default_expr ] [, ...] ] )" - - mode_names = set(('IN', 'OUT', 'INOUT', 'VARIADIC')) - parse_state = 'type' - parens = 0 - field = TypedFieldMetadata() - - for tok in tokens: - if tok.ttype in Whitespace or tok.ttype in Comment: - continue - elif tok.ttype in Punctuation: - if parens == 0 and tok.value == ',': - # End of the current field specification - if field.type: - yield field - # Initialize metadata holder for the next field - field, parse_state = TypedFieldMetadata(), 'type' - elif parens == 0 and tok.value == '=': - parse_state = 'default' - else: - field[parse_state].append(tok) - if tok.value == '(': - parens += 1 - elif tok.value == ')': - parens -= 1 - elif parens == 0: - if tok.ttype in Keyword: - if not field.name and tok.value.upper() in mode_names: - # No other keywords allowed before arg name - field.mode = tok.value.upper() - elif tok.value.upper() == 'DEFAULT': - parse_state = 'default' - else: - parse_state = 'unknown' - elif tok.ttype == Name and not field.name: - # note that `ttype in Name` would also match Name.Builtin - field.name = tok.value - else: - field[parse_state].append(tok) - else: - field[parse_state].append(tok) - - # Final argument won't be followed by a comma, so make sure it gets yielded - if field.type: - yield field - - -def fields(sql, mode_filter=('IN', 'OUT', 'INOUT', 'VARIADIC')): - """Yields ColumnMetaData namedtuples from a table declaration""" - - if not sql: - return - - # sql is something like "x int, y text, ..." - tokens = sqlparse.parse(sql)[0].flatten() - for f in parse_typed_field_list(tokens): - if f.name and (not mode_filter or f.mode in mode_filter): - yield ColumnMetadata(f.name, None, []) + return [ColumnMetadata(n, t, []) + for n, t, m in zip(self.arg_names, self.arg_types, self.arg_modes) + if m in ('o', 'b', 't')] diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 8de6fb7f..153b46ef 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -403,7 +403,8 @@ class PGCompleter(Completer): if rtbl.ref != lefttable.ref: cond = make_cond(lefttable.ref, rtbl.ref, c.name, c.name) if cond not in found: - prio = (1000 if c.datatype and c.datatype[:3] == 'int' + prio = (1000 if c.datatype and c.datatype in( + 'integer', 'bigint', 'smallint') else 0 + refprio[rtbl.ref]) conds.append((cond, 'name join', prio)) diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 15e5c8c9..7c8cd2a8 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -117,14 +117,12 @@ class PGExecute(object): SELECT nsp.nspname schema_name, cls.relname table_name, att.attname column_name, - typ.typname type_name + att.atttypid::regtype::text type_name FROM pg_catalog.pg_attribute att INNER JOIN pg_catalog.pg_class cls ON att.attrelid = cls.oid INNER JOIN pg_catalog.pg_namespace nsp ON cls.relnamespace = nsp.oid - INNER JOIN pg_catalog.pg_type typ - ON typ.oid = att.atttypid WHERE cls.relkind = ANY(%s) AND NOT att.attisdropped AND att.attnum > 0 @@ -424,8 +422,10 @@ class PGExecute(object): query = ''' SELECT n.nspname schema_name, p.proname func_name, - pg_catalog.pg_get_function_arguments(p.oid) arg_list, - pg_catalog.pg_get_function_result(p.oid) return_type, + p.proargnames, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[], + p.proargmodes, + prorettype::regtype::text return_type, p.proisagg is_aggregate, p.proiswindow is_window, p.proretset is_set_returning @@ -443,28 +443,21 @@ class PGExecute(object): SELECT n.nspname schema_name, p.proname func_name, p.proargnames, - oidvectortypes(p.proargtypes) proargtypes, - t.typname return_type, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[], + p.proargmodes, + prorettype::regtype::text, p.proisagg is_aggregate, false is_window, p.proretset is_set_returning FROM pg_catalog.pg_proc p INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace - INNER JOIN pg_catalog.pg_type t - ON p.prorettype = t.oid ORDER BY 1, 2 ''' _logger.debug('Functions Query. sql: %r', query) cur.execute(query) for row in cur: - names = row[2] if row[2] is not None else [] - args = itertools.izip_longest(names, row[3].split(', '), '') - _logger.debug(list(args)) - typed_args = [f[0] + ', ' + f[1] for f in args] - arg_list = ', '.join(typed_args) - _logger.debug(arg_list) - yield FunctionMetadata(row[0], row[1], arg_list, row[4], row[5], row[6], row[7]) + yield FunctionMetadata(*row) def datatypes(self): """Yields tuples of (schema_name, type_name)""" diff --git a/tests/test_function_metadata.py b/tests/test_function_metadata.py index 59052cc0..f7669e6f 100644 --- a/tests/test_function_metadata.py +++ b/tests/test_function_metadata.py @@ -1,61 +1,16 @@ -import sqlparse -from pgcli.packages.function_metadata import ( - FunctionMetadata, parse_typed_field_list, fields) -from pgcli.pgcompleter import ColumnMetadata +from pgcli.packages.function_metadata import FunctionMetadata def test_function_metadata_eq(): - f1 = FunctionMetadata('s', 'f', 'x int', 'int', False, False, False) - f2 = FunctionMetadata('s', 'f', 'x int', 'int', False, False, False) - f3 = FunctionMetadata('s', 'g', 'x int', 'int', False, False, False) + f1 = FunctionMetadata('s', 'f', ['x'], ['integer'], [], 'int', False, + False, False) + f2 = FunctionMetadata('s', 'f', ['x'], ['integer'], [], 'int', False, + False, False) + f3 = FunctionMetadata('s', 'g', ['x'], ['integer'], [], 'int', False, + False, False) assert f1 == f2 assert f1 != f3 assert not (f1 != f2) assert not (f1 == f3) assert hash(f1) == hash(f2) assert hash(f1) != hash(f3) - -def test_parse_typed_field_list_simple(): - sql = 'a int, b int[][], c double precision, d text' - tokens = sqlparse.parse(sql)[0].flatten() - args = list(parse_typed_field_list(tokens)) - assert [arg.name for arg in args] == ['a', 'b', 'c', 'd'] - - -def test_parse_typed_field_list_more_complex(): - sql = ''' IN a int = 5, - IN b text default 'abc'::text, - IN c double precision = 9.99", - OUT d double precision[] ''' - tokens = sqlparse.parse(sql)[0].flatten() - args = list(parse_typed_field_list(tokens)) - assert [arg.name for arg in args] == ['a', 'b', 'c', 'd'] - assert [arg.mode for arg in args] == ['IN', 'IN', 'IN', 'OUT'] - - -def test_parse_typed_field_list_no_arg_names(): - sql = 'int, double precision, text' - tokens = sqlparse.parse(sql)[0].flatten() - args = list(parse_typed_field_list(tokens)) - assert(len(args) == 3) - - -def test_table_column_names(): - tbl_str = ''' - x INT, - y DOUBLE PRECISION, - z TEXT ''' - fs = list(fields(tbl_str, mode_filter=None)) - assert fs == [ColumnMetadata(x, None, []) for x in('x' ,'y', 'z')] - - -def test_argument_names(): - func_header = 'IN x INT DEFAULT 2, OUT y DOUBLE PRECISION' - fs = fields(func_header, mode_filter=['OUT', 'INOUT']) - assert list(fs) == [ColumnMetadata('y', None, [])] - - -def test_empty_arg_list(): - # happens for e.g. parameter-less functions like now() - fs = fields('') - assert list(fs) == [] diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index c7ed1885..2614b859 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -61,7 +61,7 @@ def test_schemata_table_views_and_columns_query(executor): ('public', 'd')]) assert set(executor.view_columns()) >= set([ - ('public', 'd', 'e', 'int4')]) + ('public', 'd', 'e', 'integer')]) @dbtest def test_foreign_key_query(executor): @@ -90,13 +90,13 @@ def test_functions_query(executor): funcs = set(executor.functions()) assert funcs >= set([ - FunctionMetadata('public', 'func1', '', + FunctionMetadata('public', 'func1', None, [], [], 'integer', False, False, False), - FunctionMetadata('public', 'func3', '', - 'TABLE(x integer, y integer)', False, False, True), - FunctionMetadata('public', 'func4', 'x integer', - 'SETOF integer', False, False, True), - FunctionMetadata('schema1', 'func2', '', + FunctionMetadata('public', 'func3', ['x', 'y'], + ['integer', 'integer'], ['t', 't'], 'record', False, False, True), + FunctionMetadata('public', 'func4', ('x',), ('integer',), [], + 'integer', False, False, True), + FunctionMetadata('schema1', 'func2', None, [], [], 'integer', False, False, False), ]) diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index 60781ffe..c107a710 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -22,16 +22,16 @@ metadata = { }}, 'functions': { 'public': [ - ['func1', '', '', False, False, False], - ['func2', '', '', False, False, False]], + ['func1', [], [], [], '', False, False, + False], + ['func2', [], [], [], '', False, False, + False]], 'custom': [ - ['func3', '', '', False, False, False], - ['set_returning_func', - 'OUT x INT', 'SETOF INT', - False, False, True]], + ['func3', [], [], [], '', False, False, False], + ['set_returning_func', ['x'], ['integer'], ['o'], + 'integer', False, False, True]], 'Custom': [ - ['func4', '', '', False, False, False] - ] + ['func4', [], [], [], '', False, False, False]] }, 'datatypes': { 'public': ['typ1', 'typ2'], diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index 9f78c49f..4e2048c7 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -13,10 +13,12 @@ metadata = { 'views': { 'user_emails': ['id', 'email']}, 'functions': [ - ['custom_func1', '', '', False, False, False], - ['custom_func2', '', '', False, False, False], - ['set_returning_func', '', 'TABLE (x INT, y INT)', - False, False, True]], + ['custom_func1', [''], [''], [''], '', False, False, + False], + ['custom_func2', [''], [''], [''], '', False, False, + False], + ['set_returning_func', ['x', 'y'], ['integer', 'integer'], + ['o', 'o'], '', False, False, True]], 'datatypes': ['custom_type1', 'custom_type2'], } |