diff options
author | Georgy Frolov <gosha@fro.lv> | 2021-02-28 00:55:53 +0300 |
---|---|---|
committer | Georgy Frolov <gosha@fro.lv> | 2021-02-28 01:19:32 +0300 |
commit | 51b8ee3c1146026b2c58b897a7ef5ad0c10a0836 (patch) | |
tree | 9db5f0a0a18e57fe1c7f2161f589fa58784d8c75 | |
parent | 1310b2f739967e6db01039b626ecfd1d6a77ed20 (diff) |
fixed login with .my.cnf and .mylogin.cnf
-rw-r--r-- | changelog.md | 1 | ||||
-rw-r--r-- | mycli/config.py | 58 | ||||
-rwxr-xr-x | mycli/main.py | 48 | ||||
-rw-r--r-- | test/features/connection.feature | 12 | ||||
-rw-r--r-- | test/features/environment.py | 48 | ||||
-rw-r--r-- | test/features/steps/connection.py | 44 | ||||
-rw-r--r-- | test/features/steps/wrappers.py | 4 |
7 files changed, 191 insertions, 24 deletions
diff --git a/changelog.md b/changelog.md index 9253cff..136145a 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,7 @@ Bug Fixes: * Allow `FileNotFound` exception for SSH config files. * Fix startup error on MySQL < 5.0.22 * Check error code rather than message for Access Denied error +* Fix login with ~/.my.cnf files Features: --------- diff --git a/mycli/config.py b/mycli/config.py index 0b67bd4..5d71109 100644 --- a/mycli/config.py +++ b/mycli/config.py @@ -5,7 +5,7 @@ import os from os.path import exists import struct import sys -from typing import Union +from typing import Union, IO from configobj import ConfigObj, ConfigObjError import pyaes @@ -52,9 +52,9 @@ def read_config_file(f, list_values=True): config = ConfigObj(f, interpolation=False, encoding='utf8', list_values=list_values) except ConfigObjError as e: - log(logger, logging.ERROR, "Unable to parse line {0} of config file " + log(logger, logging.WARNING, "Unable to parse line {0} of config file " "'{1}'.".format(e.line_number, f)) - log(logger, logging.ERROR, "Using successfully parsed config values.") + log(logger, logging.WARNING, "Using successfully parsed config values.") return e.config except (IOError, OSError) as e: log(logger, logging.WARNING, "You don't have permission to read " @@ -172,6 +172,58 @@ def open_mylogin_cnf(name): return TextIOWrapper(plaintext) +# TODO reuse code between encryption an decryption +def encrypt_mylogin_cnf(plaintext: IO[str]): + """Encryption of .mylogin.cnf file, analogous to calling + mysql_config_editor. + + Code is based on the python implementation by Kristian Koehntopp + https://github.com/isotopp/mysql-config-coder + + """ + def realkey(key): + """Create the AES key from the login key.""" + rkey = bytearray(16) + for i in range(len(key)): + rkey[i % 16] ^= key[i] + return bytes(rkey) + + def encode_line(plaintext, real_key, buf_len): + aes = pyaes.AESModeOfOperationECB(real_key) + text_len = len(plaintext) + pad_len = buf_len - text_len + pad_chr = bytes(chr(pad_len), "utf8") + plaintext = plaintext.encode() + pad_chr * pad_len + encrypted_text = b''.join( + [aes.encrypt(plaintext[i: i + 16]) + for i in range(0, len(plaintext), 16)] + ) + return encrypted_text + + LOGIN_KEY_LENGTH = 20 + key = os.urandom(LOGIN_KEY_LENGTH) + real_key = realkey(key) + + outfile = BytesIO() + + outfile.write(struct.pack("i", 0)) + outfile.write(key) + + while True: + line = plaintext.readline() + if not line: + break + real_len = len(line) + pad_len = (int(real_len / 16) + 1) * 16 + + outfile.write(struct.pack("i", pad_len)) + x = encode_line(line, real_key, pad_len) + outfile.write(x) + + outfile.seek(0) + return outfile + + def read_and_decrypt_mylogin_cnf(f): """Read and decrypt the contents of .mylogin.cnf. diff --git a/mycli/main.py b/mycli/main.py index c1685a2..2e2b842 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1,3 +1,4 @@ +from collections import defaultdict from io import open import os import sys @@ -332,20 +333,33 @@ class MyCli(object): cnf = read_config_files(files, list_values=False) sections = ['client', 'mysqld'] + key_transformations = { + 'mysqld': { + 'socket': 'default_socket', + 'port': 'default_port', + }, + } + if self.login_path and self.login_path != 'client': sections.append(self.login_path) if self.defaults_suffix: sections.extend([sect + self.defaults_suffix for sect in sections]) - def get(key): - result = None - for sect in cnf: - if sect in sections and key in cnf[sect]: - result = strip_matching_quotes(cnf[sect][key]) - return result + configuration = defaultdict(lambda: None) + for key in keys: + for section in cnf: + if ( + section not in sections or + key not in cnf[section] + ): + continue + new_key = key_transformations.get(section, {}).get(key) or key + configuration[new_key] = strip_matching_quotes( + cnf[section][key]) + + return configuration - return {x: get(x) for x in keys} def merge_ssl_with_cnf(self, ssl, cnf): """Merge SSL configuration dict with cnf dict""" @@ -381,6 +395,7 @@ class MyCli(object): 'host': None, 'port': None, 'socket': None, + 'default_socket': None, 'default-character-set': None, 'local-infile': None, 'loose-local-infile': None, @@ -394,18 +409,23 @@ class MyCli(object): cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys()) # Fall back to config values only if user did not specify a value. - database = database or cnf['database'] - # Socket interface not supported for SSH connections - if port or (host and host != 'localhost') or (ssh_host and ssh_port): - socket = '' - else: - socket = socket or cnf['socket'] or guess_socket_location() user = user or cnf['user'] or os.getenv('USER') host = host or cnf['host'] - port = int(port or cnf['port'] or 3306) + port = port or cnf['port'] ssl = ssl or {} + port = port and int(port) + if not port: + port = 3306 + if not host or host == 'localhost': + socket = ( + cnf['socket'] or + cnf['default_socket'] or + guess_socket_location() + ) + + passwd = passwd if isinstance(passwd, str) else cnf['password'] charset = charset or cnf['default-character-set'] or 'utf8' diff --git a/test/features/connection.feature b/test/features/connection.feature index 04d041d..b06935e 100644 --- a/test/features/connection.feature +++ b/test/features/connection.feature @@ -21,3 +21,15 @@ Feature: connect to a database: When we run mycli without arguments "host port" When we query "status" Then status contains "via UNIX socket" + + Scenario: run mycli with my.cnf configuration + When we create my.cnf file + When we run mycli without arguments "host port user pass defaults_file" + Then we are logged in + + Scenario: run mycli with mylogin.cnf configuration + When we create mylogin.cnf file + When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file" + Then we are logged in + + diff --git a/test/features/environment.py b/test/features/environment.py index 98c2004..1ea0f08 100644 --- a/test/features/environment.py +++ b/test/features/environment.py @@ -1,4 +1,5 @@ import os +import shutil import sys from tempfile import mkstemp @@ -11,6 +12,24 @@ from steps.wrappers import run_cli, wait_prompt test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') +SELF_CONNECTING_FEATURES = ( + 'test/features/connection.feature', +) + + +MY_CNF_PATH = os.path.expanduser('~/.my.cnf') +MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup' +MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf') +MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup' + + +def get_db_name_from_context(context): + return context.config.userdata.get( + 'my_test_db', None + ) or "mycli_behave_tests" + + + def before_all(context): """Set env parameters.""" os.environ['LINES'] = "100" @@ -22,7 +41,7 @@ def before_all(context): test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) login_path_file = os.path.join(test_dir, 'mylogin.cnf') - os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file +# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file context.package_root = os.path.abspath( os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) @@ -33,8 +52,7 @@ def before_all(context): context.exit_sent = False vi = '_'.join([str(x) for x in sys.version_info[:3]]) - db_name = context.config.userdata.get( - 'my_test_db', None) or "mycli_behave_tests" + db_name = get_db_name_from_context(context) db_name_full = '{0}_{1}'.format(db_name, vi) # Store get params from config/environment variables @@ -104,11 +122,18 @@ def before_step(context, _): context.atprompt = False -def before_scenario(context, _): +def before_scenario(context, arg): with open(test_log_file, 'w') as f: f.write('') - run_cli(context) - wait_prompt(context) + if arg.location.filename not in SELF_CONNECTING_FEATURES: + run_cli(context) + wait_prompt(context) + + if os.path.exists(MY_CNF_PATH): + shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH) + + if os.path.exists(MYLOGIN_CNF_PATH): + shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH) def after_scenario(context, _): @@ -134,6 +159,17 @@ def after_scenario(context, _): context.cli.sendcontrol('d') context.cli.expect_exact(pexpect.EOF, timeout=5) + if os.path.exists(MY_CNF_BACKUP_PATH): + shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH) + + if os.path.exists(MYLOGIN_CNF_BACKUP_PATH): + shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH) + elif os.path.exists(MYLOGIN_CNF_PATH): + # This file was moved in `before_scenario`. + # If it exists now, it has been created during a test + os.remove(MYLOGIN_CNF_PATH) + + # TODO: uncomment to debug a failure # def after_step(context, step): # if step.status == "failed": diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py index f4a4929..e16dd86 100644 --- a/test/features/steps/connection.py +++ b/test/features/steps/connection.py @@ -1,8 +1,18 @@ +import io +import os import shlex + from behave import when, then +import pexpect import wrappers from test.features.steps.utils import parse_cli_args_to_dict +from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context +from test.utils import HOST, PORT, USER, PASSWORD +from mycli.config import encrypt_mylogin_cnf + + +TEST_LOGIN_PATH = 'test_login_path' @when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"') @@ -25,3 +35,37 @@ def status_contains(context, expression): context.cli.expect_exact('>') context.atprompt = True + +@when('we create my.cnf file') +def step_create_my_cnf_file(context): + my_cnf = ( + '[client]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MY_CNF_PATH, 'w') as f: + f.write(my_cnf) + + +@when('we create mylogin.cnf file') +def step_create_mylogin_cnf_file(context): + os.environ.pop('MYSQL_TEST_LOGIN_FILE', None) + mylogin_cnf = ( + f'[{TEST_LOGIN_PATH}]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MYLOGIN_CNF_PATH, 'wb') as f: + input_file = io.StringIO(mylogin_cnf) + f.write(encrypt_mylogin_cnf(input_file).read()) + + +@then('we are logged in') +def we_are_logged_in(context): + db_name = get_db_name_from_context(context) + context.cli.expect_exact(f'{db_name}>', timeout=5) + context.atprompt = True diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py index 780a1c7..6408f23 100644 --- a/test/features/steps/wrappers.py +++ b/test/features/steps/wrappers.py @@ -14,7 +14,7 @@ def expect_exact(context, expected, timeout): timedout = False try: context.cli.expect_exact(expected, timeout=timeout) - except pexpect.exceptions.TIMEOUT: + except pexpect.TIMEOUT: timedout = True if timedout: # Strip color codes out of the output. @@ -77,6 +77,8 @@ def run_cli(context, run_args=None, exclude_args=None): add_arg('defaults_file', '--defaults-file', conf['defaults-file']) if conf.get('myclirc', None): add_arg('myclirc', '--myclirc', conf['myclirc']) + if conf.get('login_path'): + add_arg('login_path', '--login-path', conf['login_path']) for arg_name, arg_value in conf.items(): if arg_name.startswith('-'): |