summaryrefslogtreecommitdiffstats
path: root/nixos/lib/test-driver
diff options
context:
space:
mode:
authorDavid Arnold <dar@xoe.solutions>2021-06-06 12:00:12 -0500
committerDavid Arnold <dgx.arnold@gmail.com>2021-08-05 19:07:11 -0500
commit926fb9396881202e727e5ec1fbf609b64455b388 (patch)
treeed0325450bfa2982cbc008edc6b5b70a78458985 /nixos/lib/test-driver
parent077b2825cd3328f83dbc1774ce62751dae4cb719 (diff)
nixos/tests/test-driver: normalise test driver entrypoint(s)
Previously the driver was configured exclusively through convoluted environment variables. Now the driver's defaults are configured through env variables. Some additional concerns are in the github comments of this PR.
Diffstat (limited to 'nixos/lib/test-driver')
-rwxr-xr-x[-rw-r--r--]nixos/lib/test-driver/test-driver.py101
1 files changed, 73 insertions, 28 deletions
diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py
index 2a3e4d94b948..1720e553d733 100644..100755
--- a/nixos/lib/test-driver/test-driver.py
+++ b/nixos/lib/test-driver/test-driver.py
@@ -24,7 +24,6 @@ import sys
import telnetlib
import tempfile
import time
-import traceback
import unicodedata
CHAR_TO_KEY = {
@@ -930,29 +929,16 @@ def join_all() -> None:
machine.wait_for_shutdown()
-def test_script() -> None:
- exec(os.environ["testScript"])
-
-
-def run_tests() -> None:
+def run_tests(interactive: bool = False) -> None:
global machines
- tests = os.environ.get("tests", None)
- if tests is not None:
- with log.nested("running the VM test script"):
- try:
- exec(tests, globals())
- except Exception as e:
- eprint("error: ")
- traceback.print_exc()
- sys.exit(1)
+ if interactive:
+ ptpython.repl.embed(globals(), locals())
else:
- ptpython.repl.embed(locals(), globals())
-
- # TODO: Collect coverage data
-
- for machine in machines:
- if machine.is_up():
- machine.execute("sync")
+ test_script()
+ # TODO: Collect coverage data
+ for machine in machines:
+ if machine.is_up():
+ machine.execute("sync")
def serial_stdout_on() -> None:
@@ -965,6 +951,31 @@ def serial_stdout_off() -> None:
log._print_serial_logs = False
+class EnvDefault(argparse.Action):
+ """An argpars Action that takes values from the specified
+ environment variable as the flags default value.
+ """
+
+ def __init__(self, envvar, required=False, default=None, nargs=None, **kwargs): # type: ignore
+ if not default and envvar:
+ if envvar in os.environ:
+ if nargs is not None and (nargs.isdigit() or nargs in ["*", "+"]):
+ default = os.environ[envvar].split()
+ else:
+ default = os.environ[envvar]
+ kwargs["help"] = (
+ kwargs["help"] + f" (default from environment: {default})"
+ )
+ if required and default:
+ required = False
+ super(EnvDefault, self).__init__(
+ default=default, required=required, nargs=nargs, **kwargs
+ )
+
+ def __call__(self, parser, namespace, values, option_string=None): # type: ignore
+ setattr(namespace, self.dest, values)
+
+
@contextmanager
def subtest(name: str) -> Iterator[None]:
with log.nested(name):
@@ -986,18 +997,52 @@ if __name__ == "__main__":
help="re-use a VM state coming from a previous run",
action="store_true",
)
- (cli_args, vm_scripts) = arg_parser.parse_known_args()
+ arg_parser.add_argument(
+ "-I",
+ "--interactive",
+ help="drop into a python repl and run the tests interactively",
+ action="store_true",
+ )
+ arg_parser.add_argument(
+ "--start-scripts",
+ metavar="START-SCRIPT",
+ action=EnvDefault,
+ envvar="startScripts",
+ nargs="*",
+ help="start scripts for participating virtual machines",
+ )
+ arg_parser.add_argument(
+ "--vlans",
+ metavar="VLAN",
+ action=EnvDefault,
+ envvar="vlans",
+ nargs="*",
+ help="vlans to span by the driver",
+ )
+ arg_parser.add_argument(
+ "testscript",
+ action=EnvDefault,
+ envvar="testScript",
+ help="the test script to run",
+ type=pathlib.Path,
+ )
+
+ args = arg_parser.parse_args()
+ global test_script
+
+ def test_script() -> None:
+ with log.nested("running the VM test script"):
+ exec(pathlib.Path(args.testscript).read_text(), globals())
log = Logger()
- vlan_nrs = list(dict.fromkeys(os.environ.get("VLANS", "").split()))
- vde_sockets = [create_vlan(v) for v in vlan_nrs]
+ vde_sockets = [create_vlan(v) for v in args.vlans]
for nr, vde_socket, _, _ in vde_sockets:
os.environ["QEMU_VDE_SOCKET_{}".format(nr)] = vde_socket
machines = [
- create_machine({"startCommand": s, "keepVmState": cli_args.keep_vm_state})
- for s in vm_scripts
+ create_machine({"startCommand": s, "keepVmState": args.keep_vm_state})
+ for s in args.start_scripts
]
machine_eval = [
"{0} = machines[{1}]".format(m.name, idx) for idx, m in enumerate(machines)
@@ -1017,6 +1062,6 @@ if __name__ == "__main__":
log.close()
tic = time.time()
- run_tests()
+ run_tests(args.interactive)
toc = time.time()
print("test script finished in {:.2f}s".format(toc - tic))