diff --git a/test/vmtests/README.md b/test/vmtests/README.md index 8bfc45b15..b82b54ea0 100644 --- a/test/vmtests/README.md +++ b/test/vmtests/README.md @@ -1,13 +1,15 @@ # VM Tests -A test suite that runs the containerized version of the image customizer tool. +A test suite that runs the containerized version of the image customizer tool and then +boots the customized images. ## How to run Requirements: - Python3 -- Docker +- QEMU/KVM +- libvirt Steps: diff --git a/test/vmtests/requirements.txt b/test/vmtests/requirements.txt index 9e58009c8..ba91c29f4 100644 --- a/test/vmtests/requirements.txt +++ b/test/vmtests/requirements.txt @@ -1,2 +1,3 @@ docker == 7.1.0 +libvirt-python == 10.9.0 pytest == 8.3.3 diff --git a/test/vmtests/vmtests/conftest.py b/test/vmtests/vmtests/conftest.py index 7bd91527e..750b3ab76 100644 --- a/test/vmtests/vmtests/conftest.py +++ b/test/vmtests/vmtests/conftest.py @@ -7,12 +7,15 @@ import string import tempfile from pathlib import Path -from typing import Generator +from typing import Generator, List import docker +import libvirt # type: ignore import pytest from docker import DockerClient +from .utils.closeable import Closeable + SCRIPT_PATH = Path(__file__).parent TEST_CONFIGS_DIR = SCRIPT_PATH.joinpath("../../../toolkit/tools/pkg/imagecustomizerlib/testdata") @@ -91,3 +94,35 @@ def docker_client() -> Generator[DockerClient, None, None]: yield client client.close() # type: ignore + + +@pytest.fixture(scope="session") +def libvirt_conn() -> Generator[libvirt.virConnect, None, None]: + # Connect to libvirt. + libvirt_conn_str = f"qemu:///system" + libvirt_conn = libvirt.open(libvirt_conn_str) + + yield libvirt_conn + + libvirt_conn.close() + + +# Fixture that will close resources after a test has run, so long as the '--keep-environment' flag is not specified. 
+@pytest.fixture(scope="function") +def close_list(keep_environment: bool) -> Generator[List[Closeable], None, None]: + vm_delete_list: List[Closeable] = [] + + yield vm_delete_list + + if keep_environment: + return + + exceptions = [] + for vm in reversed(vm_delete_list): + try: + vm.close() + except Exception as ex: + exceptions.append(ex) + + if len(exceptions) > 0: + raise ExceptionGroup("failed to close resources", exceptions) diff --git a/test/vmtests/vmtests/test_no_change.py b/test/vmtests/vmtests/test_no_change.py index 7826eb9ff..d3fbf40d5 100644 --- a/test/vmtests/vmtests/test_no_change.py +++ b/test/vmtests/vmtests/test_no_change.py @@ -1,12 +1,19 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import os from pathlib import Path +from typing import List +import libvirt # type: ignore from docker import DockerClient from .conftest import TEST_CONFIGS_DIR +from .utils import local_client +from .utils.closeable import Closeable from .utils.imagecustomizer import run_image_customizer +from .utils.libvirt_utils import VmSpec, create_libvirt_domain_xml +from .utils.libvirt_vm import LibvirtVm def test_no_change( @@ -14,9 +21,13 @@ def test_no_change( image_customizer_container_url: str, core_efi_azl2: Path, test_temp_dir: Path, + test_instance_name: str, + libvirt_conn: libvirt.virConnect, + close_list: List[Closeable], ) -> None: config_path = TEST_CONFIGS_DIR.joinpath("nochange-config.yaml") output_image_path = test_temp_dir.joinpath("image.qcow2") + diff_image_path = test_temp_dir.joinpath("image-diff.qcow2") run_image_customizer( docker_client, @@ -26,3 +37,25 @@ def test_no_change( "qcow2", output_image_path, ) + + # Create a differencing disk for the VM. + # This will make it easier to manually debug what is in the image itself and what was set during first boot. 
+ local_client.run( + ["qemu-img", "create", "-F", "qcow2", "-f", "qcow2", "-b", str(output_image_path), str(diff_image_path)], + ).check_exit_code() + + # Ensure VM can write to the disk file. + os.chmod(diff_image_path, 0o666) + + # Create VM. + vm_name = test_instance_name + domain_xml = create_libvirt_domain_xml(VmSpec(vm_name, 4096, 4, diff_image_path)) + + vm = LibvirtVm(vm_name, domain_xml, libvirt_conn) + close_list.append(vm) + + # Start VM. + vm.start() + + # Wait for VM to boot by waiting for it to request an IP address from the DHCP server. + vm.get_vm_ip_address(timeout=30) diff --git a/test/vmtests/vmtests/utils/closeable.py b/test/vmtests/vmtests/utils/closeable.py new file mode 100644 index 000000000..192669f13 --- /dev/null +++ b/test/vmtests/vmtests/utils/closeable.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Protocol + + +# Interface for classes that have a 'close' method. +class Closeable(Protocol): + def close(self) -> None: + pass diff --git a/test/vmtests/vmtests/utils/libvirt_utils.py b/test/vmtests/vmtests/utils/libvirt_utils.py new file mode 100644 index 000000000..1f0ac6ebd --- /dev/null +++ b/test/vmtests/vmtests/utils/libvirt_utils.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import xml.etree.ElementTree as ET # noqa: N817 +from pathlib import Path +from typing import Dict + + +class VmSpec: + def __init__(self, name: str, memory_mib: int, core_count: int, os_disk_path: Path): + self.name: str = name + self.memory_mib: int = memory_mib + self.core_count: int = core_count + self.os_disk_path: Path = os_disk_path + + +# Create XML definition for a VM. 
+def create_libvirt_domain_xml(vm_spec: VmSpec) -> str: + domain = ET.Element("domain") + domain.attrib["type"] = "kvm" + + name = ET.SubElement(domain, "name") + name.text = vm_spec.name + + memory = ET.SubElement(domain, "memory") + memory.attrib["unit"] = "MiB" + memory.text = str(vm_spec.memory_mib) + + vcpu = ET.SubElement(domain, "vcpu") + vcpu.text = str(vm_spec.core_count) + + os_tag = ET.SubElement(domain, "os") + os_tag.attrib["firmware"] = "efi" + + os_type = ET.SubElement(os_tag, "type") + os_type.text = "hvm" + + firmware = ET.SubElement(domain, "firmware") + firmware.attrib["secure-boot"] = "yes" + firmware.attrib["enrolled-keys"] = "yes" + + features = ET.SubElement(domain, "features") + + ET.SubElement(features, "acpi") + + ET.SubElement(features, "apic") + + cpu = ET.SubElement(domain, "cpu") + cpu.attrib["mode"] = "host-passthrough" + + clock = ET.SubElement(domain, "clock") + clock.attrib["offset"] = "utc" + + on_poweroff = ET.SubElement(domain, "on_poweroff") + on_poweroff.text = "destroy" + + on_reboot = ET.SubElement(domain, "on_reboot") + on_reboot.text = "restart" + + on_crash = ET.SubElement(domain, "on_crash") + on_crash.text = "destroy" + + devices = ET.SubElement(domain, "devices") + + serial = ET.SubElement(devices, "serial") + serial.attrib["type"] = "pty" + + serial_target = ET.SubElement(serial, "target") + serial_target.attrib["type"] = "isa-serial" + serial_target.attrib["port"] = "0" + + serial_target_model = ET.SubElement(serial_target, "model") + serial_target_model.attrib["name"] = "isa-serial" + + console = ET.SubElement(devices, "console") + console.attrib["type"] = "pty" + + console_target = ET.SubElement(console, "target") + console_target.attrib["type"] = "serial" + console_target.attrib["port"] = "0" + + video = ET.SubElement(devices, "video") + + video_model = ET.SubElement(video, "model") + video_model.attrib["type"] = "qxl" + + graphics = ET.SubElement(devices, "graphics") + graphics.attrib["type"] = "spice" + + 
network_interface = ET.SubElement(devices, "interface") + network_interface.attrib["type"] = "network" + + network_interface_source = ET.SubElement(network_interface, "source") + network_interface_source.attrib["network"] = "default" + + network_interface_model = ET.SubElement(network_interface, "model") + network_interface_model.attrib["type"] = "virtio" + + next_disk_indexes: Dict[str, int] = {} + _add_disk_xml( + devices, + str(vm_spec.os_disk_path), + "disk", + "qcow2", + "virtio", + next_disk_indexes, + ) + + xml = ET.tostring(domain, "unicode") + return xml + + +# Adds a disk to a libvirt domain XML document. +def _add_disk_xml( + devices: ET.Element, + file_path: str, + device_type: str, + image_type: str, + bus_type: str, + next_disk_indexes: Dict[str, int], +) -> None: + device_name = _gen_disk_device_name("vd", next_disk_indexes) + + disk = ET.SubElement(devices, "disk") + disk.attrib["type"] = "file" + disk.attrib["device"] = device_type + + disk_driver = ET.SubElement(disk, "driver") + disk_driver.attrib["name"] = "qemu" + disk_driver.attrib["type"] = image_type + + disk_target = ET.SubElement(disk, "target") + disk_target.attrib["dev"] = device_name + disk_target.attrib["bus"] = bus_type + + disk_source = ET.SubElement(disk, "source") + disk_source.attrib["file"] = file_path + + +def _gen_disk_device_name(prefix: str, next_disk_indexes: Dict[str, int]) -> str: + disk_index = next_disk_indexes.get(prefix, 0) + next_disk_indexes[prefix] = disk_index + 1 + + match prefix: + case "vd" | "sd": + # The disk device name is required to follow the standard Linux device naming + # scheme. That is: [ sda, sdb, ..., sdz, sdaa, sdab, ... ]. However, it is + # unlikely that someone will ever need more than 26 disks. So, keep it simple + # for now. 
+ if disk_index < 0 or disk_index > 25: + raise Exception(f"Unsupported disk index: {disk_index}.") + suffix = chr(ord("a") + disk_index) + return f"{prefix}{suffix}" + + case _: + return f"{prefix}{disk_index}" diff --git a/test/vmtests/vmtests/utils/libvirt_vm.py b/test/vmtests/vmtests/utils/libvirt_vm.py new file mode 100644 index 000000000..f3854f299 --- /dev/null +++ b/test/vmtests/vmtests/utils/libvirt_vm.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import logging +import time +from typing import Any, Optional + +import libvirt # type: ignore + + +# Assists with creating and destroying a libvirt VM. +class LibvirtVm: + def __init__(self, vm_name: str, domain_xml: str, libvirt_conn: libvirt.virConnect): + self.vm_name: str = vm_name + self.domain: libvirt.virDomain = None + + self.domain = libvirt_conn.defineXML(domain_xml) + + def start(self) -> None: + # Start the VM in the paused state. + # This gives the console logger a chance to connect before the VM starts. + self.domain.createWithFlags(libvirt.VIR_DOMAIN_START_PAUSED) + + # PLACEHOLDER + # Attach the console logger + # self.console_logger = LibvirtConsoleLogger() + # self.console_logger.attach(domain, console_log_file_path) + + # Start the VM. + self.domain.resume() + + # Wait for the VM to boot and then get the IP address. + def get_vm_ip_address(self, timeout: float = 30) -> str: + start_time = time.time() + timeout_time = start_time + timeout + + while True: + addr = self.try_get_vm_ip_address() + if addr: + total_wait_time = time.time() - start_time + logging.debug(f"Wait for VM ({self.vm_name}) boot / request IP address: {total_wait_time:.0f}s") + return addr + + if time.time() > timeout_time: + raise Exception(f"No IP addresses found for '{self.vm_name}'. OS might have failed to boot.") + + time.sleep(1) + + # Try to get the IP address of the VM. 
+ def try_get_vm_ip_address(self) -> Optional[str]: + assert self.domain + + # Acquire IP address from libvirt's DHCP server. + interfaces = self.domain.interfaceAddresses(libvirt.VIR_DOMAIN_INTERFACE_ADDRESSES_SRC_LEASE) + if len(interfaces) < 1: + return None + + interface_name = next(iter(interfaces)) + addrs = interfaces[interface_name]["addrs"] + if len(addrs) < 1: + return None + + addr = addrs[0]["addr"] + assert isinstance(addr, str) + return addr + + def close(self) -> None: + # Stop the VM. + logging.debug(f"Stop VM: {self.vm_name}") + try: + # In the libvirt API, "destroy" means "stop". + self.domain.destroy() + except libvirt.libvirtError as ex: + logging.warning(f"VM stop failed. {ex}") + + # PLACEHOLDER + # Wait for console log to close. + # Note: libvirt can deadlock if you try to undefine the VM while the stream + # is trying to close. + # if self.console_logger: + # log.debug(f"Close VM console log: {vm_name}") + # self.console_logger.close() + # self.console_logger = None + + # Undefine the VM. + logging.debug(f"Delete VM: {self.vm_name}") + try: + self.domain.undefineFlags( + libvirt.VIR_DOMAIN_UNDEFINE_MANAGED_SAVE + | libvirt.VIR_DOMAIN_UNDEFINE_SNAPSHOTS_METADATA + | libvirt.VIR_DOMAIN_UNDEFINE_NVRAM + | libvirt.VIR_DOMAIN_UNDEFINE_CHECKPOINTS_METADATA + ) + except libvirt.libvirtError as ex: + logging.warning(f"VM delete failed. {ex}") + + def __enter__(self) -> "LibvirtVm": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() diff --git a/test/vmtests/vmtests/utils/local_client.py b/test/vmtests/vmtests/utils/local_client.py new file mode 100644 index 000000000..acf9a9497 --- /dev/null +++ b/test/vmtests/vmtests/utils/local_client.py @@ -0,0 +1,194 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Mostly just a wrapper around the subprocess library but can +# live log stdout and stderr, while also collecting them as +# strings. 
import logging
import subprocess
import time
from io import StringIO
from pathlib import Path
from threading import Thread
from typing import IO, Any, Dict, List, Optional, Union


