diff options
author | Amjith Ramanujam <amjith.r@gmail.com> | 2021-02-23 10:56:30 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-23 10:56:30 -0800 |
commit | 6e5f0e107e226c478120af702d88316b04601542 (patch) | |
tree | 0d6dc0afc58b5a509bede5ef0a4416da07b5b9a7 | |
parent | d1c8699d894a2aa6cafb97c4dcbff7a8d7d6569c (diff) | |
parent | f90b36bdcb6662b9b75ac95527d5ec2e7bd43b22 (diff) |
Merge branch 'master' into pasenor/resources
-rw-r--r-- | .github/workflows/ci.yml | 1 | ||||
-rw-r--r-- | changelog.md | 6 | ||||
-rw-r--r-- | mycli/config.py | 17 | ||||
-rwxr-xr-x | mycli/main.py | 8 | ||||
-rw-r--r-- | mycli/sqlexecute.py | 101 | ||||
-rwxr-xr-x | setup.py | 4 | ||||
-rw-r--r-- | test/features/connection.feature | 23 | ||||
-rw-r--r-- | test/features/steps/auto_vertical.py | 3 | ||||
-rw-r--r-- | test/features/steps/connection.py | 27 | ||||
-rw-r--r-- | test/features/steps/utils.py | 12 | ||||
-rw-r--r-- | test/features/steps/wrappers.py | 51 | ||||
-rw-r--r-- | test/test_main.py | 2 | ||||
-rw-r--r-- | test/test_sqlexecute.py | 22 |
13 files changed, 205 insertions, 72 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ccb15aa..0a14472 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,6 +49,7 @@ jobs: - name: Pytest / behave env: PYTEST_PASSWORD: root + PYTEST_HOST: 127.0.0.1 run: | ./setup.py test --pytest-args="--cov-report= --cov=mycli" diff --git a/changelog.md b/changelog.md index f0d835b..9253cff 100644 --- a/changelog.md +++ b/changelog.md @@ -1,9 +1,11 @@ TBD -======= +=== Bug Fixes: ---------- * Allow `FileNotFound` exception for SSH config files. +* Fix startup error on MySQL < 5.0.22 +* Check error code rather than message for Access Denied error Features: --------- @@ -13,6 +15,8 @@ Internal: --------- * Remove unused function is_open_quote() * Use importlib, instead of file links, to locate resources +* Test various host-port combinations in command line arguments +* Switched from Cryptography to pyaes for decrypting mylogin.cnf 1.23.2 diff --git a/mycli/config.py b/mycli/config.py index 5111288..55d230d 100644 --- a/mycli/config.py +++ b/mycli/config.py @@ -9,8 +9,7 @@ import sys from typing import Union from configobj import ConfigObj, ConfigObjError -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.backends import default_backend +import pyaes try: import importlib.resources as resources @@ -215,11 +214,9 @@ def read_and_decrypt_mylogin_cnf(f): return None rkey = struct.pack('16B', *rkey) - # Create a decryptor object using the key. - decryptor = _get_decryptor(rkey) - # Create a bytes buffer to hold the plaintext. plaintext = BytesIO() + aes = pyaes.AESModeOfOperationECB(rkey) while True: # Read the length of the ciphertext. @@ -230,7 +227,9 @@ def read_and_decrypt_mylogin_cnf(f): # Read cipher_len bytes from the file and decrypt. cipher = f.read(cipher_len) - plain = _remove_pad(decryptor.update(cipher)) + plain = _remove_pad( + b''.join([aes.decrypt(cipher[i: i + 16]) for i in range(0, cipher_len, 16)]) + ) if plain is False: continue plaintext.write(plain) @@ -274,12 +273,6 @@ def strip_matching_quotes(s): return s -def _get_decryptor(key): - """Get the AES decryptor.""" - c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend()) - return c.decryptor() - - def _remove_pad(line): """Remove the pad from the *line*.""" try: diff --git a/mycli/main.py b/mycli/main.py index 2eb3812..6ab0f49 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -43,7 +43,7 @@ from .packages.special.favoritequeries import FavoriteQueries from .sqlcompleter import SQLCompleter from .clitoolbar import create_toolbar_tokens_func from .clistyle import style_factory, style_factory_output -from .sqlexecute import FIELD_TYPES, SQLExecute +from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED from .clibuffer import cli_is_multiline from .completion_refresher import CompletionRefresher from .config import (write_default_config, get_mylogin_cnf_path, @@ -434,7 +434,7 @@ class MyCli(object): ssh_password, ssh_key_filename, init_command ) except OperationalError as e: - if ('Access denied for user' in e.args[1]): + if e.args[0] == ERROR_CODE_ACCESS_DENIED: new_passwd = click.prompt('Password', hide_input=True, show_default=False, type=str, err=True) self.sqlexecute = SQLExecute( @@ -563,7 +563,7 @@ class MyCli(object): key_bindings = mycli_bindings(self) if not self.less_chatty: - print(' '.join(sqlexecute.server_type())) + print(sqlexecute.server_info) print('mycli', __version__) print(SUPPORT_INFO) print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file])) @@ -935,7 +935,7 @@ class MyCli(object): string = string.replace('\\u', sqlexecute.user or '(none)') string = string.replace('\\h', host or '(none)') string = string.replace('\\d', sqlexecute.dbname or '(none)') - string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli') + string = string.replace('\\t', sqlexecute.server_info.species.name) string = string.replace('\\n', "\n") string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) string = string.replace('\\m', now.strftime('%M')) diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 46cf07c..9461438 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -1,4 +1,7 @@ +import enum import logging +import re + import pymysql from .packages import special from pymysql.constants import FIELD_TYPE @@ -17,17 +20,71 @@ FIELD_TYPES.update({ FIELD_TYPE.NULL: type(None) }) + +ERROR_CODE_ACCESS_DENIED = 1045 + + +class ServerSpecies(enum.Enum): + MySQL = 'MySQL' + MariaDB = 'MariaDB' + Percona = 'Percona' + Unknown = 'MySQL' + + +class ServerInfo: + def __init__(self, species, version_str): + self.species = species + self.version_str = version_str + self.version = self.calc_mysql_version_value(version_str) + + @staticmethod + def calc_mysql_version_value(version_str) -> int: + if not version_str or not isinstance(version_str, str): + return 0 + try: + major, minor, patch = version_str.split('.') + except ValueError: + return 0 + else: + return int(major) * 10_000 + int(minor) * 100 + int(patch) + + @classmethod + def from_version_string(cls, version_string): + if not version_string: + return cls(ServerSpecies.Unknown, '') + + re_species = ( + (r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB), + (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)', + ServerSpecies.Percona), + (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)', + ServerSpecies.MySQL), + ) + for regexp, species in re_species: + match = re.search(regexp, version_string) + if match is not None: + parsed_version = match.group('version') + detected_species = species + break + else: + detected_species = ServerSpecies.Unknown + parsed_version = '' + + return cls(detected_species, parsed_version) + + def __str__(self): + if self.species: + return f'{self.species.value} {self.version_str}' + else: + return self.version_str + + class SQLExecute(object): databases_query = '''SHOW DATABASES''' tables_query = '''SHOW TABLES''' - version_query = '''SELECT @@VERSION''' - - version_comment_query = '''SELECT @@VERSION_COMMENT''' - version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"''' - show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"''' users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user''' @@ -51,7 +108,7 @@ class SQLExecute(object): self.charset = charset self.local_infile = local_infile self.ssl = ssl - self._server_type = None + self.server_info = None self.connection_id = None self.ssh_user = ssh_user self.ssh_host = ssh_host @@ -156,6 +213,7 @@ class SQLExecute(object): self.init_command = init_command # retrieve connection id self.reset_connection_id() + self.server_info = ServerInfo.from_version_string(conn.server_version) def run(self, statement): """Execute the sql in the database and return the results. The results @@ -272,37 +330,6 @@ class SQLExecute(object): for row in cur: yield row - def server_type(self): - if self._server_type: - return self._server_type - with self.conn.cursor() as cur: - _logger.debug('Version Query. sql: %r', self.version_query) - cur.execute(self.version_query) - version = cur.fetchone()[0] - if version[0] == '4': - _logger.debug('Version Comment. sql: %r', - self.version_comment_query_mysql4) - cur.execute(self.version_comment_query_mysql4) - version_comment = cur.fetchone()[1].lower() - if isinstance(version_comment, bytes): - # with python3 this query returns bytes - version_comment = version_comment.decode('utf-8') - else: - _logger.debug('Version Comment. sql: %r', - self.version_comment_query) - cur.execute(self.version_comment_query) - version_comment = cur.fetchone()[0].lower() - - if 'mariadb' in version_comment: - product_type = 'mariadb' - elif 'percona' in version_comment: - product_type = 'percona' - else: - product_type = 'mysql' - - self._server_type = (product_type, version) - return self._server_type - def get_connection_id(self): if not self.connection_id: self.reset_connection_id() @@ -23,9 +23,9 @@ install_requirements = [ 'PyMySQL >= 0.9.2', 'sqlparse>=0.3.0,<0.4.0', 'configobj >= 5.0.5', - 'cryptography >= 1.0.0', 'cli_helpers[styles] >= 2.0.1', - 'pyperclip >= 1.8.1' + 'pyperclip >= 1.8.1', + 'pyaes >= 1.6.1' ] if sys.version_info.minor < 9: diff --git a/test/features/connection.feature b/test/features/connection.feature new file mode 100644 index 0000000..04d041d --- /dev/null +++ b/test/features/connection.feature @@ -0,0 +1,23 @@ +Feature: connect to a database: + + @requires_local_db + Scenario: run mycli on localhost without port + When we run mycli with arguments "host=localhost" without arguments "port" + When we query "status" + Then status contains "via UNIX socket" + + Scenario: run mycli on TCP host without port + When we run mycli without arguments "port" + When we query "status" + Then status contains "via TCP/IP" + + Scenario: run mycli with port but without host + When we run mycli without arguments "host" + When we query "status" + Then status contains "via TCP/IP" + + @requires_local_db + Scenario: run mycli without host and port + When we run mycli without arguments "host port" + When we query "status" + Then status contains "via UNIX socket" diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py index 974740d..e1cb26f 100644 --- a/test/features/steps/auto_vertical.py +++ b/test/features/steps/auto_vertical.py @@ -3,11 +3,12 @@ from textwrap import dedent from behave import then, when import wrappers +from utils import parse_cli_args_to_dict @when('we run dbcli with {arg}') def step_run_cli_with_arg(context, arg): - wrappers.run_cli(context, run_args=arg.split('=')) + wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg)) @when('we execute a small query') diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py new file mode 100644 index 0000000..f4a4929 --- /dev/null +++ b/test/features/steps/connection.py @@ -0,0 +1,27 @@ +import shlex +from behave import when, then + +import wrappers +from test.features.steps.utils import parse_cli_args_to_dict + + +@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"') +@when('we run mycli without arguments "{excluded_args}"') +def step_run_cli_without_args(context, excluded_args, exact_args=''): + wrappers.run_cli( + context, + run_args=parse_cli_args_to_dict(exact_args), + exclude_args=parse_cli_args_to_dict(excluded_args).keys() + ) + + +@then('status contains "{expression}"') +def status_contains(context, expression): + wrappers.expect_exact(context, f'{expression}', timeout=5) + + # Normally, the shutdown after scenario waits for the prompt. + # But we may have changed the prompt, depending on parameters, + # so let's wait for its last character + context.cli.expect_exact('>') + context.atprompt = True + diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py new file mode 100644 index 0000000..1ae63d2 --- /dev/null +++ b/test/features/steps/utils.py @@ -0,0 +1,12 @@ +import shlex + + +def parse_cli_args_to_dict(cli_args: str): + args_dict = {} + for arg in shlex.split(cli_args): + if '=' in arg: + key, value = arg.split('=') + args_dict[key] = value + else: + args_dict[arg] = None + return args_dict diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py index de833dd..780a1c7 100644 --- a/test/features/steps/wrappers.py +++ b/test/features/steps/wrappers.py @@ -3,6 +3,7 @@ import pexpect import sys import textwrap + try: from StringIO import StringIO except ImportError: @@ -46,21 +47,41 @@ def expect_pager(context, expected, timeout): context.conf['pager_boundary'], expected), timeout=timeout) -def run_cli(context, run_args=None): +def run_cli(context, run_args=None, exclude_args=None): """Run the process using pexpect.""" - run_args = run_args or [] - if context.conf.get('host', None): - run_args.extend(('-h', context.conf['host'])) - if context.conf.get('user', None): - run_args.extend(('-u', context.conf['user'])) - if context.conf.get('pass', None): - run_args.extend(('-p', context.conf['pass'])) - if context.conf.get('dbname', None): - run_args.extend(('-D', context.conf['dbname'])) - if context.conf.get('defaults-file', None): - run_args.extend(('--defaults-file', context.conf['defaults-file'])) - if context.conf.get('myclirc', None): - run_args.extend(('--myclirc', context.conf['myclirc'])) + run_args = run_args or {} + rendered_args = [] + exclude_args = set(exclude_args) if exclude_args else set() + + conf = dict(**context.conf) + conf.update(run_args) + + def add_arg(name, key, value): + if name not in exclude_args: + if value is not None: + rendered_args.extend((key, value)) + else: + rendered_args.append(key) + + if conf.get('host', None): + add_arg('host', '-h', conf['host']) + if conf.get('user', None): + add_arg('user', '-u', conf['user']) + if conf.get('pass', None): + add_arg('pass', '-p', conf['pass']) + if conf.get('port', None): + add_arg('port', '-P', str(conf['port'])) + if conf.get('dbname', None): + add_arg('dbname', '-D', conf['dbname']) + if conf.get('defaults-file', None): + add_arg('defaults_file', '--defaults-file', conf['defaults-file']) + if conf.get('myclirc', None): + add_arg('myclirc', '--myclirc', conf['myclirc']) + + for arg_name, arg_value in conf.items(): + if arg_name.startswith('-'): + add_arg(arg_name, arg_name, arg_value) + try: cli_cmd = context.conf['cli_command'] except KeyError: @@ -73,7 +94,7 @@ def run_cli(context, run_args=None): '"' ).format(sys.executable) - cmd_parts = [cli_cmd] + run_args + cmd_parts = [cli_cmd] + rendered_args cmd = ' '.join(cmd_parts) context.cli = pexpect.spawnu(cmd, cwd=context.package_root) context.logfile = StringIO() diff --git a/test/test_main.py b/test/test_main.py index 91a366b..00fdc1b 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -5,6 +5,7 @@ from click.testing import CliRunner from mycli.main import MyCli, cli, thanks_picker from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.sqlexecute import ServerInfo from .utils import USER, HOST, PORT, PASSWORD, dbtest, run from textwrap import dedent @@ -174,6 +175,7 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): host = 'test' user = 'test' dbname = 'test' + server_info = ServerInfo.from_version_string('unknown') port = 0 def server_type(self): diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py index 5168bf6..0f38a97 100644 --- a/test/test_sqlexecute.py +++ b/test/test_sqlexecute.py @@ -3,6 +3,7 @@ import os import pytest import pymysql +from mycli.sqlexecute import ServerInfo, ServerSpecies from .utils import run, dbtest, set_expanded_output, is_expanded_output @@ -270,3 +271,24 @@ def test_multiple_results(executor): 'status': '1 row in set'} ] assert results == expected + + +@pytest.mark.parametrize( + 'version_string, species, parsed_version_string, version', + ( + ('5.7.32-35', 'Percona', '5.7.32', 50732), + ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732), + ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016), + ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105), + ('unexpected version string', None, '', 0), + ('', None, '', 0), + (None, None, '', 0), + ) +) +def test_version_parsing(version_string, species, parsed_version_string, version): + server_info = ServerInfo.from_version_string(version_string) + assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown + assert server_info.version_str == parsed_version_string + assert server_info.version == version |