summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAmjith Ramanujam <amjith.r@gmail.com>2015-10-28 12:42:52 -0700
committerAmjith Ramanujam <amjith.r@gmail.com>2015-10-28 12:42:52 -0700
commit00df5b44c9f0291cfadcc0e8e101e8208578350a (patch)
tree79517470a28a08db9344811b1735a9a8a4bd93a1
parent7e59c06568b38c092804e9054315b3e76e792365 (diff)
parent7adaa498b294059ea9e2e7aa7070feb9876bf161 (diff)
Merge pull request #395 from dbcli/darikg/refactor-on-error-options
Tweak on error options
-rw-r--r--pgcli/magic.py39
-rwxr-xr-xpgcli/main.py328
-rw-r--r--pgcli/pgexecute.py38
-rw-r--r--tests/conftest.py9
-rw-r--r--tests/test_main.py9
-rw-r--r--tests/test_pgexecute.py30
-rw-r--r--tests/utils.py19
7 files changed, 279 insertions, 193 deletions
diff --git a/pgcli/magic.py b/pgcli/magic.py
index 898e0092..4a52446e 100644
--- a/pgcli/magic.py
+++ b/pgcli/magic.py
@@ -5,30 +5,31 @@ import logging
_logger = logging.getLogger(__name__)
-def load_ipython_extension(ipython):
- #This is called via the ipython command '%load_ext pgcli.magic'
+def load_ipython_extension(ipython):
+ """This is called via the ipython command '%load_ext pgcli.magic'"""
- #first, load the sql magic if it isn't already loaded
+ # first, load the sql magic if it isn't already loaded
if not ipython.find_line_magic('sql'):
ipython.run_line_magic('load_ext', 'sql')
- #register our own magic
+ # register our own magic
ipython.register_magic_function(pgcli_line_magic, 'line', 'pgcli')
+
def pgcli_line_magic(line):
_logger.debug('pgcli magic called: %r', line)
parsed = sql.parse.parse(line, {})
conn = sql.connection.Connection.get(parsed['connection'])
try:
- #A corresponding pgcli object already exists
+ # A corresponding pgcli object already exists
pgcli = conn._pgcli
_logger.debug('Reusing existing pgcli')
except AttributeError:
- #I can't figure out how to get the underylying psycopg2 connection
- #from the sqlalchemy connection, so just grab the url and make a
- #new connection
+ # I can't figure out how to get the underylying psycopg2 connection
+ # from the sqlalchemy connection, so just grab the url and make a
+ # new connection
pgcli = PGCli()
u = conn.session.engine.url
_logger.debug('New pgcli: %r', str(u))
@@ -36,7 +37,7 @@ def pgcli_line_magic(line):
pgcli.connect(u.database, u.host, u.username, u.port, u.password)
conn._pgcli = pgcli
- #For convenience, print the connection alias
+ # For convenience, print the connection alias
print('Connected: {}'.format(conn.name))
try:
@@ -48,11 +49,21 @@ def pgcli_line_magic(line):
return
q = pgcli.query_history[-1]
- if q.mutating:
- _logger.debug('Mutating query detected -- ignoring')
+
+ if not q.successful:
+ _logger.debug('Unsuccessful query - ignoring')
+ return
+
+ if q.meta_changed or q.db_changed or q.path_changed:
+ _logger.debug('Dangerous query detected -- ignoring')
return
- if q.successful:
- ipython = get_ipython()
- return ipython.run_cell_magic('sql', line, q.query)
+
+ ipython = get_ipython()
+ return ipython.run_cell_magic('sql', line, q.query)
+
+
+
+
+
diff --git a/pgcli/main.py b/pgcli/main.py
index d1f763e3..4cdd3169 100755
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -36,7 +36,7 @@ import pgspecial as special
from .pgcompleter import PGCompleter
from .pgtoolbar import create_toolbar_tokens_func
from .pgstyle import style_factory
-from .pgexecute import PGExecute, ON_ERROR_RESUME, ON_ERROR_STOP
+from .pgexecute import PGExecute
from .pgbuffer import PGBuffer
from .completion_refresher import CompletionRefresher
from .config import write_default_config, load_config, config_location
@@ -57,8 +57,18 @@ from psycopg2 import OperationalError
from collections import namedtuple
# Query tuples are used for maintaining history
-Query = namedtuple('Query', ['query', 'successful', 'mutating'])
-
+MetaQuery = namedtuple(
+ 'Query',
+ [
+ 'query', # The entire text of the command
+ 'successful', # True If all subqueries were successful
+ 'total_time', # Time elapsed executing the query
+ 'meta_changed', # True if any subquery executed create/alter/drop
+ 'db_changed', # True if any subquery changed the database
+ 'path_changed', # True if any subquery changed the search path
+ 'mutated', # True if any subquery executed insert/update/delete
+ ])
+MetaQuery.__new__.__defaults__ = ('', False, 0, False, False, False, False)
class PGCli(object):
@@ -88,8 +98,7 @@ class PGCli(object):
self.cli_style = c['colors']
self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')
- on_error_modes = {'STOP': ON_ERROR_STOP, 'RESUME': ON_ERROR_RESUME}
- self.on_error = on_error_modes[c['main']['on_error'].upper()]
+ self.on_error = c['main']['on_error'].upper()
self.completion_refresher = CompletionRefresher()
@@ -257,56 +266,18 @@ class PGCli(object):
return document
def run_cli(self):
- pgexecute = self.pgexecute
logger = self.logger
original_less_opts = self.adjust_less_opts()
self.refresh_completions()
- def set_vi_mode(value):
- self.vi_mode = value
-
- key_binding_manager = pgcli_bindings(
- get_vi_mode_enabled=lambda: self.vi_mode,
- set_vi_mode_enabled=set_vi_mode)
+ self.cli = self._build_cli()
print('Version:', __version__)
print('Chat: https://gitter.im/dbcli/pgcli')
print('Mail: https://groups.google.com/forum/#!forum/pgcli')
print('Home: http://pgcli.com')
- def prompt_tokens(cli):
- return [(Token.Prompt, '%s> ' % pgexecute.dbname)]
-
- get_toolbar_tokens = create_toolbar_tokens_func(lambda: self.vi_mode,
- self.completion_refresher.is_refreshing)
-
- layout = create_default_layout(lexer=PostgresLexer,
- reserve_space_for_menu=True,
- get_prompt_tokens=prompt_tokens,
- get_bottom_toolbar_tokens=get_toolbar_tokens,
- display_completions_in_columns=self.wider_completion_menu,
- multiline=True,
- extra_input_processors=[
- # Highlight matching brackets while editing.
- ConditionalProcessor(
- processor=HighlightMatchingBracketProcessor(chars='[](){}'),
- filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
- ])
- history_file = self.config['main']['history_file']
- with self._completer_lock:
- buf = PGBuffer(always_multiline=self.multi_line, completer=self.completer,
- history=FileHistory(os.path.expanduser(history_file)),
- complete_while_typing=Always())
-
- application = Application(style=style_factory(self.syntax_style, self.cli_style),
- layout=layout, buffer=buf,
- key_bindings_registry=key_binding_manager.registry,
- on_exit=AbortAction.RAISE_EXCEPTION,
- ignore_case=True)
- self.cli = CommandLineInterface(application=application,
- eventloop=create_eventloop())
-
try:
while True:
document = self.cli.run()
@@ -326,69 +297,22 @@ class PGCli(object):
click.secho(str(e), err=True, fg='red')
continue
- # Keep track of whether or not the query is mutating. In case
- # of a multi-statement query, the overall query is considered
- # mutating if any one of the component statements is mutating
- mutating = False
+ # Initialize default metaquery in case execution fails
+ query = MetaQuery(query=document.text, successful=False)
try:
- logger.debug('sql: %r', document.text)
- successful = False
- # Initialized to [] because res might never get initialized
- # if an exception occurs in pgexecute.run(). Which causes
- # finally clause to fail.
- res = []
- # Run the query.
- start = time()
- res = pgexecute.run(document.text, self.pgspecial,
- on_error=self.on_error)
- output = []
- total = 0
- for title, cur, headers, status in res:
- logger.debug("headers: %r", headers)
- logger.debug("rows: %r", cur)
- logger.debug("status: %r", status)
- threshold = 1000
- if (is_select(status) and
- cur and cur.rowcount > threshold):
- click.secho('The result set has more than %s rows.'
- % threshold, fg='red')
- if not click.confirm('Do you want to continue?'):
- click.secho("Aborted!", err=True, fg='red')
- break
-
- if self.pgspecial.auto_expand:
- max_width = self.cli.output.get_size().columns
- else:
- max_width = None
-
- formatted = format_output(title, cur, headers, status,
- self.table_format,
- self.pgspecial.expanded_output,
- max_width)
- output.extend(formatted)
- end = time()
- total += end - start
- mutating = mutating or is_mutating(status)
-
+ output, query = self._evaluate_command(document.text)
except KeyboardInterrupt:
# Restart connection to the database
- pgexecute.connect()
+ self.pgexecute.connect()
logger.debug("cancelled query, sql: %r", document.text)
click.secho("cancelled query", err=True, fg='red')
except NotImplementedError:
click.secho('Not Yet Implemented.', fg="yellow")
except OperationalError as e:
- reconnect = True
- if ('server closed the connection' in utf8tounicode(e.args[0])):
- reconnect = click.prompt('Connection reset. Reconnect (Y/n)',
- show_default=False, type=bool, default=True)
- if reconnect:
- try:
- pgexecute.connect()
- click.secho('Reconnected!\nTry the command again.', fg='green')
- except OperationalError as e:
- click.secho(str(e), err=True, fg='red')
+ if ('server closed the connection'
+ in utf8tounicode(e.args[0])):
+ self._handle_server_closed_connection()
else:
logger.error("sql: %r, error: %r", document.text, e)
logger.error("traceback: %r", traceback.format_exc())
@@ -398,26 +322,28 @@ class PGCli(object):
logger.error("traceback: %r", traceback.format_exc())
click.secho(str(e), err=True, fg='red')
else:
- successful = True
try:
click.echo_via_pager('\n'.join(output))
except KeyboardInterrupt:
pass
- if self.pgspecial.timing_enabled:
- print('Time: %0.03fs' % total)
-
- # Refresh the table names and column names if necessary.
- if need_completion_refresh(document.text):
- self.refresh_completions(need_completion_reset(document.text))
- # Refresh search_path to set default schema.
- if need_search_path_refresh(document.text):
+ if self.pgspecial.timing_enabled:
+ print('Time: %0.03fs' % query.total_time)
+
+ # Check if we need to update completions, in order of most
+ # to least drastic changes
+ if query.db_changed:
+ self.refresh_completions(reset=True)
+ elif query.meta_changed:
+ self.refresh_completions(reset=False)
+ elif query.path_changed:
logger.debug('Refreshing search path')
with self._completer_lock:
- self.completer.set_search_path(pgexecute.search_path())
- logger.debug('Search path: %r', self.completer.search_path)
+ self.completer.set_search_path(
+ self.pgexecute.search_path())
+ logger.debug('Search path: %r',
+ self.completer.search_path)
- query = Query(document.text, successful, mutating)
self.query_history.append(query)
except EOFError:
@@ -426,6 +352,133 @@ class PGCli(object):
logger.debug('Restoring env var LESS to %r.', original_less_opts)
os.environ['LESS'] = original_less_opts
+ def _build_cli(self):
+
+ def set_vi_mode(value):
+ self.vi_mode = value
+
+ key_binding_manager = pgcli_bindings(
+ get_vi_mode_enabled=lambda: self.vi_mode,
+ set_vi_mode_enabled=set_vi_mode)
+
+ def prompt_tokens(_):
+ return [(Token.Prompt, '%s> ' % self.pgexecute.dbname)]
+
+ get_toolbar_tokens = create_toolbar_tokens_func(
+ lambda: self.vi_mode, self.completion_refresher.is_refreshing)
+
+ layout = create_default_layout(
+ lexer=PostgresLexer,
+ reserve_space_for_menu=True,
+ get_prompt_tokens=prompt_tokens,
+ get_bottom_toolbar_tokens=get_toolbar_tokens,
+ display_completions_in_columns=self.wider_completion_menu,
+ multiline=True,
+ extra_input_processors=[
+ # Highlight matching brackets while editing.
+ ConditionalProcessor(
+ processor=HighlightMatchingBracketProcessor(chars='[](){}'),
+ filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
+ ])
+
+ history_file = self.config['main']['history_file']
+ with self._completer_lock:
+ buf = PGBuffer(
+ always_multiline=self.multi_line,
+ completer=self.completer,
+ history=FileHistory(os.path.expanduser(history_file)),
+ complete_while_typing=Always())
+
+ application = Application(
+ style=style_factory(self.syntax_style, self.cli_style),
+ layout=layout,
+ buffer=buf,
+ key_bindings_registry=key_binding_manager.registry,
+ on_exit=AbortAction.RAISE_EXCEPTION,
+ ignore_case=True)
+
+ cli = CommandLineInterface(
+ application=application,
+ eventloop=create_eventloop())
+
+ return cli
+
+ def _evaluate_command(self, text):
+ """Used to run a command entered by the user during CLI operation
+ (Puts the E in REPL)
+
+ returns (results, MetaQuery)
+ """
+ logger = self.logger
+ logger.debug('sql: %r', text)
+
+ all_success = True
+ meta_changed = False # CREATE, ALTER, DROP, etc
+ mutated = False # INSERT, DELETE, etc
+ db_changed = False
+ path_changed = False
+ output = []
+ total = 0
+
+ # Run the query.
+ start = time()
+ on_error_resume = self.on_error == 'RESUME'
+ res = self.pgexecute.run(text, self.pgspecial,
+ exception_formatter, on_error_resume)
+
+ for title, cur, headers, status, sql, success in res:
+ logger.debug("headers: %r", headers)
+ logger.debug("rows: %r", cur)
+ logger.debug("status: %r", status)
+ threshold = 1000
+ if (is_select(status) and
+ cur and cur.rowcount > threshold):
+ click.secho('The result set has more than %s rows.'
+ % threshold, fg='red')
+ if not click.confirm('Do you want to continue?'):
+ click.secho("Aborted!", err=True, fg='red')
+ break
+
+ if self.pgspecial.auto_expand:
+ max_width = self.cli.output.get_size().columns
+ else:
+ max_width = None
+
+ formatted = format_output(
+ title, cur, headers, status, self.table_format,
+ self.pgspecial.expanded_output, max_width)
+
+ output.extend(formatted)
+ end = time()
+ total += end - start
+
+ # Keep track of whether any of the queries are mutating or changing
+ # the database
+ if success:
+ mutated = mutated or is_mutating(status)
+ db_changed = db_changed or has_change_db_cmd(sql)
+ meta_changed = meta_changed or has_meta_cmd(sql)
+ path_changed = path_changed or has_change_path_cmd(sql)
+ else:
+ all_success = False
+
+ meta_query = MetaQuery(text, all_success, total, meta_changed,
+ db_changed, path_changed, mutated)
+
+ return output, meta_query
+
+ def _handle_server_closed_connection(self):
+ """Used during CLI execution"""
+ reconnect = click.prompt(
+ 'Connection reset. Reconnect (Y/n)',
+ show_default=False, type=bool, default=True)
+ if reconnect:
+ try:
+ self.pgexecute.connect()
+ click.secho('Reconnected!\nTry the command again.', fg='green')
+ except OperationalError as e:
+ click.secho(str(e), err=True, fg='red')
+
def adjust_less_opts(self):
less_opts = os.environ.get('LESS', '')
self.logger.debug('Original value for LESS env var: %r', less_opts)
@@ -534,6 +587,7 @@ def cli(database, user, host, port, prompt_passwd, never_prompt, dbname,
pgcli.run_cli()
+
def obfuscate_process_password():
process_title = setproctitle.getproctitle()
if '://' in process_title:
@@ -564,58 +618,64 @@ def format_output(title, cur, headers, status, table_format, expanded=False, max
output.append(status)
return output
-def need_completion_refresh(queries):
+
+def has_meta_cmd(query):
"""Determines if the completion needs a refresh by checking if the sql
- statement is an alter, create, drop or change db."""
- for query in sqlparse.split(queries):
- try:
- first_token = query.split()[0]
- if first_token.lower() in ('alter', 'create', 'use', '\\c',
- '\\connect', 'drop'):
- return True
- except Exception:
- return False
+ statement is an alter, create, or drop"""
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in ('alter', 'create', 'drop'):
+ return True
+ except Exception:
+ return False
return False
-def need_completion_reset(queries):
- """Determines if the statement is a database switch such as 'use' or '\\c'.
- When a database is changed the existing completions must be reset before we
- start the completion refresh for the new database.
- """
- for query in sqlparse.split(queries):
- try:
- first_token = query.split()[0]
- if first_token.lower() in ('use', '\\c', '\\connect'):
- return True
- except Exception:
- return False
+
+def has_change_db_cmd(query):
+ """Determines if the statement is a database switch such as 'use' or '\\c'"""
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in ('use', '\\c', '\\connect'):
+ return True
+ except Exception:
+ return False
+
+ return False
-def need_search_path_refresh(sql):
+def has_change_path_cmd(sql):
"""Determines if the search_path should be refreshed by checking if the
sql has 'set search_path'."""
return 'set search_path' in sql.lower()
+
def is_mutating(status):
"""Determines if the statement is mutating based on the status."""
if not status:
return False
- mutating = set(['insert', 'update', 'delete', 'alter', 'create', 'drop'])
+ mutating = set(['insert', 'update', 'delete'])
return status.split(None, 1)[0].lower() in mutating
+
def is_select(status):
"""Returns true if the first word in status is 'select'."""
if not status:
return False
return status.split(None, 1)[0].lower() == 'select'
+
def quit_command(sql):
return (sql.strip().lower() == 'exit'
or sql.strip().lower() == 'quit'
or sql.strip() == '\q'
or sql.strip() == ':q')
+
+def exception_formatter(e):
+ return click.style(utf8tounicode(str(e)), fg='red')
+
+
if __name__ == "__main__":
cli()
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index 041b1b58..2210058b 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -6,8 +6,7 @@ import psycopg2.extensions as ext
import sqlparse
import pgspecial as special
from .packages.function_metadata import FunctionMetadata
-from .encodingutils import unicode2utf8, PY2, utf8tounicode
-import click
+from .encodingutils import unicode2utf8, PY2
_logger = logging.getLogger(__name__)
@@ -27,11 +26,6 @@ ext.register_type(ext.new_type((17,), 'BYTEA_TEXT', psycopg2.STRING))
ext.set_wait_callback(psycopg2.extras.wait_select)
-ON_ERROR_RAISE = 0
-ON_ERROR_RESUME = 1
-ON_ERROR_STOP = 2
-
-
def register_json_typecasters(conn, loads_fn):
"""Set the function for converting JSON data for a connection.
@@ -58,6 +52,7 @@ def register_json_typecasters(conn, loads_fn):
return available
+
def register_hstore_typecaster(conn):
"""
Instead of using register_hstore() which converts hstore into a python
@@ -73,6 +68,7 @@ def register_hstore_typecaster(conn):
except Exception:
pass
+
class PGExecute(object):
# The boolean argument to the current_schemas function indicates whether
@@ -226,18 +222,28 @@ class PGExecute(object):
else:
return json_data
- def run(self, statement, pgspecial=None, on_error=ON_ERROR_RESUME):
+ def run(self, statement, pgspecial=None, exception_formatter=None,
+ on_error_resume=False):
"""Execute the sql in the database and return the results.
:param statement: A string containing one or more sql statements
:param pgspecial: PGSpecial object
- :return: List of tuples containing (title, rows, headers, status)
+ :param exception_formatter: A callable that accepts an Exception and
+ returns a formatted (title, rows, headers, status) tuple that can
+ act as a query result. If an exception_formatter is not supplied,
+ psycopg2 exceptions are always raised.
+ :param on_error_resume: Bool. If true, queries following an exception
+ (assuming exception_formatter has been supplied) continue to
+ execute.
+
+ :return: Generator yielding tuples containing
+ (title, rows, headers, status, query, success)
"""
# Remove spaces and EOL
statement = statement.strip()
if not statement: # Empty string
- yield (None, None, None, None)
+ yield (None, None, None, None, statement, False)
# Split the sql into separate queries and run each one.
for sql in sqlparse.split(statement):
@@ -251,30 +257,30 @@ class PGExecute(object):
cur = self.conn.cursor()
try:
for result in pgspecial.execute(cur, sql):
- yield result
+ yield result + (sql, True)
continue
except special.CommandNotFound:
pass
# Not a special command, so execute as normal sql
- yield self.execute_normal_sql(sql)
+ yield self.execute_normal_sql(sql) + (sql, True)
except psycopg2.DatabaseError as e:
_logger.error("sql: %r, error: %r", sql, e)
_logger.error("traceback: %r", traceback.format_exc())
if (isinstance(e, psycopg2.OperationalError)
- or on_error == ON_ERROR_RAISE):
+ or not exception_formatter):
# Always raise operational errors, regardless of on_error
# specification
raise
- result = click.style(utf8tounicode(str(e)), fg='red')
- yield None, None, None, result
+ yield None, None, None, exception_formatter(e), sql, False
- if on_error == ON_ERROR_STOP:
+ if not on_error_resume:
break
def execute_normal_sql(self, split_sql):
+ """Returns tuple (title, rows, headers, status)"""
_logger.debug('Regular sql statement. sql: %r', split_sql)
cur = self.conn.cursor()
cur.execute(split_sql)
diff --git a/tests/conftest.py b/tests/conftest.py
index 0cff635a..06b61462 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,6 @@
import pytest
from utils import (POSTGRES_HOST, POSTGRES_USER, create_db, db_connection,
-drop_tables)
+ drop_tables)
import pgcli.pgexecute
@@ -24,3 +24,10 @@ def cursor(connection):
def executor(connection):
return pgcli.pgexecute.PGExecute(database='_test_db', user=POSTGRES_USER,
host=POSTGRES_HOST, password=None, port=None, dsn=None)
+
+
+@pytest.fixture
+def exception_formatter():
+ return lambda e: str(e)
+
+
diff --git a/tests/test_main.py b/tests/test_main.py
index f98aae5e..0d026994 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -1,19 +1,12 @@
import pytest
import platform
-from pgcli.main import need_completion_refresh, obfuscate_process_password
try:
import setproctitle
except ImportError:
setproctitle = None
+from pgcli.main import obfuscate_process_password
-@pytest.mark.parametrize('sql', [
- 'DROP TABLE foo',
- 'SELECT * FROM foo; DROP TABLE foo',
-])
-def test_need_completion_refresh(sql):
- assert need_completion_refresh(sql)
-
@pytest.mark.skipif(platform.system() == 'Windows',
reason='Not applicable in windows')
@pytest.mark.skipif(not setproctitle,
diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py
index fb830836..bec40c7f 100644
--- a/tests/test_pgexecute.py
+++ b/tests/test_pgexecute.py
@@ -6,7 +6,6 @@ from pgspecial.main import PGSpecial
from pgcli.packages.function_metadata import FunctionMetadata
from textwrap import dedent
from utils import run, dbtest, requires_json, requires_jsonb
-from pgcli.pgexecute import (ON_ERROR_STOP, ON_ERROR_RAISE, ON_ERROR_RESUME)
@dbtest
@@ -105,13 +104,15 @@ def test_database_list(executor):
assert '_test_db' in databases
@dbtest
-def test_invalid_syntax(executor):
- result = run(executor, 'invalid syntax!')
+def test_invalid_syntax(executor, exception_formatter):
+ result = run(executor, 'invalid syntax!',
+ exception_formatter=exception_formatter)
assert 'syntax error at or near "invalid"' in result[0]
@dbtest
-def test_invalid_column_name(executor):
- result = run(executor, 'select invalid command')
+def test_invalid_column_name(executor, exception_formatter):
+ result = run(executor, 'select invalid command',
+ exception_formatter=exception_formatter)
assert 'column "invalid" does not exist' in result[0]
@@ -146,8 +147,9 @@ def test_multiple_queries_with_special_command_same_line(executor, pgspecial):
assert "Schema" in result[2]
@dbtest
-def test_multiple_queries_same_line_syntaxerror(executor):
- result = run(executor, u"select 'fooé'; invalid syntax é")
+def test_multiple_queries_same_line_syntaxerror(executor, exception_formatter):
+ result = run(executor, u"select 'fooé'; invalid syntax é",
+ exception_formatter=exception_formatter)
assert u'fooé' in result[0]
assert 'syntax error at or near "invalid"' in result[-1]
@@ -224,20 +226,22 @@ def test_describe_special(executor, command, verbose, pattern):
'invalid sql',
'SELECT 1; select error;',
])
-def test_on_error_raises(executor, sql):
+def test_raises_with_no_formatter(executor, sql):
with pytest.raises(psycopg2.ProgrammingError):
- list(executor.run(sql, on_error=ON_ERROR_RAISE))
+ list(executor.run(sql))
@dbtest
-def test_on_error_resume(executor):
+def test_on_error_resume(executor, exception_formatter):
sql = 'select 1; error; select 1;'
- result = list(executor.run(sql, on_error=ON_ERROR_RESUME))
+ result = list(executor.run(sql, on_error_resume=True,
+ exception_formatter=exception_formatter))
assert len(result) == 3
@dbtest
-def test_on_error_stop(executor):
+def test_on_error_stop(executor, exception_formatter):
sql = 'select 1; error; select 1;'
- result = list(executor.run(sql, on_error=ON_ERROR_STOP))
+ result = list(executor.run(sql, on_error_resume=False,
+ exception_formatter=exception_formatter))
assert len(result) == 2
diff --git a/tests/utils.py b/tests/utils.py
index 93e9f258..1f515add 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -58,12 +58,17 @@ def drop_tables(conn):
DROP SCHEMA IF EXISTS schema2 CASCADE''')
-def run(executor, sql, join=False, expanded=False, pgspecial=None):
+def run(executor, sql, join=False, expanded=False, pgspecial=None,
+ exception_formatter=None):
" Return string output for the sql to be run "
- result = []
- for title, rows, headers, status in executor.run(sql, pgspecial):
- result.extend(format_output(title, rows, headers, status, 'psql',
- expanded=expanded))
+
+ results = executor.run(sql, pgspecial, exception_formatter)
+ formatted = []
+
+ for title, rows, headers, status, sql, success in results:
+ formatted.extend(format_output(title, rows, headers, status, 'psql',
+ expanded=expanded))
if join:
- result = '\n'.join(result)
- return result
+ formatted = '\n'.join(formatted)
+
+ return formatted