summaryrefslogtreecommitdiffstats
path: root/pgcli
diff options
context:
space:
mode:
Diffstat (limited to 'pgcli')
-rw-r--r--pgcli/__init__.py2
-rw-r--r--pgcli/auth.py60
-rw-r--r--pgcli/completion_refresher.py8
-rw-r--r--pgcli/explain_output_formatter.py19
-rw-r--r--pgcli/key_bindings.py10
-rw-r--r--pgcli/magic.py2
-rw-r--r--pgcli/main.py452
-rw-r--r--pgcli/packages/formatter/__init__.py1
-rw-r--r--pgcli/packages/formatter/sqlformatter.py74
-rw-r--r--pgcli/packages/parseutils/__init__.py48
-rw-r--r--pgcli/packages/parseutils/tables.py20
-rw-r--r--pgcli/packages/prompt_utils.py14
-rw-r--r--pgcli/packages/sqlcompletion.py5
-rw-r--r--pgcli/pgbuffer.py11
-rw-r--r--pgcli/pgclirc59
-rw-r--r--pgcli/pgcompleter.py46
-rw-r--r--pgcli/pgexecute.py414
-rw-r--r--pgcli/pgstyle.py2
-rw-r--r--pgcli/pgtoolbar.py17
-rw-r--r--pgcli/pyev.py439
20 files changed, 1300 insertions, 403 deletions
diff --git a/pgcli/__init__.py b/pgcli/__init__.py
index f5f41e56..76ad18b8 100644
--- a/pgcli/__init__.py
+++ b/pgcli/__init__.py
@@ -1 +1 @@
-__version__ = "3.1.0"
+__version__ = "4.0.1"
diff --git a/pgcli/auth.py b/pgcli/auth.py
new file mode 100644
index 00000000..2f1e5526
--- /dev/null
+++ b/pgcli/auth.py
@@ -0,0 +1,60 @@
+import click
+from textwrap import dedent
+
+
+keyring = None # keyring will be loaded later
+
+
+keyring_error_message = dedent(
+ """\
+ {}
+ {}
+ To remove this message do one of the following:
+ - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/
+ - uninstall keyring: pip uninstall keyring
+ - disable keyring in our configuration: add keyring = False to [main]"""
+)
+
+
+def keyring_initialize(keyring_enabled, *, logger):
+ """Initialize keyring only if explicitly enabled"""
+ global keyring
+
+ if keyring_enabled:
+ # Try best to load keyring (issue #1041).
+ import importlib
+
+ try:
+ keyring = importlib.import_module("keyring")
+ except (
+ ModuleNotFoundError
+ ) as e: # ImportError for Python 2, ModuleNotFoundError for Python 3
+ logger.warning("import keyring failed: %r.", e)
+
+
+def keyring_get_password(key):
+ """Attempt to get password from keyring"""
+ # Find password from store
+ passwd = ""
+ try:
+ passwd = keyring.get_password("pgcli", key) or ""
+ except Exception as e:
+ click.secho(
+ keyring_error_message.format(
+ "Load your password from keyring returned:", str(e)
+ ),
+ err=True,
+ fg="red",
+ )
+ return passwd
+
+
+def keyring_set_password(key, passwd):
+ try:
+ keyring.set_password("pgcli", key, passwd)
+ except Exception as e:
+ click.secho(
+ keyring_error_message.format("Set password in keyring returned:", str(e)),
+ err=True,
+ fg="red",
+ )
diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py
index 3e847b09..c887cb63 100644
--- a/pgcli/completion_refresher.py
+++ b/pgcli/completion_refresher.py
@@ -3,11 +3,9 @@ import os
from collections import OrderedDict
from .pgcompleter import PGCompleter
-from .pgexecute import PGExecute
class CompletionRefresher:
-
refreshers = OrderedDict()
def __init__(self):
@@ -27,6 +25,10 @@ class CompletionRefresher:
has completed the refresh. The newly created completion
object will be passed in as an argument to each callback.
"""
+ if executor.is_virtual_database():
+ # do nothing
+ return [(None, None, None, "Auto-completion refresh can't be started.")]
+
if self.is_refreshing():
self._restart_refresh.set()
return [(None, None, None, "Auto-completion refresh restarted.")]
@@ -36,7 +38,7 @@ class CompletionRefresher:
args=(executor, special, callbacks, history, settings),
name="completion_refresh",
)
- self._completer_thread.setDaemon(True)
+ self._completer_thread.daemon = True
self._completer_thread.start()
return [
(None, None, None, "Auto-completion refresh started in the background.")
diff --git a/pgcli/explain_output_formatter.py b/pgcli/explain_output_formatter.py
new file mode 100644
index 00000000..ce45b4f8
--- /dev/null
+++ b/pgcli/explain_output_formatter.py
@@ -0,0 +1,19 @@
+from pgcli.pyev import Visualizer
+import json
+
+
+"""Explain response output adapter"""
+
+
+class ExplainOutputFormatter:
+ def __init__(self, max_width):
+ self.max_width = max_width
+
+ def format_output(self, cur, headers, **output_kwargs):
+ # explain query results should always contain 1 row each
+ [(data,)] = list(cur)
+ explain_list = json.loads(data)
+ visualizer = Visualizer(self.max_width)
+ for explain in explain_list:
+ visualizer.load(explain)
+ yield visualizer.get_list()
diff --git a/pgcli/key_bindings.py b/pgcli/key_bindings.py
index 23174b6b..9c016f7f 100644
--- a/pgcli/key_bindings.py
+++ b/pgcli/key_bindings.py
@@ -9,7 +9,7 @@ from prompt_toolkit.filters import (
vi_mode,
)
-from .pgbuffer import buffer_should_be_handled
+from .pgbuffer import buffer_should_be_handled, safe_multi_line_mode
_logger = logging.getLogger(__name__)
@@ -39,6 +39,12 @@ def pgcli_bindings(pgcli):
pgcli.vi_mode = not pgcli.vi_mode
event.app.editing_mode = EditingMode.VI if pgcli.vi_mode else EditingMode.EMACS
+ @kb.add("f5")
+ def _(event):
+ """Toggle between Vi and Emacs mode."""
+ _logger.debug("Detected F5 key.")
+ pgcli.explain_mode = not pgcli.explain_mode
+
@kb.add("tab")
def _(event):
"""Force autocompletion at cursor on non-empty lines."""
@@ -108,7 +114,7 @@ def pgcli_bindings(pgcli):
_logger.debug("Detected enter key.")
event.current_buffer.validate_and_handle()
- @kb.add("escape", "enter", filter=~vi_mode)
+ @kb.add("escape", "enter", filter=~vi_mode & ~safe_multi_line_mode(pgcli))
def _(event):
"""Introduces a line break regardless of multi-line mode or not."""
_logger.debug("Detected alt-enter key.")
diff --git a/pgcli/magic.py b/pgcli/magic.py
index 6e58f28b..09902a29 100644
--- a/pgcli/magic.py
+++ b/pgcli/magic.py
@@ -43,7 +43,7 @@ def pgcli_line_magic(line):
u = conn.session.engine.url
_logger.debug("New pgcli: %r", str(u))
- pgcli.connect(u.database, u.host, u.username, u.port, u.password)
+ pgcli.connect_uri(str(u._replace(drivername="postgres")))
conn._pgcli = pgcli
# For convenience, print the connection alias
diff --git a/pgcli/main.py b/pgcli/main.py
index 2202c1a7..f95c8000 100644
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -1,13 +1,8 @@
-import platform
-import warnings
-from os.path import expanduser
-
from configobj import ConfigObj, ParseError
from pgspecial.namedqueries import NamedQueries
from .config import skip_initial_comment
-warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
-
+import atexit
import os
import re
import sys
@@ -21,12 +16,12 @@ import datetime as dt
import itertools
import platform
from time import time, sleep
-
-keyring = None # keyring will be loaded later
+from typing import Optional
from cli_helpers.tabular_output import TabularOutputFormatter
from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
from cli_helpers.utils import strip_ansi
+from .explain_output_formatter import ExplainOutputFormatter
import click
try:
@@ -52,6 +47,7 @@ from pygments.lexers.sql import PostgresLexer
from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT
import pgspecial as special
+from . import auth
from .pgcompleter import PGCompleter
from .pgtoolbar import create_toolbar_tokens_func
from .pgstyle import style_factory, style_factory_output
@@ -66,26 +62,34 @@ from .config import (
get_config_filename,
)
from .key_bindings import pgcli_bindings
-from .packages.prompt_utils import confirm_destructive_query
+from .packages.formatter.sqlformatter import register_new_formatter
+from .packages.prompt_utils import confirm, confirm_destructive_query
+from .packages.parseutils import is_destructive
+from .packages.parseutils import parse_destructive_warning
from .__init__ import __version__
click.disable_unicode_literals_warning = True
-try:
- from urlparse import urlparse, unquote, parse_qs
-except ImportError:
- from urllib.parse import urlparse, unquote, parse_qs
+from urllib.parse import urlparse
from getpass import getuser
-from psycopg2 import OperationalError, InterfaceError
-import psycopg2
+
+from psycopg import OperationalError, InterfaceError
+from psycopg.conninfo import make_conninfo, conninfo_to_dict
from collections import namedtuple
-from textwrap import dedent
+try:
+ import sshtunnel
+
+ SSH_TUNNEL_SUPPORT = True
+except ImportError:
+ SSH_TUNNEL_SUPPORT = False
+
# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
+DEFAULT_MAX_FIELD_WIDTH = 500
# Query tuples are used for maintaining history
MetaQuery = namedtuple(
@@ -106,7 +110,7 @@ MetaQuery.__new__.__defaults__ = ("", False, 0, 0, False, False, False, False)
OutputSettings = namedtuple(
"OutputSettings",
- "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output",
+ "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output max_field_width",
)
OutputSettings.__new__.__defaults__ = (
None,
@@ -117,6 +121,7 @@ OutputSettings.__new__.__defaults__ = (
None,
lambda x: x,
None,
+ DEFAULT_MAX_FIELD_WIDTH,
)
@@ -166,8 +171,8 @@ class PGCli:
prompt_dsn=None,
auto_vertical_output=False,
warn=None,
+ ssh_tunnel_url: Optional[str] = None,
):
-
self.force_passwd_prompt = force_passwd_prompt
self.never_passwd_prompt = never_passwd_prompt
self.pgexecute = pgexecute
@@ -190,10 +195,14 @@ class PGCli:
self.output_file = None
self.pgspecial = PGSpecial()
+ self.explain_mode = False
self.multi_line = c["main"].as_bool("multi_line")
self.multiline_mode = c["main"].get("multi_line_mode", "psql")
self.vi_mode = c["main"].as_bool("vi")
self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand")
+ self.auto_retry_closed_connection = c["main"].as_bool(
+ "auto_retry_closed_connection"
+ )
self.expanded_output = c["main"].as_bool("expand")
self.pgspecial.timing_enabled = c["main"].as_bool("timing")
if row_limit is not None:
@@ -201,17 +210,32 @@ class PGCli:
else:
self.row_limit = c["main"].as_int("row_limit")
+ # if not specified, set to DEFAULT_MAX_FIELD_WIDTH
+ # if specified but empty, set to None to disable truncation
+ # ellipsis will take at least 3 symbols, so this can't be less than 3 if specified and > 0
+ max_field_width = c["main"].get("max_field_width", DEFAULT_MAX_FIELD_WIDTH)
+ if max_field_width and max_field_width.lower() != "none":
+ max_field_width = max(3, abs(int(max_field_width)))
+ else:
+ max_field_width = None
+ self.max_field_width = max_field_width
+
self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines")
self.multiline_continuation_char = c["main"]["multiline_continuation_char"]
self.table_format = c["main"]["table_format"]
self.syntax_style = c["main"]["syntax_style"]
self.cli_style = c["colors"]
self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
- self.destructive_warning = warn or c["main"]["destructive_warning"]
- # also handle boolean format of destructive warning
- self.destructive_warning = {"true": "all", "false": "off"}.get(
- self.destructive_warning.lower(), self.destructive_warning
+ self.destructive_warning = parse_destructive_warning(
+ warn or c["main"].as_list("destructive_warning")
+ )
+ self.destructive_warning_restarts_connection = c["main"].as_bool(
+ "destructive_warning_restarts_connection"
)
+ self.destructive_statements_require_transaction = c["main"].as_bool(
+ "destructive_statements_require_transaction"
+ )
+
self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
self.null_string = c["main"].get("null_string", "<null>")
self.prompt_format = (
@@ -223,7 +247,7 @@ class PGCli:
self.on_error = c["main"]["on_error"].upper()
self.decimal_format = c["data_formats"]["decimal"]
self.float_format = c["data_formats"]["float"]
- self.initialize_keyring()
+ auth.keyring_initialize(c["main"].as_bool("keyring"), logger=self.logger)
self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")
self.pgspecial.pset_pager(
@@ -241,6 +265,9 @@ class PGCli:
# Initialize completer
smart_completion = c["main"].as_bool("smart_completion")
keyword_casing = c["main"]["keyword_casing"]
+ single_connection = single_connection or c["main"].as_bool(
+ "always_use_single_connection"
+ )
self.settings = {
"casing_file": get_casing_file(c),
"generate_casing_file": c["main"].as_bool("generate_casing_file"),
@@ -252,6 +279,7 @@ class PGCli:
"single_connection": single_connection,
"less_chatty": less_chatty,
"keyword_casing": keyword_casing,
+ "alias_map_file": c["main"]["alias_map_file"] or None,
}
completer = PGCompleter(
@@ -263,11 +291,18 @@ class PGCli:
self.prompt_app = None
+ self.ssh_tunnel_config = c.get("ssh tunnels")
+ self.ssh_tunnel_url = ssh_tunnel_url
+ self.ssh_tunnel = None
+
+ # formatter setup
+ self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
+ register_new_formatter(self.formatter)
+
def quit(self):
raise PgCliQuitError
def register_special_commands(self):
-
self.pgspecial.register(
self.change_db,
"\\c",
@@ -329,6 +364,23 @@ class PGCli:
"Change the table format used to output results",
)
+ self.pgspecial.register(
+ self.echo,
+ "\\echo",
+ "\\echo [string]",
+ "Echo a string to stdout",
+ )
+
+ self.pgspecial.register(
+ self.echo,
+ "\\qecho",
+ "\\qecho [string]",
+ "Echo a string to the query output channel.",
+ )
+
+ def echo(self, pattern, **_):
+ return [(None, None, None, pattern)]
+
def change_table_format(self, pattern, **_):
try:
if pattern not in TabularOutputFormatter().supported_formats:
@@ -398,16 +450,27 @@ class PGCli:
except OSError as e:
return [(None, None, None, str(e), "", False, True)]
- if (
- self.destructive_warning != "off"
- and confirm_destructive_query(query, self.destructive_warning) is False
- ):
- message = "Wise choice. Command execution stopped."
- return [(None, None, None, message)]
+ if self.destructive_warning:
+ if (
+ self.destructive_statements_require_transaction
+ and not self.pgexecute.valid_transaction()
+ and is_destructive(query, self.destructive_warning)
+ ):
+ message = "Destructive statements must be run within a transaction. Command execution stopped."
+ return [(None, None, None, message)]
+ destroy = confirm_destructive_query(
+ query, self.destructive_warning, self.dsn_alias
+ )
+ if destroy is False:
+ message = "Wise choice. Command execution stopped."
+ return [(None, None, None, message)]
on_error_resume = self.on_error == "RESUME"
return self.pgexecute.run(
- query, self.pgspecial, on_error_resume=on_error_resume
+ query,
+ self.pgspecial,
+ on_error_resume=on_error_resume,
+ explain_mode=self.explain_mode,
)
def write_to_file(self, pattern, **_):
@@ -428,7 +491,6 @@ class PGCli:
return [(None, None, None, message, "", True, True)]
def initialize_logging(self):
-
log_file = self.config["main"]["log_file"]
if log_file == "default":
log_file = config_location() + "log"
@@ -471,19 +533,6 @@ class PGCli:
pgspecial_logger.addHandler(handler)
pgspecial_logger.setLevel(log_level)
- def initialize_keyring(self):
- global keyring
-
- keyring_enabled = self.config["main"].as_bool("keyring")
- if keyring_enabled:
- # Try best to load keyring (issue #1041).
- import importlib
-
- try:
- keyring = importlib.import_module("keyring")
- except Exception as e: # ImportError for Python 2, ModuleNotFoundError for Python 3
- self.logger.warning("import keyring failed: %r.", e)
-
def connect_dsn(self, dsn, **kwargs):
self.connect(dsn=dsn, **kwargs)
@@ -503,7 +552,7 @@ class PGCli:
)
def connect_uri(self, uri):
- kwargs = psycopg2.extensions.parse_dsn(uri)
+ kwargs = conninfo_to_dict(uri)
remap = {"dbname": "database", "password": "passwd"}
kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
self.connect(**kwargs)
@@ -526,30 +575,6 @@ class PGCli:
if not self.force_passwd_prompt and not passwd:
passwd = os.environ.get("PGPASSWORD", "")
- # Find password from store
- key = f"{user}@{host}"
- keyring_error_message = dedent(
- """\
- {}
- {}
- To remove this message do one of the following:
- - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/
- - uninstall keyring: pip uninstall keyring
- - disable keyring in our configuration: add keyring = False to [main]"""
- )
- if not passwd and keyring:
-
- try:
- passwd = keyring.get_password("pgcli", key)
- except (RuntimeError, keyring.errors.InitError) as e:
- click.secho(
- keyring_error_message.format(
- "Load your password from keyring returned:", str(e)
- ),
- err=True,
- fg="red",
- )
-
# Prompt for a password immediately if requested via the -W flag. This
# avoids wasting time trying to connect to the database and catching a
# no-password exception.
@@ -560,6 +585,11 @@ class PGCli:
"Password for %s" % user, hide_input=True, show_default=False, type=str
)
+ key = f"{user}@{host}"
+
+ if not passwd and auth.keyring:
+ passwd = auth.keyring_get_password(key)
+
def should_ask_for_password(exc):
# Prompt for a password after 1st attempt to connect
# fails. Don't prompt if the -w flag is supplied
@@ -572,6 +602,56 @@ class PGCli:
return True
return False
+ if dsn:
+ parsed_dsn = conninfo_to_dict(dsn)
+ if "host" in parsed_dsn:
+ host = parsed_dsn["host"]
+ if "port" in parsed_dsn:
+ port = parsed_dsn["port"]
+
+ if self.ssh_tunnel_config and not self.ssh_tunnel_url:
+ for db_host_regex, tunnel_url in self.ssh_tunnel_config.items():
+ if re.search(db_host_regex, host):
+ self.ssh_tunnel_url = tunnel_url
+ break
+
+ if self.ssh_tunnel_url:
+ # We add the protocol as urlparse doesn't find it by itself
+ if "://" not in self.ssh_tunnel_url:
+ self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}"
+
+ tunnel_info = urlparse(self.ssh_tunnel_url)
+ params = {
+ "local_bind_address": ("127.0.0.1",),
+ "remote_bind_address": (host, int(port or 5432)),
+ "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22),
+ "logger": self.logger,
+ }
+ if tunnel_info.username:
+ params["ssh_username"] = tunnel_info.username
+ if tunnel_info.password:
+ params["ssh_password"] = tunnel_info.password
+
+ # Hack: sshtunnel adds a console handler to the logger, so we revert handlers.
+ # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged.
+ logger_handlers = self.logger.handlers.copy()
+ try:
+ self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params)
+ self.ssh_tunnel.start()
+ except Exception as e:
+ self.logger.handlers = logger_handlers
+ self.logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ exit(1)
+ self.logger.handlers = logger_handlers
+
+ atexit.register(self.ssh_tunnel.stop)
+ host = "127.0.0.1"
+ port = self.ssh_tunnel.local_bind_ports[0]
+
+ if dsn:
+ dsn = make_conninfo(dsn, host=host, port=port)
+
# Attempt to connect to the database.
# Note that passwd may be empty on the first attempt. If connection
# fails because of a missing or incorrect password, but we're allowed to
@@ -592,17 +672,8 @@ class PGCli:
)
else:
raise e
- if passwd and keyring:
- try:
- keyring.set_password("pgcli", key, passwd)
- except (RuntimeError, keyring.errors.KeyringError) as e:
- click.secho(
- keyring_error_message.format(
- "Set password in keyring returned:", str(e)
- ),
- err=True,
- fg="red",
- )
+ if passwd and auth.keyring:
+ auth.keyring_set_password(key, passwd)
except Exception as e: # Connecting to a database could fail.
self.logger.debug("Database connection failed: %r.", e)
@@ -650,34 +721,52 @@ class PGCli:
editor_command = special.editor_command(text)
return text
- def execute_command(self, text):
+ def execute_command(self, text, handle_closed_connection=True):
logger = self.logger
query = MetaQuery(query=text, successful=False)
try:
- if self.destructive_warning != "off":
- destroy = confirm = confirm_destructive_query(
- text, self.destructive_warning
+ if self.destructive_warning:
+ if (
+ self.destructive_statements_require_transaction
+ and not self.pgexecute.valid_transaction()
+ and is_destructive(text, self.destructive_warning)
+ ):
+ click.secho(
+ "Destructive statements must be run within a transaction."
+ )
+ raise KeyboardInterrupt
+ destroy = confirm_destructive_query(
+ text, self.destructive_warning, self.dsn_alias
)
if destroy is False:
click.secho("Wise choice!")
raise KeyboardInterrupt
elif destroy:
click.secho("Your call!")
+
output, query = self._evaluate_command(text)
except KeyboardInterrupt:
- # Restart connection to the database
- self.pgexecute.connect()
- logger.debug("cancelled query, sql: %r", text)
- click.secho("cancelled query", err=True, fg="red")
+ if self.destructive_warning_restarts_connection:
+ # Restart connection to the database
+ self.pgexecute.connect()
+ logger.debug("cancelled query and restarted connection, sql: %r", text)
+ click.secho(
+ "cancelled query and restarted connection", err=True, fg="red"
+ )
+ else:
+ logger.debug("cancelled query, sql: %r", text)
+ click.secho("cancelled query", err=True, fg="red")
except NotImplementedError:
click.secho("Not Yet Implemented.", fg="yellow")
except OperationalError as e:
logger.error("sql: %r, error: %r", text, e)
logger.error("traceback: %r", traceback.format_exc())
- self._handle_server_closed_connection(text)
- except (PgCliQuitError, EOFError) as e:
+ click.secho(str(e), err=True, fg="red")
+ if handle_closed_connection:
+ self._handle_server_closed_connection(text)
+ except (PgCliQuitError, EOFError):
raise
except Exception as e:
logger.error("sql: %r, error: %r", text, e)
@@ -685,7 +774,9 @@ class PGCli:
click.secho(str(e), err=True, fg="red")
else:
try:
- if self.output_file and not text.startswith(("\\o ", "\\? ")):
+ if self.output_file and not text.startswith(
+ ("\\o ", "\\? ", "\\echo ")
+ ):
try:
with open(self.output_file, "a", encoding="utf-8") as f:
click.echo(text, file=f)
@@ -729,6 +820,34 @@ class PGCli:
logger.debug("Search path: %r", self.completer.search_path)
return query
+ def _check_ongoing_transaction_and_allow_quitting(self):
+ """Return whether we can really quit, possibly by asking the
+ user to confirm so if there is an ongoing transaction.
+ """
+ if not self.pgexecute.valid_transaction():
+ return True
+ while 1:
+ try:
+ choice = click.prompt(
+ "A transaction is ongoing. Choose `c` to COMMIT, `r` to ROLLBACK, `a` to abort exit.",
+ default="a",
+ )
+ except click.Abort:
+ # Print newline if user aborts with `^C`, otherwise
+ # pgcli's prompt will be printed on the same line
+ # (just after the confirmation prompt).
+ click.echo(None, err=False)
+ choice = "a"
+ choice = choice.lower()
+ if choice == "a":
+ return False # do not quit
+ if choice == "c":
+ query = self.execute_command("commit")
+ return query.successful # quit only if query is successful
+ if choice == "r":
+ query = self.execute_command("rollback")
+ return query.successful # quit only if query is successful
+
def run_cli(self):
logger = self.logger
@@ -751,6 +870,10 @@ class PGCli:
text = self.prompt_app.prompt()
except KeyboardInterrupt:
continue
+ except EOFError:
+ if not self._check_ongoing_transaction_and_allow_quitting():
+ continue
+ raise
try:
text = self.handle_editor_command(text)
@@ -760,18 +883,12 @@ class PGCli:
click.secho(str(e), err=True, fg="red")
continue
- # Initialize default metaquery in case execution fails
- self.watch_command, timing = special.get_watch_command(text)
- if self.watch_command:
- while self.watch_command:
- try:
- query = self.execute_command(self.watch_command)
- click.echo(f"Waiting for {timing} seconds before repeating")
- sleep(timing)
- except KeyboardInterrupt:
- self.watch_command = None
- else:
- query = self.execute_command(text)
+ try:
+ self.handle_watch_command(text)
+ except PgCliQuitError:
+ if not self._check_ongoing_transaction_and_allow_quitting():
+ continue
+ raise
self.now = dt.datetime.today()
@@ -779,12 +896,40 @@ class PGCli:
with self._completer_lock:
self.completer.extend_query_history(text)
- self.query_history.append(query)
-
except (PgCliQuitError, EOFError):
if not self.less_chatty:
print("Goodbye!")
+ def handle_watch_command(self, text):
+ # Initialize default metaquery in case execution fails
+ self.watch_command, timing = special.get_watch_command(text)
+
+ # If we run \watch without a command, apply it to the last query run.
+ if self.watch_command is not None and not self.watch_command.strip():
+ try:
+ self.watch_command = self.query_history[-1].query
+ except IndexError:
+ click.secho(
+ "\\watch cannot be used with an empty query", err=True, fg="red"
+ )
+ self.watch_command = None
+
+ # If there's a command to \watch, run it in a loop.
+ if self.watch_command:
+ while self.watch_command:
+ try:
+ query = self.execute_command(self.watch_command)
+ click.echo(f"Waiting for {timing} seconds before repeating")
+ sleep(timing)
+ except KeyboardInterrupt:
+ self.watch_command = None
+
+ # Otherwise, execute it as a regular command.
+ else:
+ query = self.execute_command(text)
+
+ self.query_history.append(query)
+
def _build_cli(self, history):
key_bindings = pgcli_bindings(self)
@@ -857,6 +1002,8 @@ class PGCli:
def _should_limit_output(self, sql, cur):
"""returns True if the output should be truncated, False otherwise."""
+ if self.explain_mode:
+ return False
if not is_select(sql):
return False
@@ -889,6 +1036,8 @@ class PGCli:
logger = self.logger
logger.debug("sql: %r", text)
+ # set query to formatter in order to parse table name
+ self.formatter.query = text
all_success = True
meta_changed = False # CREATE, ALTER, DROP, etc
mutated = False # INSERT, DELETE, etc
@@ -902,7 +1051,11 @@ class PGCli:
start = time()
on_error_resume = self.on_error == "RESUME"
res = self.pgexecute.run(
- text, self.pgspecial, exception_formatter, on_error_resume
+ text,
+ self.pgspecial,
+ exception_formatter,
+ on_error_resume,
+ explain_mode=self.explain_mode,
)
is_special = None
@@ -934,9 +1087,12 @@ class PGCli:
else lambda x: x
),
style_output=self.style_output,
+ max_field_width=self.max_field_width,
)
execution = time() - start
- formatted = format_output(title, cur, headers, status, settings)
+ formatted = format_output(
+ title, cur, headers, status, settings, self.explain_mode
+ )
output.extend(formatted)
total = time() - start
@@ -971,10 +1127,17 @@ class PGCli:
click.secho("Reconnecting...", fg="green")
self.pgexecute.connect()
click.secho("Reconnected!", fg="green")
- self.execute_com