summaryrefslogtreecommitdiffstats
path: root/nixos
diff options
context:
space:
mode:
authorJacek Galowicz <jacek@galowicz.de>2021-08-20 11:25:52 +0200
committerGitHub <noreply@github.com>2021-08-20 11:25:52 +0200
commit85e131e51a5d6813230b1c63fc115a737f15de99 (patch)
treea19129c0088d6069c2dd16cfee309a557d07e2ea /nixos
parentf4ddae2ba55889d5f21bd4601bcf6cc256baf39c (diff)
parentdb614e11d672cf8e3c1268d34e74e0c9981ab5be (diff)
Merge pull request #125992 from blaggacao/nixos-test-ref/04-better-control-test-env-symbols
nixos test ref/04 better control test env symbols
Diffstat (limited to 'nixos')
-rwxr-xr-xnixos/lib/test-driver/test-driver.py57
-rw-r--r--nixos/lib/testing-python.nix4
2 files changed, 45 insertions, 16 deletions
diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py
index 0372148cb33c..488789e119d0 100755
--- a/nixos/lib/test-driver/test-driver.py
+++ b/nixos/lib/test-driver/test-driver.py
@@ -89,9 +89,7 @@ CHAR_TO_KEY = {
")": "shift-0x0B",
}
-# Forward references
-log: "Logger"
-machines: "List[Machine]"
+global log, machines, test_script
def eprint(*args: object, **kwargs: Any) -> None:
@@ -103,7 +101,6 @@ def make_command(args: list) -> str:
def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]:
- global log
log.log("starting VDE switch for network {}".format(vlan_nr))
vde_socket = tempfile.mkdtemp(
prefix="nixos-test-vde-", suffix="-vde{}.ctl".format(vlan_nr)
@@ -246,6 +243,9 @@ def _perform_ocr_on_screenshot(
class Machine:
+ def __repr__(self) -> str:
+ return f"<Machine '{self.name}'>"
+
def __init__(self, args: Dict[str, Any]) -> None:
if "name" in args:
self.name = args["name"]
@@ -910,29 +910,25 @@ class Machine:
def create_machine(args: Dict[str, Any]) -> Machine:
- global log
args["log"] = log
return Machine(args)
def start_all() -> None:
- global machines
with log.nested("starting all VMs"):
for machine in machines:
machine.start()
def join_all() -> None:
- global machines
with log.nested("waiting for all VMs to finish"):
for machine in machines:
machine.wait_for_shutdown()
def run_tests(interactive: bool = False) -> None:
- global machines
if interactive:
- ptpython.repl.embed(globals(), locals())
+ ptpython.repl.embed(test_symbols(), {})
else:
test_script()
# TODO: Collect coverage data
@@ -942,12 +938,10 @@ def run_tests(interactive: bool = False) -> None:
def serial_stdout_on() -> None:
- global log
log._print_serial_logs = True
def serial_stdout_off() -> None:
- global log
log._print_serial_logs = False
@@ -989,6 +983,37 @@ def subtest(name: str) -> Iterator[None]:
return False
+def _test_symbols() -> Dict[str, Any]:
+ general_symbols = dict(
+ start_all=start_all,
+ test_script=globals().get("test_script"), # same
+ machines=globals().get("machines"), # without being initialized
+ log=globals().get("log"), # extracting those symbol keys
+ os=os,
+ create_machine=create_machine,
+ subtest=subtest,
+ run_tests=run_tests,
+ join_all=join_all,
+ serial_stdout_off=serial_stdout_off,
+ serial_stdout_on=serial_stdout_on,
+ )
+ return general_symbols
+
+
+def test_symbols() -> Dict[str, Any]:
+
+ general_symbols = _test_symbols()
+
+ machine_symbols = {m.name: machines[idx] for idx, m in enumerate(machines)}
+ print(
+ "additionally exposed symbols:\n "
+ + ", ".join(map(lambda m: m.name, machines))
+ + ",\n "
+ + ", ".join(list(general_symbols.keys()))
+ )
+ return {**general_symbols, **machine_symbols}
+
+
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
arg_parser.add_argument(
@@ -1028,12 +1053,9 @@ if __name__ == "__main__":
)
args = arg_parser.parse_args()
- global test_script
testscript = pathlib.Path(args.testscript).read_text()
- def test_script() -> None:
- with log.nested("running the VM test script"):
- exec(testscript, globals())
+ global log, machines, test_script
log = Logger()
@@ -1062,6 +1084,11 @@ if __name__ == "__main__":
process.terminate()
log.close()
+ def test_script() -> None:
+ with log.nested("running the VM test script"):
+ symbols = test_symbols() # call eagerly
+ exec(testscript, symbols, None)
+
interactive = args.interactive or (not bool(testscript))
tic = time.time()
run_tests(interactive)
diff --git a/nixos/lib/testing-python.nix b/nixos/lib/testing-python.nix
index e95ebe16ecac..43b4f9b159b2 100644
--- a/nixos/lib/testing-python.nix
+++ b/nixos/lib/testing-python.nix
@@ -42,7 +42,9 @@ rec {
python <<EOF
from pydoc import importfile
with open('driver-symbols', 'w') as fp:
- fp.write(','.join(dir(importfile('${testDriverScript}'))))
+ t = importfile('${testDriverScript}')
+ test_symbols = t._test_symbols()
+ fp.write(','.join(test_symbols.keys()))
EOF
'';