summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorkoljonen <koljonen@outlook.com>2016-05-25 13:01:57 +0200
committerkoljonen <koljonen@outlook.com>2016-06-02 01:44:31 +0200
commit5d9dfcdcc3e6049942e35f9e3bfb831aace5f824 (patch)
tree47e9f7566048dd25d7d5bdae82d6d3a803c8c7b4
parentf912633d6ded6ed563f8b8ba88525839afa20031 (diff)
Use pg_proc.proargmodes &c. instead of parsing arg_list
Getting the parameters from proargnames, proallargtypes and proargmodes instead of from parsing the arg_list string simplifies FunctionMetadata quite a bit. I also made the ColumnMetadata for table/view columns use the same format for the type (i.e. regtype instead of typname). This means we now get join-condition suggestions for joining tables/views to functions, which didn't work before.
-rw-r--r--pgcli/packages/function_metadata.py137
-rw-r--r--pgcli/pgcompleter.py3
-rw-r--r--pgcli/pgexecute.py25
-rw-r--r--tests/test_function_metadata.py59
-rw-r--r--tests/test_pgexecute.py14
-rw-r--r--tests/test_smart_completion_multiple_schemata.py16
-rw-r--r--tests/test_smart_completion_public_schema_only.py10
7 files changed, 60 insertions, 204 deletions
diff --git a/pgcli/packages/function_metadata.py b/pgcli/packages/function_metadata.py
index a434b67d..b183229b 100644
--- a/pgcli/packages/function_metadata.py
+++ b/pgcli/packages/function_metadata.py
@@ -1,22 +1,22 @@
-import re
-import sqlparse
from collections import namedtuple
-from sqlparse.tokens import Whitespace, Comment, Keyword, Name, Punctuation
-table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE)
ColumnMetadata = namedtuple('ColumnMetadata', ['name', 'datatype', 'foreignkeys'])
ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable',
'parentcolumn', 'childschema', 'childtable', 'childcolumn'])
+TypedFieldMetadata = namedtuple('TypedFieldMetadata', ['name', 'mode', 'type'])
+
class FunctionMetadata(object):
- def __init__(self, schema_name, func_name, arg_list, return_type, is_aggregate,
- is_window, is_set_returning):
+ def __init__(self, schema_name, func_name, arg_names, arg_types,
+ arg_modes, return_type, is_aggregate, is_window, is_set_returning):
"""Class for describing a postgresql function"""
self.schema_name = schema_name
self.func_name = func_name
- self.arg_list = arg_list.strip()
+ self.arg_names = tuple(arg_names) if arg_names else None
+ self.arg_types = tuple(arg_types)
+ self.arg_modes = tuple(arg_modes) if arg_modes else None
self.return_type = return_type.strip()
self.is_aggregate = is_aggregate
self.is_window = is_window
@@ -30,122 +30,27 @@ class FunctionMetadata(object):
return not self.__eq__(other)
def __hash__(self):
- return hash((self.schema_name, self.func_name, self.arg_list,
- self.return_type, self.is_aggregate, self.is_window,
- self.is_set_returning))
+ return hash((self.schema_name, self.func_name, self.arg_names,
+ self.arg_types, self.arg_modes, self.return_type,
+ self.is_aggregate, self.is_window, self.is_set_returning))
def __repr__(self):
- return (('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,'
- ' is_aggregate=%r, is_window=%r, is_set_returning=%r)')
+ return (('%s(schema_name=%r, func_name=%r, arg_names=%r, '
+ 'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, '
+ 'is_window=%r, is_set_returning=%r)')
% (self.__class__.__name__, self.schema_name, self.func_name,
- self.arg_list, self.return_type, self.is_aggregate,
- self.is_window, self.is_set_returning))
+ self.arg_names, self.arg_types, self.arg_modes,
+ self.return_type, self.is_aggregate, self.is_window,
+ self.is_set_returning))
def fields(self):
"""Returns a list of output-field ColumnMetadata namedtuples"""
if self.return_type.lower() == 'void':
return []
+ elif not self.arg_modes:
+ return [ColumnMetadata(self.func_name, self.return_type, [])]
- match = table_def_regex.match(self.return_type)
- if match:
- # Function returns a table -- get the column names
- return list(fields(match.group(1), mode_filter=None))
-
- # Function may have named output arguments -- find them and return
- # their names
- return list(fields(self.arg_list, mode_filter=('OUT', 'INOUT')))
-
-
-class TypedFieldMetadata(object):
- """Describes typed field from a function signature or table definition
-
- Attributes are:
- name The name of the argument/column
- mode 'IN', 'OUT', 'INOUT', 'VARIADIC'
- type A list of tokens denoting the type
- default A list of tokens denoting the default value
- unknown A list of tokens not assigned to type or default
- """
-
- __slots__ = ['name', 'mode', 'type', 'default', 'unknown']
-
- def __init__(self):
- self.name = None
- self.mode = 'IN'
- self.type = []
- self.default = []
- self.unknown = []
-
- def __getitem__(self, attr):
- return getattr(self, attr)
-
-
-def parse_typed_field_list(tokens):
- """Parses a argument/column list, yielding TypedFieldMetadata objects
-
- Field/column lists are used in function signatures and table
- definitions. This function parses a flattened list of sqlparse tokens
- and yields one metadata argument per argument / column.
- """
-
- # postgres function argument list syntax:
- # " ( [ [ argmode ] [ argname ] argtype
- # [ { DEFAULT | = } default_expr ] [, ...] ] )"
-
- mode_names = set(('IN', 'OUT', 'INOUT', 'VARIADIC'))
- parse_state = 'type'
- parens = 0
- field = TypedFieldMetadata()
-
- for tok in tokens:
- if tok.ttype in Whitespace or tok.ttype in Comment:
- continue
- elif tok.ttype in Punctuation:
- if parens == 0 and tok.value == ',':
- # End of the current field specification
- if field.type:
- yield field
- # Initialize metadata holder for the next field
- field, parse_state = TypedFieldMetadata(), 'type'
- elif parens == 0 and tok.value == '=':
- parse_state = 'default'
- else:
- field[parse_state].append(tok)
- if tok.value == '(':
- parens += 1
- elif tok.value == ')':
- parens -= 1
- elif parens == 0:
- if tok.ttype in Keyword:
- if not field.name and tok.value.upper() in mode_names:
- # No other keywords allowed before arg name
- field.mode = tok.value.upper()
- elif tok.value.upper() == 'DEFAULT':
- parse_state = 'default'
- else:
- parse_state = 'unknown'
- elif tok.ttype == Name and not field.name:
- # note that `ttype in Name` would also match Name.Builtin
- field.name = tok.value
- else:
- field[parse_state].append(tok)
- else:
- field[parse_state].append(tok)
-
- # Final argument won't be followed by a comma, so make sure it gets yielded
- if field.type:
- yield field
-
-
-def fields(sql, mode_filter=('IN', 'OUT', 'INOUT', 'VARIADIC')):
- """Yields ColumnMetaData namedtuples from a table declaration"""
-
- if not sql:
- return
-
- # sql is something like "x int, y text, ..."
- tokens = sqlparse.parse(sql)[0].flatten()
- for f in parse_typed_field_list(tokens):
- if f.name and (not mode_filter or f.mode in mode_filter):
- yield ColumnMetadata(f.name, None, [])
+ return [ColumnMetadata(n, t, [])
+ for n, t, m in zip(self.arg_names, self.arg_types, self.arg_modes)
+ if m in ('o', 'b', 't')]
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
index 8de6fb7f..153b46ef 100644
--- a/pgcli/pgcompleter.py
+++ b/pgcli/pgcompleter.py
@@ -403,7 +403,8 @@ class PGCompleter(Completer):
if rtbl.ref != lefttable.ref:
cond = make_cond(lefttable.ref, rtbl.ref, c.name, c.name)
if cond not in found:
- prio = (1000 if c.datatype and c.datatype[:3] == 'int'
+ prio = (1000 if c.datatype and c.datatype in(
+ 'integer', 'bigint', 'smallint')
else 0 + refprio[rtbl.ref])
conds.append((cond, 'name join', prio))
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index 15e5c8c9..7c8cd2a8 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -117,14 +117,12 @@ class PGExecute(object):
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name,
- typ.typname type_name
+ att.atttypid::regtype::text type_name
FROM pg_catalog.pg_attribute att
INNER JOIN pg_catalog.pg_class cls
ON att.attrelid = cls.oid
INNER JOIN pg_catalog.pg_namespace nsp
ON cls.relnamespace = nsp.oid
- INNER JOIN pg_catalog.pg_type typ
- ON typ.oid = att.atttypid
WHERE cls.relkind = ANY(%s)
AND NOT att.attisdropped
AND att.attnum > 0
@@ -424,8 +422,10 @@ class PGExecute(object):
query = '''
SELECT n.nspname schema_name,
p.proname func_name,
- pg_catalog.pg_get_function_arguments(p.oid) arg_list,
- pg_catalog.pg_get_function_result(p.oid) return_type,
+ p.proargnames,
+ COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
+ p.proargmodes,
+ prorettype::regtype::text return_type,
p.proisagg is_aggregate,
p.proiswindow is_window,
p.proretset is_set_returning
@@ -443,28 +443,21 @@ class PGExecute(object):
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
- oidvectortypes(p.proargtypes) proargtypes,
- t.typname return_type,
+ COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
+ p.proargmodes,
+ prorettype::regtype::text,
p.proisagg is_aggregate,
false is_window,
p.proretset is_set_returning
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = p.pronamespace
- INNER JOIN pg_catalog.pg_type t
- ON p.prorettype = t.oid
ORDER BY 1, 2
'''
_logger.debug('Functions Query. sql: %r', query)
cur.execute(query)
for row in cur:
- names = row[2] if row[2] is not None else []
- args = itertools.izip_longest(names, row[3].split(', '), '')
- _logger.debug(list(args))
- typed_args = [f[0] + ', ' + f[1] for f in args]
- arg_list = ', '.join(typed_args)
- _logger.debug(arg_list)
- yield FunctionMetadata(row[0], row[1], arg_list, row[4], row[5], row[6], row[7])
+ yield FunctionMetadata(*row)
def datatypes(self):
"""Yields tuples of (schema_name, type_name)"""
diff --git a/tests/test_function_metadata.py b/tests/test_function_metadata.py
index 59052cc0..f7669e6f 100644
--- a/tests/test_function_metadata.py
+++ b/tests/test_function_metadata.py
@@ -1,61 +1,16 @@
-import sqlparse
-from pgcli.packages.function_metadata import (
- FunctionMetadata, parse_typed_field_list, fields)
-from pgcli.pgcompleter import ColumnMetadata
+from pgcli.packages.function_metadata import FunctionMetadata
def test_function_metadata_eq():
- f1 = FunctionMetadata('s', 'f', 'x int', 'int', False, False, False)
- f2 = FunctionMetadata('s', 'f', 'x int', 'int', False, False, False)
- f3 = FunctionMetadata('s', 'g', 'x int', 'int', False, False, False)
+ f1 = FunctionMetadata('s', 'f', ['x'], ['integer'], [], 'int', False,
+ False, False)
+ f2 = FunctionMetadata('s', 'f', ['x'], ['integer'], [], 'int', False,
+ False, False)
+ f3 = FunctionMetadata('s', 'g', ['x'], ['integer'], [], 'int', False,
+ False, False)
assert f1 == f2
assert f1 != f3
assert not (f1 != f2)
assert not (f1 == f3)
assert hash(f1) == hash(f2)
assert hash(f1) != hash(f3)
-
-def test_parse_typed_field_list_simple():
- sql = 'a int, b int[][], c double precision, d text'
- tokens = sqlparse.parse(sql)[0].flatten()
- args = list(parse_typed_field_list(tokens))
- assert [arg.name for arg in args] == ['a', 'b', 'c', 'd']
-
-
-def test_parse_typed_field_list_more_complex():
- sql = ''' IN a int = 5,
- IN b text default 'abc'::text,
- IN c double precision = 9.99",
- OUT d double precision[] '''
- tokens = sqlparse.parse(sql)[0].flatten()
- args = list(parse_typed_field_list(tokens))
- assert [arg.name for arg in args] == ['a', 'b', 'c', 'd']
- assert [arg.mode for arg in args] == ['IN', 'IN', 'IN', 'OUT']
-
-
-def test_parse_typed_field_list_no_arg_names():
- sql = 'int, double precision, text'
- tokens = sqlparse.parse(sql)[0].flatten()
- args = list(parse_typed_field_list(tokens))
- assert(len(args) == 3)
-
-
-def test_table_column_names():
- tbl_str = '''
- x INT,
- y DOUBLE PRECISION,
- z TEXT '''
- fs = list(fields(tbl_str, mode_filter=None))
- assert fs == [ColumnMetadata(x, None, []) for x in('x' ,'y', 'z')]
-
-
-def test_argument_names():
- func_header = 'IN x INT DEFAULT 2, OUT y DOUBLE PRECISION'
- fs = fields(func_header, mode_filter=['OUT', 'INOUT'])
- assert list(fs) == [ColumnMetadata('y', None, [])]
-
-
-def test_empty_arg_list():
- # happens for e.g. parameter-less functions like now()
- fs = fields('')
- assert list(fs) == []
diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py
index c7ed1885..2614b859 100644
--- a/tests/test_pgexecute.py
+++ b/tests/test_pgexecute.py
@@ -61,7 +61,7 @@ def test_schemata_table_views_and_columns_query(executor):
('public', 'd')])
assert set(executor.view_columns()) >= set([
- ('public', 'd', 'e', 'int4')])
+ ('public', 'd', 'e', 'integer')])
@dbtest
def test_foreign_key_query(executor):
@@ -90,13 +90,13 @@ def test_functions_query(executor):
funcs = set(executor.functions())
assert funcs >= set([
- FunctionMetadata('public', 'func1', '',
+ FunctionMetadata('public', 'func1', None, [], [],
'integer', False, False, False),
- FunctionMetadata('public', 'func3', '',
- 'TABLE(x integer, y integer)', False, False, True),
- FunctionMetadata('public', 'func4', 'x integer',
- 'SETOF integer', False, False, True),
- FunctionMetadata('schema1', 'func2', '',
+ FunctionMetadata('public', 'func3', ['x', 'y'],
+ ['integer', 'integer'], ['t', 't'], 'record', False, False, True),
+ FunctionMetadata('public', 'func4', ('x',), ('integer',), [],
+ 'integer', False, False, True),
+ FunctionMetadata('schema1', 'func2', None, [], [],
'integer', False, False, False),
])
diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py
index 60781ffe..c107a710 100644
--- a/tests/test_smart_completion_multiple_schemata.py
+++ b/tests/test_smart_completion_multiple_schemata.py
@@ -22,16 +22,16 @@ metadata = {
}},
'functions': {
'public': [
- ['func1', '', '', False, False, False],
- ['func2', '', '', False, False, False]],
+ ['func1', [], [], [], '', False, False,
+ False],
+ ['func2', [], [], [], '', False, False,
+ False]],
'custom': [
- ['func3', '', '', False, False, False],
- ['set_returning_func',
- 'OUT x INT', 'SETOF INT',
- False, False, True]],
+ ['func3', [], [], [], '', False, False, False],
+ ['set_returning_func', ['x'], ['integer'], ['o'],
+ 'integer', False, False, True]],
'Custom': [
- ['func4', '', '', False, False, False]
- ]
+ ['func4', [], [], [], '', False, False, False]]
},
'datatypes': {
'public': ['typ1', 'typ2'],
diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py
index 9f78c49f..4e2048c7 100644
--- a/tests/test_smart_completion_public_schema_only.py
+++ b/tests/test_smart_completion_public_schema_only.py
@@ -13,10 +13,12 @@ metadata = {
'views': {
'user_emails': ['id', 'email']},
'functions': [
- ['custom_func1', '', '', False, False, False],
- ['custom_func2', '', '', False, False, False],
- ['set_returning_func', '', 'TABLE (x INT, y INT)',
- False, False, True]],
+ ['custom_func1', [''], [''], [''], '', False, False,
+ False],
+ ['custom_func2', [''], [''], [''], '', False, False,
+ False],
+ ['set_returning_func', ['x', 'y'], ['integer', 'integer'],
+ ['o', 'o'], '', False, False, True]],
'datatypes': ['custom_type1', 'custom_type2'],
}