From 6b27b2c5ae679751cd8e4197618fd766c2db2bf8 Mon Sep 17 00:00:00 2001 From: Claus Holbech Date: Tue, 17 Dec 2024 07:19:45 +0100 Subject: [PATCH] Refactor(plugins): Improve schema models (#4795) --- .../plugins/plugin_utils/pyavd_wrappers.py | 5 + .../roles/eos_designs/docs/input-variables.md | 2 +- .../shared_utils/link_tracking_groups.py | 4 +- .../pyavd/_eos_designs/shared_utils/misc.py | 34 ++-- .../structured_config/base/__init__.py | 2 +- .../ethernet_interfaces.py | 2 +- .../port_channel_interfaces.py | 2 +- .../__init__.py | 6 +- .../structured_config/mlag/__init__.py | 10 +- .../network_services/ethernet_interfaces.py | 2 +- .../network_services/router_bgp.py | 6 +- .../network_services/struct_cfgs.py | 2 +- .../network_services/vlan_interfaces.py | 2 +- .../structured_config/overlay/router_bgp.py | 2 +- .../structured_config/underlay/router_bgp.py | 2 +- .../structured_config/underlay/utils.py | 4 +- python-avd/pyavd/_schema/coerce_type.py | 58 +------ python-avd/pyavd/_schema/models/avd_base.py | 33 +++- .../pyavd/_schema/models/avd_indexed_list.py | 95 +++++++++--- python-avd/pyavd/_schema/models/avd_list.py | 65 ++++++-- python-avd/pyavd/_schema/models/avd_model.py | 145 +++++++++++++----- .../test_eos_designs_class.py | 4 +- 22 files changed, 305 insertions(+), 182 deletions(-) diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/pyavd_wrappers.py b/ansible_collections/arista/avd/plugins/plugin_utils/pyavd_wrappers.py index 58f98a98aad..63ae2dcf886 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/pyavd_wrappers.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/pyavd_wrappers.py @@ -32,6 +32,11 @@ def __init__(self, exception: Exception) -> None: def __call__(self, *_args: Any, **_kwargs: Any) -> NoReturn: raise self.exception + def __getattr__(self, name: str) -> Any: + if not name.startswith("__"): + raise self.exception + return self.__getattribute__(name) + def wrap_plugin(plugin_type: Literal["filter", "test"], name: str) -> Callable: plugin_map = { diff --git a/ansible_collections/arista/avd/roles/eos_designs/docs/input-variables.md b/ansible_collections/arista/avd/roles/eos_designs/docs/input-variables.md index e2b0f2da452..f2b658a8152 100644 --- a/ansible_collections/arista/avd/roles/eos_designs/docs/input-variables.md +++ b/ansible_collections/arista/avd/roles/eos_designs/docs/input-variables.md @@ -1480,7 +1480,7 @@ This feature currently provides the following configurations based on the given `max_uplink_switches` and `max_parallel_uplinks` to ensure consistent IP addressing. ??? example "`cv_topology` example" - To use this feature set `default_interfaces` according to the intended design (see [default_intefaces](#default-interface-settings) for details) and set `use_cv_topology` to `true`. + To use this feature set `default_interfaces` according to the intended design (see [default_interfaces](#default-interface-settings) for details) and set `use_cv_topology` to `true`. Provide a full topology under `cv_topology` like this example: ```yaml diff --git a/python-avd/pyavd/_eos_designs/shared_utils/link_tracking_groups.py b/python-avd/pyavd/_eos_designs/shared_utils/link_tracking_groups.py index 0782dd521af..8849a5e90fc 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/link_tracking_groups.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/link_tracking_groups.py @@ -6,7 +6,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from pyavd._utils import default +from pyavd._utils import default, strip_empties_from_list if TYPE_CHECKING: from . import SharedUtils @@ -33,6 +33,6 @@ def link_tracking_groups(self: SharedUtils) -> list | None: else: link_tracking_groups.append({"name": "LT_GROUP1", "recovery_delay": default_recovery_delay}) - return link_tracking_groups + return strip_empties_from_list(link_tracking_groups) return None diff --git a/python-avd/pyavd/_eos_designs/shared_utils/misc.py b/python-avd/pyavd/_eos_designs/shared_utils/misc.py index bf7a7b88ac7..7cf27721407 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/misc.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/misc.py @@ -188,6 +188,7 @@ def get_ipv4_acl( "interface_ip": interface_ip, "peer_ip": peer_ip, } + changed = False for index, entry in enumerate(ipv4_acl.entries): if entry._get("remark"): continue @@ -202,13 +203,15 @@ def get_ipv4_acl( entry.source = self._get_ipv4_acl_field_with_substitution(entry.source, ip_replacements, f"{err_context}.source", interface_name) entry.destination = self._get_ipv4_acl_field_with_substitution(entry.destination, ip_replacements, f"{err_context}.destination", interface_name) + if entry.source != org_ipv4_acl.entries[index].source or entry.destination != org_ipv4_acl.entries[index].destination: + changed = True - if ipv4_acl != org_ipv4_acl: + if changed: ipv4_acl.name += f"_{self.sanitize_interface_name(interface_name)}" return ipv4_acl @staticmethod - def _get_ipv4_acl_field_with_substitution(field_value: str, replacements: dict[str, str], field_context: str, interface_name: str) -> str: + def _get_ipv4_acl_field_with_substitution(field_value: str, replacements: dict[str, str | None], field_context: str, interface_name: str) -> str: """ Checks one field if the value can be substituted. @@ -218,18 +221,15 @@ def _get_ipv4_acl_field_with_substitution(field_value: str, replacements: dict[s If a replacement is done, but the value is None, an error will be raised. """ - for key, value in replacements.items(): - if field_value != key: - continue - - if value is None: - msg = ( - f"Unable to perform substitution of the value '{key}' defined under '{field_context}', " - f"since no substitution value was found for interface '{interface_name}'. " - "Make sure to set the appropriate fields on the interface." - ) - raise AristaAvdError(msg) - - return value - - return field_value + if field_value not in replacements: + return field_value + + if (replacement_value := replacements[field_value]) is None: + msg = ( + f"Unable to perform substitution of the value '{field_value}' defined under '{field_context}', " + f"since no substitution value was found for interface '{interface_name}'. " + "Make sure to set the appropriate fields on the interface." + ) + raise AristaAvdError(msg) + + return replacement_value diff --git a/python-avd/pyavd/_eos_designs/structured_config/base/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/base/__init__.py index da090f15922..296cb9787d3 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/base/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/base/__init__.py @@ -780,6 +780,6 @@ def route_maps(self) -> list | None: @cached_property def struct_cfgs(self) -> list | None: if self.shared_utils.platform_settings.structured_config: - return [self.shared_utils.platform_settings.structured_config._as_dict(strip_values=())] + return [self.shared_utils.platform_settings.structured_config._as_dict()] return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/ethernet_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/ethernet_interfaces.py index 4b1bc4f2e2d..c8161251a87 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/ethernet_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/ethernet_interfaces.py @@ -182,7 +182,7 @@ def _get_ethernet_interface_cfg( "dot1x": adapter.dot1x._as_dict() or None, "poe": self._get_adapter_poe(adapter), "eos_cli": adapter.raw_eos_cli, - "struct_cfg": adapter.structured_config._as_dict(strip_values=()), + "struct_cfg": adapter.structured_config._as_dict(), } # Port-channel member diff --git a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/port_channel_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/port_channel_interfaces.py index 0ceeff3d2b1..3a8160e2a97 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/port_channel_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/port_channel_interfaces.py @@ -158,7 +158,7 @@ def _get_port_channel_interface_cfg( "validate_state": None if (adapter.validate_state if adapter.validate_state is not None else True) else False, "validate_lldp": None if (adapter.validate_lldp if adapter.validate_lldp is not None else True) else False, "eos_cli": adapter.port_channel.raw_eos_cli, - "struct_cfg": adapter.port_channel.structured_config._as_dict(strip_values=()) or None, + "struct_cfg": adapter.port_channel.structured_config._as_dict() or None, } if adapter.port_channel.subinterfaces: diff --git a/python-avd/pyavd/_eos_designs/structured_config/custom_structured_configuration/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/custom_structured_configuration/__init__.py index 6dc0d62b595..7965dd82000 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/custom_structured_configuration/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/custom_structured_configuration/__init__.py @@ -39,7 +39,7 @@ def _extract_and_apply_struct_cfg_from_list_of_dicts(self, list_of_dicts: list, return struct_cfgs def _struct_cfg(self) -> list: - if struct_cfg := self.shared_utils.node_config.structured_config._as_dict(strip_values=()): + if struct_cfg := self.shared_utils.node_config.structured_config._as_dict(): return [struct_cfg] return [] @@ -114,9 +114,7 @@ def _router_bgp_vlans(self) -> list: ] def _custom_structured_configurations(self) -> list[dict]: - return [ - custom_structured_configuration.value._as_dict(strip_values=()) for custom_structured_configuration in self.inputs._custom_structured_configurations - ] + return [custom_structured_configuration.value._as_dict() for custom_structured_configuration in self.inputs._custom_structured_configurations] def render(self) -> list[dict]: """ diff --git a/python-avd/pyavd/_eos_designs/structured_config/mlag/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/mlag/__init__.py index 90abdddb325..ec50862acff 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/mlag/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/mlag/__init__.py @@ -70,7 +70,7 @@ def vlan_interfaces(self) -> list | None: ), "shutdown": False, "no_autostate": True, - "struct_cfg": self.shared_utils.node_config.mlag_peer_vlan_structured_config._as_dict(strip_values=()) or None, + "struct_cfg": self.shared_utils.node_config.mlag_peer_vlan_structured_config._as_dict() or None, "mtu": self.shared_utils.p2p_uplinks_mtu, } @@ -82,7 +82,7 @@ def vlan_interfaces(self) -> list | None: return [strip_empties_from_dict(main_vlan_interface)] # Create L3 data which will go on either a dedicated l3 vlan or the main mlag vlan - l3_cfg = {"struct_cfg": self.shared_utils.node_config.mlag_peer_l3_vlan_structured_config._as_dict(strip_values=()) or None} + l3_cfg = {"struct_cfg": self.shared_utils.node_config.mlag_peer_l3_vlan_structured_config._as_dict() or None} if self.shared_utils.underlay_routing_protocol == "ospf": l3_cfg.update( { @@ -121,7 +121,7 @@ def vlan_interfaces(self) -> list | None: main_vlan_interface.update(l3_cfg) # Applying structured config again in the case it is set on both l3vlan and main vlan if self.shared_utils.node_config.mlag_peer_vlan_structured_config is not None: - main_vlan_interface["struct_cfg"] = self.shared_utils.node_config.mlag_peer_vlan_structured_config._as_dict(strip_values=()) + main_vlan_interface["struct_cfg"] = self.shared_utils.node_config.mlag_peer_vlan_structured_config._as_dict() return [strip_empties_from_dict(main_vlan_interface)] @@ -169,7 +169,7 @@ def port_channel_interfaces(self) -> list: }, "shutdown": False, "service_profile": self.inputs.p2p_uplinks_qos_profile, - "struct_cfg": self.shared_utils.node_config.mlag_port_channel_structured_config._as_dict(strip_values=()) or None, + "struct_cfg": self.shared_utils.node_config.mlag_port_channel_structured_config._as_dict() or None, "flow_tracker": self.shared_utils.get_flow_tracker(self.inputs.fabric_flow_tracking.mlag_interfaces), } @@ -350,7 +350,7 @@ def _router_bgp_mlag_peer_group(self) -> dict: "bfd": self.inputs.bgp_peer_groups.ipv4_underlay_peers.bfd or None, "maximum_routes": 12000, "send_community": "all", - "struct_cfg": self.inputs.bgp_peer_groups.mlag_ipv4_underlay_peer.structured_config._as_dict(strip_values=()) or None, + "struct_cfg": self.inputs.bgp_peer_groups.mlag_ipv4_underlay_peer.structured_config._as_dict() or None, } if self.shared_utils.node_config.mlag_ibgp_origin_incomplete: peer_group["route_map_in"] = "RM-MLAG-PEER-IN" diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/ethernet_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/ethernet_interfaces.py index 49cab1e01c6..6fffe42a324 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/ethernet_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/ethernet_interfaces.py @@ -70,7 +70,7 @@ def ethernet_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None "shutdown": not l3_interface.enabled, "description": interface_description, "eos_cli": l3_interface.raw_eos_cli, - "struct_cfg": l3_interface.structured_config._as_dict(strip_values=()) or None, + "struct_cfg": l3_interface.structured_config._as_dict() or None, "flow_tracker": self.shared_utils.get_flow_tracker(l3_interface.flow_tracking), } diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_bgp.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_bgp.py index 8d26e868441..907d2eac546 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_bgp.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_bgp.py @@ -153,7 +153,7 @@ def _router_bgp_vrfs(self: AvdStructuredConfigNetworkServices) -> dict: bgp_vrf = strip_empties_from_dict( { "eos_cli": vrf.bgp.raw_eos_cli, - "struct_cfg": vrf.bgp.structured_config._as_dict(strip_values=()) or None, + "struct_cfg": vrf.bgp.structured_config._as_dict() or None, } ) @@ -493,7 +493,7 @@ def _router_bgp_vlans_vlan( "route_targets": {"both": [vlan_rt]}, "redistribute_routes": ["learned"], "eos_cli": vlan.bgp.raw_eos_cli, - "struct_cfg": vlan.bgp.structured_config._as_dict(strip_values=()) or None, + "struct_cfg": vlan.bgp.structured_config._as_dict() or None, } if self.shared_utils.node_config.evpn_gateway.evpn_l2.enabled and default( vlan.evpn_l2_multi_domain, vrf.evpn_l2_multi_domain, tenant.evpn_l2_multi_domain @@ -816,7 +816,7 @@ def _router_bgp_mlag_peer_group(self: AvdStructuredConfigNetworkServices) -> dic "password": self.inputs.bgp_peer_groups.mlag_ipv4_underlay_peer.password, "maximum_routes": 12000, "send_community": "all", - "struct_cfg": self.inputs.bgp_peer_groups.mlag_ipv4_underlay_peer.structured_config._as_dict(strip_values=()) or None, + "struct_cfg": self.inputs.bgp_peer_groups.mlag_ipv4_underlay_peer.structured_config._as_dict() or None, } if self.shared_utils.node_config.mlag_ibgp_origin_incomplete: diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/struct_cfgs.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/struct_cfgs.py index 4b34c21ceb0..48f2ccc8054 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/struct_cfgs.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/struct_cfgs.py @@ -32,7 +32,7 @@ def struct_cfgs(self: AvdStructuredConfigNetworkServices) -> list | None: for vrf in tenant.vrfs: if vrf.structured_config: # Inserting VRF into structured_config to perform duplication checks - vrf_struct_cfg = {"vrf": vrf.name, "struct_cfg": vrf.structured_config._as_dict(strip_values=())} + vrf_struct_cfg = {"vrf": vrf.name, "struct_cfg": vrf.structured_config._as_dict()} append_if_not_duplicate( list_of_dicts=vrf_struct_cfgs, primary_key="vrf", diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/vlan_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/vlan_interfaces.py index d07d1fa2fee..92b6958923a 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/vlan_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/vlan_interfaces.py @@ -103,7 +103,7 @@ def _check_virtual_router_mac_address(vlan_interface_config: dict, variables: li "access_group_out": get(self._svi_acls, f"{interface_name}.ipv4_acl_out.name"), "mtu": svi.mtu if self.shared_utils.platform_settings.feature_support.per_interface_mtu else None, "eos_cli": svi.raw_eos_cli, - "struct_cfg": svi.structured_config._as_dict(strip_values=()) or None, + "struct_cfg": svi.structured_config._as_dict() or None, } # Only set VARP if ip_address is set if vlan_interface_config["ip_address"] is not None: diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bgp.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bgp.py index 54196e460d3..89d90a33cad 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bgp.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bgp.py @@ -88,7 +88,7 @@ def _generate_base_peer_group( "password": peer_group.password, "send_community": "all", "maximum_routes": maximum_routes, - "struct_cfg": peer_group.structured_config._as_dict(strip_values=()) or None, + "struct_cfg": peer_group.structured_config._as_dict() or None, } def _peer_groups(self: AvdStructuredConfigOverlay) -> list | None: diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_bgp.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_bgp.py index 835147cb981..1c35a78678c 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_bgp.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_bgp.py @@ -36,7 +36,7 @@ def router_bgp(self: AvdStructuredConfigUnderlay) -> dict | None: "bfd": self.inputs.bgp_peer_groups.ipv4_underlay_peers.bfd or None, "maximum_routes": 12000, "send_community": "all", - "struct_cfg": self.inputs.bgp_peer_groups.ipv4_underlay_peers.structured_config._as_dict(strip_values=()) or None, + "struct_cfg": self.inputs.bgp_peer_groups.ipv4_underlay_peers.structured_config._as_dict() or None, } if self.shared_utils.overlay_routing_protocol == "ibgp" and self.shared_utils.is_cv_pathfinder_router: diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/utils.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/utils.py index 556bc1d8dc9..20da753bfba 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/utils.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/utils.py @@ -182,7 +182,7 @@ def _get_l3_interface_cfg( "access_group_in": get(self._l3_interface_acls, f"{l3_interface.name}..ipv4_acl_in..name", separator=".."), "access_group_out": get(self._l3_interface_acls, f"{l3_interface.name}..ipv4_acl_out..name", separator=".."), "eos_cli": l3_interface.raw_eos_cli, - "struct_cfg": l3_interface.structured_config._as_dict(strip_values=()), + "struct_cfg": l3_interface.structured_config._as_dict(), "flow_tracker": self.shared_utils.get_flow_tracker(l3_interface.flow_tracking), } @@ -268,7 +268,7 @@ def _get_l2_as_subint( "ipv6_enable": svi.ipv6_enable, "mtu": svi.mtu if self.shared_utils.platform_settings.feature_support.per_interface_mtu else None, "eos_cli": svi.raw_eos_cli, - "struct_cfg": svi.structured_config._as_dict(strip_values=()) or None, + "struct_cfg": svi.structured_config._as_dict() or None, "flow_tracker": link.get("flow_tracker"), } if (mtu := subinterface["mtu"]) is not None and subinterface["mtu"] > self.shared_utils.p2p_uplinks_mtu: diff --git a/python-avd/pyavd/_schema/coerce_type.py b/python-avd/pyavd/_schema/coerce_type.py index d071a0f7539..34d9a3e2905 100644 --- a/python-avd/pyavd/_schema/coerce_type.py +++ b/python-avd/pyavd/_schema/coerce_type.py @@ -12,64 +12,9 @@ if TYPE_CHECKING: from typing import NoReturn, TypeVar - from typing_extensions import Self - - from pyavd._schema.models.type_vars import T_AvdBase - T = TypeVar("T") -def nullifiy_class(cls: type[T_AvdBase]) -> type: - """ - Returns a subclass of the given class with overrides for "null" values. - - This class is used when the input for a dict or a list is None/null/none, - to be able to signal to the deepmerge/inherit methods that this is not the same as an unset variable. - """ - - class NullifiedCls(cls): - def _get_defined_attr(self, name: str) -> T_AvdBase | None: - """ - Return the default values or None. - - This is required for the various merge / inheritance logic to always take from this if undefined. - """ - return getattr(self, name) - - def _as_dict(self, *_args: Any, **_kwargs: Any) -> None: - """Always None.""" - - def _as_list(self, *_args: Any, **_kwargs: Any) -> None: - """Always None.""" - - def __repr__(self) -> str: - return f"" - - def _deepinherited(self, *_args: Any, **_kwargs: Any) -> Self: - """Nothing to do since a NullifiedCls will override anything with None.""" - return self._deepcopy() - - def _deepinherit(self, *_args: Any, **_kwargs: Any) -> None: - """Nothing to do since a NullifiedCls will override anything with None.""" - - def _inherit(self, *_args: Any, **_kwargs: Any) -> None: - """Nothing to do since a NullifiedCls will override anything with None.""" - - def _deepmerge(self, *_args: Any, **_kwargs: Any) -> NoReturn: - msg = "A NullifiedCls cannot be inplace deepmerged. Use _deepmerged() instead." - raise NotImplementedError(msg) - - def _deepmerged(self, other: T_AvdBase, *_args: Any, **_kwargs: Any) -> T_AvdBase: - """Returning the other directly since NullifiedCls is empty.""" - return other._deepcopy() - - def _cast_as(self, new_type: type[T_AvdBase], *_args: Any, **_kwargs: Any) -> T_AvdBase: - """Wrap the new type in it's own NullifiedCls.""" - return nullifiy_class(new_type)() - - return NullifiedCls - - def coerce_type(value: Any, target_type: type[T]) -> T: """ Return a coerced variant of the given value to the target_type. @@ -81,8 +26,7 @@ def coerce_type(value: Any, target_type: type[T]) -> T: if value is None: if issubclass(target_type, AvdBase): # None values are sometimes used to overwrite inherited profiles. - # This ensures we still follow the type hint of the class. - return nullifiy_class(target_type)() + return target_type._from_null() # Other None values are left untouched. elif target_type is Any or isinstance(value, target_type): diff --git a/python-avd/pyavd/_schema/models/avd_base.py b/python-avd/pyavd/_schema/models/avd_base.py index 22beb558607..f9dd8bf714f 100644 --- a/python-avd/pyavd/_schema/models/avd_base.py +++ b/python-avd/pyavd/_schema/models/avd_base.py @@ -18,11 +18,21 @@ class AvdBase(ABC): """Base class used for schema-based data classes holding data loaded from AVD inputs.""" - def __eq__(self, other: object) -> bool: - """Compare two instances of AvdBase by comparing their repr.""" - if isinstance(other, self.__class__): - return repr(self) == repr(other) - return False + _created_from_null: bool = False + """ + Flag to say if this data was loaded from a ': null' value in YAML. + + This is used to handle inheritance and merging correctly. + When _created_from_null we inherit nothing (we win!). + When _created_from_null we take anything in when merging and clear the flag. + TODO: Stop changing data in-place. + + The flag is not carried across between classes, so it should not affect anything outside the loaded inputs. + Only exception is on _cast_as, where the flag is carried over. + """ + + _block_inheritance: bool = False + """Flag to block inheriting further if we at some point inherited from a class with _created_from_null set.""" def _deepcopy(self) -> Self: """Return a copy including all nested models.""" @@ -33,8 +43,19 @@ def _deepcopy(self) -> Self: def _load(cls, data: Sequence | Mapping) -> Self: """Returns a new instance loaded with the given data.""" + @classmethod + def _from_null(cls) -> Self: + """Returns a new instance with all attributes set to None. This represents the YAML input ': null'.""" + new_instance = cls() + new_instance._created_from_null = True + return new_instance + + @abstractmethod + def _strip_empties(self) -> None: + """In-place update the instance to remove data matching the given strip_values.""" + @abstractmethod - def _dump(self, include_default_values: bool = False, strip_values: tuple = (None, [], {})) -> dict | list: + def _dump(self, include_default_values: bool = False) -> dict | list: """Dump data into native Python types with or without default values.""" @abstractmethod diff --git a/python-avd/pyavd/_schema/models/avd_indexed_list.py b/python-avd/pyavd/_schema/models/avd_indexed_list.py index b2e43f5e651..cfb6118519c 100644 --- a/python-avd/pyavd/_schema/models/avd_indexed_list.py +++ b/python-avd/pyavd/_schema/models/avd_indexed_list.py @@ -5,9 +5,9 @@ import re from collections.abc import Iterable, Iterator, Sequence -from copy import deepcopy -from typing import TYPE_CHECKING, ClassVar, Generic, Literal +from typing import TYPE_CHECKING, ClassVar, Generic, Literal, cast +from pyavd._errors import AristaAvdDuplicateDataError from pyavd._schema.coerce_type import coerce_type from pyavd._utils import Undefined, UndefinedType @@ -52,7 +52,7 @@ def _from_list(cls, data: Sequence) -> Self: msg = f"Expecting 'data' as a 'Sequence' when loading data into '{cls.__name__}'. Got '{type(data)}" raise TypeError(msg) - cls_items = [coerce_type(item, cls._item_type) for item in data] + cls_items = cast(Iterable[T_AvdModel], (coerce_type(item, cls._item_type) for item in data)) return cls(cls_items) def __init__(self, items: Iterable[T_AvdModel] | UndefinedType = Undefined) -> None: @@ -104,22 +104,52 @@ def keys(self) -> Iterable[T_PrimaryKey]: def values(self) -> Iterable[T_AvdModel]: return self._items.values() - def append(self, item: T_AvdModel) -> None: - self._items[getattr(item, self._primary_key)] = item + def obtain(self, key: T_PrimaryKey) -> T_AvdModel: + """Return item with given primary key, autocreating if missing.""" + if key not in self._items: + item_type = cast(T_AvdModel, self._item_type) + self._items[key] = item_type._from_dict({self._primary_key: key}) + return self._items[key] + + def append(self, item: T_AvdModel, ignore_fields: tuple[str, ...] = ()) -> None: + if (primary_key := getattr(item, self._primary_key)) in self._items: + # Found existing entry using the same primary key. Ignore if it is the exact same content. + if item._compare(existing_item := self._items[primary_key], ignore_fields): + # Ignore identical item. + return + raise AristaAvdDuplicateDataError(type(self).__name__, str(item), str(existing_item)) + + self._items[primary_key] = item + + if TYPE_CHECKING: + append_new: type[T_AvdModel] + + else: + + def append_new(self, *args: Any, **kwargs: Any) -> T_AvdModel: + """ + Create a new instance with the given arguments and append to the list. + + Returns the new item, or in case of an identical duplicate item it returns the existing item. + """ + new_item = self._item_type(*args, **kwargs) + self.append(new_item) + return self._items[kwargs[self._primary_key]] def extend(self, items: Iterable[T_AvdModel]) -> None: self._items.update({getattr(item, self._primary_key): item for item in items}) - def _as_list(self, include_default_values: bool = False, strip_values: tuple = (None, [], {})) -> list[dict]: + def _strip_empties(self) -> None: + """In-place update the instance to remove data matching the given strip_values.""" + [item._strip_empties() for item in self._items.values()] + self._items = {primary_key: item for primary_key, item in self._items.items() if item} + + def _as_list(self, include_default_values: bool = False) -> list[dict]: """Returns a list with all the data from this model and any nested models.""" - return [ - value - for item in self._items.values() - if (value := item._as_dict(include_default_values=include_default_values, strip_values=strip_values)) not in strip_values - ] + return [item._as_dict(include_default_values=include_default_values) for item in self._items.values()] - def _dump(self, include_default_values: bool = False, strip_values: tuple = (None, [], {})) -> list[dict]: - return self._as_list(include_default_values=include_default_values, strip_values=strip_values) + def _dump(self, include_default_values: bool = False) -> list[dict]: + return self._as_list(include_default_values=include_default_values) def _natural_sorted(self, ignore_case: bool = True) -> Self: """Return new instance where the items are natural sorted by primary key.""" @@ -151,15 +181,24 @@ def _deepmerge(self, other: Self, list_merge: Literal["append", "replace"] = "ap msg = f"Unable to merge type '{type(other)}' into '{cls}'" raise TypeError(msg) + if self._created_from_null or other._created_from_null: + # Clear the flag and set list_merge to replace so we overwrite with data from other below. + self._created_from_null = False + list_merge = "replace" + if list_merge == "replace": - self._items = deepcopy(other._items) + self._items = other._items.copy() return for primary_key, new_item in other.items(): - old_value = self.get(primary_key) - if old_value is Undefined or not isinstance(old_value, type(new_item)): + if new_item._created_from_null: + # Remove the complete item when merging in a Null item. + self._items.pop(primary_key, None) + continue + + if (old_value := self.get(primary_key)) is Undefined or not isinstance(old_value, type(new_item)): # New item or different type so we can just replace - self[primary_key] = deepcopy(new_item) + self[primary_key] = new_item continue # Existing item of same type, so deepmerge. @@ -172,11 +211,19 @@ def _deepinherit(self, other: Self) -> None: msg = f"Unable to inherit from type '{type(other)}' into '{cls}'" raise TypeError(msg) + if self._created_from_null or self._block_inheritance: + # Null always wins, so no inheritance. + return + + if other._created_from_null: + # Nothing to inherit, and we set the special block flag to prevent inheriting from something else later. + self._block_inheritance = True + return + for primary_key, new_item in other.items(): - old_value = self.get(primary_key) - if old_value is Undefined: + if self.get(primary_key) is Undefined: # New item so we can just append - self[primary_key] = deepcopy(new_item) + self[primary_key] = new_item continue # Existing item, so deepinherit. @@ -195,4 +242,10 @@ def _cast_as(self, new_type: type[T_AvdIndexedList], ignore_extra_keys: bool = F msg = f"Unable to cast '{cls}' as type '{new_type}' since '{new_type}' is not an AvdIndexedList subclass." raise TypeError(msg) - return new_type([item._cast_as(new_type._item_type, ignore_extra_keys=ignore_extra_keys) for item in self]) + new_instance = new_type([item._cast_as(new_type._item_type, ignore_extra_keys=ignore_extra_keys) for item in self]) + + # Pass along the internal flags + new_instance._created_from_null = self._created_from_null + new_instance._block_inheritance = self._block_inheritance + + return new_instance diff --git a/python-avd/pyavd/_schema/models/avd_list.py b/python-avd/pyavd/_schema/models/avd_list.py index 51266627969..490bf725d11 100644 --- a/python-avd/pyavd/_schema/models/avd_list.py +++ b/python-avd/pyavd/_schema/models/avd_list.py @@ -5,8 +5,7 @@ import re from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence -from copy import deepcopy -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, cast from pyavd._schema.coerce_type import coerce_type from pyavd._utils import Undefined, UndefinedType @@ -57,10 +56,10 @@ def _from_list(cls, data: Sequence) -> Self: def __init__(self, items: Iterable[T_ItemType] | UndefinedType = Undefined) -> None: """ - AvdIndexedList subclass. + AvdList subclass. Args: - items: Iterable holding items of the correct type to be loaded into the indexed list. + items: Iterable holding items of the correct type to be loaded into the list. """ if isinstance(items, UndefinedType): self._items = [] @@ -98,20 +97,44 @@ def get(self, index: int, default: T | UndefinedType = Undefined) -> T_ItemType def append(self, item: T_ItemType) -> None: self._items.append(item) + def append_unique(self, item: T_ItemType) -> None: + """Append the item if not there already. Otherwise ignore.""" + if item not in self._items: + self._items.append(item) + + if TYPE_CHECKING: + append_new: type[T_ItemType] + else: + + def append_new(self, *args: Any, **kwargs: Any) -> T_ItemType: + """Create a new instance with the given arguments and append to the list. Returns the new item.""" + new_item = self._item_type(*args, **kwargs) + self.append(new_item) + return new_item + def extend(self, items: Iterable[T_ItemType]) -> None: self._items.extend(items) - def _as_list(self, include_default_values: bool = False, strip_values: tuple = (None, [], {})) -> list: + def _strip_empties(self) -> None: + """In-place update the instance to remove data matching the given strip_values.""" + if issubclass(self._item_type, AvdBase): + items = cast(list[AvdBase], self._items) + [item._strip_empties() for item in items] + self._items = [item for item in self._items if item] + return + + self._items = [item for item in self._items if item is not None] + + def _as_list(self, include_default_values: bool = False) -> list: """Returns a list with all the data from this model and any nested models.""" if issubclass(self._item_type, AvdBase): - items: list[AvdBase] = self._items - return [ - value for item in items if (value := item._dump(include_default_values=include_default_values, strip_values=strip_values)) not in strip_values - ] - return [item for item in self._items if item not in strip_values] + items = cast(list[AvdBase], self._items) + return [item._dump(include_default_values=include_default_values) for item in items] + + return list(self._items) - def _dump(self, include_default_values: bool = False, strip_values: tuple = (None, [], {})) -> list: - return self._as_list(include_default_values=include_default_values, strip_values=strip_values) + def _dump(self, include_default_values: bool = False) -> list: + return self._as_list(include_default_values=include_default_values) def _natural_sorted(self, sort_key: str | None = None, ignore_case: bool = True) -> Self: """Return new instance where the items are natural sorted by the given sort key or by the item itself.""" @@ -153,12 +176,17 @@ def _deepmerge(self, other: Self, list_merge: Literal["append", "replace"] = "ap msg = f"Unable to merge type '{type(other)}' into '{cls}'" raise TypeError(msg) + if self._created_from_null: + # Overwrite all data from other and clear the flag. + self._created_from_null = False + list_merge = "replace" + if list_merge == "replace": - self._items = deepcopy(other._items) + self._items = other._items.copy() return # Append non-existing items. - self._items.extend(deepcopy([new_item for new_item in other._items if new_item not in self._items])) + self._items.extend(new_item for new_item in other._items if new_item not in self._items) def _cast_as(self, new_type: type[T_AvdList], ignore_extra_keys: bool = False) -> T_AvdList: """ @@ -174,11 +202,16 @@ def _cast_as(self, new_type: type[T_AvdList], ignore_extra_keys: bool = False) - raise TypeError(msg) if issubclass(self._item_type, AvdBase): - items: list[AvdBase] = self._items + items = cast(list[AvdBase], self._items) return new_type([item._cast_as(new_type._item_type, ignore_extra_keys=ignore_extra_keys) for item in items]) if self._item_type != new_type._item_type: msg = f"Unable to cast '{cls}' as type '{new_type}' since they have incompatible item types." raise TypeError(msg) - return new_type(self._items) + new_instance = new_type(self._items) + + # Pass along the _created_from_null flag + new_instance._created_from_null = self._created_from_null + + return new_instance diff --git a/python-avd/pyavd/_schema/models/avd_model.py b/python-avd/pyavd/_schema/models/avd_model.py index c425e51708d..cca18716d25 100644 --- a/python-avd/pyavd/_schema/models/avd_model.py +++ b/python-avd/pyavd/_schema/models/avd_model.py @@ -5,7 +5,8 @@ from collections.abc import Mapping from copy import deepcopy -from typing import TYPE_CHECKING, Any, ClassVar, Literal +from logging import getLogger +from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast from pyavd._schema.coerce_type import coerce_type from pyavd._utils import Undefined, UndefinedType, merge @@ -18,6 +19,8 @@ from .type_vars import T_AvdModel +LOGGER = getLogger(__name__) + class AvdModel(AvdBase): """Base class used for schema-based data classes holding dictionaries loaded from AVD inputs.""" @@ -131,6 +134,8 @@ def _get_defined_attr(self, name: str) -> Any | UndefinedType: Get attribute or Undefined. Avoids the overridden __getattr__ to avoid default values. + + Falls back to __getattr__ in case of _created_from_null to always insert None or default value. """ if name not in self._fields: msg = f"'{type(self).__name__}' object has no attribute '{name}'" @@ -164,39 +169,57 @@ def __bool__(self) -> bool: for field in self._fields ) - def _as_dict(self, include_default_values: bool = False, strip_values: tuple = (None, [], {})) -> dict: + def _strip_empties(self) -> None: + """In-place update the instance to remove data matching the given strip_values.""" + for field, field_info in self._fields.items(): + if (value := self._get_defined_attr(field)) is Undefined or field == "_custom_data": + continue + + if issubclass(field_info["type"], AvdBase): + value = cast(AvdBase, value) + value._strip_empties() + if not value: + delattr(self, field) + continue + + if value is None: + delattr(self, field) + + def _as_dict(self, include_default_values: bool = False) -> dict: """ Returns a dict with all the data from this model and any nested models. Filtered for nested None, {} and [] values. """ as_dict = {} - for field, field_info in self._fields.items() or (): - if (value := self._get_defined_attr(field)) is Undefined: + for field, field_info in self._fields.items(): + value = self._get_defined_attr(field) + + if field == "_custom_data": + if value: + value = cast(dict[str, Any], value) + as_dict.update(value) + continue + + if value is Undefined: if not include_default_values: continue value = self._get_field_default_value(field) - if field == "_custom_data" and isinstance(value, dict) and value: - as_dict.update(value) - continue - # Removing field_ prefix if needed. key = self._field_to_key_map.get(field, field) - if issubclass(field_info["type"], AvdBase) and isinstance(value, AvdBase): - value = value._dump(include_default_values=include_default_values, strip_values=strip_values) - - if value in strip_values: - continue + if issubclass(field_info["type"], AvdBase): + value = cast(AvdBase, value) + value = None if value._created_from_null else value._dump(include_default_values=include_default_values) as_dict[key] = value return as_dict - def _dump(self, include_default_values: bool = False, strip_values: tuple = (None, [], {})) -> dict: - return self._as_dict(include_default_values=include_default_values, strip_values=strip_values) + def _dump(self, include_default_values: bool = False) -> dict: + return self._as_dict(include_default_values=include_default_values) def _get(self, name: str, default: Any = None) -> Any: """ @@ -208,20 +231,14 @@ def _get(self, name: str, default: Any = None) -> Any: return default return value - def _update(self, other: Self) -> None: - """Update instance by shallow merging the other instance in.""" - cls = type(self) - if not isinstance(other, cls): - msg = f"Unable to merge type '{type(other)}' into '{cls}'" - raise TypeError(msg) + if TYPE_CHECKING: + _update: type[Self] + else: - for field in cls._fields: - if new_value := other._get_defined_attr(field) is Undefined: - continue - old_value = self._get_defined_attr(field) - if old_value == new_value: - continue - setattr(self, field, new_value) + def _update(self, *args: Any, **kwargs: Any) -> Self: + """Update instance with the given kwargs. Reuses __init__.""" + self.__init__(*args, **kwargs) + return self def _deepmerge(self, other: Self, list_merge: Literal["append", "replace"] = "append") -> None: """ @@ -239,6 +256,10 @@ def _deepmerge(self, other: Self, list_merge: Literal["append", "replace"] = "ap raise TypeError(msg) for field, field_info in cls._fields.items(): + if other._created_from_null and self._get_defined_attr(field) is not Undefined: + # Force the field back to unset if other is a "null" class. + delattr(self, field) + if (new_value := other._get_defined_attr(field)) is Undefined: continue old_value = self._get_defined_attr(field) @@ -246,25 +267,34 @@ def _deepmerge(self, other: Self, list_merge: Literal["append", "replace"] = "ap continue if not isinstance(old_value, type(new_value)): - # Different type so we can just replace - setattr(self, field, deepcopy(new_value)) + # Different types so we can just replace with the new value. + setattr(self, field, new_value) continue # Merge new value field_type = field_info["type"] - if issubclass(field_type, AvdBase) and isinstance(old_value, field_type): + if issubclass(field_type, AvdBase): # Merge in to the existing object + old_value = cast(AvdBase, old_value) + new_value = cast(AvdBase, new_value) old_value._deepmerge(new_value, list_merge=list_merge) continue if field_type is dict: # In-place deepmerge in to the existing dict without schema. # Deepcopying since merge() does not copy. - merge(old_value, deepcopy(new_value), list_merge=list_merge) + merge(old_value, new_value, list_merge=list_merge) continue setattr(self, field, new_value) + if other._created_from_null: + # Inherit the _created_from_null attribute to make sure we output null values instead of empty dicts. + self._created_from_null = True + elif self._created_from_null: + # We merged into a "null" class, but since we now have proper data, we clear the flag. + self._created_from_null = False + def _inherit(self, other: Self) -> None: """Update unset fields on this instance with fields from other instance. No merging.""" cls = type(self) @@ -272,6 +302,15 @@ def _inherit(self, other: Self) -> None: msg = f"Unable to inherit from type '{type(other)}' into '{cls}'" raise TypeError(msg) + if self._created_from_null: + # Null always wins, so no inheritance. + return + + if other._created_from_null: + # Nothing to inherit, but we set the flag to prevent inheriting from something else later. + self._created_from_null = True + return + for field in cls._fields: if self._get_defined_attr(field) is not Undefined: continue @@ -287,6 +326,15 @@ def _deepinherit(self, other: Self) -> None: msg = f"Unable to inherit from type '{type(other)}' into '{cls}'" raise TypeError(msg) + if self._created_from_null or self._block_inheritance: + # Null always wins, so no inheritance. + return + + if other._created_from_null: + # Nothing to inherit, and we set the special block flag to prevent inheriting from something else later. + self._block_inheritance = True + return + for field, field_info in cls._fields.items(): if (new_value := other._get_defined_attr(field)) is Undefined: continue @@ -296,19 +344,26 @@ def _deepinherit(self, other: Self) -> None: # Inherit the field only if the old value is Undefined. if old_value is Undefined: - setattr(self, field, deepcopy(new_value)) + setattr(self, field, new_value) continue # Merge new value if it is a class with inheritance support. field_type = field_info["type"] - if issubclass(field_type, (AvdModel, AvdIndexedList)) and isinstance(old_value, field_type): + if issubclass(field_type, AvdModel): # Inherit into the existing object. + old_value = cast(AvdModel, old_value) + new_value = cast(AvdModel, new_value) + old_value._deepinherit(new_value) + continue + if issubclass(field_type, AvdIndexedList): + # Inherit into the existing object. + old_value = cast(AvdIndexedList, old_value) + new_value = cast(AvdIndexedList, new_value) old_value._deepinherit(new_value) continue if field_type is dict: # In-place deepmerge in to the existing dict without schema. - # Deepcopying since merge() does not copy. merge(old_value, deepcopy(new_value), list_merge="replace") def _deepinherited(self, other: Self) -> Self: @@ -341,9 +396,10 @@ def _cast_as(self, new_type: type[T_AvdModel], ignore_extra_keys: bool = False) msg = f"Unable to cast '{cls}' as type '{new_type}' since the field '{field}' is missing from the new class. " raise TypeError(msg) if field_info != new_type._fields[field]: - if issubclass(field_info["type"], (AvdBase)) and isinstance(value, (AvdBase)): + if issubclass(field_info["type"], AvdBase): # TODO: Consider using the TypeError we raise below to ensure we know the outer type. # TODO: with suppress(TypeError): + value = cast(AvdBase, value) new_args[field] = value._cast_as(new_type._fields[field]["type"], ignore_extra_keys=ignore_extra_keys) continue @@ -351,6 +407,19 @@ def _cast_as(self, new_type: type[T_AvdModel], ignore_extra_keys: bool = False) raise TypeError(msg) new_args[field] = value - continue - return new_type(**new_args) + new_instance = new_type(**new_args) + + # Pass along the internal flags + new_instance._created_from_null = self._created_from_null + new_instance._block_inheritance = self._block_inheritance + + return new_instance + + def _compare(self, other: Self, ignore_fields: tuple[str, ...] = ()) -> bool: + cls = type(self) + if not isinstance(other, cls): + msg = f"Unable to compare '{cls}' with a '{type(other)}' class." + raise TypeError(msg) + + return all(self._get_defined_attr(field) == other._get_defined_attr(field) for field in self._fields if field not in ignore_fields) diff --git a/python-avd/tests/pyavd/molecule_scenarios/test_eos_designs_class.py b/python-avd/tests/pyavd/molecule_scenarios/test_eos_designs_class.py index 40b7eb0b4b2..823c0756b21 100644 --- a/python-avd/tests/pyavd/molecule_scenarios/test_eos_designs_class.py +++ b/python-avd/tests/pyavd/molecule_scenarios/test_eos_designs_class.py @@ -46,7 +46,7 @@ @pytest.mark.parametrize(("prefix", "expected_data"), CSC_TESTS) -def test_eos_designs_custom_structured_configuration(prefix: str | None, expected_data: dict) -> None: +def test_eos_designs_custom_structured_configuration(prefix: str | None, expected_data: EosDesigns._CustomStructuredConfigurations) -> None: data = CSC_DATA.copy() if prefix: data.update({"custom_structured_configuration_prefix": prefix}) @@ -57,7 +57,7 @@ def test_eos_designs_custom_structured_configuration(prefix: str | None, expecte for entry in loaded_model._custom_structured_configurations: assert isinstance(entry, EosDesigns._CustomStructuredConfigurationsItem) - assert loaded_model._custom_structured_configurations == expected_data + assert repr(loaded_model._custom_structured_configurations) == repr(expected_data) # eos_cli_config_gen inputs are validated by `validate_structured_config` in another file.