from __future__ import print_function, unicode_literals import logging import re from itertools import count, repeat, chain import operator from collections import namedtuple, defaultdict, OrderedDict from cli_helpers.tabular_output import TabularOutputFormatter from pgspecial.namedqueries import NamedQueries from prompt_toolkit.completion import Completer, Completion from prompt_toolkit.contrib.completers import PathCompleter from prompt_toolkit.document import Document from .packages.sqlcompletion import ( FromClauseItem, suggest_type, Special, Database, Schema, Table, TableFormat, Function, Column, View, Keyword, NamedQuery, Datatype, Alias, Path, JoinCondition, Join) from .packages.parseutils.meta import ColumnMetadata, ForeignKey from .packages.parseutils.utils import last_word from .packages.parseutils.tables import TableReference from .packages.pgliterals.main import get_literals from .packages.prioritization import PrevalenceCounter from .config import load_config, config_location _logger = logging.getLogger(__name__) NamedQueries.instance = NamedQueries.from_config( load_config(config_location() + 'config')) Match = namedtuple('Match', ['completion', 'priority']) _SchemaObject = namedtuple('SchemaObject', 'name schema meta') def SchemaObject(name, schema=None, meta=None): return _SchemaObject(name, schema, meta) _Candidate = namedtuple( 'Candidate', 'completion prio meta synonyms prio2 display' ) def Candidate( completion, prio=None, meta=None, synonyms=None, prio2=None, display=None ): return _Candidate( completion, prio, meta, synonyms or [completion], prio2, display or completion ) # Used to strip trailing '::some_type' from default-value expressions arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$') normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' def generate_alias(tbl): """ Generate a table alias, consisting of all upper-case letters in the table name, or, if there are no upper-case letters, the first letter + all letters preceded by _ param tbl - unescaped name of the table to alias """ return ''.join([l for l in tbl if l.isupper()] or [l for l, prev in zip(tbl, '_' + tbl) if prev == '_' and l != '_']) class PGCompleter(Completer): # keywords_tree: A dict mapping keywords to well known following keywords. # e.g. 'CREATE': ['TABLE', 'USER', ...], keywords_tree = get_literals('keywords', type_=dict) keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values()))) functions = get_literals('functions') datatypes = get_literals('datatypes') reserved_words = set(get_literals('reserved')) def __init__(self, smart_completion=True, pgspecial=None, settings=None): super(PGCompleter, self).__init__() self.smart_completion = smart_completion self.pgspecial = pgspecial self.prioritizer = PrevalenceCounter() settings = settings or {} self.signature_arg_style = settings.get( 'signature_arg_style', '{arg_name} {arg_type}' ) self.call_arg_style = settings.get( 'call_arg_style', '{arg_name: <{max_arg_len}} := {arg_default}' ) self.call_arg_display_style = settings.get( 'call_arg_display_style', '{arg_name}' ) self.call_arg_oneliner_max = settings.get('call_arg_oneliner_max', 2) self.search_path_filter = settings.get('search_path_filter') self.generate_aliases = settings.get('generate_aliases') self.casing_file = settings.get('casing_file') self.insert_col_skip_patterns = [ re.compile(pattern) for pattern in settings.get( 'insert_col_skip_patterns', [r'^now\(\)$', r'^nextval\('] ) ] self.generate_casing_file = settings.get('generate_casing_file') self.qualify_columns = settings.get( 'qualify_columns', 'if_more_than_one_table') self.asterisk_column_order = settings.get( 'asterisk_column_order', 'table_order') keyword_casing = settings.get('keyword_casing', 'upper').lower() if keyword_casing not in ('upper', 'lower', 'auto'): keyword_casing = 'upper' self.keyword_casing = keyword_casing self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$") self.databases = [] self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}, 'datatypes': {}} self.search_path = [] self.casing = {} self.all_completions = set(self.keywords + self.functions) def escape_name(self, name): if name and ((not self.name_pattern.match(name)) or (name.upper() in self.reserved_words) or (name.upper() in self.functions)): name = '"%s"' % name return name def escape_schema(self, name): return "'{}'".format(self.unescape_name(name)) def unescape_name(self, name): """ Unquote a string.""" if name and name[0] == '"' and name[-1] == '"': name = name[1:-1] return name def escaped_names(self, names): return [self.escape_name(name) for name in names] def extend_database_names(self, databases): self.databases.extend(databases) def extend_keywords(self, additional_keywords): self.keywords.extend(additional_keywords) self.all_completions.update(additional_keywords) def extend_schemata(self, schemata): # schemata is a list of schema names schemata = self.escaped_names(schemata) metadata = self.dbmetadata['tables'] for schema in schemata: metadata[schema] = {} # dbmetadata.values() are the 'tables' and 'functions' dicts for metadata in self.dbmetadata.values(): for schema in schemata: metadata[schema] = {} self.all_completions.update(schemata) def extend_casing(self, words): """ extend casing data :return: """ # casing should be a dict {lowercasename:PreferredCasingName} self.casing = dict((word.lower(), word) for word in words) def extend_relations(self, data, kind): """extend metadata for tables or views. :param data: list of (schema_name, rel_name) tuples :param kind: either 'tables' or 'views' :return: """ data = [self.escaped_names(d) for d in data] # dbmetadata['tables']['schema_name']['table_name'] should be an # OrderedDict {column_name:ColumnMetaData}. metadata = self.dbmetadata[kind] for schema, relname in data: try: metadata[schema][relname] = OrderedDict() except KeyError: _logger.error('%r %r listed in unrecognized schema %r', kind, relname, schema) self.all_completions.add(relname) def extend_columns(self, column_data, kind): """extend column metadata. :param column_data: list of (schema_name, rel_name, column_name, column_type, has_default, default) tuples :param kind: either 'tables' or 'views' :return: """ metadata = self.dbmetadata[kind] for schema, relname, colname, datatype, has_default, default in column_data: (schema, relname, colname) = self.escaped_names( [schema, relname, colname]) column = ColumnMetadata( name=colname, datatype=datatype, has_default=has_default, default=default ) metadata[schema][relname][colname] = column self.all_completions.add(colname) def extend_functions(self, func_data): # func_data is a list of function metadata namedtuples # dbmetadata['schema_name']['functions']['function_name'] should return # the function metadata namedtuple for the corresponding function metadata = self.dbmetadata['functions'] for f in func_data: schema, func = self.escaped_names([f.schema_name, f.func_name]) if func in metadata[schema]: metadata[schema][func].append(f) else: metadata[schema][func] = [f] self.all_completions.add(func) self._refresh_arg_list_cache() def _refresh_arg_list_cache(self): # We keep a cache of {function_usage:{function_metadata: function_arg_list_string}} # This is used when suggesting functions, to avoid the latency that would result # if we'd recalculate the arg lists each time we suggest functions (in large DBs) self._arg_list_cache = { usage: { meta: self._arg_list(meta, usage) for sch, funcs in self.dbmetadata['functions'].items() for func, metas in funcs.items() for meta in metas } for usage in ('call', 'call_display', 'signature') } def extend_foreignkeys(self, fk_data): # fk_data is a list of ForeignKey namedtuples, with fields # parentschema, childschema, parenttable, childtable, # parentcolumns, childcolumns # These are added as a list of ForeignKey namedtuples to the # ColumnMetadata namedtuple for both the child and parent meta = self.dbmetadata['tables'] for fk in fk_data: e = self.escaped_names parentschema, childschema = e([fk.parentschema, fk.childschema]) parenttable, childtable = e([fk.parenttable, fk.childtable]) childcol, parcol = e([fk.childcolumn, fk.parentcolumn]) childcolmeta = meta[childschema][childtable][childcol] parcolmeta = meta[parentschema][parenttable][parcol] fk = ForeignKey(parentschema, parenttable, parcol, childschema, childtable, childcol) childcolmeta.foreignkeys.append((fk)) parcolmeta.foreignkeys.append((fk)) def extend_datatypes(self, type_data): # dbmetadata['datatypes'][schema_name][type_name] should store type # metadata, such as composite type field names. Currently, we're not # storing any metadata beyond typename, so just store None meta = self.dbmetadata['datatypes'] for t in type_data: schema, type_name = self.escaped_names(t) meta[schema][type_name] = None self.all_completions.add(type_name) def extend_query_history(self, text, is_init=False): if is_init: # During completer initialization, only load keyword preferences, # not names self.prioritizer.update_keywords(text) else: self.prioritizer.update(text) def set_search_path(self, search_path): self.search_path = self.escaped_names(search_path) def reset_completions(self): self.databases = [] self.special_commands = [] self.search_path = [] self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}, 'datatypes': {}} self.all_completions = set(self.keywords + self.functions) def find_matches(self, text, collection, mode='fuzzy', meta=None): """Find completion matches for the given text. Given the user's input text and a collection of available completions, find completions matching the last word of the text. `collection` can be either a list of strings or a list of Candidate namedtuples. `mode` can be either 'fuzzy', or 'strict' 'fuzzy': fuzzy matching, ties broken by name prevalance `keyword`: start only matching, ties broken by keyword prevalance yields prompt_toolkit Completion instances for any matches found in the collection of available completions. """ if not collection: return [] prio_order = [ 'keyword', 'function', 'view', 'table', 'datatype', 'database', 'schema', 'column', 'table alias', 'join', 'name join', 'fk join', 'table format' ] type_priority = prio_order.index(meta) if meta in prio_order else -1 text = last_word(text, include='most_punctuations').lower() text_len = len(text) if text and text[0] == '"': # text starts with double quote; user is manually escaping a name # Match on everything that follows the double-quote. Note that # text_len is calculated before removing the quote, so the # Completion.position value is correct text = text[1:] if mode == 'fuzzy': fuzzy = True priority_func = self.prioritizer.name_count else: fuzzy = False priority_func = self.prioritizer.keyword_count # Construct a `_match` function for either fuzzy or non-fuzzy matching # The match function returns a 2-tuple used for sorting the matches, # or None if the item doesn't match # Note: higher priority values mean more important, so use negative # signs to flip the direction of the tuple if fuzzy: regex = '.*?'.join(map(re.escape, text)) pat = re.compile('(%s)' % regex) def _match(item): if item.lower()[:len(text) + 1] in (text, text + ' '): # Exact match of first word in suggestion # This is to get exact alias matches to the top # E.g. for input `e`, 'Entries E' should be on top # (before e.g. `EndUsers EU`) return float('Infinity'), -1 r = pat.search(self.unescape_name(item.lower())) if r: return -len(r.group()), -r.start() else: match_end_limit = len(text) def _match(item): match_point = item.lower().find(text, 0, match_end_limit) if match_point >= 0: # Use negative infinity to force keywords to sort after all # fuzzy matches return -float('Infinity'), -match_point matches = [] for cand in collection: if isinstance(cand, _Candidate): item, prio, display_meta, synonyms, prio2, display = cand if display_meta is None: display_meta = meta syn_matches = (_match(x) for x in synonyms) # Nones need to be removed to avoid max() crashing in Python 3 syn_matches = [m for m in syn_matches if m] sort_key = max(syn_matches) if syn_matches else None else: item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand sort_key = _match(cand) if sort_key: if display_meta and len(display_meta) > 50: # Truncate meta-text to 50 characters, if necessary display_meta = display_meta[:47] + u'...' # Lexical order of items in the collection, used for # tiebreaking items with the same match group length and start # position. Since we use *higher* priority to mean "more # important," we use -ord(c) to prioritize "aa" > "ab" and end # with 1 to prioritize shorter strings (ie "user" > "users"). # We first do a case-insensitive sort and then a # case-sensitive one as a tie breaker. # We also use the unescape_name to make sure quoted names have # the same priority as unquoted names. lexical_priority = (tuple(0 if c in(' _') else -ord(c) for c in self.unescape_name(item.lower())) + (1,) + tuple(c for c in item)) item = self.case(item) display = self.case(display) priority = ( sort_key, type_priority, prio, priority_func(item), prio2, lexical_priority ) matches.append( Match( completion=Completion( text=item, start_position=-text_len, display_meta=display_meta, display=display ), priority=priority ) ) return matches def case(self, word): return self.casing.get(word, word) def get_completions(self, document, complete_event, smart_completion=None): word_before_cursor = document.get_word_before_cursor(WORD=True) if smart_completion is None: smart_completion = self.smart_completion # If smart_completion is off then match any word that starts with # 'word_before_cursor'. if not smart_completion: matches = self.find_matches(word_before_cursor, self.all_completions, mode='strict') completions = [m.completion for m in matches] return sorted(completions, key=operator.attrgetter('text')) matches = [] suggestions = suggest_type(document.text, document.text_before_cursor) for suggestion in suggestions: suggestion_type = type(suggestion) _logger.debug('Suggestion type: %r', suggestion_type) # Map suggestion type to method # e.g. 'table' -> self.get_table_matches matcher = self.suggestion_matchers[suggestion_type] matches.extend(matcher(self, suggestion, word_before_cursor)) # Sort matches so highest priorities are first matches = sorted(matches, key=operator.attrgetter('priority'), reverse=True) return [m.completion for m in matches] def get_column_matches(self, suggestion, word_before_cursor): tables = suggestion.table_refs do_qualify = suggestion.qualifiable and {'always': True, 'never': False, 'if_more_than_one_table': len(tables) > 1}[self.qualify_columns] qualify = lambda col, tbl: ( (tbl + '.' + self.case(col)) if do_qualify else self.case(col)) _logger.debug("Completion column scope: %r", tables) scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables) def make_cand(name, ref): synonyms = (name, generate_alias(self.case(name))) return Candidate(qualify(name, ref), 0, 'column', synonyms) def flat_cols(): return [make_cand(c.name, t.ref) for t, cols in scoped_cols.items() for c in cols] if suggestion.require_last_table: # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should # suggest only columns that appear in the last table and one more ltbl = tables[-1].ref other_tbl_cols = set( c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs) scoped_cols = { t: [col for col in cols if col.name in other_tbl_cols] for t, cols in scoped_cols.items() if t.ref == ltbl } lastword = last_word(word_before_cursor, include='most_punctuations') if lastword == '*': if suggestion.context == 'insert': def filter(col): if not col.has_default: return True return not any( p.match(col.default) for p in self.insert_col_skip_patterns ) scoped_cols = { t: [col for col in cols if filter(col)] for t, cols in scoped_cols.items() } if self.asterisk_column_order == 'alphabetic': for cols in scoped_cols.values(): cols.sort(key=operator.attrgetter('name')) if (lastword != word_before_cursor and len(tables) == 1 and word_before_cursor[-len(lastword) - 1] == '.'): # User typed x.*; replicate "x." for all columns except the # first, which gets the original (as we only replace the "*"") sep = ', ' + word_before_cursor[:-1] collist = sep.join(self.case(c.completion) for c in flat_cols()) else: collist = ', '.join(qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs) return [Match( completion=Completion( collist, -1, display_meta='columns', display='*' ), priority=(1, 1, 1) )] return self.find_matches(word_before_cursor, flat_cols(), meta='column') def alias(self, tbl, tbls): """ Generate a unique table alias tbl - name of the table to alias, quoted if it needs to be tbls - TableReference iterable of tables already in query """ tbl = self.case(tbl) tbls = set(normalize_ref(t.ref) for t in tbls) if self.generate_aliases: tbl = generate_alias(self.unescape_name(tbl)) if normalize_ref(tbl) not in tbls: return tbl elif tbl[0] == '"': aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2)) else: aliases = (tbl + str(i) for i in count(2)) return next(a for a in aliases if normalize_ref(a) not in tbls) def get_join_matches(self, suggestion, word_before_cursor): tbls = suggestion.table_refs cols = self.populate_scoped_cols(tbls) # Set up some data structures for efficient access qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls) ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls)) refs = set(normalize_ref(t.ref) for t in tbls) other_tbls = set((t.schema, t.name) for t in list(cols)[:-1]) joins = [] # Iterate over FKs in existing tables to find potential joins fks = ((fk, rtbl, rcol) for rtbl, rcols in cols.items() for rcol in rcols for fk in rcol.foreignkeys) col = namedtuple('col', 'schema tbl col') for fk, rtbl, rcol in fks: right = col(rtbl.schema, rtbl.name, rcol.name) child = col(fk.childschema, fk.childtable, fk.childcolumn) parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn) left = child if parent == right else parent if suggestion.schema and left.schema != suggestion.schema: continue c = self.case if self.generate_aliases or normalize_ref(left.tbl) in refs: lref = self.alias(left.tbl, suggestion.table_refs) join = '{0} {4} ON {4}.{1} = {2}.{3}'.format( c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref) else: join = '{0} ON {0}.{1} = {2}.{3}'.format( c(left.tbl), c(left.col), rtbl.ref, c(right.col)) alias = generate_alias(self.case(left.tbl)) synonyms = [join, '{0} ON {0}.{1} = {2}.{3}'.format( alias, c(left.col), rtbl.ref, c(right.col))] # Schema-qualify if (1) new table in same schema as old, and old # is schema-qualified, or (2) new in other schema, except public if not suggestion.schema and (qualified[normalize_ref(rtbl.ref)] and left.schema == right.schema or left.schema not in(right.schema, 'public')): join = left.schema + '.' + join prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + ( 0 if (left.schema, left.tbl) in other_tbls else 1) joins.append(Candidate(join, prio, 'join', synonyms=synonyms)) return self.find_matches(word_before_cursor, joins, meta='join') def get_join_condition_matches(self, suggestion, word_before_cursor): col = namedtuple('col', 'schema tbl col') tbls = self.populate_scoped_cols(suggestion.table_refs).items cols = [(t, c) for t, cs in tbls() for c in cs] try: lref = (suggestion.parent or suggestion.table_refs[-1]).ref ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1] except IndexError: # The user typed an incorrect table qualifier return [] conds, found_conds = [], set() def add_cond(lcol, rcol, rref, prio, meta): prefix = '' if suggestion.parent else ltbl.ref + '.' case = self.case cond = prefix + case(lcol) + ' = ' + rref + '.' + case(rcol) if cond not in found_conds: found_conds.add(cond) conds.append(Candidate(cond, prio + ref_prio[rref], meta)) def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]} d = defaultdict(list) for pair in pairs: d[pair[0]].append(pair[1]) return d # Tables that are closer to the cursor get higher prio ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs)) # Map (schema, table, col) to tables coldict = list_dict(((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref) # For each fk from the left table, generate a join condition if # the other table is also in the scope fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys) for fk, lcol in fks: left = col(ltbl.schema, ltbl.name, lcol) child = col(fk.childschema, fk.childtable, fk.childcolumn) par = col(fk.parentschema, fk.parenttable, fk.parentcolumn) left, right = (child, par) if left == child else (par, child) for rtbl in coldict[right]: add_cond(left.col, right.col, rtbl.ref, 2000, 'fk join') # For name matching, use a {(colname, coltype): TableReference} dict coltyp = namedtuple('coltyp', 'name datatype') col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols) # Find all name-match join conditions for c in (coltyp(c.name, c.datatype) for c in lcols): for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref): prio = 1000 if c.datatype in ( 'integer', 'bigint', 'smallint') else 0 add_cond(c.name, c.name, rtbl.ref, prio, 'name join') return self.find_matches(word_before_cursor, conds, meta='join') def get_function_matches(self, suggestion, word_before_cursor, alias=False): if suggestion.usage == 'from': # Only suggest functions allowed in FROM clause def filt(f): return not f.is_aggregate and not f.is_window else: alias = False def filt(f): return True arg_mode = { 'signature': 'signature', 'special': None, }.get(suggestion.usage, 'call') # Function overloading means we way have multiple functions of the same # name at this point, so keep unique names only funcs = set( self._make_cand(f, alias, suggestion, arg_mode) for f in self.populate_functions(suggestion.schema, filt) ) matches = self.find_matches(word_before_cursor, funcs, meta='function') if not suggestion.schema and not suggestion.usage: # also suggest hardcoded functions using startswith matching predefined_funcs = self.find_matches( word_before_cursor, self.functions, mode='strict', meta='function') matches.extend(predefined_funcs) return matches def get_schema_matches(self, suggestion, word_before_cursor): schema_names = self.dbmetadata['tables'].keys() # Unless we're sure the user really wants them, hide schema names # starting with pg_, which are mostly temporary schemas if not word_before_cursor.startswith('pg_'): schema_names = [s for s in schema_names if not s.startswith('pg_')] if suggestion.quoted: schema_names = [self.escape_schema(s) for s in schema_names] return self.find_matches(word_before_cursor, schema_names, meta='schema') def get_from_clause_item_matches(self, suggestion, word_before_cursor): alias = self.generate_aliases s = suggestion t_sug = Table(s.schema, s.table_refs, s.local_tables) v_sug = View(s.schema, s.table_refs) f_sug = Function(s.schema, s.table_refs, usage='from') return ( self.get_table_matches(t_sug, word_before_cursor, alias) + self.get_view_matches(v_sug, word_before_cursor, alias) + self.get_function_matches(f_sug, word_before_cursor, alias) ) def _arg_list(self, func, usage): """Returns a an arg list string, e.g. `(_foo:=23)` for a func. :param func is a FunctionMetadata object :param usage is 'call', 'call_display' or 'signature' """ template = { 'call': self.call_arg_style, 'call_display': self.call_arg_display_style, 'signature': self.signature_arg_style }[usage] args = func.args() if not template: return '()' elif usage == 'call' and len(args) < 2: return '()' elif usage == 'call' and func.has_variadic(): return '()' multiline = usage == 'call' and len(args) > self.call_arg_oneliner_max max_arg_len = max(len(a.name) for a in args) if multiline else 0 args = ( self._format_arg(template, arg, arg_num + 1, max_arg_len) for arg_num, arg in enumerate(args) ) if multiline: return '(' + ','.join('\n ' + a for a in args if a) + '\n)' else: return '(' + ', '.join(a for a in args if a) + ')' def _format_arg(self, template, arg, arg_num, max_arg_len): if not template: return None if arg.has_default: arg_default = 'NULL' if arg.default is None else arg.default # Remove trailing ::(schema.)type arg_default = arg_default_type_strip_regex.sub('', arg_default) else: arg_default = '' return template.format( max_arg_len=max_arg_len, arg_name=self.case(arg.name), arg_num=arg_num, arg_type=arg.datatype, arg_default=arg_default ) def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None): """Returns a Candidate namedtuple. :param tbl is a SchemaObject :param arg_mode determines what type of arg list to suffix for functions. Possible values: call, signature """ cased_tbl = self.case(tbl.name) if do_alias: alias = self.alias(cased_tbl, suggestion.table_refs) synonyms = (cased_tbl, generate_alias(cased_tbl)) maybe_alias = (' ' + alias) if do_alias else '' maybe_schema = (self.case(tbl.schema) + '.') if tbl.schema else '' suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else '' if arg_mode == 'call': display_suffix = self._arg_list_cache['call_display'][tbl.meta] elif arg_mode == 'signature': display_suffix = self._arg_list_cache['signature'][tbl.meta] else: display_suffix = '' item = maybe_schema + cased_tbl + suffix + maybe_alias display = maybe_schema + cased_tbl + display_suffix + maybe_alias prio2 = 0 if tbl.schema else 1 return Candidate(item, synonyms=synonyms, prio2=prio2, display=display) def get_table_matches(self, suggestion, word_before_cursor, alias=False): tables = self.populate_schema_objects(suggestion.schema, 'tables') tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables) # Unless we're sure the user really wants them, don't suggest the # pg_catalog tables that are implicitly on the search path if not suggestion.schema and ( not word_before_cursor.startswith('pg_')): tables = [t for t in tables if not t.name.startswith('pg_')] tables = [self._make_cand(t, alias, suggestion) for t in tables] return self.find_matches(word_before_cursor, tables, meta='table') def get_table_formats(self, _, word_before_cursor): formats = TabularOutputFormatter().supported_formats return self.find_matches(word_before_cursor, formats, meta='table format') def get_view_matches(self, suggestion, word_before_cursor, alias=False): views = self.populate_schema_objects(suggestion.schema, 'views') if not suggestion.schema and ( not word_before_cursor.startswith('pg_')): views = [v for v in views if not v.name.startswith('pg_')] views = [self._make_cand(v, alias, suggestion) for v in views] return self.find_matches(word_before_cursor, views, meta='view') def get_alias_matches(self, suggestion, word_before_cursor): aliases = suggestion.aliases return self.find_matches(word_before_cursor, aliases, meta='table alias') def get_database_matches(self, _, word_before_cursor): return self.find_matches(word_before_cursor, self.databases, meta='database') def get_keyword_matches(self, suggestion, word_before_cursor): keywords = self.keywords_tree.keys() # Get well known following keywords for the last token. If any, narrow # candidates to this list. next_keywords = self.keywords_tree.get(suggestion.last_token, []) if next_keywords: keywords = next_keywords casing = self.keyword_casing if casing == 'auto': if word_before_cursor and word_before_cursor[-1].islower(): casing = 'lower' else: casing = 'upper' if casing == 'upper': keywords = [k.upper() for k in keywords] else: keywords = [k.lower() for k in keywords] return self.find_matches(word_before_cursor, keywords, mode='strict', meta='keyword') def get_path_matches(self, _, word_before_cursor): completer = PathCompleter(expanduser=True) document = Document(text=word_before_cursor, cursor_position=len(word_before_cursor)) for c in completer.get_completions(document, None): yield Match(completion=c, priority=(0,)) def get_special_matches(self, _, word_before_cursor): if not self.pgspecial: return [] commands = self.pgspecial.commands cmds = commands.keys() cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds] return self.find_matches(word_before_cursor, cmds, mode='strict') def get_datatype_matches(self, suggestion, word_before_cursor): # suggest custom datatypes types = self.populate_schema_objects(suggestion.schema, 'datatypes') types = [self._make_cand(t, False, suggestion) for t in types] matches = self.find_matches(word_before_cursor, types, meta='datatype') if not suggestion.schema: # Also suggest hardcoded types matches.extend(self.find_matches(word_before_cursor, self.datatypes, mode='strict', meta='datatype')) return matches def get_namedquery_matches(self, _, word_before_cursor): return self.find_matches( word_before_cursor, NamedQueries.instance.list(), meta='named query') suggestion_matchers = { FromClauseItem: get_from_clause_item_matches, JoinCondition: get_join_condition_matches, Join: get_join_matches, Column: get_column_matches, Function: get_function_matches, Schema: get_schema_matches, Table: get_table_matches, TableFormat: get_table_formats, View: get_view_matches, Alias: get_alias_matches, Database: get_database_matches, Keyword: get_keyword_matches, Special: get_special_matches, Datatype: get_datatype_matches, NamedQuery: get_namedquery_matches, Path: get_path_matches, } def populate_scoped_cols(self, scoped_tbls, local_tbls=()): """Find all columns in a set of scoped_tables. :param scoped_tbls: list of TableReference namedtuples :param local_tbls: tuple(TableMetadata) :return: {TableReference:{colname:ColumnMetaData}} """ ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls) columns = OrderedDict() meta = self.dbmetadata def addcols(schema, rel, alias, reltype, cols): tbl = TableReference(schema, rel, alias, reltype == 'functions') if tbl not in columns: columns[tbl] = [] columns[tbl].extend(cols) for tbl in scoped_tbls: # Local tables should shadow database tables if tbl.schema is None and normalize_ref(tbl.name) in ctes: cols = ctes[normalize_ref(tbl.name)] addcols(None, tbl.name, 'CTE', tbl.alias, cols) continue schemas = [tbl.schema] if tbl.schema else self.search_path for schema in schemas: relname = self.escape_name(tbl.name) schema = self.escape_name(schema) if tbl.is_function: # Return column names from a set-returning function # Get an array of FunctionMetadata objects functions = meta['functions'].get(schema, {}).get(relname) for func in (functions or []): # func is a FunctionMetadata object cols = func.fields() addcols(schema, relname, tbl.alias, 'functions', cols) else: for reltype in ('tables', 'views'): cols = meta[reltype].get(schema, {}).get(relname) if cols: cols = cols.values() addcols(schema, relname, tbl.alias, reltype, cols) break return columns def _get_schemas(self, obj_typ, schema): """Returns a list of schemas from which to suggest objects. :param schema is the schema qualification input by the user (if any) """ metadata = self.dbmetadata[obj_typ] if schema: schema = self.escape_name(schema) return [schema] if schema in metadata else [] return self.search_path if self.search_path_filter else metadata.keys() def _maybe_schema(self, schema, parent): return None if parent or schema in self.search_path else schema def populate_schema_objects(self, schema, obj_type): """Returns a list of SchemaObjects representing tables or views. :param schema is the schema qualification input by the user (if any) """ return [ SchemaObject( name=obj, schema=(self._maybe_schema(schema=sch, parent=schema)) ) for sch in self._get_schemas(obj_type, schema) for obj in self.dbmetadata[obj_type][sch].keys() ] def populate_functions(self, schema, filter_func): """Returns a list of function SchemaObjects. :param filter_func is a function that accepts a FunctionMetadata namedtuple and returns a boolean indicating whether that function should be kept or discarded """ # Because of multiple dispatch, we can have multiple functions # with the same name, which is why `for meta in metas` is necessary # in the comprehensions below return [ SchemaObject( name=func, schema=(self._maybe_schema(schema=sch, parent=schema)), meta=meta ) for sch in self._get_schemas('functions', schema) for (func, metas) in self.dbmetadata['functions'][sch].items() for meta in metas if filter_func(meta) ]