# The result of a subprocess execution.
#
# stdout / stderr: everything the process wrote to each stream.
# exit_code:       the process's exit code.
# cmd:             the command that was run.
# elapsed:         run time in seconds.
# is_timeout:      True if the process was killed because the wait timed out.
class LocalExecutableResult:
    def __init__(
        self,
        stdout: str,
        stderr: str,
        exit_code: int,
        cmd: Union[str, List[str]],
        elapsed: float,
        is_timeout: bool,
    ) -> None:
        self.stdout = stdout
        self.stderr = stderr
        self.exit_code = exit_code
        self.cmd = cmd
        self.elapsed = elapsed
        self.is_timeout = is_timeout

    # Raises an Exception if the process timed out or exited with a non-zero exit code.
    def check_exit_code(self) -> None:
        if self.is_timeout:
            raise Exception("Process timed out")

        elif self.exit_code != 0:
            raise Exception(f"Process failed with exit code: {self.exit_code}")


# Handles reading a pipe (stdout or stderr).
# The contents are both collected as a string and logged.
# Reading happens on a background thread that is started by the constructor.
class _PipeReader:
    def __init__(self, pipe: IO[str], log_level: int, log_name: str) -> None:
        self._pipe = pipe
        self._log_level = log_level
        self._log_name = log_name
        self._output: Optional[str] = None

        self._thread: Thread = Thread(target=self._read_thread)
        self._thread.start()

    # Waits for the reader thread to finish (i.e. for the pipe to reach end-of-file).
    def close(self) -> None:
        self._thread.join()

    def __enter__(self) -> "_PipeReader":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    # Waits for the reader thread to finish and returns everything that was read from the pipe.
    def wait_for_output(self) -> str:
        self._thread.join()

        assert self._output is not None
        return self._output

    def _read_thread(self) -> None:
        log_enabled = logging.getLogger().isEnabledFor(self._log_level)

        collected = StringIO()
        try:
            # Read output one line at a time; readline() returns "" at end-of-file.
            for line in iter(self._pipe.readline, ""):
                # Store the line.
                collected.write(line)

                # Log the line.
                if log_enabled:
                    logging.log(self._log_level, "%s: %s", self._log_name, line.removesuffix("\n"))
        finally:
            # Publish whatever was read and release the pipe even if reading failed part-way, so that
            # wait_for_output() still returns the partial output instead of tripping its assertion.
            self._output = collected.getvalue()
            self._pipe.close()


