diff options
author | Darik Gamble <darik.gamble@gmail.com> | 2015-09-22 11:31:41 -0400 |
---|---|---|
committer | Darik Gamble <darik.gamble.spam@gmail.com> | 2015-09-29 09:07:20 -0400 |
commit | 07030e20d29f17a88d4bc5644a0e421f935af101 (patch) | |
tree | 4557e67f87d38f872973b94e2fd22df17f5cff62 | |
parent | a0266de19270d0331cfbe552b3b0072e331872ca (diff) |
Move FunctionMetadata definition into its own package
-rw-r--r-- | pgcli/packages/function_metadata.py | 146 | ||||
-rw-r--r-- | pgcli/pgexecute.py | 7 | ||||
-rw-r--r-- | tests/test_function_metadata.py | 57 | ||||
-rw-r--r-- | tests/test_pgexecute.py | 3 | ||||
-rw-r--r-- | tests/test_smart_completion_multiple_schemata.py | 2 | ||||
-rw-r--r-- | tests/test_smart_completion_public_schema_only.py | 2 |
6 files changed, 208 insertions, 9 deletions
diff --git a/pgcli/packages/function_metadata.py b/pgcli/packages/function_metadata.py new file mode 100644 index 00000000..ea560603 --- /dev/null +++ b/pgcli/packages/function_metadata.py @@ -0,0 +1,146 @@ +import re +import sqlparse +from sqlparse.tokens import Whitespace, Comment, Keyword, Name, Punctuation + + +table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE) + + +class FunctionMetadata(object): + + def __init__(self, schema_name, func_name, arg_list, return_type, is_aggregate, + is_window, is_set_returning): + """Class for describing a postgresql function""" + + self.schema_name = schema_name + self.func_name = func_name + self.arg_list = arg_list.strip() + self.return_type = return_type.strip() + self.is_aggregate = is_aggregate + self.is_window = is_window + self.is_set_returning = is_set_returning + + def __eq__(self, other): + return (isinstance(other, self.__class__) + and self.__dict__ == other.__dict__) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.schema_name, self.func_name, self.arg_list, + self.return_type, self.is_aggregate, self.is_window, + self.is_set_returning)) + + def __repr__(self): + return (('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,' + ' is_aggregate=%r, is_window=%r, is_set_returning=%r)') + % (self.__class__.__name__, self.schema_name, self.func_name, + self.arg_list, self.return_type, self.is_aggregate, + self.is_window, self.is_set_returning)) + + def fieldnames(self): + """Returns a list of output field names""" + + if self.return_type.lower() == 'void': + return [] + + match = table_def_regex.match(self.return_type) + if match: + # Function returns a table -- get the column names + return list(field_names(match.group(1), mode_filter=None)) + + # Function may have named output arguments -- find them and return + # their names + return list(field_names(self.arg_list, mode_filter=('OUT', 'INOUT'))) + + +class TypedFieldMetadata(object): + """Describes typed field from a function signature or table definition + + Attributes are: + name The name of the argument/column + mode 'IN', 'OUT', 'INOUT', 'VARIADIC' + type A list of tokens denoting the type + default A list of tokens denoting the default value + unknown A list of tokens not assigned to type or default + """ + + __slots__ = ['name', 'mode', 'type', 'default', 'unknown'] + + def __init__(self): + self.name = None + self.mode = 'IN' + self.type = [] + self.default = [] + self.unknown = [] + + +def parse_typed_field_list(tokens): + """Parses a argument/column list, yielding TypedFieldMetadata objects + + Field/column lists are used in function signatures and table + definitions. This function parses a flattened list of sqlparse tokens + and yields one metadata argument per argument / column. + """ + + # postgres function argument list syntax: + # " ( [ [ argmode ] [ argname ] argtype + # [ { DEFAULT | = } default_expr ] [, ...] ] )" + + mode_names = set(('IN', 'OUT', 'INOUT', 'VARIADIC')) + parse_state = 'type' + parens = 0 + field = TypedFieldMetadata() + + for tok in tokens: + if tok.ttype in Whitespace or tok.ttype in Comment: + continue + elif tok.ttype in Punctuation: + if parens == 0 and tok.value == ',': + # End of the current field specification + if field.type: + yield field + # Initialize metadata holder for the next field + field, parse_state = TypedFieldMetadata(), 'type' + elif parens == 0 and tok.value == '=': + parse_state = 'default' + else: + getattr(field, parse_state).append(tok) + if tok.value == '(': + parens += 1 + elif tok.value == ')': + parens -= 1 + elif parens == 0: + if tok.ttype in Keyword: + if not field.name and tok.value.upper() in mode_names: + # No other keywords allowed before arg name + field.mode = tok.value.upper() + elif tok.value.upper() == 'DEFAULT': + parse_state = 'default' + else: + parse_state = 'unknown' + elif tok.ttype == Name and not field.name: + # note that `ttype in Name` would also match Name.Builtin + field.name = tok.value + else: + getattr(field, parse_state).append(tok) + else: + getattr(field, parse_state).append(tok) + + # Final argument won't be followed by a comma, so make sure it gets yielded + if field.type: + yield field + + +def field_names(sql, mode_filter=('IN', 'OUT', 'INOUT', 'VARIADIC')): + """Yields field names from a table declaration""" + # sql is something like "x int, y text, ..." + tokens = sqlparse.parse(sql)[0].flatten() + for f in parse_typed_field_list(tokens): + if f.name and (not mode_filter or f.mode in mode_filter): + yield f.name + + + + diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 24611bd7..2b7d1eea 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -4,8 +4,8 @@ import psycopg2 import psycopg2.extras import psycopg2.extensions as ext import sqlparse -from collections import namedtuple import pgspecial as special +from .packages.function_metadata import FunctionMetadata from .encodingutils import unicode2utf8, PY2, utf8tounicode import click @@ -26,11 +26,6 @@ ext.register_type(ext.new_type((17,), 'BYTEA_TEXT', psycopg2.STRING)) # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ ext.set_wait_callback(psycopg2.extras.wait_select) -FunctionMetadata = namedtuple('FunctionMetadata', - ['schema_name', 'func_name', 'arg_list', - 'return_type', 'is_aggregate', 'is_window', - 'is_set_returning']) - def register_json_typecasters(conn, loads_fn): """Set the function for converting JSON data for a connection. diff --git a/tests/test_function_metadata.py b/tests/test_function_metadata.py new file mode 100644 index 00000000..ff722cfb --- /dev/null +++ b/tests/test_function_metadata.py @@ -0,0 +1,57 @@ +import sqlparse +from pgcli.packages.function_metadata import ( + FunctionMetadata, parse_typed_field_list, field_names) + + +def test_function_metadata_eq(): + f1 = FunctionMetadata('s', 'f', 'x int', 'int', False, False, False) + f2 = FunctionMetadata('s', 'f', 'x int', 'int', False, False, False) + f3 = FunctionMetadata('s', 'g', 'x int', 'int', False, False, False) + assert f1 == f2 + assert f1 != f3 + assert not (f1 != f2) + assert not (f1 == f3) + assert hash(f1) == hash(f2) + assert hash(f1) != hash(f3) + +def test_parse_typed_field_list_simple(): + sql = 'a int, b int[][], c double precision, d text' + tokens = sqlparse.parse(sql)[0].flatten() + args = list(parse_typed_field_list(tokens)) + assert [arg.name for arg in args] == ['a', 'b', 'c', 'd'] + + +def test_parse_typed_field_list_more_complex(): + sql = ''' IN a int = 5, + IN b text default 'abc'::text, + IN c double precision = 9.99", + OUT d double precision[] ''' + tokens = sqlparse.parse(sql)[0].flatten() + args = list(parse_typed_field_list(tokens)) + assert [arg.name for arg in args] == ['a', 'b', 'c', 'd'] + assert [arg.mode for arg in args] == ['IN', 'IN', 'IN', 'OUT'] + + +def test_parse_typed_field_list_no_arg_names(): + #waiting on sqlparse/169 + sql = 'int, double precision, text' + tokens = sqlparse.parse(sql)[0].flatten() + args = list(parse_typed_field_list(tokens)) + assert(len(args) == 3) + + +def test_table_column_names(): + tbl_str = ''' + x INT, + y DOUBLE PRECISION, + z TEXT ''' + names = list(field_names(tbl_str, mode_filter=None)) + assert names == ['x', 'y', 'z'] + + +def test_argument_names(): + func_header = 'IN x INT DEFAULT 2, OUT y DOUBLE PRECISION' + names = field_names(func_header, mode_filter=['OUT', 'INOUT']) + assert list(names) == ['y'] + + diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index 7be1d0b4..5d14eae8 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -2,9 +2,10 @@ import pytest from pgspecial.main import PGSpecial +from pgcli.packages.function_metadata import FunctionMetadata from textwrap import dedent from utils import run, dbtest, requires_json, requires_jsonb -from pgcli.pgexecute import FunctionMetadata + @dbtest def test_conn(executor): diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index ea6c3381..6597f6d0 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals import pytest from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document -from pgcli.pgexecute import FunctionMetadata +from pgcli.packages.function_metadata import FunctionMetadata metadata = { 'tables': { diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index 98e8f088..cbf7d990 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals import pytest from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document -from pgcli.pgexecute import FunctionMetadata +from pgcli.packages.function_metadata import FunctionMetadata metadata = { 'tables': { |