summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDavid Arnold <dar@xoe.solutions>2021-06-12 17:47:25 -0500
committerDavid Arnold <david.arnold@iohk.io>2021-10-05 14:38:48 -0500
commitb0fc9da879812e47c1ed3438fb0fd51db00a3494 (patch)
treec238d3e8ce9c6ad17c47e8414001a29e137d8e52
parent3069ba0dd1dec75c5dc4f6a1ee238a4fab9828cd (diff)
nixos/test/test-driver: Class-ify the test driver
This commit encapsulates the involved domain into classes and defines explicit and typed arguments where untyped dicts where used. It preserves backwards compatibility through legacy wrappers.
-rwxr-xr-xnixos/lib/test-driver/test-driver.py804
-rw-r--r--nixos/lib/testing-python.nix11
-rw-r--r--nixos/modules/installer/tools/nixos-build-vms/build-vms.nix19
3 files changed, 527 insertions, 307 deletions
diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py
index f8502188bde8..fdc440a896a0 100755
--- a/nixos/lib/test-driver/test-driver.py
+++ b/nixos/lib/test-driver/test-driver.py
@@ -21,7 +21,6 @@ import shutil
import socket
import subprocess
import sys
-import telnetlib
import tempfile
import time
import unicodedata
@@ -89,55 +88,6 @@ CHAR_TO_KEY = {
")": "shift-0x0B",
}
-global log, machines, test_script
-
-
-def eprint(*args: object, **kwargs: Any) -> None:
- print(*args, file=sys.stderr, **kwargs)
-
-
-def make_command(args: list) -> str:
- return " ".join(map(shlex.quote, (map(str, args))))
-
-
-def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]:
- log.log("starting VDE switch for network {}".format(vlan_nr))
- vde_socket = tempfile.mkdtemp(
- prefix="nixos-test-vde-", suffix="-vde{}.ctl".format(vlan_nr)
- )
- pty_master, pty_slave = pty.openpty()
- vde_process = subprocess.Popen(
- ["vde_switch", "-s", vde_socket, "--dirmode", "0700"],
- stdin=pty_slave,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- shell=False,
- )
- fd = os.fdopen(pty_master, "w")
- fd.write("version\n")
- # TODO: perl version checks if this can be read from
- # an if not, dies. we could hang here forever. Fix it.
- assert vde_process.stdout is not None
- vde_process.stdout.readline()
- if not os.path.exists(os.path.join(vde_socket, "ctl")):
- raise Exception("cannot start vde_switch")
-
- return (vlan_nr, vde_socket, vde_process, fd)
-
-
-def retry(fn: Callable, timeout: int = 900) -> None:
- """Call the given function repeatedly, with 1 second intervals,
- until it returns True or a timeout is reached.
- """
-
- for _ in range(timeout):
- if fn(False):
- return
- time.sleep(1)
-
- if not fn(True):
- raise Exception(f"action timed out after {timeout} seconds")
-
class Logger:
def __init__(self) -> None:
@@ -151,6 +101,10 @@ class Logger:
self._print_serial_logs = True
+ @staticmethod
+ def _eprint(*args: object, **kwargs: Any) -> None:
+ print(*args, file=sys.stderr, **kwargs)
+
def close(self) -> None:
self.xml.endElement("logfile")
self.xml.endDocument()
@@ -169,15 +123,27 @@ class Logger:
self.xml.characters(message)
self.xml.endElement("line")
+ def info(self, *args, **kwargs) -> None: # type: ignore
+ self.log(*args, **kwargs)
+
+ def warning(self, *args, **kwargs) -> None: # type: ignore
+ self.log(*args, **kwargs)
+
+ def error(self, *args, **kwargs) -> None: # type: ignore
+ self.log(*args, **kwargs)
+ sys.exit(1)
+
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
- eprint(self.maybe_prefix(message, attributes))
+ self._eprint(self.maybe_prefix(message, attributes))
self.drain_log_queue()
self.log_line(message, attributes)
def log_serial(self, message: str, machine: str) -> None:
self.enqueue({"msg": message, "machine": machine, "type": "serial"})
if self._print_serial_logs:
- eprint(Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL)
+ self._eprint(
+ Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL
+ )
def enqueue(self, item: Dict[str, str]) -> None:
self.queue.put(item)
@@ -194,7 +160,7 @@ class Logger:
@contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
- eprint(self.maybe_prefix(message, attributes))
+ self._eprint(self.maybe_prefix(message, attributes))
self.xml.startElement("nest", attrs={})
self.xml.startElement("head", attributes)
@@ -211,6 +177,27 @@ class Logger:
self.xml.endElement("nest")
+rootlog = Logger()
+
+
+def make_command(args: list) -> str:
+ return " ".join(map(shlex.quote, (map(str, args))))
+
+
+def retry(fn: Callable, timeout: int = 900) -> None:
+ """Call the given function repeatedly, with 1 second intervals,
+ until it returns True or a timeout is reached.
+ """
+
+ for _ in range(timeout):
+ if fn(False):
+ return
+ time.sleep(1)
+
+ if not fn(True):
+ raise Exception(f"action timed out after {timeout} seconds")
+
+
def _perform_ocr_on_screenshot(
screenshot_path: str, model_ids: Iterable[int]
) -> List[str]:
@@ -242,113 +229,256 @@ def _perform_ocr_on_screenshot(
return model_results
-class Machine:
- def __repr__(self) -> str:
- return f"<Machine '{self.name}'>"
-
- def __init__(self, args: Dict[str, Any]) -> None:
- if "name" in args:
- self.name = args["name"]
- else:
- self.name = "machine"
- cmd = args.get("startCommand", None)
- if cmd:
- match = re.search("run-(.+)-vm$", cmd)
- if match:
- self.name = match.group(1)
- self.logger = args["log"]
- self.script = args.get("startCommand", self.create_startcommand(args))
-
- tmp_dir = os.environ.get("TMPDIR", tempfile.gettempdir())
-
- def create_dir(name: str) -> str:
- path = os.path.join(tmp_dir, name)
- os.makedirs(path, mode=0o700, exist_ok=True)
- return path
+class StartCommand:
+ """The Base Start Command knows how to append the necesary
+ runtime qemu options as determined by a particular test driver
+ run. Any such start command is expected to happily receive and
+ append additional qemu args.
+ """
- self.state_dir = os.path.join(tmp_dir, f"vm-state-{self.name}")
- if not args.get("keepVmState", False):
- self.cleanup_statedir()
- os.makedirs(self.state_dir, mode=0o700, exist_ok=True)
- self.shared_dir = create_dir("shared-xchg")
+ _cmd: str
- self.booted = False
- self.connected = False
- self.pid: Optional[int] = None
- self.socket = None
- self.monitor: Optional[socket.socket] = None
- self.allow_reboot = args.get("allowReboot", False)
+ def cmd(
+ self,
+ monitor_socket_path: pathlib.Path,
+ shell_socket_path: pathlib.Path,
+ allow_reboot: bool = False, # TODO: unused, legacy?
+ ) -> str:
+ display_opts = ""
+ display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"])
+ if display_available:
+ display_opts += " -nographic"
+
+ # qemu options
+ qemu_opts = ""
+ qemu_opts += (
+ ""
+ if allow_reboot
+ else " -no-reboot"
+ " -device virtio-serial"
+ " -device virtconsole,chardev=shell"
+ " -device virtio-rng-pci"
+ " -serial stdio"
+ )
+ # TODO: qemu script already catpures this env variable, legacy?
+ qemu_opts += " " + os.environ.get("QEMU_OPTS", "")
+
+ return (
+ f"{self._cmd}"
+ f" -monitor unix:{monitor_socket_path}"
+ f" -chardev socket,id=shell,path={shell_socket_path}"
+ f"{qemu_opts}"
+ f"{display_opts}"
+ )
@staticmethod
- def create_startcommand(args: Dict[str, str]) -> str:
- net_backend = "-netdev user,id=net0"
- net_frontend = "-device virtio-net-pci,netdev=net0"
+ def build_environment(
+ state_dir: pathlib.Path,
+ shared_dir: pathlib.Path,
+ ) -> dict:
+ # We make a copy to not update the current environment
+ env = dict(os.environ)
+ env.update(
+ {
+ "TMPDIR": str(state_dir),
+ "SHARED_DIR": str(shared_dir),
+ "USE_TMPDIR": "1",
+ }
+ )
+ return env
+
+ def run(
+ self,
+ state_dir: pathlib.Path,
+ shared_dir: pathlib.Path,
+ monitor_socket_path: pathlib.Path,
+ shell_socket_path: pathlib.Path,
+ ) -> subprocess.Popen:
+ return subprocess.Popen(
+ self.cmd(monitor_socket_path, shell_socket_path),
+ stdin=subprocess.DEVNULL,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ shell=True,
+ cwd=state_dir,
+ env=self.build_environment(state_dir, shared_dir),
+ )
+
- if "netBackendArgs" in args:
- net_backend += "," + args["netBackendArgs"]
+class NixStartScript(StartCommand):
+ """A start script from nixos/modules/virtualiation/qemu-vm.nix
+ that also satisfies the requirement of the BaseStartCommand.
+ These Nix commands have the particular charactersitic that the
+ machine name can be extracted out of them via a regex match.
+ (Admittedly a _very_ implicit contract, evtl. TODO fix)
+ """
- if "netFrontendArgs" in args:
- net_frontend += "," + args["netFrontendArgs"]
+ def __init__(self, script: str):
+ self._cmd = script
- start_command = (
- args.get("qemuBinary", "qemu-kvm")
- + " -m 384 "
- + net_backend
- + " "
- + net_frontend
- + " $QEMU_OPTS "
- )
+ @property
+ def machine_name(self) -> str:
+ match = re.search("run-(.+)-vm$", self._cmd)
+ name = "machine"
+ if match:
+ name = match.group(1)
+ return name
- if "hda" in args:
- hda_path = os.path.abspath(args["hda"])
- if args.get("hdaInterface", "") == "scsi":
- start_command += (
- "-drive id=hda,file="
- + hda_path
- + ",werror=report,if=none "
- + "-device scsi-hd,drive=hda "
+
+class LegacyStartCommand(StartCommand):
+ """Used in some places to create an ad-hoc machine instead of
+ using nix test instrumentation + module system for that purpose.
+ Legacy.
+ """
+
+ def __init__(
+ self,
+ netBackendArgs: Optional[str] = None,
+ netFrontendArgs: Optional[str] = None,
+ hda: Optional[Tuple[pathlib.Path, str]] = None,
+ cdrom: Optional[str] = None,
+ usb: Optional[str] = None,
+ bios: Optional[str] = None,
+ qemuFlags: Optional[str] = None,
+ ):
+ self._cmd = "qemu-kvm -m 384"
+
+ # networking
+ net_backend = "-netdev user,id=net0"
+ net_frontend = "-device virtio-net-pci,netdev=net0"
+ if netBackendArgs is not None:
+ net_backend += "," + netBackendArgs
+ if netFrontendArgs is not None:
+ net_frontend += "," + netFrontendArgs
+ self._cmd += f" {net_backend} {net_frontend}"
+
+ # hda
+ hda_cmd = ""
+ if hda is not None:
+ hda_path = hda[0].resolve()
+ hda_interface = hda[1]
+ if hda_interface == "scsi":
+ hda_cmd += (
+ f" -drive id=hda,file={hda_path},werror=report,if=none"
+ " -device scsi-hd,drive=hda"
)
else:
- start_command += (
- "-drive file="
- + hda_path
- + ",if="
- + args["hdaInterface"]
- + ",werror=report "
- )
+ hda_cmd += f" -drive file={hda_path},if={hda_interface},werror=report"
+ self._cmd += hda_cmd
- if "cdrom" in args:
- start_command += "-cdrom " + args["cdrom"] + " "
+ # cdrom
+ if cdrom is not None:
+ self._cmd += f" -cdrom {cdrom}"
- if "usb" in args:
+ # usb
+ usb_cmd = ""
+ if usb is not None:
# https://github.com/qemu/qemu/blob/master/docs/usb2.txt
- start_command += (
- "-device usb-ehci -drive "
- + "id=usbdisk,file="
- + args["usb"]
- + ",if=none,readonly "
- + "-device usb-storage,drive=usbdisk "
+ usb_cmd += (
+ " -device usb-ehci"
+ f" -drive id=usbdisk,file={usb},if=none,readonly"
+ " -device usb-storage,drive=usbdisk "
)
- if "bios" in args:
- start_command += "-bios " + args["bios"] + " "
+ self._cmd += usb_cmd
+
+ # bios
+ if bios is not None:
+ self._cmd += f" -bios {bios}"
+
+ # qemu flags
+ if qemuFlags is not None:
+ self._cmd += f" {qemuFlags}"
+
+
+class Machine:
+ """A handle to the machine with this name, that also knows how to manage
+ the machine lifecycle with the help of a start script / command."""
+
+ name: str
+ tmp_dir: pathlib.Path
+ shared_dir: pathlib.Path
+ state_dir: pathlib.Path
+ monitor_path: pathlib.Path
+ shell_path: pathlib.Path
+
+ start_command: StartCommand
+ keep_vm_state: bool
+ allow_reboot: bool
+
+ process: Optional[subprocess.Popen] = None
+ pid: Optional[int] = None
+ monitor: Optional[socket.socket] = None
+ shell: Optional[socket.socket] = None
+
+ booted: bool = False
+ connected: bool = False
+ # Store last serial console lines for use
+ # of wait_for_console_text
+ last_lines: Queue = Queue()
- start_command += args.get("qemuFlags", "")
+ def __repr__(self) -> str:
+ return f"<Machine '{self.name}'>"
+
+ def __init__(
+ self,
+ tmp_dir: pathlib.Path,
+ start_command: StartCommand,
+ name: str = "machine",
+ keep_vm_state: bool = False,
+ allow_reboot: bool = False,
+ ) -> None:
+ self.tmp_dir = tmp_dir
+ self.keep_vm_state = keep_vm_state
+ self.allow_reboot = allow_reboot
+ self.name = name
+ self.start_command = start_command
+
+ # set up directories
+ self.shared_dir = self.tmp_dir / "shared-xchg"
+ self.shared_dir.mkdir(mode=0o700, exist_ok=True)
+
+ self.state_dir = self.tmp_dir / f"vm-state-{self.name}"
+ self.monitor_path = self.state_dir / "monitor"
+ self.shell_path = self.state_dir / "shell"
+ if (not self.keep_vm_state) and self.state_dir.exists():
+ self.cleanup_statedir()
+ self.state_dir.mkdir(mode=0o700, exist_ok=True)
- return start_command
+ @staticmethod
+ def create_startcommand(args: Dict[str, str]) -> StartCommand:
+ rootlog.warning(
+ "Using legacy create_startcommand(),"
+ "please use proper nix test vm instrumentation, instead"
+ "to generate the appropriate nixos test vm qemu startup script"
+ )
+ hda = None
+ if args.get("hda"):
+ hda_arg: str = args.get("hda", "")
+ hda_arg_path: pathlib.Path = pathlib.Path(hda_arg)
+ hda = (hda_arg_path, args.get("hdaInterface", ""))
+ return LegacyStartCommand(
+ netBackendArgs=args.get("netBackendArgs"),
+ netFrontendArgs=args.get("netFrontendArgs"),
+ hda=hda,
+ cdrom=args.get("cdrom"),
+ usb=args.get("usb"),
+ bios=args.get("bios"),
+ qemuFlags=args.get("qemuFlags"),
+ )
def is_up(self) -> bool:
return self.booted and self.connected
def log(self, msg: str) -> None:
- self.logger.log(msg, {"machine": self.name})
+ rootlog.log(msg, {"machine": self.name})
def log_serial(self, msg: str) -> None:
- self.logger.log_serial(msg, self.name)
+ rootlog.log_serial(msg, self.name)
def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
my_attrs = {"machine": self.name}
my_attrs.update(attrs)
- return self.logger.nested(msg, my_attrs)
+ return rootlog.nested(msg, my_attrs)
def wait_for_monitor_prompt(self) -> str:
assert self.monitor is not None
@@ -446,6 +576,7 @@ class Machine:
self.connect()
out_command = "( set -euo pipefail; {} ); echo '|!=EOF' $?\n".format(command)
+ assert self.shell
self.shell.send(out_command.encode())
output = ""
@@ -466,6 +597,8 @@ class Machine:
Should only be used during test development, not in the production test."""
self.connect()
self.log("Terminal is ready (there is no prompt):")
+
+ assert self.shell
subprocess.run(
["socat", "READLINE", f"FD:{self.shell.fileno()}"],
pass_fds=[self.shell.fileno()],
@@ -534,6 +667,7 @@ class Machine:
with self.nested("waiting for the VM to power off"):
sys.stdout.flush()
+ assert self.process
self.process.wait()
self.pid = None
@@ -611,6 +745,8 @@ class Machine:
with self.nested("waiting for the VM to finish booting"):
self.start()
+ assert self.shell
+
tic = time.time()
self.shell.recv(1024)
# TODO: Timeout
@@ -750,65 +886,35 @@ class Machine:
self.log("starting vm")
- def create_socket(path: str) -> socket.socket:
- if os.path.exists(path):
- os.unlink(path)
+ def clear(path: pathlib.Path) -> pathlib.Path:
+ if path.exists():
+ path.unlink()
+ return path
+
+ def create_socket(path: pathlib.Path) -> socket.socket:
s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
- s.bind(path)
+ s.bind(str(path))
s.listen(1)
return s
- monitor_path = os.path.join(self.state_dir, "monitor")
- self.monitor_socket = create_socket(monitor_path)
-
- shell_path = os.path.join(self.state_dir, "shell")
- self.shell_socket = create_socket(shell_path)
-
- display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"])
- qemu_options = (
- " ".join(
- [
- "" if self.allow_reboot else "-no-reboot",
- "-monitor unix:{}".format(monitor_path),
- "-chardev socket,id=shell,path={}".format(shell_path),
- "-device virtio-serial",
- "-device virtconsole,chardev=shell",
- "-device virtio-rng-pci",
- "-serial stdio" if display_available else "-nographic",
- ]
- )
- + " "
- + os.environ.get("QEMU_OPTS", "")
+ monitor_socket = create_socket(clear(self.monitor_path))
+ shell_socket = create_socket(clear(self.shell_path))
+ self.process = self.start_command.run(
+ self.state_dir,
+ self.shared_dir,
+ self.monitor_path,
+ self.shell_path,
)
-
- environment = dict(os.environ)
- environment.update(
- {
- "TMPDIR": self.state_dir,
- "SHARED_DIR": self.shared_dir,
- "USE_TMPDIR": "1",
- "QEMU_OPTS": qemu_options,
- }
- )
-
- self.process = subprocess.Popen(
- self.script,
- stdin=subprocess.DEVNULL,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- shell=True,
- cwd=self.state_dir,
- env=environment,
- )
- self.monitor, _ = self.monitor_socket.accept()
- self.shell, _ = self.shell_socket.accept()
+ self.monitor, _ = monitor_socket.accept()
+ self.shell, _ = shell_socket.accept()
# Store last serial console lines for use
# of wait_for_console_text
self.last_lines: Queue = Queue()
def process_serial_output() -> None:
- assert self.process.stdout is not None
+ assert self.process
+ assert self.process.stdout
for _line in self.process.stdout:
# Ignore undecodable bytes that may occur in boot menus
line = _line.decode(errors="ignore").replace("\r", "").rstrip()
@@ -825,15 +931,15 @@ class Machine:
self.log("QEMU running (pid {})".format(self.pid))
def cleanup_statedir(self) -> None:
- if os.path.isdir(self.state_dir):
- shutil.rmtree(self.state_dir)
- self.logger.log(f"deleting VM state directory {self.state_dir}")
- self.logger.log("if you want to keep the VM state, pass --keep-vm-state")
+ shutil.rmtree(self.state_dir)
+ rootlog.log(f"deleting VM state directory {self.state_dir}")
+ rootlog.log("if you want to keep the VM state, pass --keep-vm-state")
def shutdown(self) -> None:
if not self.booted:
return
+ assert self.shell
self.shell.send("poweroff\n".encode())
self.wait_for_shutdown()
@@ -908,41 +1014,225 @@ class Machine:
"""Make the machine reachable."""
self.send_monitor_command("set_link virtio-net-pci.1 on")
+ def release(self) -> None:
+ if self.pid is None:
+ return
+ rootlog.info(f"kill machine (pid {self.pid})")
+ assert self.process
+ assert self.shell
+ assert self.monitor
+ self.process.terminate()
+ self.shell.close()
+ self.monitor.close()
+
+
+class VLan:
+ """A handle to the vlan with this number, that also knows how to manage
+ it's lifecycle.
+ """
-def create_machine(args: Dict[str, Any]) -> Machine:
- args["log"] = log
- return Machine(args)
+ nr: int
+ socket_dir: pathlib.Path
+ process: Optional[subprocess.Popen]
+ pid: Optional[int]
+ fd: Optional[io.TextIOBase]
-def start_all() -> None:
- with log.nested("starting all VMs"):
- for machine in machines:
- machine.start()
+ def __repr__(self) -> str:
+ return f"<Vlan Nr. {self.nr}>"
+ def __init__(self, nr: int, tmp_dir: pathlib.Path):
+ self.nr = nr
+ self.socket_dir = tmp_dir / f"vde{self.nr}.ctl"
-def join_all() -> None:
- with log.nested("waiting for all VMs to finish"):
- for machine in machines:
- machine.wait_for_shutdown()
+ # TODO: don't side-effect environment here
+ os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir)
+ def start(self) -> None:
-def run_tests(interactive: bool = False) -> None:
- if interactive:
- ptpython.repl.embed(test_symbols(), {})
- else:
- test_script()
+ rootlog.info("start vlan")
+ pty_master, pty_slave = pty.openpty()
+
+ self.process = subprocess.Popen(
+ ["vde_switch", "-s", self.socket_dir, "--dirmode", "0700"],
+ stdin=pty_slave,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=False,
+ )
+ self.pid = self.process.pid
+ self.fd = os.fdopen(pty_master, "w")
+ self.fd.write("version\n")
+
+ # TODO: perl version checks if this can be read from
+ # an if not, dies. we could hang here forever. Fix it.
+ assert self.process.stdout is not None
+ self.process.stdout.readline()
+ if not (self.socket_dir / "ctl").exists():
+ rootlog.error("cannot start vde_switch")
+
+ rootlog.info(f"running vlan (pid {self.pid})")
+
+ def release(self) -> None:
+ if self.pid is None:
+ return
+ rootlog.info(f"kill vlan (pid {self.pid})")
+ assert self.fd
+ assert self.process
+ self.fd.close()
+ self.process.terminate()
+
+
+class Driver:
+ """A handle to the driver that sets up the environment
+ and runs the tests"""
+
+ tests: str
+ vlans: List[VLan]
+ machines: List[Machine]
+
+ def __init__(
+ self,
+ start_scripts: List[str],
+ vlans: List[int],
+ tests: str,
+ keep_vm_state: bool = False,
+ ):
+ self.tests = tests
+
+ tmp_dir = pathlib.Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
+ tmp_dir.mkdir(mode=0o700, exist_ok=True)
+
+ self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
+ with rootlog.nested("start all VLans"):
+ for vlan in self.vlans:
+ vlan.start()
+
+ def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
+ for s in scripts:
+ yield NixStartScript(s)
+
+ self.machines = [
+ Machine(
+ start_command=cmd,
+ keep_vm_state=keep_vm_state,
+ name=cmd.machine_name,
+ tmp_dir=tmp_dir,
+ )
+ for cmd in cmd(start_scripts)
+ ]
+
+ @atexit.register
+ def clean_up() -> None:
+ with rootlog.nested("clean up"):
+ for machine in self.machines:
+ machine.release()
+ for vlan in self.vlans:
+ vlan.release()
+
+ def subtest(self, name: str) -> Iterator[None]:
+ """Group logs under a given test name"""
+ with rootlog.nested(name):
+ try:
+ yield
+ return True
+ except:
+ rootlog.error(f'Test "{name}" failed with error:')
+ raise
+
+ def test_symbols(self) -> Dict[str, Any]:
+ @contextmanager
+ def subtest(name: str) -> Iterator[None]:
+ return self.subtest(name)
+
+ general_symbols = dict(
+ start_all=self.start_all,
+ test_script=self.test_script,
+ machines=self.machines,
+ vlans=self.vlans,
+ driver=self,
+ log=rootlog,
+ os=os,
+ create_machine=self.create_machine,
+ subtest=subtest,
+ run_tests=self.run_tests,
+ join_all=self.join_all,
+ retry=retry,
+ serial_stdout_off=self.serial_stdout_off,
+ serial_stdout_on=self.serial_stdout_on,
+ Machine=Machine, # for typing
+ )
+ machine_symbols = {
+ m.name: self.machines[idx] for idx, m in enumerate(self.machines)
+ }
+ vlan_symbols = {
+ f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans)
+ }
+ print(
+ "additionally exposed symbols:\n "
+ + ", ".join(map(lambda m: m.name, self.machines))
+ + ",\n "
+ + ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans))
+ + ",\n "
+ + ", ".join(list(general_symbols.keys()))
+ )
+ return {**general_symbols, **machine_symbols, **vlan_symbols}
+
+ def test_script(self) -> None:
+ """Run the test script"""
+ with rootlog.nested("run the VM test script"):
+ symbols = self.test_symbols() # call eagerly
+ exec(self.tests, symbols, None)
+
+ def run_tests(self) -> None:
+ """Run the test script (for non-interactive test runs)"""
+ self.test_script()
# TODO: Collect coverage data
- for machine in machines:
+ for machine in self.machines:
if machine.is_up():
machine.execute("sync")
+ def start_all(self) -> None:
+ """Start all machines"""
+ with rootlog.nested("start all VMs"):
+ for machine in self.machines:
+ machine.start()
+
+ def join_all(self) -> None:
+ """Wait for all machines to shut down"""
+ with rootlog.nested("wait for all VMs to finish"):
+ for machine in self.machines:
+ machine.wait_for_shutdown()
+
+ def create_machine(self, args: Dict[str, Any]) -> Machine:
+ rootlog.warning(
+ "Using legacy create_machine(), please instantiate the"
+ "Machine class directly, instead"
+ )
+ tmp_dir = pathlib.Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
+ tmp_dir.mkdir(mode=0o700, exist_ok=True)
-def serial_stdout_on() -> None:
- log._print_serial_logs = True
+ if args.get("startCommand"):
+ start_command: str = args.get("startCommand", "")
+ cmd = NixStartScript(start_command)
+ name = args.get("name", cmd.machine_name)
+ else:
+ cmd = Machine.create_startcommand(args) # type: ignore
+ name = args.get("name", "machine")
+
+ return Machine(
+ tmp_dir=tmp_dir,
+ start_command=cmd,
+ name=name,
+ keep_vm_state=args.get("keep_vm_state", False),
+ allow_reboot=args.get("allow_reboot", False),
+ )
+ def serial_stdout_on(self) -> None:
+ rootlog._print_serial_logs = True
-def serial_stdout_off() -> None:
- log._print_serial_logs = False
+ def serial_stdout_off(self) -> None:
+ rootlog._print_serial_logs = False
class EnvDefault(argparse.Action):
@@ -970,52 +1260,6 @@ class EnvDefault(argparse.Action):
setattr(namespace, self.dest, values)
-@contextmanager
-def subtest(name: str) -> Iterator[None]:
- with log.nested(name):
- try:
- yield
- return True
- except Exception as e:
- log.log(f'Test "{name}" failed with error: "{e}"')
- raise e
-
- return False
-
-
-def _test_symbols() -> Dict[str, Any]:
- general_symbols = dict(
- start_all=start_all,
- test_script=globals().get("test_script"), # same
- machines=globals().get("machines"), # without being initialized
- log=globals().get("log"), # extracting those symbol keys
- os=os,
- create_machine=create_machine,
- subtest=subtest,
- run_tests=run_tests,
- join_all=join_all,
- retry=retry,
- serial_stdout_off=serial_stdout_off,
- serial_stdout_on=serial_stdout_on,
- Machine=Machine, # for typing
- )
- return general_symbols
-
-
-def test_symbols() -> Dict[str, Any]:
-
- general_symbols = _test_symbols()
-
- machine_symbols = {m.name: machines[idx] for idx, m in enumerate(machines)}
- print(
- "additionally exposed symbols:\n "
- + ", ".join(map(lambda m: m.name, machines))
- + ",\n "
- + ", ".join(list(general_symbols.keys()))
- )
- return {**general_symbols, **machine_symbols}
-
-
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
arg_parser.add_argument(
@@ -1055,44 +1299,18 @@ if __name__ == "__main__":
)
args = arg_parser.parse_args()
- testscript = pathlib.Path(args.testscript).read_text()
-
- global log, machines, test_script
-
- log = Logger()
-
- vde_sockets = [create_vlan(v) for v in args.vlans]
- for nr, vde_socket, _, _ in vde_sockets:
- os.environ["QEMU_VDE_SOCKET_{}".format(nr)] = vde_socket
-
- machines = [
- create_machine({"startCommand": s, "keepVmState": args.keep_vm_state})
- for s in args.start_scripts
- ]
- machine_eval = [
- "{0} = machines[{1}]".format(m.name, idx) for idx, m in enumerate(machines)
- ]
- exec("\n".join(machine_eval))
-
- @atexit.register
- def clean_up() -> None:
- with log.nested("cleaning up"):
- for machine in machines:
- if machine.pid is None:
- continue
- log.log("killing {} (pid {})".format(machine.name, machine.pid))
- machine.process.kill()
- for _, _, process, _ in vde_sockets:
- process.terminate()
- log.close()
-
- def test_script() -> None:
- with log.nested("running the VM test script"):
- symbols = test_symbols() # call eagerly
- exec(testscript, symbols, None)
-
- interactive = args.interactive or (not bool(testscript))
- tic = time.time()
- run_tests(interactive)
- toc = time.time()
- print("test script finished in {:.2f}s".format(toc - tic))
+
+ if not args.keep_vm_state:
+ rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")
+
+ driver = Driver(
+ args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state
+ )
+
+ if args.interactive:
+ ptpython.repl.embed(driver.test_symbols(), {})
+ else:
+ tic = time.time()
+ driver.run_tests()
+ toc = time.time()
+ rootlog.info(f"test script finished in {(toc-tic):.2f}s")
diff --git a/nixos/lib/testing-python.nix b/nixos/lib/testing-python.nix
index 43b4f9b159b2..1969f40edb6b 100644
--- a/nixos/lib/testing-python.nix
+++ b/nixos/lib/testing-python.nix
@@ -43,7 +43,8 @@ rec {
from pydoc import importfile
with open('driver-symbols', 'w') as fp:
t = importfile('${testDriverScript}')
- test_symbols = t._test_symbols()
+ d = t.Driver([],[],"")
+ test_symbols = d.test_symbols()
fp.write(','.join(test_symbols.keys()))
EOF
'';
@@ -188,14 +189,6 @@ rec {
--set startScripts "''${vmStartScripts[*]}" \
--set testScript "$out/test-script" \
--set vlans '${toString vlans}'
-
- ${lib.opti