summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pgcli/packages/function_metadata.py146
-rw-r--r--pgcli/pgexecute.py7
-rw-r--r--tests/test_function_metadata.py57
-rw-r--r--tests/test_pgexecute.py3
-rw-r--r--tests/test_smart_completion_multiple_schemata.py2
-rw-r--r--tests/test_smart_completion_public_schema_only.py2
6 files changed, 208 insertions, 9 deletions
diff --git a/pgcli/packages/function_metadata.py b/pgcli/packages/function_metadata.py
new file mode 100644
index 00000000..ea560603
--- /dev/null
+++ b/pgcli/packages/function_metadata.py
@@ -0,0 +1,146 @@
+import re
+import sqlparse
+from sqlparse.tokens import Whitespace, Comment, Keyword, Name, Punctuation
+
+
+table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE)
+
+
+class FunctionMetadata(object):
+
+ def __init__(self, schema_name, func_name, arg_list, return_type, is_aggregate,
+ is_window, is_set_returning):
+ """Class for describing a postgresql function"""
+
+ self.schema_name = schema_name
+ self.func_name = func_name
+ self.arg_list = arg_list.strip()
+ self.return_type = return_type.strip()
+ self.is_aggregate = is_aggregate
+ self.is_window = is_window
+ self.is_set_returning = is_set_returning
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__)
+ and self.__dict__ == other.__dict__)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __hash__(self):
+ return hash((self.schema_name, self.func_name, self.arg_list,
+ self.return_type, self.is_aggregate, self.is_window,
+ self.is_set_returning))
+
+ def __repr__(self):
+ return (('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,'
+ ' is_aggregate=%r, is_window=%r, is_set_returning=%r)')
+ % (self.__class__.__name__, self.schema_name, self.func_name,
+ self.arg_list, self.return_type, self.is_aggregate,
+ self.is_window, self.is_set_returning))
+
+ def fieldnames(self):
+ """Returns a list of output field names"""
+
+ if self.return_type.lower() == 'void':
+ return []
+
+ match = table_def_regex.match(self.return_type)
+ if match:
+ # Function returns a table -- get the column names
+ return list(field_names(match.group(1), mode_filter=None))
+
+ # Function may have named output arguments -- find them and return
+ # their names
+ return list(field_names(self.arg_list, mode_filter=('OUT', 'INOUT')))
+
+
+class TypedFieldMetadata(object):
+ """Describes typed field from a function signature or table definition
+
+ Attributes are:
+ name The name of the argument/column
+ mode 'IN', 'OUT', 'INOUT', 'VARIADIC'
+ type A list of tokens denoting the type
+ default A list of tokens denoting the default value
+ unknown A list of tokens not assigned to type or default
+ """
+
+ __slots__ = ['name', 'mode', 'type', 'default', 'unknown']
+
+ def __init__(self):
+ self.name = None
+ self.mode = 'IN'
+ self.type = []
+ self.default = []
+ self.unknown = []
+
+
+def parse_typed_field_list(tokens):
+ """Parses a argument/column list, yielding TypedFieldMetadata objects
+
+ Field/column lists are used in function signatures and table
+ definitions. This function parses a flattened list of sqlparse tokens
+ and yields one metadata argument per argument / column.
+ """
+
+ # postgres function argument list syntax:
+ # " ( [ [ argmode ] [ argname ] argtype
+ # [ { DEFAULT | = } default_expr ] [, ...] ] )"
+
+ mode_names = set(('IN', 'OUT', 'INOUT', 'VARIADIC'))
+ parse_state = 'type'
+ parens = 0
+ field = TypedFieldMetadata()
+
+ for tok in tokens:
+ if tok.ttype in Whitespace or tok.ttype in Comment:
+ continue
+ elif tok.ttype in Punctuation:
+ if parens == 0 and tok.value == ',':
+ # End of the current field specification
+ if field.type:
+ yield field
+ # Initialize metadata holder for the next field
+ field, parse_state = TypedFieldMetadata(), 'type'
+ elif parens == 0 and tok.value == '=':
+ parse_state = 'default'
+ else:
+ getattr(field, parse_state).append(tok)
+ if tok.value == '(':
+ parens += 1
+ elif tok.value == ')':
+ parens -= 1
+ elif parens == 0:
+ if tok.ttype in Keyword:
+ if not field.name and tok.value.upper() in mode_names:
+ # No other keywords allowed before arg name
+ field.mode = tok.value.upper()
+ elif tok.value.upper() == 'DEFAULT':
+ parse_state = 'default'
+ else:
+ parse_state = 'unknown'
+ elif tok.ttype == Name and not field.name:
+ # note that `ttype in Name` would also match Name.Builtin
+ field.name = tok.value
+ else:
+ getattr(field, parse_state).append(tok)
+ else:
+ getattr(field, parse_state).append(tok)
+
+ # Final argument won't be followed by a comma, so make sure it gets yielded
+ if field.type:
+ yield field
+
+
+def field_names(sql, mode_filter=('IN', 'OUT', 'INOUT', 'VARIADIC')):
+ """Yields field names from a table declaration"""
+ # sql is something like "x int, y text, ..."
+ tokens = sqlparse.parse(sql)[0].flatten()
+ for f in parse_typed_field_list(tokens):
+ if f.name and (not mode_filter or f.mode in mode_filter):
+ yield f.name
+
+
+
+
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index 24611bd7..2b7d1eea 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -4,8 +4,8 @@ import psycopg2
import psycopg2.extras
import psycopg2.extensions as ext
import sqlparse
-from collections import namedtuple
import pgspecial as special
+from .packages.function_metadata import FunctionMetadata
from .encodingutils import unicode2utf8, PY2, utf8tounicode
import click
@@ -26,11 +26,6 @@ ext.register_type(ext.new_type((17,), 'BYTEA_TEXT', psycopg2.STRING))
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
ext.set_wait_callback(psycopg2.extras.wait_select)
-FunctionMetadata = namedtuple('FunctionMetadata',
- ['schema_name', 'func_name', 'arg_list',
- 'return_type', 'is_aggregate', 'is_window',
- 'is_set_returning'])
-
def register_json_typecasters(conn, loads_fn):
"""Set the function for converting JSON data for a connection.
diff --git a/tests/test_function_metadata.py b/tests/test_function_metadata.py
new file mode 100644
index 00000000..ff722cfb
--- /dev/null
+++ b/tests/test_function_metadata.py
@@ -0,0 +1,57 @@
+import sqlparse
+from pgcli.packages.function_metadata import (
+ FunctionMetadata, parse_typed_field_list, field_names)
+
+
+def test_function_metadata_eq():
+ f1 = FunctionMetadata('s', 'f', 'x int', 'int', False, False, False)
+ f2 = FunctionMetadata('s', 'f', 'x int', 'int', False, False, False)
+ f3 = FunctionMetadata('s', 'g', 'x int', 'int', False, False, False)
+ assert f1 == f2
+ assert f1 != f3
+ assert not (f1 != f2)
+ assert not (f1 == f3)
+ assert hash(f1) == hash(f2)
+ assert hash(f1) != hash(f3)
+
+def test_parse_typed_field_list_simple():
+ sql = 'a int, b int[][], c double precision, d text'
+ tokens = sqlparse.parse(sql)[0].flatten()
+ args = list(parse_typed_field_list(tokens))
+ assert [arg.name for arg in args] == ['a', 'b', 'c', 'd']
+
+
+def test_parse_typed_field_list_more_complex():
+ sql = ''' IN a int = 5,
+ IN b text default 'abc'::text,
+ IN c double precision = 9.99",
+ OUT d double precision[] '''
+ tokens = sqlparse.parse(sql)[0].flatten()
+ args = list(parse_typed_field_list(tokens))
+ assert [arg.name for arg in args] == ['a', 'b', 'c', 'd']
+ assert [arg.mode for arg in args] == ['IN', 'IN', 'IN', 'OUT']
+
+
+def test_parse_typed_field_list_no_arg_names():
+ #waiting on sqlparse/169
+ sql = 'int, double precision, text'
+ tokens = sqlparse.parse(sql)[0].flatten()
+ args = list(parse_typed_field_list(tokens))
+ assert(len(args) == 3)
+
+
+def test_table_column_names():
+ tbl_str = '''
+ x INT,
+ y DOUBLE PRECISION,
+ z TEXT '''
+ names = list(field_names(tbl_str, mode_filter=None))
+ assert names == ['x', 'y', 'z']
+
+
+def test_argument_names():
+ func_header = 'IN x INT DEFAULT 2, OUT y DOUBLE PRECISION'
+ names = field_names(func_header, mode_filter=['OUT', 'INOUT'])
+ assert list(names) == ['y']
+
+
diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py
index 7be1d0b4..5d14eae8 100644
--- a/tests/test_pgexecute.py
+++ b/tests/test_pgexecute.py
@@ -2,9 +2,10 @@
import pytest
from pgspecial.main import PGSpecial
+from pgcli.packages.function_metadata import FunctionMetadata
from textwrap import dedent
from utils import run, dbtest, requires_json, requires_jsonb
-from pgcli.pgexecute import FunctionMetadata
+
@dbtest
def test_conn(executor):
diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py
index ea6c3381..6597f6d0 100644
--- a/tests/test_smart_completion_multiple_schemata.py
+++ b/tests/test_smart_completion_multiple_schemata.py
@@ -2,7 +2,7 @@ from __future__ import unicode_literals
import pytest
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
-from pgcli.pgexecute import FunctionMetadata
+from pgcli.packages.function_metadata import FunctionMetadata
metadata = {
'tables': {
diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py
index 98e8f088..cbf7d990 100644
--- a/tests/test_smart_completion_public_schema_only.py
+++ b/tests/test_smart_completion_public_schema_only.py
@@ -2,7 +2,7 @@ from __future__ import unicode_literals
import pytest
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
-from pgcli.pgexecute import FunctionMetadata
+from pgcli.packages.function_metadata import FunctionMetadata
metadata = {
'tables': {