summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAmjith Ramanujam <amjith.r@gmail.com>2021-02-23 10:56:30 -0800
committerGitHub <noreply@github.com>2021-02-23 10:56:30 -0800
commit6e5f0e107e226c478120af702d88316b04601542 (patch)
tree0d6dc0afc58b5a509bede5ef0a4416da07b5b9a7
parentd1c8699d894a2aa6cafb97c4dcbff7a8d7d6569c (diff)
parentf90b36bdcb6662b9b75ac95527d5ec2e7bd43b22 (diff)
Merge branch 'master' into pasenor/resources
-rw-r--r--.github/workflows/ci.yml1
-rw-r--r--changelog.md6
-rw-r--r--mycli/config.py17
-rwxr-xr-xmycli/main.py8
-rw-r--r--mycli/sqlexecute.py101
-rwxr-xr-xsetup.py4
-rw-r--r--test/features/connection.feature23
-rw-r--r--test/features/steps/auto_vertical.py3
-rw-r--r--test/features/steps/connection.py27
-rw-r--r--test/features/steps/utils.py12
-rw-r--r--test/features/steps/wrappers.py51
-rw-r--r--test/test_main.py2
-rw-r--r--test/test_sqlexecute.py22
13 files changed, 205 insertions, 72 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ccb15aa..0a14472 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -49,6 +49,7 @@ jobs:
- name: Pytest / behave
env:
PYTEST_PASSWORD: root
+ PYTEST_HOST: 127.0.0.1
run: |
./setup.py test --pytest-args="--cov-report= --cov=mycli"
diff --git a/changelog.md b/changelog.md
index f0d835b..9253cff 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,9 +1,11 @@
TBD
-=======
+===
Bug Fixes:
----------
* Allow `FileNotFound` exception for SSH config files.
+* Fix startup error on MySQL < 5.0.22
+* Check error code rather than message for Access Denied error
Features:
---------
@@ -13,6 +15,8 @@ Internal:
---------
* Remove unused function is_open_quote()
* Use importlib, instead of file links, to locate resources
+* Test various host-port combinations in command line arguments
+* Switched from Cryptography to pyaes for decrypting mylogin.cnf
1.23.2
diff --git a/mycli/config.py b/mycli/config.py
index 5111288..55d230d 100644
--- a/mycli/config.py
+++ b/mycli/config.py
@@ -9,8 +9,7 @@ import sys
from typing import Union
from configobj import ConfigObj, ConfigObjError
-from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
-from cryptography.hazmat.backends import default_backend
+import pyaes
try:
import importlib.resources as resources
@@ -215,11 +214,9 @@ def read_and_decrypt_mylogin_cnf(f):
return None
rkey = struct.pack('16B', *rkey)
- # Create a decryptor object using the key.
- decryptor = _get_decryptor(rkey)
-
# Create a bytes buffer to hold the plaintext.
plaintext = BytesIO()
+ aes = pyaes.AESModeOfOperationECB(rkey)
while True:
# Read the length of the ciphertext.
@@ -230,7 +227,9 @@ def read_and_decrypt_mylogin_cnf(f):
# Read cipher_len bytes from the file and decrypt.
cipher = f.read(cipher_len)
- plain = _remove_pad(decryptor.update(cipher))
+ plain = _remove_pad(
+ b''.join([aes.decrypt(cipher[i: i + 16]) for i in range(0, cipher_len, 16)])
+ )
if plain is False:
continue
plaintext.write(plain)
@@ -274,12 +273,6 @@ def strip_matching_quotes(s):
return s
-def _get_decryptor(key):
- """Get the AES decryptor."""
- c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend())
- return c.decryptor()
-
-
def _remove_pad(line):
"""Remove the pad from the *line*."""
try:
diff --git a/mycli/main.py b/mycli/main.py
index 2eb3812..6ab0f49 100755
--- a/mycli/main.py
+++ b/mycli/main.py
@@ -43,7 +43,7 @@ from .packages.special.favoritequeries import FavoriteQueries
from .sqlcompleter import SQLCompleter
from .clitoolbar import create_toolbar_tokens_func
from .clistyle import style_factory, style_factory_output
-from .sqlexecute import FIELD_TYPES, SQLExecute
+from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED
from .clibuffer import cli_is_multiline
from .completion_refresher import CompletionRefresher
from .config import (write_default_config, get_mylogin_cnf_path,
@@ -434,7 +434,7 @@ class MyCli(object):
ssh_password, ssh_key_filename, init_command
)
except OperationalError as e:
- if ('Access denied for user' in e.args[1]):
+ if e.args[0] == ERROR_CODE_ACCESS_DENIED:
new_passwd = click.prompt('Password', hide_input=True,
show_default=False, type=str, err=True)
self.sqlexecute = SQLExecute(
@@ -563,7 +563,7 @@ class MyCli(object):
key_bindings = mycli_bindings(self)
if not self.less_chatty:
- print(' '.join(sqlexecute.server_type()))
+ print(sqlexecute.server_info)
print('mycli', __version__)
print(SUPPORT_INFO)
print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))
@@ -935,7 +935,7 @@ class MyCli(object):
string = string.replace('\\u', sqlexecute.user or '(none)')
string = string.replace('\\h', host or '(none)')
string = string.replace('\\d', sqlexecute.dbname or '(none)')
- string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
+ string = string.replace('\\t', sqlexecute.server_info.species.name)
string = string.replace('\\n', "\n")
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
string = string.replace('\\m', now.strftime('%M'))
diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py
index 46cf07c..9461438 100644
--- a/mycli/sqlexecute.py
+++ b/mycli/sqlexecute.py
@@ -1,4 +1,7 @@
+import enum
import logging
+import re
+
import pymysql
from .packages import special
from pymysql.constants import FIELD_TYPE
@@ -17,17 +20,71 @@ FIELD_TYPES.update({
FIELD_TYPE.NULL: type(None)
})
+
+ERROR_CODE_ACCESS_DENIED = 1045
+
+
+class ServerSpecies(enum.Enum):
+ MySQL = 'MySQL'
+ MariaDB = 'MariaDB'
+ Percona = 'Percona'
+ Unknown = 'MySQL'
+
+
+class ServerInfo:
+ def __init__(self, species, version_str):
+ self.species = species
+ self.version_str = version_str
+ self.version = self.calc_mysql_version_value(version_str)
+
+ @staticmethod
+ def calc_mysql_version_value(version_str) -> int:
+ if not version_str or not isinstance(version_str, str):
+ return 0
+ try:
+ major, minor, patch = version_str.split('.')
+ except ValueError:
+ return 0
+ else:
+ return int(major) * 10_000 + int(minor) * 100 + int(patch)
+
+ @classmethod
+ def from_version_string(cls, version_string):
+ if not version_string:
+ return cls(ServerSpecies.Unknown, '')
+
+ re_species = (
+ (r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
+ ServerSpecies.Percona),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
+ ServerSpecies.MySQL),
+ )
+ for regexp, species in re_species:
+ match = re.search(regexp, version_string)
+ if match is not None:
+ parsed_version = match.group('version')
+ detected_species = species
+ break
+ else:
+ detected_species = ServerSpecies.Unknown
+ parsed_version = ''
+
+ return cls(detected_species, parsed_version)
+
+ def __str__(self):
+ if self.species:
+ return f'{self.species.value} {self.version_str}'
+ else:
+ return self.version_str
+
+
class SQLExecute(object):
databases_query = '''SHOW DATABASES'''
tables_query = '''SHOW TABLES'''
- version_query = '''SELECT @@VERSION'''
-
- version_comment_query = '''SELECT @@VERSION_COMMENT'''
- version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"'''
-
show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
@@ -51,7 +108,7 @@ class SQLExecute(object):
self.charset = charset
self.local_infile = local_infile
self.ssl = ssl
- self._server_type = None
+ self.server_info = None
self.connection_id = None
self.ssh_user = ssh_user
self.ssh_host = ssh_host
@@ -156,6 +213,7 @@ class SQLExecute(object):
self.init_command = init_command
# retrieve connection id
self.reset_connection_id()
+ self.server_info = ServerInfo.from_version_string(conn.server_version)
def run(self, statement):
"""Execute the sql in the database and return the results. The results
@@ -272,37 +330,6 @@ class SQLExecute(object):
for row in cur:
yield row
- def server_type(self):
- if self._server_type:
- return self._server_type
- with self.conn.cursor() as cur:
- _logger.debug('Version Query. sql: %r', self.version_query)
- cur.execute(self.version_query)
- version = cur.fetchone()[0]
- if version[0] == '4':
- _logger.debug('Version Comment. sql: %r',
- self.version_comment_query_mysql4)
- cur.execute(self.version_comment_query_mysql4)
- version_comment = cur.fetchone()[1].lower()
- if isinstance(version_comment, bytes):
- # with python3 this query returns bytes
- version_comment = version_comment.decode('utf-8')
- else:
- _logger.debug('Version Comment. sql: %r',
- self.version_comment_query)
- cur.execute(self.version_comment_query)
- version_comment = cur.fetchone()[0].lower()
-
- if 'mariadb' in version_comment:
- product_type = 'mariadb'
- elif 'percona' in version_comment:
- product_type = 'percona'
- else:
- product_type = 'mysql'
-
- self._server_type = (product_type, version)
- return self._server_type
-
def get_connection_id(self):
if not self.connection_id:
self.reset_connection_id()
diff --git a/setup.py b/setup.py
index ce32977..5acbae7 100755
--- a/setup.py
+++ b/setup.py
@@ -23,9 +23,9 @@ install_requirements = [
'PyMySQL >= 0.9.2',
'sqlparse>=0.3.0,<0.4.0',
'configobj >= 5.0.5',
- 'cryptography >= 1.0.0',
'cli_helpers[styles] >= 2.0.1',
- 'pyperclip >= 1.8.1'
+ 'pyperclip >= 1.8.1',
+ 'pyaes >= 1.6.1'
]
if sys.version_info.minor < 9:
diff --git a/test/features/connection.feature b/test/features/connection.feature
new file mode 100644
index 0000000..04d041d
--- /dev/null
+++ b/test/features/connection.feature
@@ -0,0 +1,23 @@
+Feature: connect to a database:
+
+ @requires_local_db
+ Scenario: run mycli on localhost without port
+ When we run mycli with arguments "host=localhost" without arguments "port"
+ When we query "status"
+ Then status contains "via UNIX socket"
+
+ Scenario: run mycli on TCP host without port
+ When we run mycli without arguments "port"
+ When we query "status"
+ Then status contains "via TCP/IP"
+
+ Scenario: run mycli with port but without host
+ When we run mycli without arguments "host"
+ When we query "status"
+ Then status contains "via TCP/IP"
+
+ @requires_local_db
+ Scenario: run mycli without host and port
+ When we run mycli without arguments "host port"
+ When we query "status"
+ Then status contains "via UNIX socket"
diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py
index 974740d..e1cb26f 100644
--- a/test/features/steps/auto_vertical.py
+++ b/test/features/steps/auto_vertical.py
@@ -3,11 +3,12 @@ from textwrap import dedent
from behave import then, when
import wrappers
+from utils import parse_cli_args_to_dict
@when('we run dbcli with {arg}')
def step_run_cli_with_arg(context, arg):
- wrappers.run_cli(context, run_args=arg.split('='))
+ wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
@when('we execute a small query')
diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py
new file mode 100644
index 0000000..f4a4929
--- /dev/null
+++ b/test/features/steps/connection.py
@@ -0,0 +1,27 @@
+import shlex
+from behave import when, then
+
+import wrappers
+from test.features.steps.utils import parse_cli_args_to_dict
+
+
+@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
+@when('we run mycli without arguments "{excluded_args}"')
+def step_run_cli_without_args(context, excluded_args, exact_args=''):
+ wrappers.run_cli(
+ context,
+ run_args=parse_cli_args_to_dict(exact_args),
+ exclude_args=parse_cli_args_to_dict(excluded_args).keys()
+ )
+
+
+@then('status contains "{expression}"')
+def status_contains(context, expression):
+ wrappers.expect_exact(context, f'{expression}', timeout=5)
+
+ # Normally, the shutdown after scenario waits for the prompt.
+ # But we may have changed the prompt, depending on parameters,
+ # so let's wait for its last character
+ context.cli.expect_exact('>')
+ context.atprompt = True
+
diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py
new file mode 100644
index 0000000..1ae63d2
--- /dev/null
+++ b/test/features/steps/utils.py
@@ -0,0 +1,12 @@
+import shlex
+
+
+def parse_cli_args_to_dict(cli_args: str):
+ args_dict = {}
+ for arg in shlex.split(cli_args):
+ if '=' in arg:
+ key, value = arg.split('=')
+ args_dict[key] = value
+ else:
+ args_dict[arg] = None
+ return args_dict
diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py
index de833dd..780a1c7 100644
--- a/test/features/steps/wrappers.py
+++ b/test/features/steps/wrappers.py
@@ -3,6 +3,7 @@ import pexpect
import sys
import textwrap
+
try:
from StringIO import StringIO
except ImportError:
@@ -46,21 +47,41 @@ def expect_pager(context, expected, timeout):
context.conf['pager_boundary'], expected), timeout=timeout)
-def run_cli(context, run_args=None):
+def run_cli(context, run_args=None, exclude_args=None):
"""Run the process using pexpect."""
- run_args = run_args or []
- if context.conf.get('host', None):
- run_args.extend(('-h', context.conf['host']))
- if context.conf.get('user', None):
- run_args.extend(('-u', context.conf['user']))
- if context.conf.get('pass', None):
- run_args.extend(('-p', context.conf['pass']))
- if context.conf.get('dbname', None):
- run_args.extend(('-D', context.conf['dbname']))
- if context.conf.get('defaults-file', None):
- run_args.extend(('--defaults-file', context.conf['defaults-file']))
- if context.conf.get('myclirc', None):
- run_args.extend(('--myclirc', context.conf['myclirc']))
+ run_args = run_args or {}
+ rendered_args = []
+ exclude_args = set(exclude_args) if exclude_args else set()
+
+ conf = dict(**context.conf)
+ conf.update(run_args)
+
+ def add_arg(name, key, value):
+ if name not in exclude_args:
+ if value is not None:
+ rendered_args.extend((key, value))
+ else:
+ rendered_args.append(key)
+
+ if conf.get('host', None):
+ add_arg('host', '-h', conf['host'])
+ if conf.get('user', None):
+ add_arg('user', '-u', conf['user'])
+ if conf.get('pass', None):
+ add_arg('pass', '-p', conf['pass'])
+ if conf.get('port', None):
+ add_arg('port', '-P', str(conf['port']))
+ if conf.get('dbname', None):
+ add_arg('dbname', '-D', conf['dbname'])
+ if conf.get('defaults-file', None):
+ add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
+ if conf.get('myclirc', None):
+ add_arg('myclirc', '--myclirc', conf['myclirc'])
+
+ for arg_name, arg_value in conf.items():
+ if arg_name.startswith('-'):
+ add_arg(arg_name, arg_name, arg_value)
+
try:
cli_cmd = context.conf['cli_command']
except KeyError:
@@ -73,7 +94,7 @@ def run_cli(context, run_args=None):
'"'
).format(sys.executable)
- cmd_parts = [cli_cmd] + run_args
+ cmd_parts = [cli_cmd] + rendered_args
cmd = ' '.join(cmd_parts)
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
context.logfile = StringIO()
diff --git a/test/test_main.py b/test/test_main.py
index 91a366b..00fdc1b 100644
--- a/test/test_main.py
+++ b/test/test_main.py
@@ -5,6 +5,7 @@ from click.testing import CliRunner
from mycli.main import MyCli, cli, thanks_picker
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
+from mycli.sqlexecute import ServerInfo
from .utils import USER, HOST, PORT, PASSWORD, dbtest, run
from textwrap import dedent
@@ -174,6 +175,7 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
host = 'test'
user = 'test'
dbname = 'test'
+ server_info = ServerInfo.from_version_string('unknown')
port = 0
def server_type(self):
diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py
index 5168bf6..0f38a97 100644
--- a/test/test_sqlexecute.py
+++ b/test/test_sqlexecute.py
@@ -3,6 +3,7 @@ import os
import pytest
import pymysql
+from mycli.sqlexecute import ServerInfo, ServerSpecies
from .utils import run, dbtest, set_expanded_output, is_expanded_output
@@ -270,3 +271,24 @@ def test_multiple_results(executor):
'status': '1 row in set'}
]
assert results == expected
+
+
+@pytest.mark.parametrize(
+ 'version_string, species, parsed_version_string, version',
+ (
+ ('5.7.32-35', 'Percona', '5.7.32', 50732),
+ ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
+ ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
+ ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
+ ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
+ ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
+ ('unexpected version string', None, '', 0),
+ ('', None, '', 0),
+ (None, None, '', 0),
+ )
+)
+def test_version_parsing(version_string, species, parsed_version_string, version):
+ server_info = ServerInfo.from_version_string(version_string)
+ assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown
+ assert server_info.version_str == parsed_version_string
+ assert server_info.version == version