# A running subprocess whose stdout and stderr are logged and collected by background reader threads.
# Can be used as a context manager, in which case close() runs on exit.
class LocalProcess:
    def __init__(
        self,
        cmd: Union[str, List[str]],
        proc: subprocess.Popen[str],
        stdout_log_level: int,
        stderr_log_level: int,
    ) -> None:
        self.cmd = cmd
        self._proc = proc
        self._result: Optional[LocalExecutableResult] = None

        self._start_time = time.monotonic()

        logging.debug("[%d][cmd]: %s", proc.pid, cmd)

        # Both streams must have been opened as pipes by the caller.
        assert proc.stdout
        assert proc.stderr

        self._stdout_reader = _PipeReader(proc.stdout, stdout_log_level, f"[{proc.pid}][stdout]")
        self._stderr_reader = _PipeReader(proc.stderr, stderr_log_level, f"[{proc.pid}][stderr]")

    # Kills the subprocess (if it is still running) and waits for the reader threads to finish.
    def close(self) -> None:
        self._proc.kill()
        # Reap the process so that it does not linger as a zombie and its return code is recorded.
        self._proc.wait()
        self._stdout_reader.close()
        self._stderr_reader.close()

    def __enter__(self) -> "LocalProcess":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    # Wait for the subprocess to exit.
    # If the timeout expires, the subprocess is killed.
    # The result is cached: later calls return the same object without waiting again.
    def wait(
        self,
        timeout: float = 600,
    ) -> LocalExecutableResult:
        result = self._result
        if result is None:
            # Wait for the process to exit.
            completed = False
            try:
                exit_code = self._proc.wait(timeout)
                completed = True

            except subprocess.TimeoutExpired:
                self._proc.kill()
                exit_code = self._proc.wait()

            # Get the process's output.
            stdout = self._stdout_reader.wait_for_output()
            stderr = self._stderr_reader.wait_for_output()

            elapsed_time = time.monotonic() - self._start_time

            logging.debug("[%d][cmd]: execution time: %f, exit code: %d", self._proc.pid, elapsed_time, exit_code)

            result = LocalExecutableResult(stdout, stderr, exit_code, self.cmd, elapsed_time, not completed)
            self._result = result

        return result


