from __future__ import print_function, unicode_literals
import logging
import re
from itertools import count, repeat, chain
import operator
from collections import namedtuple, defaultdict, OrderedDict
from cli_helpers.tabular_output import TabularOutputFormatter
from pgspecial.namedqueries import NamedQueries
from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.contrib.completers import PathCompleter
from prompt_toolkit.document import Document
from .packages.sqlcompletion import (
FromClauseItem, suggest_type, Special, Database, Schema, Table,
TableFormat, Function, Column, View, Keyword, NamedQuery,
Datatype, Alias, Path, JoinCondition, Join)
from .packages.parseutils.meta import ColumnMetadata, ForeignKey
from .packages.parseutils.utils import last_word
from .packages.parseutils.tables import TableReference
from .packages.pgliterals.main import get_literals
from .packages.prioritization import PrevalenceCounter
from .config import load_config, config_location
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# At import time, build a NamedQueries object from the configuration loaded
# from the path config_location() + 'config' and store it on the class as
# NamedQueries.instance.
NamedQueries.instance = NamedQueries.from_config(
load_config(config_location() + 'config'))
# Match: a completion paired with its priority.
Match = namedtuple('Match', ['completion', 'priority'])
_SchemaObject = namedtuple('SchemaObject', 'name schema meta')


def SchemaObject(name, schema=None, meta=None):
    """Build a ``SchemaObject`` tuple.

    Thin factory around the underlying namedtuple that lets ``schema`` and
    ``meta`` be omitted; both default to ``None``.
    """
    return _SchemaObject(name=name, schema=schema, meta=meta)
_Candidate = namedtuple(
    'Candidate', 'completion prio meta synonyms prio2 display'
)


def Candidate(
    completion, prio=None, meta=None, synonyms=None, prio2=None,
    display=None
):
    """Build a ``Candidate`` tuple, filling in fallback values.

    When ``synonyms`` is falsy it becomes a one-element list holding
    ``completion``; when ``display`` is falsy it becomes ``completion``.
    All other fields are stored exactly as given.
    """
    if not synonyms:
        synonyms = [completion]
    if not display:
        display = completion
    return _Candidate(
        completion=completion,
        prio=prio,
        meta=meta,
        synonyms=synonyms,
        prio2=prio2,
        display=display,
    )
# Used to strip trailing '::some_type' from default-value expressions.
# Matches, anchored at the end of the string, '::' followed by one or more
# word characters or dots, optionally followed by a literal '[]'.
arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$')
def normalize_ref(ref):
    """Return ``ref`` as a double-quoted reference.

    A reference that already starts with a double quote is returned
    unchanged; anything else is lower-cased and wrapped in double quotes.
    An empty reference yields ``'""'`` instead of raising ``IndexError``.
    """
    # startswith() is safe on the empty string, unlike indexing ref[0].
    if ref.startswith('"'):
        return ref
    return '"' + ref.lower() + '"'
def generate_alias(tbl):
    """ Build a short alias for a table name.

    The alias is made of the upper-case letters of the name when it has any.
    Otherwise it is the first character plus every character that directly
    follows an underscore; underscores themselves are never included.

    param tbl - unescaped name of the table to alias
    """
    capitals = [ch for ch in tbl if ch.isupper()]
    if capitals:
        return ''.join(capitals)

    initials = []
    # Treat the start of the name as if it were preceded by an underscore so
    # that the first character is picked up by the same rule.
    previous = '_'
    for ch in tbl:
        if previous == '_' and ch != '_':
            initials.append(ch)
        previous = ch
    return ''.join(initials)
class PGCompleter(Completer):
# keywords_tree: A dict mapping keywords to well known following keywords.
# e.g. 'CREATE': ['TABLE', 'USER', ...],
keywords_tree = get_literals('keywords', type_=dict)
# keywords: every key of keywords_tree plus every entry of its values,
# de-duplicated through a set (so the tuple's order is not defined).
keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values())))
# functions / datatypes: the 'functions' and 'datatypes' literals as
# returned by get_literals.
functions = get_literals('functions')
datatypes = get_literals('datatypes')
# reserved_words: set of the 'reserved' literals; escape_name tests
# upper-cased names for membership in it.
reserved_words = set(get_literals('reserved'))
def __init__(self, smart_completion=True, pgspecial=None, settings=None):
    """Initialise completer state.

    ``smart_completion`` and ``pgspecial`` are stored as given.
    ``settings`` is an optional mapping of option names to values; any
    option that is missing falls back to the default written next to it
    below. A falsy ``settings`` is treated as an empty mapping.
    """
    super(PGCompleter, self).__init__()
    self.smart_completion = smart_completion
    self.pgspecial = pgspecial
    self.prioritizer = PrevalenceCounter()

    if not settings:
        settings = {}
    option = settings.get

    # Format templates for function signatures and call arguments.
    self.signature_arg_style = option(
        'signature_arg_style', '{arg_name} {arg_type}')
    self.call_arg_style = option(
        'call_arg_style', '{arg_name: <{max_arg_len}} := {arg_default}')
    self.call_arg_display_style = option(
        'call_arg_display_style', '{arg_name}')
    self.call_arg_oneliner_max = option('call_arg_oneliner_max', 2)

    # Options without a default are None when absent.
    self.search_path_filter = option('search_path_filter')
    self.generate_aliases = option('generate_aliases')
    self.casing_file = option('casing_file')

    # Pre-compile the configured patterns once per instance.
    skip_patterns = option(
        'insert_col_skip_patterns', [r'^now\(\)$', r'^nextval\('])
    self.insert_col_skip_patterns = [
        re.compile(skip_pattern) for skip_pattern in skip_patterns
    ]

    self.generate_casing_file = option('generate_casing_file')
    self.qualify_columns = option(
        'qualify_columns', 'if_more_than_one_table')
    self.asterisk_column_order = option(
        'asterisk_column_order', 'table_order')

    # Unrecognised casing values silently fall back to 'upper'.
    requested_casing = option('keyword_casing', 'upper').lower()
    if requested_casing in ('upper', 'lower', 'auto'):
        self.keyword_casing = requested_casing
    else:
        self.keyword_casing = 'upper'

    # Names matching this pattern are candidates for being left unquoted
    # by escape_name.
    self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")

    # Metadata containers start out empty.
    self.databases = []
    self.dbmetadata = {
        'tables': {},
        'views': {},
        'functions': {},
        'datatypes': {},
    }
    self.search_path = []
    self.casing = {}

    self.all_completions = set(self.keywords + self.functions)
def escape_name(self, name):
    """Wrap ``name`` in double quotes when it cannot be used bare.

    Quoting is applied when ``name`` does not match ``self.name_pattern``,
    or when its upper-cased form is found in ``self.reserved_words`` or
    ``self.functions``. Falsy names are returned unchanged.
    """
    if not name:
        return name
    needs_quotes = (
        self.name_pattern.match(name) is None
        or name.upper() in self.reserved_words
        or name.upper() in self.functions
    )
    if needs_quotes:
        return '"%s"' % name
    return name
def escape_schema(self, name):
    """Return ``name``, passed through ``self.unescape_name``, wrapped in
    single quotes."""
    unquoted = self.unescape_name(name)
    return "'{}'".format(unquoted)
def unescape_name(self, name):
""" Unquote a string."""
if name and name[0] == '"' and name[