Skip to content

Commit

Permalink
Refactor(plugins): Improve schema models (aristanetworks#4795)
Browse files Browse the repository at this point in the history
  • Loading branch information
ClausHolbechArista authored Dec 17, 2024
1 parent 660b4b8 commit 6b27b2c
Show file tree
Hide file tree
Showing 22 changed files with 305 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def __init__(self, exception: Exception) -> None:
def __call__(self, *_args: Any, **_kwargs: Any) -> NoReturn:
raise self.exception

def __getattr__(self, name: str) -> Any:
if not name.startswith("__"):
raise self.exception
return self.__getattribute__(name)


def wrap_plugin(plugin_type: Literal["filter", "test"], name: str) -> Callable:
plugin_map = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1480,7 +1480,7 @@ This feature currently provides the following configurations based on the given
`max_uplink_switches` and `max_parallel_uplinks` to ensure consistent IP addressing.

??? example "`cv_topology` example"
To use this feature set `default_interfaces` according to the intended design (see [default_intefaces](#default-interface-settings) for details) and set `use_cv_topology` to `true`.
To use this feature set `default_interfaces` according to the intended design (see [default_interfaces](#default-interface-settings) for details) and set `use_cv_topology` to `true`.
Provide a full topology under `cv_topology` like this example:

```yaml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import cached_property
from typing import TYPE_CHECKING

from pyavd._utils import default
from pyavd._utils import default, strip_empties_from_list

if TYPE_CHECKING:
from . import SharedUtils
Expand All @@ -33,6 +33,6 @@ def link_tracking_groups(self: SharedUtils) -> list | None:
else:
link_tracking_groups.append({"name": "LT_GROUP1", "recovery_delay": default_recovery_delay})

return link_tracking_groups
return strip_empties_from_list(link_tracking_groups)

return None
34 changes: 17 additions & 17 deletions python-avd/pyavd/_eos_designs/shared_utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def get_ipv4_acl(
"interface_ip": interface_ip,
"peer_ip": peer_ip,
}
changed = False
for index, entry in enumerate(ipv4_acl.entries):
if entry._get("remark"):
continue
Expand All @@ -202,13 +203,15 @@ def get_ipv4_acl(

entry.source = self._get_ipv4_acl_field_with_substitution(entry.source, ip_replacements, f"{err_context}.source", interface_name)
entry.destination = self._get_ipv4_acl_field_with_substitution(entry.destination, ip_replacements, f"{err_context}.destination", interface_name)
if entry.source != org_ipv4_acl.entries[index].source or entry.destination != org_ipv4_acl.entries[index].destination:
changed = True

if ipv4_acl != org_ipv4_acl:
if changed:
ipv4_acl.name += f"_{self.sanitize_interface_name(interface_name)}"
return ipv4_acl

@staticmethod
def _get_ipv4_acl_field_with_substitution(field_value: str, replacements: dict[str, str], field_context: str, interface_name: str) -> str:
def _get_ipv4_acl_field_with_substitution(field_value: str, replacements: dict[str, str | None], field_context: str, interface_name: str) -> str:
"""
Checks one field if the value can be substituted.
Expand All @@ -218,18 +221,15 @@ def _get_ipv4_acl_field_with_substitution(field_value: str, replacements: dict[s
If a replacement is done, but the value is None, an error will be raised.
"""
for key, value in replacements.items():
if field_value != key:
continue

if value is None:
msg = (
f"Unable to perform substitution of the value '{key}' defined under '{field_context}', "
f"since no substitution value was found for interface '{interface_name}'. "
"Make sure to set the appropriate fields on the interface."
)
raise AristaAvdError(msg)

return value

return field_value
if field_value not in replacements:
return field_value

if (replacement_value := replacements[field_value]) is None:
msg = (
f"Unable to perform substitution of the value '{field_value}' defined under '{field_context}', "
f"since no substitution value was found for interface '{interface_name}'. "
"Make sure to set the appropriate fields on the interface."
)
raise AristaAvdError(msg)

return replacement_value
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,6 @@ def route_maps(self) -> list | None:
@cached_property
def struct_cfgs(self) -> list | None:
if self.shared_utils.platform_settings.structured_config:
return [self.shared_utils.platform_settings.structured_config._as_dict(strip_values=())]
return [self.shared_utils.platform_settings.structured_config._as_dict()]

return None
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _get_ethernet_interface_cfg(
"dot1x": adapter.dot1x._as_dict() or None,
"poe": self._get_adapter_poe(adapter),
"eos_cli": adapter.raw_eos_cli,
"struct_cfg": adapter.structured_config._as_dict(strip_values=()),
"struct_cfg": adapter.structured_config._as_dict(),
}

# Port-channel member
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _get_port_channel_interface_cfg(
"validate_state": None if (adapter.validate_state if adapter.validate_state is not None else True) else False,
"validate_lldp": None if (adapter.validate_lldp if adapter.validate_lldp is not None else True) else False,
"eos_cli": adapter.port_channel.raw_eos_cli,
"struct_cfg": adapter.port_channel.structured_config._as_dict(strip_values=()) or None,
"struct_cfg": adapter.port_channel.structured_config._as_dict() or None,
}

if adapter.port_channel.subinterfaces:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _extract_and_apply_struct_cfg_from_list_of_dicts(self, list_of_dicts: list,
return struct_cfgs

def _struct_cfg(self) -> list:
if struct_cfg := self.shared_utils.node_config.structured_config._as_dict(strip_values=()):
if struct_cfg := self.shared_utils.node_config.structured_config._as_dict():
return [struct_cfg]

return []
Expand Down Expand Up @@ -114,9 +114,7 @@ def _router_bgp_vlans(self) -> list:
]

def _custom_structured_configurations(self) -> list[dict]:
return [
custom_structured_configuration.value._as_dict(strip_values=()) for custom_structured_configuration in self.inputs._custom_structured_configurations
]
return [custom_structured_configuration.value._as_dict() for custom_structured_configuration in self.inputs._custom_structured_configurations]

def render(self) -> list[dict]:
"""
Expand Down
10 changes: 5 additions & 5 deletions python-avd/pyavd/_eos_designs/structured_config/mlag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def vlan_interfaces(self) -> list | None:
),
"shutdown": False,
"no_autostate": True,
"struct_cfg": self.shared_utils.node_config.mlag_peer_vlan_structured_config._as_dict(strip_values=()) or None,
"struct_cfg": self.shared_utils.node_config.mlag_peer_vlan_structured_config._as_dict() or None,
"mtu": self.shared_utils.p2p_uplinks_mtu,
}

Expand All @@ -82,7 +82,7 @@ def vlan_interfaces(self) -> list | None:
return [strip_empties_from_dict(main_vlan_interface)]

# Create L3 data which will go on either a dedicated l3 vlan or the main mlag vlan
l3_cfg = {"struct_cfg": self.shared_utils.node_config.mlag_peer_l3_vlan_structured_config._as_dict(strip_values=()) or None}
l3_cfg = {"struct_cfg": self.shared_utils.node_config.mlag_peer_l3_vlan_structured_config._as_dict() or None}
if self.shared_utils.underlay_routing_protocol == "ospf":
l3_cfg.update(
{
Expand Down Expand Up @@ -121,7 +121,7 @@ def vlan_interfaces(self) -> list | None:
main_vlan_interface.update(l3_cfg)
# Applying structured config again in the case it is set on both l3vlan and main vlan
if self.shared_utils.node_config.mlag_peer_vlan_structured_config is not None:
main_vlan_interface["struct_cfg"] = self.shared_utils.node_config.mlag_peer_vlan_structured_config._as_dict(strip_values=())
main_vlan_interface["struct_cfg"] = self.shared_utils.node_config.mlag_peer_vlan_structured_config._as_dict()

return [strip_empties_from_dict(main_vlan_interface)]

Expand Down Expand Up @@ -169,7 +169,7 @@ def port_channel_interfaces(self) -> list:
},
"shutdown": False,
"service_profile": self.inputs.p2p_uplinks_qos_profile,
"struct_cfg": self.shared_utils.node_config.mlag_port_channel_structured_config._as_dict(strip_values=()) or None,
"struct_cfg": self.shared_utils.node_config.mlag_port_channel_structured_config._as_dict() or None,
"flow_tracker": self.shared_utils.get_flow_tracker(self.inputs.fabric_flow_tracking.mlag_interfaces),
}

Expand Down Expand Up @@ -350,7 +350,7 @@ def _router_bgp_mlag_peer_group(self) -> dict:
"bfd": self.inputs.bgp_peer_groups.ipv4_underlay_peers.bfd or None,
"maximum_routes": 12000,
"send_community": "all",
"struct_cfg": self.inputs.bgp_peer_groups.mlag_ipv4_underlay_peer.structured_config._as_dict(strip_values=()) or None,
"struct_cfg": self.inputs.bgp_peer_groups.mlag_ipv4_underlay_peer.structured_config._as_dict() or None,
}
if self.shared_utils.node_config.mlag_ibgp_origin_incomplete:
peer_group["route_map_in"] = "RM-MLAG-PEER-IN"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def ethernet_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None
"shutdown": not l3_interface.enabled,
"description": interface_description,
"eos_cli": l3_interface.raw_eos_cli,
"struct_cfg": l3_interface.structured_config._as_dict(strip_values=()) or None,
"struct_cfg": l3_interface.structured_config._as_dict() or None,
"flow_tracker": self.shared_utils.get_flow_tracker(l3_interface.flow_tracking),
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _router_bgp_vrfs(self: AvdStructuredConfigNetworkServices) -> dict:
bgp_vrf = strip_empties_from_dict(
{
"eos_cli": vrf.bgp.raw_eos_cli,
"struct_cfg": vrf.bgp.structured_config._as_dict(strip_values=()) or None,
"struct_cfg": vrf.bgp.structured_config._as_dict() or None,
}
)

Expand Down Expand Up @@ -493,7 +493,7 @@ def _router_bgp_vlans_vlan(
"route_targets": {"both": [vlan_rt]},
"redistribute_routes": ["learned"],
"eos_cli": vlan.bgp.raw_eos_cli,
"struct_cfg": vlan.bgp.structured_config._as_dict(strip_values=()) or None,
"struct_cfg": vlan.bgp.structured_config._as_dict() or None,
}
if self.shared_utils.node_config.evpn_gateway.evpn_l2.enabled and default(
vlan.evpn_l2_multi_domain, vrf.evpn_l2_multi_domain, tenant.evpn_l2_multi_domain
Expand Down Expand Up @@ -816,7 +816,7 @@ def _router_bgp_mlag_peer_group(self: AvdStructuredConfigNetworkServices) -> dic
"password": self.inputs.bgp_peer_groups.mlag_ipv4_underlay_peer.password,
"maximum_routes": 12000,
"send_community": "all",
"struct_cfg": self.inputs.bgp_peer_groups.mlag_ipv4_underlay_peer.structured_config._as_dict(strip_values=()) or None,
"struct_cfg": self.inputs.bgp_peer_groups.mlag_ipv4_underlay_peer.structured_config._as_dict() or None,
}

if self.shared_utils.node_config.mlag_ibgp_origin_incomplete:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def struct_cfgs(self: AvdStructuredConfigNetworkServices) -> list | None:
for vrf in tenant.vrfs:
if vrf.structured_config:
# Inserting VRF into structured_config to perform duplication checks
vrf_struct_cfg = {"vrf": vrf.name, "struct_cfg": vrf.structured_config._as_dict(strip_values=())}
vrf_struct_cfg = {"vrf": vrf.name, "struct_cfg": vrf.structured_config._as_dict()}
append_if_not_duplicate(
list_of_dicts=vrf_struct_cfgs,
primary_key="vrf",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _check_virtual_router_mac_address(vlan_interface_config: dict, variables: li
"access_group_out": get(self._svi_acls, f"{interface_name}.ipv4_acl_out.name"),
"mtu": svi.mtu if self.shared_utils.platform_settings.feature_support.per_interface_mtu else None,
"eos_cli": svi.raw_eos_cli,
"struct_cfg": svi.structured_config._as_dict(strip_values=()) or None,
"struct_cfg": svi.structured_config._as_dict() or None,
}
# Only set VARP if ip_address is set
if vlan_interface_config["ip_address"] is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _generate_base_peer_group(
"password": peer_group.password,
"send_community": "all",
"maximum_routes": maximum_routes,
"struct_cfg": peer_group.structured_config._as_dict(strip_values=()) or None,
"struct_cfg": peer_group.structured_config._as_dict() or None,
}

def _peer_groups(self: AvdStructuredConfigOverlay) -> list | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def router_bgp(self: AvdStructuredConfigUnderlay) -> dict | None:
"bfd": self.inputs.bgp_peer_groups.ipv4_underlay_peers.bfd or None,
"maximum_routes": 12000,
"send_community": "all",
"struct_cfg": self.inputs.bgp_peer_groups.ipv4_underlay_peers.structured_config._as_dict(strip_values=()) or None,
"struct_cfg": self.inputs.bgp_peer_groups.ipv4_underlay_peers.structured_config._as_dict() or None,
}

if self.shared_utils.overlay_routing_protocol == "ibgp" and self.shared_utils.is_cv_pathfinder_router:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _get_l3_interface_cfg(
"access_group_in": get(self._l3_interface_acls, f"{l3_interface.name}..ipv4_acl_in..name", separator=".."),
"access_group_out": get(self._l3_interface_acls, f"{l3_interface.name}..ipv4_acl_out..name", separator=".."),
"eos_cli": l3_interface.raw_eos_cli,
"struct_cfg": l3_interface.structured_config._as_dict(strip_values=()),
"struct_cfg": l3_interface.structured_config._as_dict(),
"flow_tracker": self.shared_utils.get_flow_tracker(l3_interface.flow_tracking),
}

Expand Down Expand Up @@ -268,7 +268,7 @@ def _get_l2_as_subint(
"ipv6_enable": svi.ipv6_enable,
"mtu": svi.mtu if self.shared_utils.platform_settings.feature_support.per_interface_mtu else None,
"eos_cli": svi.raw_eos_cli,
"struct_cfg": svi.structured_config._as_dict(strip_values=()) or None,
"struct_cfg": svi.structured_config._as_dict() or None,
"flow_tracker": link.get("flow_tracker"),
}
if (mtu := subinterface["mtu"]) is not None and subinterface["mtu"] > self.shared_utils.p2p_uplinks_mtu:
Expand Down
58 changes: 1 addition & 57 deletions python-avd/pyavd/_schema/coerce_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,64 +12,9 @@
if TYPE_CHECKING:
from typing import NoReturn, TypeVar

from typing_extensions import Self

from pyavd._schema.models.type_vars import T_AvdBase

T = TypeVar("T")


def nullifiy_class(cls: type[T_AvdBase]) -> type:
"""
Returns a subclass of the given class with overrides for "null" values.
This class is used when the input for a dict or a list is None/null/none,
to be able to signal to the deepmerge/inherit methods that this is not the same as an unset variable.
"""

class NullifiedCls(cls):
def _get_defined_attr(self, name: str) -> T_AvdBase | None:
"""
Return the default values or None.
This is required for the various merge / inheritance logic to always take from this if undefined.
"""
return getattr(self, name)

def _as_dict(self, *_args: Any, **_kwargs: Any) -> None:
"""Always None."""

def _as_list(self, *_args: Any, **_kwargs: Any) -> None:
"""Always None."""

def __repr__(self) -> str:
return f"<NullifiedCls[{cls.__name__}]>"

def _deepinherited(self, *_args: Any, **_kwargs: Any) -> Self:
"""Nothing to do since a NullifiedCls will override anything with None."""
return self._deepcopy()

def _deepinherit(self, *_args: Any, **_kwargs: Any) -> None:
"""Nothing to do since a NullifiedCls will override anything with None."""

def _inherit(self, *_args: Any, **_kwargs: Any) -> None:
"""Nothing to do since a NullifiedCls will override anything with None."""

def _deepmerge(self, *_args: Any, **_kwargs: Any) -> NoReturn:
msg = "A NullifiedCls cannot be inplace deepmerged. Use _deepmerged() instead."
raise NotImplementedError(msg)

def _deepmerged(self, other: T_AvdBase, *_args: Any, **_kwargs: Any) -> T_AvdBase:
"""Returning the other directly since NullifiedCls is empty."""
return other._deepcopy()

def _cast_as(self, new_type: type[T_AvdBase], *_args: Any, **_kwargs: Any) -> T_AvdBase:
"""Wrap the new type in it's own NullifiedCls."""
return nullifiy_class(new_type)()

return NullifiedCls


def coerce_type(value: Any, target_type: type[T]) -> T:
"""
Return a coerced variant of the given value to the target_type.
Expand All @@ -81,8 +26,7 @@ def coerce_type(value: Any, target_type: type[T]) -> T:
if value is None:
if issubclass(target_type, AvdBase):
# None values are sometimes used to overwrite inherited profiles.
# This ensures we still follow the type hint of the class.
return nullifiy_class(target_type)()
return target_type._from_null()

# Other None values are left untouched.
elif target_type is Any or isinstance(value, target_type):
Expand Down
33 changes: 27 additions & 6 deletions python-avd/pyavd/_schema/models/avd_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,21 @@
class AvdBase(ABC):
"""Base class used for schema-based data classes holding data loaded from AVD inputs."""

def __eq__(self, other: object) -> bool:
"""Compare two instances of AvdBase by comparing their repr."""
if isinstance(other, self.__class__):
return repr(self) == repr(other)
return False
_created_from_null: bool = False
"""
Flag to say if this data was loaded from a '<key>: null' value in YAML.
This is used to handle inheritance and merging correctly.
When _created_from_null we inherit nothing (we win!).
When _created_from_null we take anything in when merging and clear the flag.
TODO: Stop changing data in-place.
The flag is not carried across between classes, so it should not affect anything outside the loaded inputs.
Only exception is on _cast_as, where the flag is carried over.
"""

_block_inheritance: bool = False
"""Flag to block inheriting further if we at some point inherited from a class with _created_from_null set."""

def _deepcopy(self) -> Self:
"""Return a copy including all nested models."""
Expand All @@ -33,8 +43,19 @@ def _deepcopy(self) -> Self:
def _load(cls, data: Sequence | Mapping) -> Self:
"""Returns a new instance loaded with the given data."""

@classmethod
def _from_null(cls) -> Self:
"""Returns a new instance with all attributes set to None. This represents the YAML input '<key>: null'."""
new_instance = cls()
new_instance._created_from_null = True
return new_instance

@abstractmethod
def _strip_empties(self) -> None:
"""In-place update the instance to remove data matching the given strip_values."""

@abstractmethod
def _dump(self, include_default_values: bool = False, strip_values: tuple = (None, [], {})) -> dict | list:
def _dump(self, include_default_values: bool = False) -> dict | list:
"""Dump data into native Python types with or without default values."""

@abstractmethod
Expand Down
Loading

0 comments on commit 6b27b2c

Please sign in to comment.