# Runs a subprocess and waits for the result.
# Stdout and stderr are both logged and collected as strings that are returned as part of the result.
# If the subprocess has not exited within 'timeout' seconds it is killed and the result is marked as timed
# out. A failing exit code does not raise here; call check_exit_code() on the result to enforce success.
def run(
    cmd: Union[str, List[str]],
    shell: bool = False,
    cwd: Optional[Path] = None,
    env: Optional[Dict[str, str]] = None,
    stdout_log_level: int = logging.DEBUG,
    stderr_log_level: int = logging.DEBUG,
    timeout: float = 600,
) -> LocalExecutableResult:
    # Leaving the 'with' block closes the process wrapper, which also waits for the output readers.
    with popen(
        cmd,
        shell=shell,
        cwd=cwd,
        env=env,
        stdout_log_level=stdout_log_level,
        stderr_log_level=stderr_log_level,
    ) as process:
        return process.wait(
            timeout=timeout,
        )


# Runs a subprocess.
# Stdout and stderr are both logged and collected as strings that are returned as part of the result.
# The caller owns the returned process and is responsible for waiting on it and closing it.
def popen(
    cmd: Union[str, List[str]],
    shell: bool = False,
    cwd: Optional[Path] = None,
    env: Optional[Dict[str, str]] = None,
    stdout_log_level: int = logging.DEBUG,
    stderr_log_level: int = logging.DEBUG,
) -> LocalProcess:
    proc = subprocess.Popen(
        cmd,
        shell=shell,
        env=env,
        cwd=cwd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        encoding="utf-8",
        # Tools are not guaranteed to emit valid UTF-8. Substitute undecodable bytes instead of letting a
        # decode error abort the reader thread and lose the rest of the output.
        errors="replace",
    )
    return LocalProcess(cmd, proc, stdout_log_level, stderr_log_level)