From 0d5f35d15eb95ffd58c6b4de593b5ed5390537e8 Mon Sep 17 00:00:00 2001 From: Claus Holbech Date: Fri, 2 Aug 2024 18:42:03 +0200 Subject: [PATCH] CI: Add Ruff config and fix tons of linting issues (#4310) Co-authored-by: Guillaume Mulocher --- .github/generate_release.py | 42 +-- .../custom_interface_descriptions.py | 51 ++-- ...custom_interface_descriptions_with_data.py | 17 +- .../custom_modules/custom_ip_addressing.py | 64 ++-- .../arista/avd/plugins/action/cv_workflow.py | 76 +++-- .../avd/plugins/action/eos_cli_config_gen.py | 50 ++-- .../avd/plugins/action/eos_designs_facts.py | 22 +- .../action/eos_designs_structured_config.py | 15 +- .../action/eos_validate_state_reports.py | 19 +- .../action/eos_validate_state_runner.py | 19 +- .../plugins/action/inventory_to_container.py | 29 +- .../arista/avd/plugins/action/set_vars.py | 6 +- .../avd/plugins/action/verify_requirements.py | 98 +++---- .../arista/avd/plugins/filter/add_md_toc.py | 8 +- .../avd/plugins/filter/convert_dicts.py | 8 +- .../arista/avd/plugins/filter/decrypt.py | 6 +- .../arista/avd/plugins/filter/default.py | 8 +- .../arista/avd/plugins/filter/encrypt.py | 6 +- .../avd/plugins/filter/hide_passwords.py | 6 +- .../arista/avd/plugins/filter/is_in_filter.py | 8 +- .../avd/plugins/filter/list_compress.py | 8 +- .../arista/avd/plugins/filter/natural_sort.py | 8 +- .../arista/avd/plugins/filter/range_expand.py | 8 +- .../arista/avd/plugins/filter/snmp_hash.py | 7 +- .../avd/plugins/filter/status_render.py | 9 +- .../plugins/modules/configlet_build_config.py | 49 ++-- .../plugins/modules/inventory_to_container.py | 143 +++++---- .../ansible_eos_device.py | 15 +- .../eos_validate_state_utils/avdtestbase.py | 6 +- .../config_manager.py | 6 +- .../eos_validate_state_utils/csv_report.py | 5 +- .../get_anta_results.py | 6 +- .../eos_validate_state_utils/md_report.py | 5 +- .../eos_validate_state_utils/mixins.py | 6 +- .../results_manager.py | 4 +- .../plugin_utils/merge/mergecatalog.py | 2 +- .../plugins/plugin_utils/pyavd_wrappers.py | 28 +- .../plugin_utils/schema/avdschematools.py | 65 +++-- .../plugin_utils/utils/compile_searchpath.py | 13 +- .../plugin_utils/utils/cprofile_decorator.py | 5 +- .../plugins/plugin_utils/utils/get_templar.py | 4 +- .../plugins/plugin_utils/utils/log_message.py | 3 +- .../python_to_ansible_logging_handler.py | 23 +- .../plugins/plugin_utils/utils/yaml_dumper.py | 4 +- .../arista/avd/plugins/test/contains.py | 43 ++- .../arista/avd/plugins/test/defined.py | 45 ++- .../arista/avd/plugins/vars/global_vars.py | 40 +-- .../interface_descriptions/__init__.py | 2 +- .../python_modules/ip_addressing/__init__.py | 2 +- .../tests/avdtestconnectivity.py | 34 +-- .../python_modules/tests/avdtesthardware.py | 10 +- .../python_modules/tests/avdtestinterfaces.py | 9 +- .../python_modules/tests/avdtestmlag.py | 4 +- .../python_modules/tests/avdtestrouting.py | 17 +- .../python_modules/tests/avdtestsecurity.py | 13 +- .../python_modules/tests/avdteststun.py | 5 +- .../python_modules/tests/avdtestsystem.py | 10 +- .../unit/action/test_verify_requirements.py | 132 ++++----- .../modules/test_configlet_build_config.py | 21 +- .../modules/test_inventory_to_container.py | 82 +++--- .../eos_validate_state_utils/test_catalog.py | 4 +- pyproject.toml | 78 +++++ .../pyavd/_cv/api/arista/alert/v1/__init__.py | 274 ++---------------- .../_cv/api/arista/configlet/v1/__init__.py | 35 +++ .../_cv/api/arista/imagestatus/v1/__init__.py | 55 +++- .../_cv/api/arista/workspace/v1/__init__.py | 3 + python-avd/pyavd/_cv/client/__init__.py | 34 ++- python-avd/pyavd/_cv/client/change_control.py | 50 ++-- python-avd/pyavd/_cv/client/configlet.py | 81 +++--- python-avd/pyavd/_cv/client/exceptions.py | 20 +- python-avd/pyavd/_cv/client/inventory.py | 23 +- python-avd/pyavd/_cv/client/studio.py | 187 ++++-------- python-avd/pyavd/_cv/client/swg.py | 17 +- python-avd/pyavd/_cv/client/tag.py | 74 ++--- python-avd/pyavd/_cv/client/utils.py | 28 +- python-avd/pyavd/_cv/client/workspace.py | 42 +-- .../_cv/workflows/create_workspace_on_cv.py | 16 +- .../_cv/workflows/deploy_configs_to_cv.py | 33 ++- .../deploy_cv_pathfinder_metadata_to_cv.py | 50 ++-- .../workflows/deploy_studio_inputs_to_cv.py | 45 +-- .../pyavd/_cv/workflows/deploy_tags_to_cv.py | 18 +- .../pyavd/_cv/workflows/deploy_to_cv.py | 7 +- .../finalize_change_control_on_cv.py | 23 +- .../_cv/workflows/finalize_workspace_on_cv.py | 23 +- python-avd/pyavd/_cv/workflows/models.py | 4 +- .../_cv/workflows/verify_devices_on_cv.py | 56 ++-- python-avd/pyavd/_eos_designs/avdfacts.py | 30 +- .../eos_designs_facts/__init__.py | 114 +++----- .../_eos_designs/eos_designs_facts/mlag.py | 21 +- .../_eos_designs/eos_designs_facts/overlay.py | 34 +-- .../eos_designs_facts/short_esi.py | 17 +- .../_eos_designs/eos_designs_facts/uplinks.py | 117 ++++---- .../_eos_designs/eos_designs_facts/vlans.py | 40 +-- .../_eos_designs/eos_designs_facts/wan.py | 16 +- .../interface_descriptions/__init__.py | 66 +++-- .../interface_descriptions/models.py | 7 +- .../interface_descriptions/utils.py | 3 +- .../_eos_designs/ip_addressing/__init__.py | 78 +++-- .../pyavd/_eos_designs/ip_addressing/utils.py | 46 +-- .../_eos_designs/shared_utils/__init__.py | 5 +- .../shared_utils/bgp_peer_groups.py | 21 +- .../shared_utils/connected_endpoints_keys.py | 14 +- .../_eos_designs/shared_utils/cv_topology.py | 26 +- .../shared_utils/filtered_tenants.py | 74 ++--- .../shared_utils/flow_tracking.py | 20 +- .../shared_utils/inband_management.py | 37 +-- .../shared_utils/interface_descriptions.py | 18 +- .../shared_utils/ip_addressing.py | 24 +- .../shared_utils/l3_interfaces.py | 33 +-- .../shared_utils/link_tracking_groups.py | 7 +- .../pyavd/_eos_designs/shared_utils/mgmt.py | 19 +- .../pyavd/_eos_designs/shared_utils/misc.py | 78 ++--- .../pyavd/_eos_designs/shared_utils/mlag.py | 55 ++-- .../_eos_designs/shared_utils/node_type.py | 82 +++--- .../shared_utils/node_type_keys.py | 26 +- .../_eos_designs/shared_utils/overlay.py | 46 ++- .../_eos_designs/shared_utils/platform.py | 11 +- .../pyavd/_eos_designs/shared_utils/ptp.py | 11 +- .../_eos_designs/shared_utils/routing.py | 46 ++- .../_eos_designs/shared_utils/switch_data.py | 22 +- .../_eos_designs/shared_utils/underlay.py | 11 +- .../pyavd/_eos_designs/shared_utils/utils.py | 28 +- .../pyavd/_eos_designs/shared_utils/wan.py | 162 +++++------ .../structured_config/__init__.py | 14 +- .../structured_config/base/__init__.py | 266 +++++++---------- .../structured_config/base/ntp.py | 20 +- .../structured_config/base/snmp_server.py | 52 ++-- .../structured_config/base/utils.py | 14 +- .../connected_endpoints/__init__.py | 7 +- .../ethernet_interfaces.py | 47 +-- .../connected_endpoints/monitor_sessions.py | 24 +- .../port_channel_interfaces.py | 46 +-- .../connected_endpoints/utils.py | 78 +++-- .../core_interfaces_and_l3_edge/__init__.py | 7 +- .../ethernet_interfaces.py | 10 +- .../port_channel_interfaces.py | 7 +- .../core_interfaces_and_l3_edge/router_bgp.py | 17 +- .../router_ospf.py | 12 +- .../core_interfaces_and_l3_edge/utils.py | 43 ++- .../__init__.py | 25 +- .../structured_config/flows/__init__.py | 46 +-- .../inband_management/__init__.py | 30 +- .../structured_config/metadata/__init__.py | 10 +- .../metadata/cv_pathfinder.py | 35 ++- .../structured_config/metadata/cv_tags.py | 66 +++-- .../structured_config/mlag/__init__.py | 83 +++--- .../network_services/__init__.py | 8 +- .../application_traffic_recognition.py | 34 ++- .../network_services/dps_interfaces.py | 7 +- .../network_services/eos_cli.py | 16 +- .../network_services/ethernet_interfaces.py | 72 +++-- .../network_services/ip_access_lists.py | 24 +- .../network_services/ip_igmp_snooping.py | 15 +- .../network_services/ip_nat.py | 7 +- .../network_services/ip_security.py | 17 +- .../ip_virtual_router_mac_address.py | 7 +- .../network_services/ipv6_static_routes.py | 6 +- .../network_services/loopback_interfaces.py | 9 +- .../network_services/metadata.py | 22 +- .../network_services/monitor_connectivity.py | 15 +- .../network_services/patch_panel.py | 19 +- .../port_channel_interfaces.py | 57 ++-- .../network_services/prefix_lists.py | 17 +- .../network_services/route_maps.py | 54 ++-- .../router_adaptive_virtual_topology.py | 33 +-- .../network_services/router_bgp.py | 203 ++++++------- .../network_services/router_internet_exit.py | 8 +- .../network_services/router_isis.py | 6 +- .../network_services/router_multicast.py | 15 +- .../network_services/router_ospf.py | 20 +- .../network_services/router_path_selection.py | 23 +- .../router_pim_sparse_mode.py | 15 +- .../router_service_insertion.py | 5 +- .../network_services/spanning_tree.py | 14 +- .../network_services/standard_access_lists.py | 9 +- .../network_services/static_routes.py | 6 +- .../network_services/struct_cfgs.py | 11 +- .../network_services/tunnel_interfaces.py | 8 +- .../network_services/utils.py | 217 +++++++------- .../network_services/utils_zscaler.py | 38 ++- .../virtual_source_nat_vrfs.py | 8 +- .../network_services/vlan_interfaces.py | 34 ++- .../network_services/vlans.py | 20 +- .../network_services/vrfs.py | 9 +- .../network_services/vxlan_interface.py | 36 ++- .../structured_config/overlay/__init__.py | 10 +- .../structured_config/overlay/cvx.py | 10 +- .../overlay/ip_extcommunity_lists.py | 9 +- .../structured_config/overlay/ip_security.py | 35 +-- .../overlay/management_cvx.py | 6 +- .../overlay/management_security.py | 7 +- .../structured_config/overlay/route_maps.py | 23 +- .../router_adaptive_virtual_topology.py | 7 +- .../structured_config/overlay/router_bfd.py | 10 +- .../structured_config/overlay/router_bgp.py | 145 ++++----- .../overlay/router_path_selection.py | 79 ++--- .../overlay/router_traffic_engineering.py | 8 +- .../structured_config/overlay/stun.py | 16 +- .../structured_config/overlay/utils.py | 36 +-- .../structured_config/underlay/__init__.py | 3 +- .../structured_config/underlay/agents.py | 12 +- .../structured_config/underlay/as_path.py | 9 +- .../underlay/ethernet_interfaces.py | 41 +-- .../underlay/ip_access_lists.py | 8 +- .../underlay/loopback_interfaces.py | 26 +- .../structured_config/underlay/mpls.py | 7 +- .../underlay/port_channel_interfaces.py | 14 +- .../underlay/prefix_lists.py | 38 ++- .../structured_config/underlay/route_maps.py | 18 +- .../structured_config/underlay/router_bgp.py | 25 +- .../structured_config/underlay/router_isis.py | 21 +- .../structured_config/underlay/router_msdp.py | 11 +- .../structured_config/underlay/router_ospf.py | 10 +- .../underlay/router_pim_sparse_mode.py | 13 +- .../underlay/standard_access_lists.py | 8 +- .../underlay/static_routes.py | 9 +- .../structured_config/underlay/utils.py | 86 +++--- .../structured_config/underlay/vlans.py | 31 +- python-avd/pyavd/_errors/__init__.py | 29 +- python-avd/pyavd/_schema/avddataconverter.py | 66 +++-- python-avd/pyavd/_schema/avdschema.py | 54 ++-- python-avd/pyavd/_schema/avdvalidator.py | 48 +-- python-avd/pyavd/_schema/store.py | 12 +- .../pyavd/_utils/append_if_not_duplicate.py | 5 +- python-avd/pyavd/_utils/batch.py | 6 +- python-avd/pyavd/_utils/compare_dicts.py | 2 +- python-avd/pyavd/_utils/default.py | 10 +- python-avd/pyavd/_utils/get.py | 25 +- python-avd/pyavd/_utils/get_all.py | 18 +- .../_utils/get_indices_of_duplicate_items.py | 7 +- python-avd/pyavd/_utils/get_ip_from_pool.py | 14 +- python-avd/pyavd/_utils/get_item.py | 28 +- python-avd/pyavd/_utils/groupby.py | 10 +- python-avd/pyavd/_utils/load_python_class.py | 18 +- python-avd/pyavd/_utils/merge/__init__.py | 42 ++- .../pyavd/_utils/merge/mergeonschema.py | 26 +- .../pyavd/_utils/password_utils/password.py | 68 +++-- .../_utils/password_utils/password_utils.py | 15 +- .../pyavd/_utils/replace_or_append_item.py | 9 +- python-avd/pyavd/_utils/strip_empties.py | 41 ++- python-avd/pyavd/_utils/template.py | 11 +- python-avd/pyavd/_utils/template_var.py | 8 +- python-avd/pyavd/_utils/unique.py | 7 +- python-avd/pyavd/avd_schema_tools.py | 26 +- python-avd/pyavd/get_avd_facts.py | 14 +- python-avd/pyavd/get_device_config.py | 0 python-avd/pyavd/get_device_doc.py | 0 .../pyavd/get_device_structured_config.py | 3 +- python-avd/pyavd/j2filters/add_md_toc.py | 27 +- python-avd/pyavd/j2filters/convert_dicts.py | 41 +-- python-avd/pyavd/j2filters/decrypt.py | 10 +- python-avd/pyavd/j2filters/default.py | 6 +- python-avd/pyavd/j2filters/encrypt.py | 10 +- python-avd/pyavd/j2filters/hide_passwords.py | 3 +- python-avd/pyavd/j2filters/is_in_filter.py | 2 +- python-avd/pyavd/j2filters/list_compress.py | 11 +- python-avd/pyavd/j2filters/natural_sort.py | 17 +- python-avd/pyavd/j2filters/range_expand.py | 178 ++++++------ python-avd/pyavd/j2filters/snmp_hash.py | 24 +- python-avd/pyavd/j2filters/status_render.py | 18 +- python-avd/pyavd/j2tests/__init__.py | 3 + python-avd/pyavd/j2tests/contains.py | 4 +- python-avd/pyavd/j2tests/defined.py | 41 ++- python-avd/pyavd/templater.py | 30 +- python-avd/pyavd/validate_inputs.py | 2 +- .../pyavd/validate_structured_config.py | 2 +- python-avd/pyavd/validation_result.py | 4 +- python-avd/pyproject.toml | 13 +- python-avd/schema_tools/avdschemaresolver.py | 48 +-- python-avd/schema_tools/constants.py | 1 + .../schema_tools/generate_docs/mdtabsgen.py | 11 +- .../schema_tools/generate_docs/tablegen.py | 9 +- .../schema_tools/generate_docs/tablerowgen.py | 61 ++-- .../schema_tools/generate_docs/utils.py | 2 +- .../schema_tools/generate_docs/yamlgen.py | 6 +- .../schema_tools/generate_docs/yamllinegen.py | 75 ++--- .../schema_tools/key_to_display_name.py | 7 +- .../metaschema/meta_schema_model.py | 57 ++-- .../schema_tools/metaschema/resolvemodel.py | 17 +- python-avd/schema_tools/store.py | 36 +-- python-avd/scripts/build-schemas.py | 23 +- python-avd/scripts/custom_build_backend.py | 33 +-- .../tests/pyavd/eos_designs/conftest.py | 1 + .../tests/pyavd/j2filters/test_add_md_toc.py | 15 +- .../pyavd/j2filters/test_convert_dict.py | 1 + .../tests/pyavd/j2filters/test_decrypt.py | 5 +- .../tests/pyavd/j2filters/test_default.py | 1 + .../tests/pyavd/j2filters/test_encrypt.py | 1 + .../pyavd/j2filters/test_hide_passwords.py | 1 + .../pyavd/j2filters/test_is_in_filter.py | 1 + .../pyavd/j2filters/test_list_compress.py | 1 + .../pyavd/j2filters/test_natural_sort.py | 1 + .../pyavd/j2filters/test_range_expand.py | 1 + .../tests/pyavd/j2filters/test_snmp_hash.py | 1 + .../pyavd/j2filters/test_status_render.py | 1 + .../tests/pyavd/j2tests/test_contains.py | 1 + .../pyavd/j2tests/test_defined_plugin.py | 1 + .../tests/pyavd/schema/test_avdschema.py | 1 + .../tests/pyavd/utils/merge/test_merge.py | 1 + .../pyavd/utils/password/test_password.py | 1 + .../utils/password/test_password_utils.py | 1 + python-avd/tests/pyavd/utils/test_get.py | 1 + .../pyavd/utils/test_get_ip_from_pool.py | 1 + .../utils/test_short_esi_to_route_target.py | 1 + .../tests/pyavd/utils/test_strip_empties.py | 1 + 305 files changed, 4259 insertions(+), 4334 deletions(-) mode change 100644 => 100755 .github/generate_release.py mode change 100755 => 100644 python-avd/pyavd/get_device_config.py mode change 100755 => 100644 python-avd/pyavd/get_device_doc.py mode change 100755 => 100644 python-avd/pyavd/get_device_structured_config.py create mode 100644 python-avd/pyavd/j2tests/__init__.py diff --git a/.github/generate_release.py b/.github/generate_release.py old mode 100644 new mode 100755 index 5b18d1bf1ad..f9bdcee42f4 --- a/.github/generate_release.py +++ b/.github/generate_release.py @@ -3,12 +3,15 @@ # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. """ -generate_release.py +generate_release.py. This script is used to generate the release.yml file as per https://docs.github.com/en/repositories/releasing-projects-on-github/automatically-generated-release-notes """ +from pathlib import Path +from typing import Any + import yaml SCOPES = [ @@ -37,9 +40,9 @@ "Fix": "Bug Fixes", "Cut": "Cut", "Doc": "Documentation", - # "CI": "CI", + # Excluding "CI": "CI", "Bump": "Bump", - # "Test": "Test", + # Excluding "Test": "Test", "Revert": "Revert", "Refactor": "Refactoring", } @@ -47,13 +50,14 @@ class SafeDumper(yaml.SafeDumper): """ - Make yamllint happy + Make yamllint happy. + https://github.com/yaml/pyyaml/issues/234#issuecomment-765894586 """ # pylint: disable=R0901,W0613,W1113 - def increase_indent(self, flow=False, *args, **kwargs): + def increase_indent(self, flow: bool = False, *_args: Any, **_kwargs: Any) -> None: return super().increase_indent(flow=flow, indentless=False) @@ -89,14 +93,14 @@ def increase_indent(self, flow=False, *args, **kwargs): { "title": "Breaking Changes", "labels": breaking_labels, - } + }, ) # Add fixes in eos_cli_config_gen categories_list.append( { "title": "Fixed issues in eos_cli_config_gen", "labels": ["rn: Fix(eos_cli_config_gen)"], - } + }, ) # Add fixes in eos_designs @@ -104,7 +108,7 @@ def increase_indent(self, flow=False, *args, **kwargs): { "title": "Fixed issues in eos_designs", "labels": ["rn: Fix(eos_designs)"], - } + }, ) # Add fixes in eos_cli_config_gen|eos_designs or eos_designs|eos_cli_config_gen categories_list.append( @@ -114,7 +118,7 @@ def increase_indent(self, flow=False, *args, **kwargs): "rn: Fix(eos_cli_config_gen|eos_designs)", "rn: Fix(eos_designs|eos_cli_config_gen)", ], - } + }, ) # Add other fixes @@ -125,7 +129,7 @@ def increase_indent(self, flow=False, *args, **kwargs): { "title": "Other Fixed issues", "labels": other_fixes_labels, - } + }, ) # Add Documentation - except for PyAVD @@ -136,7 +140,7 @@ def increase_indent(self, flow=False, *args, **kwargs): { "title": "Documentation", "labels": doc_labels, - } + }, ) # Add new features in eos_cli_config_gen @@ -144,7 +148,7 @@ def increase_indent(self, flow=False, *args, **kwargs): { "title": "New features and enhancements in eos_cli_config_gen", "labels": ["rn: Feat(eos_cli_config_gen)"], - } + }, ) # Add new features in eos_designs @@ -152,7 +156,7 @@ def increase_indent(self, flow=False, *args, **kwargs): { "title": "New features and enhancements in eos_designs", "labels": ["rn: Feat(eos_designs)"], - } + }, ) # Add new features in both eos_cli_config_gen|eos_designs or eos_designs|eos_cli_config_gen @@ -163,7 +167,7 @@ def increase_indent(self, flow=False, *args, **kwargs): "rn: Feat(eos_cli_config_gen|eos_designs)", "rn: Feat(eos_designs|eos_cli_config_gen)", ], - } + }, ) # Add other new features @@ -174,7 +178,7 @@ def increase_indent(self, flow=False, *args, **kwargs): { "title": "Other new features and enhancements", "labels": other_feat_labels, - } + }, ) # Add all PyAVD changes @@ -183,7 +187,7 @@ def increase_indent(self, flow=False, *args, **kwargs): { "title": "PyAVD Changes", "labels": pyavd_labels, - } + }, ) # Add the catch all @@ -191,15 +195,15 @@ def increase_indent(self, flow=False, *args, **kwargs): { "title": "Other Changes", "labels": ["*"], - } + }, ) - with open(r"release.yml", "w", encoding="utf-8") as release_file: + with Path("release.yml").open("w", encoding="utf-8") as release_file: yaml.dump( { "changelog": { "exclude": {"labels": exclude_list}, "categories": categories_list, - } + }, }, release_file, Dumper=SafeDumper, diff --git a/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_interface_descriptions.py b/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_interface_descriptions.py index d096a85fd88..6feba93258f 100644 --- a/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_interface_descriptions.py +++ b/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_interface_descriptions.py @@ -12,18 +12,15 @@ class CustomAvdInterfaceDescriptions(AvdInterfaceDescriptions): @cached_property - def _custom_description_prefix(self): + def _custom_description_prefix(self) -> str: return get(self._hostvars, "description_prefix", "") @cached_property - def _switch_type(self): + def _switch_type(self) -> str: return get(self._hostvars, "switch.type", "") def underlay_ethernet_interfaces(self, link_type: str, link_peer: str, link_peer_interface: str) -> str: - """ - Implementation of custom code similar to jinja in - custom_templates/interface_descriptions/underlay/ethernet-interfaces.j2 - """ + """Implementation of custom code similar to jinja in custom_templates/interface_descriptions/underlay/ethernet-interfaces.j2.""" link_peer = str(link_peer).upper() if link_type == "underlay_p2p": return f"{self._custom_description_prefix}_P2P_LINK_TO_{link_peer}_{link_peer_interface}" @@ -33,10 +30,11 @@ def underlay_ethernet_interfaces(self, link_type: str, link_peer: str, link_peer return "" - def underlay_port_channel_interfaces(self, link_peer: str, link_peer_channel_group_id: int, link_channel_description: str) -> str: + def underlay_port_channel_interfaces(self, link_peer: str, link_peer_channel_group_id: int, link_channel_description: str | None) -> str: """ - Implementation of custom code similar to jinja: - custom_templates/interface_descriptions/underlay/port-channel-interfaces.j2 + Implementation of custom code similar to jinja. + + custom_templates/interface_descriptions/underlay/port-channel-interfaces.j2. """ if link_channel_description is not None: link_channel_description = str(link_channel_description).upper() @@ -47,45 +45,48 @@ def underlay_port_channel_interfaces(self, link_peer: str, link_peer_channel_gro def mlag_ethernet_interfaces(self, mlag_interface: str) -> str: """ - Implementation of custom code similar to jinja: - custom_templates/interface_descriptions/mlag/ethernet-interfaces.j2 + Implementation of custom code similar to jinja. + + custom_templates/interface_descriptions/mlag/ethernet-interfaces.j2. """ return f"{self._custom_description_prefix}_MLAG_PEER_{self._mlag_peer}_{mlag_interface}" def mlag_port_channel_interfaces(self) -> str: """ - Implementation of custom code similar to jinja: - custom_templates/interface_descriptions/mlag/port-channel-interfaces.j2 + Implementation of custom code similar to jinja. + + custom_templates/interface_descriptions/mlag/port-channel-interfaces.j2. """ return f"{self._custom_description_prefix}_MLAG_PEER_{self._mlag_peer}_Po{self._mlag_port_channel_id}" - def connected_endpoints_ethernet_interfaces(self, peer: str = None, peer_interface: str = None) -> str: + def connected_endpoints_ethernet_interfaces(self, peer: str | None = None, peer_interface: str | None = None) -> str: """ - Implementation of custom code similar to jinja: - custom_templates/interface_descriptions/connected-endpoints/ethernet-interfaces.j2 + Implementation of custom code similar to jinja. + + custom_templates/interface_descriptions/connected-endpoints/ethernet-interfaces.j2. """ elements = [peer, peer_interface] return "_".join([str(element) for element in elements if element is not None]) - def connected_endpoints_port_channel_interfaces(self, peer: str = None, adapter_port_channel_description: str = None) -> str: + def connected_endpoints_port_channel_interfaces(self, peer: str | None = None, adapter_port_channel_description: str | None = None) -> str: """ - Implementation of custom code similar to jinja: - custom_templates/interface_descriptions/connected-endpoints/port-channel-interfaces.j2 + Implementation of custom code similar to jinja. + + custom_templates/interface_descriptions/connected-endpoints/port-channel-interfaces.j2. """ elements = [self._custom_description_prefix, peer, adapter_port_channel_description] return "_".join([str(element) for element in elements if element is not None]) - def overlay_loopback_interface(self, overlay_loopback_description: str = None) -> str: + def overlay_loopback_interface(self, overlay_loopback_description: str | None = None) -> str: # noqa: ARG002 """ - Implementation of custom code similar to jinja: - custom_templates/interface_descriptions/loopbacks/overlay-loopback.j2 + Implementation of custom code similar to jinja. + + custom_templates/interface_descriptions/loopbacks/overlay-loopback.j2. """ switch_type = str(self._switch_type).upper() return f"{self._custom_description_prefix}_EVPN_Overlay_Peering_{switch_type}" def vtep_loopback_interface(self) -> str: - """ - Implementation of custom code similar to jinja: - """ + """Implementation of custom code similar to jinja.""" switch_type = str(self._switch_type).upper() return f"{self._custom_description_prefix}_VTEP_VXLAN_Tunnel_Source_{switch_type}" diff --git a/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_interface_descriptions_with_data.py b/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_interface_descriptions_with_data.py index a0195b64eaa..15b2e89424b 100644 --- a/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_interface_descriptions_with_data.py +++ b/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_interface_descriptions_with_data.py @@ -12,7 +12,7 @@ class CustomAvdInterfaceDescriptions(AvdInterfaceDescriptions): @cached_property - def _custom_description_prefix(self): + def _custom_description_prefix(self) -> str: return get(self._hostvars, "description_prefix", "") def underlay_ethernet_interface(self, data: InterfaceDescriptionData) -> str: @@ -120,12 +120,14 @@ def connected_endpoints_port_channel_interface(self, data: InterfaceDescriptionD def router_id_loopback_interface(self, data: InterfaceDescriptionData) -> str: """ + Called per device. + Available data: - - description - - mpls_overlay_role - - mpls_lsr - - overlay_routing_protocol - - type + - description + - mpls_overlay_role + - mpls_lsr + - overlay_routing_protocol + - type """ switch_type = str(data.type).upper() return f"{self._custom_description_prefix}_EVPN_Overlay_Peering_{switch_type}" @@ -133,7 +135,8 @@ def router_id_loopback_interface(self, data: InterfaceDescriptionData) -> str: def vtep_loopback_interface(self) -> str: """ Implementation of custom code similar to jinja. - TODO: AVD5.0.0 Update to use InterfaceDescriptionData + + TODO: AVD5.0.0 Update to use InterfaceDescriptionData. """ switch_type = str(self.shared_utils.type).upper() return f"{self._custom_description_prefix}_VTEP_VXLAN_Tunnel_Source_{switch_type}" diff --git a/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_ip_addressing.py b/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_ip_addressing.py index 63ef4b3062b..08f202021fd 100644 --- a/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_ip_addressing.py +++ b/ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/custom_ip_addressing.py @@ -8,53 +8,56 @@ class CustomAvdIpAddressing(AvdIpAddressing): - pass - @cached_property - def _custom_ip_offset_10(self): - return self._hostvars.get("ip_offset_10", 0) + def _custom_ip_offset_10(self) -> int: + return int(self._hostvars.get("ip_offset_10", 0)) @cached_property - def _custom_ip_offset_10_subnets(self): + def _custom_ip_offset_10_subnets(self) -> int: """ The jinja code did a blind addition of 10 to the resulting IP even for subnets. - Here we divide the offset with two, since we are calculating /31 subnets more intelligently + + Here we divide the offset with two, since we are calculating /31 subnets more intelligently. """ return int(self._custom_ip_offset_10 / 2) @cached_property - def _custom_ip_offset_20(self): - return self._hostvars.get("ip_offset_20", 0) + def _custom_ip_offset_20(self) -> int: + return int(self._hostvars.get("ip_offset_20", 0)) @cached_property - def _custom_ip_offset_20_subnets(self): + def _custom_ip_offset_20_subnets(self) -> int: """ The jinja code did a blind addition of 20 to the resulting IP even for subnets. - Here we divide the offset with two, since we are calculating /31 subnets more intelligently + + Here we divide the offset with two, since we are calculating /31 subnets more intelligently. """ return int(self._custom_ip_offset_20 / 2) def mlag_ibgp_peering_ip_primary(self, mlag_ibgp_peering_ipv4_pool: str) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ vrf.mlag_ibgp_peering_ipv4_pool | ansible.utils.ipaddr('subnet') | - ansible.utils.ipmath((switch.mlag_switch_ids.primary - 1) * 2 + ip_offset_10) }} + ansible.utils.ipmath((switch.mlag_switch_ids.primary - 1) * 2 + ip_offset_10) }}. """ offset = self._mlag_primary_id - 1 + self._custom_ip_offset_10_subnets return self._ip(mlag_ibgp_peering_ipv4_pool, 31, offset, 0) def mlag_ibgp_peering_ip_secondary(self, mlag_ibgp_peering_ipv4_pool: str) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ vrf.mlag_ibgp_peering_ipv4_pool | ansible.utils.ipaddr('subnet') | - ansible.utils.ipmath(((switch.mlag_switch_ids.primary - 1) * 2) + 1 + ip_offset_10) }} + ansible.utils.ipmath(((switch.mlag_switch_ids.primary - 1) * 2) + 1 + ip_offset_10) }}. """ offset = self._mlag_primary_id - 1 + self._custom_ip_offset_10_subnets return self._ip(mlag_ibgp_peering_ipv4_pool, 31, offset, 1) def mlag_ip_primary(self) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ switch_data.combined.mlag_peer_ipv4_pool | ansible.utils.ipaddr('network') | ansible.utils.ipmath((mlag_primary_id - 1) * 2 + ip_offset_10) }} """ @@ -63,7 +66,8 @@ def mlag_ip_primary(self) -> str: def mlag_ip_secondary(self) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ switch_data.combined.mlag_peer_ipv4_pool | ansible.utils.ipaddr('network') | ansible.utils.ipmath(((mlag_primary_id - 1) * 2) + 1 + ip_offset_10) }} """ @@ -72,7 +76,8 @@ def mlag_ip_secondary(self) -> str: def mlag_l3_ip_primary(self) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ switch_data.combined.mlag_peer_l3_ipv4_pool | ansible.utils.ipaddr('network') | ansible.utils.ipmath((mlag_primary_id - 1) * 2 + ip_offset_10) }} """ @@ -81,7 +86,8 @@ def mlag_l3_ip_primary(self) -> str: def mlag_l3_ip_secondary(self) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ switch_data.combined.mlag_peer_l3_ipv4_pool | ansible.utils.ipaddr('network') | ansible.utils.ipmath(((mlag_primary_id - 1) * 2) + 1 + ip_offset_10) }} """ @@ -90,7 +96,8 @@ def mlag_l3_ip_secondary(self) -> str: def p2p_uplinks_ip(self, uplink_switch_index: int) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ switch.uplink_ipv4_pool | ansible.utils.ipaddr('network') | ansible.utils.ipmath(((switch.id -1) * 2 * switch.max_uplink_switches * switch.max_parallel_uplinks) + (uplink_switch_index) * 2 + 1 + ip_offset_20) }} """ @@ -99,7 +106,8 @@ def p2p_uplinks_ip(self, uplink_switch_index: int) -> str: def p2p_uplinks_peer_ip(self, uplink_switch_index: int) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ switch.uplink_ipv4_pool | ansible.utils.ipaddr('network') | ansible.utils.ipmath(((switch.id -1) * 2 * switch.max_uplink_switches * switch.max_parallel_uplinks) + (uplink_switch_index) * 2 + ip_offset_20) }} """ @@ -109,6 +117,7 @@ def p2p_uplinks_peer_ip(self, uplink_switch_index: int) -> str: def p2p_vrfs_uplinks_ip(self, uplink_switch_index: int, vrf: str) -> str: """ Implementation of custom code to override IP addressing on VRF uplinks. + We read the uplink pool from a custom dict `custom_ip_pools_for_vrfs.` - Note no error handling in this example. """ offset = ((self._id - 1) * self._max_uplink_switches * self._max_parallel_uplinks) + uplink_switch_index + self._custom_ip_offset_20_subnets @@ -118,7 +127,8 @@ def p2p_vrfs_uplinks_ip(self, uplink_switch_index: int, vrf: str) -> str: def p2p_vrfs_uplinks_peer_ip(self, uplink_switch_index: int, vrf: str) -> str: """ - Implementation of custom code to override IP addressing on VRF downlinks + Implementation of custom code to override IP addressing on VRF downlinks. + We read the uplink pool from a custom dict `custom_ip_pools_for_vrfs.` - Note no error handling in this example. """ offset = ((self._id - 1) * self._max_uplink_switches * self._max_parallel_uplinks) + uplink_switch_index + self._custom_ip_offset_20_subnets @@ -127,7 +137,8 @@ def p2p_vrfs_uplinks_peer_ip(self, uplink_switch_index: int, vrf: str) -> str: def router_id(self) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ loopback_ipv4_pool | ansible.utils.ipaddr("network") | ansible.utils.ipmath(switch_id + loopback_ipv4_offset + ip_offset_20) }} """ @@ -136,7 +147,8 @@ def router_id(self) -> str: def ipv6_router_id(self) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ loopback_ipv6_pool | ansible.utils.ipaddr("network") | ansible.utils.ipmath(switch_id + loopback_ipv6_offset + ip_offset_20) }} """ @@ -145,7 +157,8 @@ def ipv6_router_id(self) -> str: def vtep_ip_mlag(self) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ switch_vtep_loopback_ipv4_pool | ansible.utils.ipaddr('network') | ansible.utils.ipmath(mlag_primary_id + loopback_ipv4_offset + ip_offset_20) }} """ @@ -154,7 +167,8 @@ def vtep_ip_mlag(self) -> str: def vtep_ip(self) -> str: """ - Implementation of custom code similar to jinja: + Implementation of custom code similar to jinja. + {{ switch_vtep_loopback_ipv4_pool | ansible.utils.ipaddr('network') | ansible.utils.ipmath(switch_id + loopback_ipv4_offset + ip_offset_20) }} """ diff --git a/ansible_collections/arista/avd/plugins/action/cv_workflow.py b/ansible_collections/arista/avd/plugins/action/cv_workflow.py index cca38e61e1f..ba4fd049fc6 100644 --- a/ansible_collections/arista/avd/plugins/action/cv_workflow.py +++ b/ansible_collections/arista/avd/plugins/action/cv_workflow.py @@ -1,9 +1,7 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, annotations, division, print_function - -__metaclass__ = type +from __future__ import annotations import json import logging @@ -11,6 +9,7 @@ from dataclasses import asdict from pathlib import Path from string import Template +from typing import Any from ansible.errors import AnsibleActionFail from ansible.plugins.action import ActionBase, display @@ -85,7 +84,7 @@ class ActionModule(ActionBase): - def run(self, tmp=None, task_vars=None): + def run(self, tmp: Any = None, task_vars: dict | None = None) -> dict: self._supports_check_mode = False if task_vars is None: @@ -95,7 +94,8 @@ def run(self, tmp=None, task_vars=None): del tmp # tmp no longer has any effect if not HAS_PYAVD: - raise AnsibleActionFail("The arista.avd.cv_workflow' plugin requires the 'pyavd' Python library. Got import error") + msg = "The arista.avd.cv_workflow' plugin requires the 'pyavd' Python library. Got import error" + raise AnsibleActionFail(msg) # Setup module logging setup_module_logging(result) @@ -111,10 +111,8 @@ def run(self, tmp=None, task_vars=None): # Running asyncio coroutine to deploy everything. return run(self.deploy(validated_args, result)) - async def deploy(self, validated_args: dict, result: dict): - """ - Prepare data, perform deployment and convert result data. - """ + async def deploy(self, validated_args: dict, result: dict) -> dict: + """Prepare data, perform deployment and convert result data.""" LOGGER.info("deploy: %s", {**validated_args, "cv_token": ""}) try: # Create CloudVision object @@ -135,7 +133,7 @@ async def deploy(self, validated_args: dict, result: dict): if validated_args["return_details"]: # Objects are converted to JSON compatible dicts. result.update( - cloudvision=dict(asdict(cloudvision), token=""), + cloudvision=dict(asdict(cloudvision), token=""), # noqa: S106 configs=[asdict(config) for config in eos_config_objects], device_tags=[asdict(device_tag) for device_tag in device_tag_objects], interface_tags=[asdict(interface_tag) for interface_tag in interface_tag_objects], @@ -171,7 +169,7 @@ async def deploy(self, validated_args: dict, result: dict): "warnings": result_object.warnings, "errors": result_object.errors, "failed": result_object.failed, - } + }, ) # Set changed if we did anything. TODO: Improve this logic to only set changed if something actually changed. @@ -186,9 +184,16 @@ async def deploy(self, validated_args: dict, result: dict): return result async def build_objects( - self, device_list: list[str], structured_config_dir: str, structured_config_suffix: str, configuration_dir: str, configlet_name_template: str + self, + device_list: list[str], + structured_config_dir: str, + structured_config_suffix: str, + configuration_dir: str, + configlet_name_template: str, ) -> tuple[list[CVEosConfig], list[CVDeviceTag], list[CVInterfaceTag], list[CVPathfinderMetadata]]: """ + Build objects. + Parameters: device_list: List of device hostnames. structured_config_dir: Path to structured config files. @@ -196,7 +201,7 @@ async def build_objects( configuration_dir: Path to EOS config files. configlet_name_template: Python string template used for naming configlets. Ex. "AVD-${hostname}" Return: - Tuple containing (, , , ) + Tuple containing (, , , ). Workflow: Per device: @@ -223,9 +228,16 @@ async def build_objects( return eos_config_objects, device_tag_objects, interface_tag_objects, cv_pathfinder_metadata_objects async def build_object_for_device( - self, hostname: str, structured_config_dir: str, structured_config_suffix: str, configuration_dir: str, configlet_name_template: str + self, + hostname: str, + structured_config_dir: str, + structured_config_suffix: str, + configuration_dir: str, + configlet_name_template: str, ) -> tuple[list[CVEosConfig], list[CVDeviceTag], list[CVInterfaceTag], list[CVPathfinderMetadata]]: """ + Build objects for one device. + Parameters: device_list: List of device hostnames. structured_config_dir: Path to structured config files. @@ -233,7 +245,7 @@ async def build_object_for_device( configuration_dir: Path to EOS config files. configlet_name_template: Python string template used for naming configlets. Ex. "AVD-${hostname}" Return: - Tuple containing (, , , ) + Tuple containing (, , , ). Workflow: Per device: @@ -245,7 +257,9 @@ async def build_object_for_device( TODO: Refactor into smaller functions. """ LOGGER.info("build_object_for_device: %s", hostname) - with Path(structured_config_dir, f"{hostname}.{structured_config_suffix}").open(mode="r", encoding="UTF-8") as structured_config_stream: + with Path(structured_config_dir, f"{hostname}.{structured_config_suffix}").open( # noqa: ASYNC101 + mode="r", encoding="UTF-8" + ) as structured_config_stream: if structured_config_suffix in ["yml", "yaml"]: interesting_keys = ("is_deployed", "serial_number", "metadata") in_interesting_context = False @@ -257,7 +271,7 @@ async def build_object_for_device( else: in_interesting_context = False - structured_config = load("".join(structured_config_lines), Loader=YamlLoader) + structured_config = load("".join(structured_config_lines), Loader=YamlLoader) # noqa: S506 TODO: Consider safeload else: # Load as JSON structured_config = json.load(structured_config_stream) @@ -277,11 +291,11 @@ async def build_object_for_device( eos_config_objects = [CVEosConfig(file=config_file_path, device=device_object, configlet_name=configlet_name)] # Build device tag objects for this device. - # metadata: - # cv_tags: - # device_tags: - # - name: topology_hint_datacenter - # value: DC1 + # ! metadata: + # ! cv_tags: + # ! device_tags: + # ! - name: topology_hint_datacenter + # ! value: DC1 device_tags = get(structured_config, "metadata.cv_tags.device_tags", default=[]) device_tag_objects = [ CVDeviceTag(label=device_tag["name"], value=device_tag["value"], device=device_object) @@ -290,13 +304,13 @@ async def build_object_for_device( ] # Build interface tag objects for this device. - # metadata: - # cv_tags: - # interface_tags: - # - interface: Ethernet3 - # tags: - # - name: peer_device_interface - # value: Ethernet3 + # ! metadata: + # ! cv_tags: + # ! interface_tags: + # ! - interface: Ethernet3 + # ! tags: + # ! - name: peer_device_interface + # ! value: Ethernet3 all_interface_tags = get(structured_config, "metadata.cv_tags.interface_tags", default=[]) interface_tag_objects = [ CVInterfaceTag( @@ -322,13 +336,13 @@ async def build_object_for_device( def setup_module_logging(result: dict) -> None: """ - Add a Handler to copy the logs from the plugin into Ansible output based on their level + Add a Handler to copy the logs from the plugin into Ansible output based on their level. Parameters: result: The dictionary used for the ansible module results """ python_to_ansible_handler = PythonToAnsibleHandler(result, display) LOGGER.addHandler(python_to_ansible_handler) - # TODO mechanism to manipulate the logger globally for pyavd + # TODO: mechanism to manipulate the logger globally for pyavd # Keep debug to be able to see logs with `-v` and `-vvv` LOGGER.setLevel(logging.DEBUG) diff --git a/ansible_collections/arista/avd/plugins/action/eos_cli_config_gen.py b/ansible_collections/arista/avd/plugins/action/eos_cli_config_gen.py index 1c74b4775dd..71ec66bab4b 100644 --- a/ansible_collections/arista/avd/plugins/action/eos_cli_config_gen.py +++ b/ansible_collections/arista/avd/plugins/action/eos_cli_config_gen.py @@ -1,13 +1,12 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, annotations, division, print_function - -__metaclass__ = type +from __future__ import annotations import json import logging from pathlib import Path +from typing import Any import yaml from ansible.errors import AnsibleActionFail @@ -51,13 +50,14 @@ class ActionModule(ActionBase): """Action Module for eos_cli_config_gen.""" @cprofile() - def run(self, tmp=None, task_vars=None): + def run(self, tmp: Any = None, task_vars: dict | None = None) -> None: """Ansible Action entry point.""" if task_vars is None: task_vars = {} if not HAS_PYAVD: - raise AnsibleActionFail("The arista.avd.eos_cli_config_gen' plugin requires the 'pyavd' Python library. Got import error") + msg = "The arista.avd.eos_cli_config_gen' plugin requires the 'pyavd' Python library. Got import error" + raise AnsibleActionFail(msg) result = super().run(tmp, task_vars) del tmp # tmp no longer has any effect @@ -66,9 +66,7 @@ def run(self, tmp=None, task_vars=None): hostname = task_vars["inventory_hostname"] setup_module_logging(hostname, result) - result = self.main(task_vars, result) - - return result + return self.main(task_vars, result) def main(self, task_vars: dict, result: dict) -> dict: """Main function in charge of validating the input variables and generating the device configuration and documentation.""" @@ -80,7 +78,9 @@ def main(self, task_vars: dict, result: dict) -> dict: # Read structured config from file or task_vars and run templating to handle inline jinja. LOGGER.debug("Preparing task vars...") task_vars = self.prepare_task_vars( - task_vars, validated_args.get("structured_config_filename"), read_structured_config_from_file=validated_args["read_structured_config_from_file"] + task_vars, + validated_args.get("structured_config_filename"), + read_structured_config_from_file=validated_args["read_structured_config_from_file"], ) LOGGER.debug("Preparing task vars [done].") @@ -95,7 +95,7 @@ def main(self, task_vars: dict, result: dict) -> dict: ) LOGGER.debug("Validating structured configuration [done].") except Exception as e: - LOGGER.error(e) + LOGGER.exception(e) # noqa: TRY401 TODO: Improve code return result if result.get("failed"): @@ -168,22 +168,21 @@ def prepare_task_vars(self, task_vars: dict, structured_config_filename: str, *, structured_config_filename: The filename where the structured_config for the device is stored. read_structured_config_from_file: Flag to indicate whether or not the structured_config_filname should be read. - Returns + Returns: ------- dict: Task vars updated with the structured_config content if read and all inline Jinja rendered. - Raises + Raises: ------ AnsibleActionFail: If templating fails. """ - if read_structured_config_from_file: task_vars.update(read_vars(structured_config_filename)) # Read ansible variables and perform templating to support inline jinja2 for var in task_vars: - # TODO - reevaluate these variables + # TODO: - reevaluate these variables if str(var).startswith(("ansible", "molecule", "hostvars", "vars", "avd_switch_facts")): continue if self._templar.is_template(task_vars[var]): @@ -191,7 +190,8 @@ def prepare_task_vars(self, task_vars: dict, structured_config_filename: str, *, try: task_vars[var] = self._templar.template(task_vars[var], fail_on_undefined=False) except Exception as e: - raise AnsibleActionFail(f"Exception during templating of task_var '{var}': '{e}'") from e + msg = f"Exception during templating of task_var '{var}': '{e}'" + raise AnsibleActionFail(msg) from e if not isinstance(task_vars, dict): # Corner case for ansible-test where the passed task_vars is a nested chain-map @@ -213,7 +213,6 @@ def validate_task_vars(self, hostname: str, conversion_mode: str, validation_mod def render_template_with_ansible_templar(self, task_vars: dict, templatefile: str) -> str: """Render a template with the Ansible Templar.""" - # Get updated templar instance to be passed along to our simplified "templater" if not hasattr(self, "ansible_templar"): self.ansible_templar = get_templar(self, task_vars) @@ -229,7 +228,7 @@ def write_file(self, content: str, filename: str) -> bool: content: The content to write filename: Target filename - Returns + Returns: ------- bool: Indicate if the content of filename has changed. """ @@ -250,7 +249,7 @@ def write_file(self, content: str, filename: str) -> bool: def setup_module_logging(hostname: str, result: dict) -> None: """ - Add a Handler to copy the logs from the plugin into Ansible output based on their level + Add a Handler to copy the logs from the plugin into Ansible output based on their level. Parameters ---------- @@ -261,7 +260,7 @@ def setup_module_logging(hostname: str, result: dict) -> None: python_to_ansible_handler = PythonToAnsibleHandler(result, display) python_to_ansible_handler.addFilter(python_to_ansible_filter) LOGGER.addHandler(python_to_ansible_handler) - # TODO mechanism to manipulate the logger globally for pyavd + # TODO: mechanism to manipulate the logger globally for pyavd LOGGER.setLevel(logging.DEBUG) @@ -274,11 +273,11 @@ def read_vars(filename: Path | str) -> dict: ---------- filename: The path to the file to read as a string or a Path. - Returns + Returns: ------- dict: The content of the file as dict or an empty dict if the file does not exist. - Raises + Raises: ------ NotImplementedError: If the file extension is not json, yml or yaml. """ @@ -291,8 +290,9 @@ def read_vars(filename: Path | str) -> dict: with filename.open(mode="r", encoding="UTF-8") as stream: if filename.suffix in [".yml", ".yaml"]: - return yaml.load(stream, Loader=YamlLoader) - elif filename.suffix == ".json": + return yaml.load(stream, Loader=YamlLoader) # noqa: S506 TODO: Figure out if we can move to safeloader everywhere + if filename.suffix == ".json": return json.load(stream) - else: - raise NotImplementedError(f"Unsupported file suffix for file '{filename}'") + + msg = f"Unsupported file suffix for file '{filename}'" + raise NotImplementedError(msg) diff --git a/ansible_collections/arista/avd/plugins/action/eos_designs_facts.py b/ansible_collections/arista/avd/plugins/action/eos_designs_facts.py index 2198e64002b..9cf842713d9 100644 --- a/ansible_collections/arista/avd/plugins/action/eos_designs_facts.py +++ b/ansible_collections/arista/avd/plugins/action/eos_designs_facts.py @@ -1,12 +1,11 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, division, print_function -__metaclass__ = type import cProfile import pstats +from typing import Any from ansible.errors import AnsibleActionFail from ansible.plugins.action import ActionBase, display @@ -26,12 +25,12 @@ AnsibleActionFail( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) class ActionModule(ActionBase): - def run(self, tmp=None, task_vars=None): + def run(self, tmp: Any = None, task_vars: dict | None = None) -> None: if task_vars is None: task_vars = {} @@ -54,12 +53,13 @@ def run(self, tmp=None, task_vars=None): # Check if fabric_name is set and that all play hosts are part Ansible group set in "fabric_name" if fabric_name is None or not set(ansible_play_hosts_all).issubset(fabric_hosts): - raise AnsibleActionFail( + msg = ( "Invalid/missing 'fabric_name' variable. " "All hosts in the play must have the same 'fabric_name' value " "which must point to an Ansible Group containing the hosts." f"play_hosts: {ansible_play_hosts_all}" ) + raise AnsibleActionFail(msg) # This is not all the hostvars, but just the Ansible Hostvars Manager object where we can retrieve hostvars for each host on-demand. hostvars = task_vars["hostvars"] @@ -106,6 +106,7 @@ def run(self, tmp=None, task_vars=None): def create_avd_switch_facts_instances(self, fabric_hosts: list, hostvars: object, result: dict) -> dict: """ Fetch hostvars for all hosts and perform data conversion & validation. + Initialize all instances of EosDesignsFacts and insert various references into the variable space. Returns dict with avd_switch_facts_instances. @@ -120,7 +121,7 @@ def create_avd_switch_facts_instances(self, fabric_hosts: list, hostvars: object failure : bool msg : str - Returns + Returns: ------- dict hostname1 : dict @@ -183,15 +184,15 @@ def create_avd_switch_facts_instances(self, fabric_hosts: list, hostvars: object return avd_switch_facts - def render_avd_switch_facts(self, avd_switch_facts_instances: dict): + def render_avd_switch_facts(self, avd_switch_facts_instances: dict) -> dict: """ - Run the render method on each EosDesignsFacts object + Run the render method on each EosDesignsFacts object. Parameters ---------- avd_switch_facts_instances : dict of EosDesignsFacts - Returns + Returns: ------- dict hostname1 : dict @@ -204,7 +205,8 @@ def render_avd_switch_facts(self, avd_switch_facts_instances: dict): try: rendered_facts[host] = {"switch": avd_switch_facts_instances[host]["switch"].render()} except AristaAvdMissingVariableError as e: - raise AnsibleActionFail(f"{e} is required but was not found for host '{host}'") from e + msg = f"{e} is required but was not found for host '{host}'" + raise AnsibleActionFail(msg) from e # If the argument 'template_output' is set, run the output data through jinja2 rendering. # This is to resolve any input values with inline jinja using variables/facts set by eos_designs_facts. diff --git a/ansible_collections/arista/avd/plugins/action/eos_designs_structured_config.py b/ansible_collections/arista/avd/plugins/action/eos_designs_structured_config.py index fd59fa1c3f9..a9caf16ee61 100644 --- a/ansible_collections/arista/avd/plugins/action/eos_designs_structured_config.py +++ b/ansible_collections/arista/avd/plugins/action/eos_designs_structured_config.py @@ -1,13 +1,12 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, division, print_function -__metaclass__ = type import cProfile import pstats from collections import ChainMap +from typing import Any import yaml from ansible.errors import AnsibleActionFail @@ -28,12 +27,12 @@ AnsibleActionFail( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) class ActionModule(ActionBase): - def run(self, tmp=None, task_vars=None): + def run(self, tmp: Any = None, task_vars: dict | None = None) -> dict: if task_vars is None: task_vars = {} @@ -64,7 +63,8 @@ def run(self, tmp=None, task_vars=None): try: task_vars[var] = self._templar.template(task_vars[var], fail_on_undefined=False) except Exception as e: - raise AnsibleActionFail(f"Exception during templating of task_var '{var}'") from e + msg = f"Exception during templating of task_var '{var}'" + raise AnsibleActionFail(msg) from e # Get updated templar instance to be passed along to our simplified "templater" self.templar = get_templar(self, task_vars) @@ -171,10 +171,11 @@ def run(self, tmp=None, task_vars=None): return result - def write_file(self, content, task_vars): + def write_file(self, content: str, task_vars: dict) -> dict: """ This function implements the Ansible 'copy' action_module, to benefit from Ansible builtin functionality like 'changed'. - Reuse task data + + Reuse task data. """ new_task = self._task.copy() new_task.args = { diff --git a/ansible_collections/arista/avd/plugins/action/eos_validate_state_reports.py b/ansible_collections/arista/avd/plugins/action/eos_validate_state_reports.py index f4b6e10386e..02af2db6380 100644 --- a/ansible_collections/arista/avd/plugins/action/eos_validate_state_reports.py +++ b/ansible_collections/arista/avd/plugins/action/eos_validate_state_reports.py @@ -1,12 +1,10 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, annotations, division, print_function - -__metaclass__ = type +from __future__ import annotations from json import JSONDecodeError, load -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING, Any from ansible.errors import AnsibleActionFail from ansible.plugins.action import ActionBase, display @@ -16,6 +14,7 @@ from ansible_collections.arista.avd.plugins.plugin_utils.utils import get_validated_path, get_validated_value if TYPE_CHECKING: + from collections.abc import Generator from pathlib import Path PLUGIN_NAME = "arista.avd.eos_validate_state_reports" @@ -27,7 +26,7 @@ AnsibleActionFail( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) @@ -51,7 +50,7 @@ def _test_results_gen(input_path: Path) -> Generator[dict, None, None]: class ActionModule(ActionBase): # @cprofile() - def run(self, tmp=None, task_vars=None): + def run(self, tmp: Any = None, task_vars: dict | None = None) -> dict: if task_vars is None: task_vars = {} @@ -88,11 +87,11 @@ def run(self, tmp=None, task_vars=None): # Getting the host results JSON file saved by eos_validate_state_runner action plugin result_path = get_validated_path(path_input=test_results_dir / f"{host}-results.json", parent=False) # Process the host test results - for test_result in _test_results_gen(input_path=result_path): - try: + try: + for test_result in _test_results_gen(input_path=result_path): test_results.update_results(test_result) - except TypeError as error: - display.warning(f"Failed to update the test results of host {host}: {error}") + except TypeError as error: + display.warning(f"Failed to update the test results of host {host}: {error}") except (JSONDecodeError, OSError, TypeError, FileNotFoundError) as error: display.warning(f"Failed to load the test results of host {host}: {error}") diff --git a/ansible_collections/arista/avd/plugins/action/eos_validate_state_runner.py b/ansible_collections/arista/avd/plugins/action/eos_validate_state_runner.py index d4c24e96575..2b1e84270fa 100644 --- a/ansible_collections/arista/avd/plugins/action/eos_validate_state_runner.py +++ b/ansible_collections/arista/avd/plugins/action/eos_validate_state_runner.py @@ -1,13 +1,11 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, annotations, division, print_function - -__metaclass__ = type +from __future__ import annotations import logging from json import dump -from typing import TYPE_CHECKING, Mapping +from typing import TYPE_CHECKING, Any from ansible.errors import AnsibleActionFail from ansible.parsing.yaml.dumper import AnsibleDumper @@ -22,6 +20,7 @@ ) if TYPE_CHECKING: + from collections.abc import Mapping from pathlib import Path LOGGER = logging.getLogger("ansible_collections.arista.avd") @@ -31,13 +30,13 @@ class AnsibleNoAliasDumper(AnsibleDumper): - def ignore_aliases(self, data): + def ignore_aliases(self, _data: Any) -> bool: return True class ActionModule(ActionBase): # @cprofile() - def run(self, tmp=None, task_vars=None): + def run(self, tmp: Any = None, task_vars: dict | None = None) -> dict: self._supports_check_mode = True if task_vars is None: @@ -66,7 +65,11 @@ def run(self, tmp=None, task_vars=None): # Get task arguments and validate them try: logging_level = get_validated_value( - data=self._task.args, key="logging_level", expected_type=str, default_value="WARNING", allowed_values=LOGGING_LEVELS + data=self._task.args, + key="logging_level", + expected_type=str, + default_value="WARNING", + allowed_values=LOGGING_LEVELS, ) skip_tests = get_validated_value(data=self._task.args, key="skip_tests", expected_type=list, default_value=[]) save_catalog = get_validated_value(data=self._task.args, key="save_catalog", expected_type=bool, default_value=False) @@ -163,6 +166,6 @@ def setup_module_logging(hostname: str, result: dict) -> None: python_to_ansible_handler = PythonToAnsibleHandler(result, display) python_to_ansible_handler.addFilter(python_to_ansible_filter) LOGGER.addHandler(python_to_ansible_handler) - # TODO mechanism to manipulate the logger globally for pyavd + # TODO: mechanism to manipulate the logger globally for pyavd # Keep debug to be able to see logs with `-v` and `-vvv` LOGGER.setLevel(logging.DEBUG) diff --git a/ansible_collections/arista/avd/plugins/action/inventory_to_container.py b/ansible_collections/arista/avd/plugins/action/inventory_to_container.py index 41ca3f12d8c..6df3753d174 100644 --- a/ansible_collections/arista/avd/plugins/action/inventory_to_container.py +++ b/ansible_collections/arista/avd/plugins/action/inventory_to_container.py @@ -1,19 +1,20 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, division, print_function - -__metaclass__ = type +from pathlib import Path +from typing import TYPE_CHECKING, Any import yaml from ansible.errors import AnsibleActionFail from ansible.inventory.group import Group -from ansible.inventory.host import Host -from ansible.inventory.manager import InventoryManager from ansible.parsing.yaml.dumper import AnsibleDumper from ansible.plugins.action import ActionBase from ansible.utils.display import Display +if TYPE_CHECKING: + from ansible.inventory.host import Host + from ansible.inventory.manager import InventoryManager + # Root container on CloudVision. # Shall not be changed unless CloudVision changes it in the core. CVP_ROOT_CONTAINER = "Tenant" @@ -22,14 +23,14 @@ class ActionModule(ActionBase): - def _maybe_convert_device_filter(self): + def _maybe_convert_device_filter(self) -> None: # Converting string device filter to list device_filter = self._task.args.get("device_filter") if device_filter is not None and not isinstance(device_filter, list): display.debug(f"device_filter must be of type list, got '{device_filter}' of type {type(device_filter)} instead. Converting...") self._task.args["device_filter"] = [device_filter] - def run(self, tmp=None, task_vars=None): + def run(self, tmp: Any = None, task_vars: dict | None = None) -> dict: if task_vars is None: task_vars = {} @@ -50,12 +51,12 @@ def run(self, tmp=None, task_vars=None): file_data_keys = ["cvp_configlets", "cvp_topology"] file_data = {key: result[key] for key in file_data_keys if key in result} - with open(destination, "w", encoding="utf8") as file: + with Path(destination).open("w", encoding="utf8") as file: yaml.dump(file_data, file, Dumper=AnsibleDumper) return result - def build_cvp_topology_from_inventory(self, task_vars, module_args: dict) -> dict: + def build_cvp_topology_from_inventory(self, task_vars: dict, module_args: dict) -> dict: # Inventory Manager is the Ansible Class handling everything about hosts and groups inventory_manager: InventoryManager = task_vars["hostvars"]._inventory @@ -67,9 +68,8 @@ def build_cvp_topology_from_inventory(self, task_vars, module_args: dict) -> dic # Verify that the group referenced in 'container_root' is valid if container_root not in inventory_manager.groups: - raise AnsibleActionFail( - f"Group '{container_root}' given as 'container_root' argument on 'arista.avd.inventory_to_container' cannot be found in Ansible inventory" - ) + msg = f"Group '{container_root}' given as 'container_root' argument on 'arista.avd.inventory_to_container' cannot be found in Ansible inventory" + raise AnsibleActionFail(msg) # cvp_topology holds the final output data cvp_topology = {} @@ -93,17 +93,18 @@ def build_cvp_topology_from_inventory(self, task_vars, module_args: dict) -> dic return cvp_topology - def get_group_data(self, group: Group, device_filter: list, all_groups_from_root: set = None, parent_container: str = None) -> dict: + def get_group_data(self, group: Group, device_filter: list, all_groups_from_root: set | None = None, parent_container: str | None = None) -> dict: # Find parent container if not set if parent_container is None: # Only evaluate parent_groups which are part of all_groups_from_root. A group can have multiple parents. parent_groups: set = set(group.parent_groups).intersection(all_groups_from_root) # Ensure that we have only one parent group. Otherwise we cannot build a tree. if len(parent_groups) > 1: - raise AnsibleActionFail( + msg = ( f"arista.avd.inventory_to_container: Group '{group}' has more than one parent group ({parent_groups}) below the 'container_root'." " Unable to build CloudVision container hierarchy." ) + raise AnsibleActionFail(msg) parent_container = parent_groups.pop().name # Build list of devices under the group diff --git a/ansible_collections/arista/avd/plugins/action/set_vars.py b/ansible_collections/arista/avd/plugins/action/set_vars.py index abbdea826bd..dd0654fe7ec 100644 --- a/ansible_collections/arista/avd/plugins/action/set_vars.py +++ b/ansible_collections/arista/avd/plugins/action/set_vars.py @@ -1,15 +1,13 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, division, print_function - -__metaclass__ = type +from typing import Any from ansible.plugins.action import ActionBase class ActionModule(ActionBase): - def run(self, tmp=None, task_vars=None): + def run(self, tmp: Any = None, task_vars: dict | None = None) -> dict: # noqa: ARG002 if task_vars is None: task_vars = {} diff --git a/ansible_collections/arista/avd/plugins/action/verify_requirements.py b/ansible_collections/arista/avd/plugins/action/verify_requirements.py index 08de325de62..764901da45d 100644 --- a/ansible_collections/arista/avd/plugins/action/verify_requirements.py +++ b/ansible_collections/arista/avd/plugins/action/verify_requirements.py @@ -1,16 +1,16 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -__metaclass__ = type import json -import os import sys from importlib.metadata import Distribution, PackageNotFoundError, version +from pathlib import Path from subprocess import PIPE, Popen +from typing import Any import yaml -from ansible import constants as C +from ansible import constants as C # noqa: N812 from ansible.errors import AnsibleActionFail from ansible.module_utils.compat.importlib import import_module from ansible.plugins.action import ActionBase, display @@ -31,7 +31,7 @@ def _validate_python_version(info: dict, result: dict) -> bool: """ - TODO - avoid hardcoding the min supported version + TODO: - avoid hardcoding the min supported version. Args: info (dict): Dictionary to store information to present in ansible logs @@ -51,10 +51,10 @@ def _validate_python_version(info: dict, result: dict) -> bool: running_version = ".".join(str(v) for v in sys.version_info[:3]) min_version = ".".join(str(v) for v in MIN_PYTHON_SUPPORTED_VERSION) if sys.version_info < MIN_PYTHON_SUPPORTED_VERSION: - display.error(f"Python Version running {running_version} - Minimum Version required is {min_version}", False) + display.error(f"Python Version running {running_version} - Minimum Version required is {min_version}", wrap_text=False) return False # Keeping this for next deprecation adjust the message as required - elif DEPRECATE_MIN_PYTHON_SUPPORTED_VERSION and sys.version_info[:2] == MIN_PYTHON_SUPPORTED_VERSION: + if DEPRECATE_MIN_PYTHON_SUPPORTED_VERSION and sys.version_info[:2] == MIN_PYTHON_SUPPORTED_VERSION: result.setdefault("deprecations", []).append( { "msg": ( @@ -65,8 +65,8 @@ def _validate_python_version(info: dict, result: dict) -> bool: "version 2.15 End-Of-Life is scheduled for November 2024 and it will be the " "last `ansible-core` version supporting Python version 3.9 as documented here: " "https://docs.ansible.com/ansible/latest/reference_appendices/release_and_maintenance.html#ansible-core-support-matrix." - ) - } + ), + }, ) return True @@ -74,7 +74,7 @@ def _validate_python_version(info: dict, result: dict) -> bool: def _validate_python_requirements(requirements: list, info: dict) -> bool: """ - Validate python lib versions + Validate python lib versions. Args: requirements (list): List of requirements for pythom modules @@ -97,7 +97,8 @@ def _validate_python_requirements(requirements: list, info: dict) -> bool: try: req = Requirement(raw_req) except InvalidRequirement as exc: - raise AnsibleActionFail(f"Wrong format for requirement {raw_req}") from exc + msg = f"Wrong format for requirement {raw_req}" + raise AnsibleActionFail(msg) from exc try: installed_version = version(req.name) @@ -115,7 +116,7 @@ def _validate_python_requirements(requirements: list, info: dict) -> bool: "installed": None, "required_version": str(req.specifier) if len(req.specifier) > 0 else None, } - display.error(f"Python library '{req.name}' required but not found - requirement is {str(req)}", False) + display.error(f"Python library '{req.name}' required but not found - requirement is {req!s}", wrap_text=False) valid = False continue @@ -134,18 +135,19 @@ def _validate_python_requirements(requirements: list, info: dict) -> bool: } display.warning( f"Found {req.name} valid versions {valid_versions} among {detected_versions} from metadata - assuming a valid version is running - more" - " information available with -v" + " information available with -v", ) display.v( "The Arista AVD collection relies on Python built-in library `importlib.metadata` to detect running versions. In some cases where legacy" " dist-info folders are leftovers in the site-packages folder, there can be misdetection of the version. This module assumes that if any" " version matches the required one, then the requirement is met. This could led to false positive results. Please make sure to clean the" - " leftovers dist-info folders." + " leftovers dist-info folders.", ) elif len(detected_versions) > 1: # More than one dist found and none matching the requirements display.error( - f"Python library '{req.name}' detected versions {detected_versions} - requirement is {str(req)} - more information available with -v", False + f"Python library '{req.name}' detected versions {detected_versions} - requirement is {req!s} - more information available with -v", + wrap_text=False, ) requirements_dict["mismatched"][req.name] = { "installed": installed_version, @@ -154,7 +156,7 @@ def _validate_python_requirements(requirements: list, info: dict) -> bool: "required_version": str(req.specifier) if len(req.specifier) > 0 else None, } else: - display.error(f"Python library '{req.name}' version running {installed_version} - requirement is {str(req)}", False) + display.error(f"Python library '{req.name}' version running {installed_version} - requirement is {req!s}", wrap_text=False) requirements_dict["mismatched"][req.name] = { "installed": installed_version, "required_version": str(req.specifier) if len(req.specifier) > 0 else None, @@ -167,7 +169,7 @@ def _validate_python_requirements(requirements: list, info: dict) -> bool: def _validate_ansible_version(collection_name: str, running_version: str, info: dict, result: dict) -> bool: """ - Validate ansible version in use, running_version, based on the collection requirements + Validate ansible version in use, running_version, based on the collection requirements. Args: collection_name (str): The collection name @@ -186,20 +188,20 @@ def _validate_ansible_version(collection_name: str, running_version: str, info: info["requires_ansible"] = str(specifiers_set) if not specifiers_set.contains(running_version): display.error( - f"Ansible Version running {running_version} - Requirement is {str(specifiers_set)}", - False, + f"Ansible Version running {running_version} - Requirement is {specifiers_set!s}", + wrap_text=False, ) return False # Keeping this for next deprecation - set the value of deprecation_specifiers_set when needed and adjust message - elif not deprecation_specifiers_set.contains(running_version): + if not deprecation_specifiers_set.contains(running_version): result.setdefault("deprecations", []).append( { "msg": ( f"You are currently running ansible-core {running_version}. The next minor release of AVD after November 6th 2023 will drop support for" " ansible-core<2.14. Python 3.8 support will be dropped at the same time as ansible-core>=2.14 does not support it. See the following link" " for more details: https://docs.ansible.com/ansible/latest/reference_appendices/release_and_maintenance.html#ansible-core-support-matrix" - ) - } + ), + }, ) return True @@ -207,10 +209,10 @@ def _validate_ansible_version(collection_name: str, running_version: str, info: def _validate_ansible_collections(running_collection_name: str, info: dict) -> bool: """ - Verify the version of required ansible collections running based on the collection requirements + Verify the version of required ansible collections running based on the collection requirements. Args: - collection_name (str): The collection name + running_collection_name (str): The collection name info (dict): Dictionary to store information to present in ansible logs Return True if all collection requirements are valid, False otherwise @@ -218,8 +220,7 @@ def _validate_ansible_collections(running_collection_name: str, info: dict) -> b valid = True collection_path = _get_collection_path(running_collection_name) - collections_file = os.path.join(collection_path, "collections.yml") - with open(collections_file, "rb") as fd: + with Path(collection_path, "collections.yml").open("rb") as fd: metadata = yaml.safe_load(fd) if "collections" not in metadata: # no requirements @@ -234,7 +235,7 @@ def _validate_ansible_collections(running_collection_name: str, info: dict) -> b for collection_dict in metadata["collections"]: if "name" not in collection_dict: - display.error("key `name` required but not found in collections requirement - please raise an issue on Github", False) + display.error("key `name` required but not found in collections requirement - please raise an issue on Github", wrap_text=False) continue collection_name = collection_dict["name"] @@ -249,9 +250,9 @@ def _validate_ansible_collections(running_collection_name: str, info: dict) -> b "required_version": str(specifiers_set) if len(specifiers_set) > 0 else None, } if specifiers_set: - display.error(f"{collection_name} required but not found - required version is {str(specifiers_set)}", False) + display.error(f"{collection_name} required but not found - required version is {specifiers_set!s}", wrap_text=False) else: - display.error(f"{collection_name} required but not found", False) + display.error(f"{collection_name} required but not found", wrap_text=False) valid = False continue @@ -263,7 +264,7 @@ def _validate_ansible_collections(running_collection_name: str, info: dict) -> b "required_version": str(specifiers_set) if len(specifiers_set) > 0 else None, } else: - display.error(f"{collection_name} version running {installed_version} - required version is {str(specifiers_set)}", False) + display.error(f"{collection_name} version running {installed_version} - required version is {specifiers_set!s}", wrap_text=False) requirements_dict["mismatched"][collection_name] = { "installed": installed_version, "required_version": str(specifiers_set) if len(specifiers_set) > 0 else None, @@ -275,41 +276,35 @@ def _validate_ansible_collections(running_collection_name: str, info: dict) -> b def _get_collection_path(collection_name: str) -> str: - """ - Retrieve the collection path based on the collection_name - """ + """Retrieve the collection path based on the collection_name.""" collection = import_module(f"ansible_collections.{collection_name}") - return os.path.dirname(collection.__file__) + return str(Path(collection.__file__).parent) -def _get_collection_version(collection_path) -> str: - """ - Returns the collection version based on the collection path - """ +def _get_collection_version(collection_path: str) -> str: + """Returns the collection version based on the collection path.""" # Trying to find the version based on either galaxy.yml or MANIFEST.json try: - galaxy_file = os.path.join(collection_path, "galaxy.yml") - with open(galaxy_file, "rb") as fd: + galaxy_file = Path(collection_path, "galaxy.yml") + with galaxy_file.open("rb") as fd: metadata = yaml.safe_load(fd) except FileNotFoundError: - manifest_file = os.path.join(collection_path, "MANIFEST.json") - with open(manifest_file, "rb") as fd: + manifest_file = Path(collection_path, "MANIFEST.json") + with manifest_file.open("rb") as fd: metadata = json.load(fd)["collection_info"] return metadata["version"] def _get_running_collection_version(running_collection_name: str, result: dict) -> None: - """ - Stores the version collection in result - """ + """Stores the version collection in result.""" collection_path = _get_collection_path(running_collection_name) version = _get_collection_version(collection_path) try: # Try to detect a git tag # Using subprocess for now - with Popen(["git", "describe", "--tags"], stdout=PIPE, stderr=PIPE, cwd=collection_path) as process: + with Popen(["git", "describe", "--tags"], stdout=PIPE, stderr=PIPE, cwd=collection_path) as process: # noqa: S603, S607 output, err = process.communicate() if err: # Not that when molecule runs, it runs in a copy of the directory that is not a git repo @@ -324,13 +319,13 @@ def _get_running_collection_version(running_collection_name: str, result: dict) result["collection"] = { "name": running_collection_name, - "path": os.path.dirname(os.path.dirname(collection_path)), + "path": str(Path(collection_path).parents[1]), "version": version, } class ActionModule(ActionBase): - def run(self, tmp=None, task_vars=None): + def run(self, tmp: Any = None, task_vars: dict | None = None) -> dict: if task_vars is None: task_vars = {} @@ -338,10 +333,12 @@ def run(self, tmp=None, task_vars=None): del tmp # tmp no longer has any effect if not HAS_PACKAGING: - raise AnsibleActionFail("packaging is required to run this plugin") + msg = "packaging is required to run this plugin" + raise AnsibleActionFail(msg) if not (self._task.args and "requirements" in self._task.args): - raise AnsibleActionFail("The argument 'requirements' must be set") + msg = "The argument 'requirements' must be set" + raise AnsibleActionFail(msg) py_requirements = self._task.args.get("requirements") avd_ignore_requirements = self._task.args.get("avd_ignore_requirements", False) @@ -349,7 +346,8 @@ def run(self, tmp=None, task_vars=None): avd_ignore_requirements = True if not isinstance(py_requirements, list): - raise AnsibleActionFail("The argument 'requirements' is not a list") + msg = "The argument 'requirements' is not a list" + raise AnsibleActionFail(msg) running_ansible_version = task_vars["ansible_version"]["string"] running_collection_name = task_vars["ansible_collection_name"] diff --git a/ansible_collections/arista/avd/plugins/filter/add_md_toc.py b/ansible_collections/arista/avd/plugins/filter/add_md_toc.py index 8f53db7a871..89aefb83a90 100644 --- a/ansible_collections/arista/avd/plugins/filter/add_md_toc.py +++ b/ansible_collections/arista/avd/plugins/filter/add_md_toc.py @@ -4,9 +4,7 @@ # # def arista.avd.add_md_toc # -from __future__ import absolute_import, division, print_function -__metaclass__ = type from ansible.errors import AnsibleFilterError @@ -21,7 +19,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) DOCUMENTATION = r""" @@ -78,6 +76,6 @@ """ -class FilterModule(object): - def filters(self): +class FilterModule: + def filters(self) -> dict: return {"add_md_toc": wrap_filter(PLUGIN_NAME)(add_md_toc)} diff --git a/ansible_collections/arista/avd/plugins/filter/convert_dicts.py b/ansible_collections/arista/avd/plugins/filter/convert_dicts.py index 1e5ef833c67..f87614d3ce5 100644 --- a/ansible_collections/arista/avd/plugins/filter/convert_dicts.py +++ b/ansible_collections/arista/avd/plugins/filter/convert_dicts.py @@ -4,9 +4,7 @@ # # def arista.avd.convert_dicts # -from __future__ import absolute_import, division, print_function -__metaclass__ = type from ansible.errors import AnsibleFilterError @@ -21,7 +19,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) DOCUMENTATION = r""" @@ -100,8 +98,8 @@ """ -class FilterModule(object): - def filters(self): +class FilterModule: + def filters(self) -> dict: return { "convert_dicts": wrap_filter(PLUGIN_NAME)(convert_dicts), } diff --git a/ansible_collections/arista/avd/plugins/filter/decrypt.py b/ansible_collections/arista/avd/plugins/filter/decrypt.py index 9e0abafd483..fe8c6a60899 100644 --- a/ansible_collections/arista/avd/plugins/filter/decrypt.py +++ b/ansible_collections/arista/avd/plugins/filter/decrypt.py @@ -14,7 +14,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) DOCUMENTATION = r""" @@ -85,8 +85,8 @@ """ -class FilterModule(object): - def filters(self): +class FilterModule: + def filters(self) -> dict: return { "decrypt": wrap_filter(PLUGIN_NAME)(decrypt), } diff --git a/ansible_collections/arista/avd/plugins/filter/default.py b/ansible_collections/arista/avd/plugins/filter/default.py index c092b6fd468..2be3f576a93 100644 --- a/ansible_collections/arista/avd/plugins/filter/default.py +++ b/ansible_collections/arista/avd/plugins/filter/default.py @@ -4,9 +4,7 @@ # # def arista.avd.default # -from __future__ import absolute_import, division, print_function -__metaclass__ = type from ansible.errors import AnsibleFilterError @@ -21,7 +19,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) DOCUMENTATION = r""" @@ -61,8 +59,8 @@ """ -class FilterModule(object): - def filters(self): +class FilterModule: + def filters(self) -> dict: return { "default": wrap_filter(PLUGIN_NAME)(default), } diff --git a/ansible_collections/arista/avd/plugins/filter/encrypt.py b/ansible_collections/arista/avd/plugins/filter/encrypt.py index 70692d677cc..6e7f48b0f5e 100644 --- a/ansible_collections/arista/avd/plugins/filter/encrypt.py +++ b/ansible_collections/arista/avd/plugins/filter/encrypt.py @@ -14,7 +14,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) DOCUMENTATION = r""" @@ -100,8 +100,8 @@ """ -class FilterModule(object): - def filters(self): +class FilterModule: + def filters(self) -> dict: return { "encrypt": wrap_filter(PLUGIN_NAME)(encrypt), } diff --git a/ansible_collections/arista/avd/plugins/filter/hide_passwords.py b/ansible_collections/arista/avd/plugins/filter/hide_passwords.py index 540e507bd68..8a7cdc0f5f5 100644 --- a/ansible_collections/arista/avd/plugins/filter/hide_passwords.py +++ b/ansible_collections/arista/avd/plugins/filter/hide_passwords.py @@ -17,7 +17,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) DOCUMENTATION = r""" @@ -52,8 +52,8 @@ """ -class FilterModule(object): - def filters(self): +class FilterModule: + def filters(self) -> dict: return { "hide_passwords": wrap_filter(PLUGIN_NAME)(hide_passwords), } diff --git a/ansible_collections/arista/avd/plugins/filter/is_in_filter.py b/ansible_collections/arista/avd/plugins/filter/is_in_filter.py index ceb8a2db299..ffda75dba57 100644 --- a/ansible_collections/arista/avd/plugins/filter/is_in_filter.py +++ b/ansible_collections/arista/avd/plugins/filter/is_in_filter.py @@ -4,9 +4,7 @@ # # device-filter filter # -from __future__ import absolute_import, division, print_function -__metaclass__ = type from ansible.errors import AnsibleFilterError @@ -21,7 +19,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) DOCUMENTATION = r""" @@ -64,8 +62,8 @@ """ -class FilterModule(object): - def filters(self): +class FilterModule: + def filters(self) -> dict: return { "is_in_filter": wrap_filter(PLUGIN_NAME)(is_in_filter), } diff --git a/ansible_collections/arista/avd/plugins/filter/list_compress.py b/ansible_collections/arista/avd/plugins/filter/list_compress.py index a6122379828..654568e3156 100644 --- a/ansible_collections/arista/avd/plugins/filter/list_compress.py +++ b/ansible_collections/arista/avd/plugins/filter/list_compress.py @@ -4,9 +4,7 @@ # # list_compress filter # -from __future__ import absolute_import, division, print_function -__metaclass__ = type from ansible.errors import AnsibleFilterError @@ -21,7 +19,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) DOCUMENTATION = r""" @@ -56,8 +54,8 @@ """ -class FilterModule(object): - def filters(self): +class FilterModule: + def filters(self) -> dict: return { "list_compress": wrap_filter(PLUGIN_NAME)(list_compress), } diff --git a/ansible_collections/arista/avd/plugins/filter/natural_sort.py b/ansible_collections/arista/avd/plugins/filter/natural_sort.py index d707bad44f6..7dda9400a34 100644 --- a/ansible_collections/arista/avd/plugins/filter/natural_sort.py +++ b/ansible_collections/arista/avd/plugins/filter/natural_sort.py @@ -4,9 +4,7 @@ # # natural_sort filter # -from __future__ import absolute_import, division, print_function -__metaclass__ = type from ansible.errors import AnsibleFilterError @@ -21,7 +19,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) DOCUMENTATION = r""" @@ -76,8 +74,8 @@ """ -class FilterModule(object): - def filters(self): +class FilterModule: + def filters(self) -> dict: return { "natural_sort": wrap_filter(PLUGIN_NAME)(natural_sort), } diff --git a/ansible_collections/arista/avd/plugins/filter/range_expand.py b/ansible_collections/arista/avd/plugins/filter/range_expand.py index 2150f5b3cfd..eb2a94e57e7 100644 --- a/ansible_collections/arista/avd/plugins/filter/range_expand.py +++ b/ansible_collections/arista/avd/plugins/filter/range_expand.py @@ -4,9 +4,7 @@ # # range_expand filter # -from __future__ import absolute_import, division, print_function -__metaclass__ = type from ansible.errors import AnsibleFilterError @@ -21,7 +19,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) @@ -95,8 +93,8 @@ """ -class FilterModule(object): - def filters(self): +class FilterModule: + def filters(self) -> dict: return { "range_expand": wrap_filter(PLUGIN_NAME)(range_expand), } diff --git a/ansible_collections/arista/avd/plugins/filter/snmp_hash.py b/ansible_collections/arista/avd/plugins/filter/snmp_hash.py index 1ab440e8d35..5d5e7f7ae6d 100644 --- a/ansible_collections/arista/avd/plugins/filter/snmp_hash.py +++ b/ansible_collections/arista/avd/plugins/filter/snmp_hash.py @@ -5,7 +5,6 @@ # snmp_hash filter # -__metaclass__ = type from ansible.errors import AnsibleFilterError @@ -20,7 +19,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) DOCUMENTATION = r""" @@ -72,6 +71,6 @@ """ -class FilterModule(object): - def filters(self): +class FilterModule: + def filters(self) -> dict: return {"snmp_hash": wrap_filter(PLUGIN_NAME)(snmp_hash)} diff --git a/ansible_collections/arista/avd/plugins/filter/status_render.py b/ansible_collections/arista/avd/plugins/filter/status_render.py index f5f29232b35..07848c425de 100644 --- a/ansible_collections/arista/avd/plugins/filter/status_render.py +++ b/ansible_collections/arista/avd/plugins/filter/status_render.py @@ -1,9 +1,7 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, division, print_function -__metaclass__ = type from ansible.errors import AnsibleFilterError @@ -18,7 +16,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) DOCUMENTATION = r""" @@ -52,7 +50,6 @@ """ -class FilterModule(object): - - def filters(self): +class FilterModule: + def filters(self) -> dict: return {"status_render": wrap_filter(PLUGIN_NAME)(status_render)} diff --git a/ansible_collections/arista/avd/plugins/modules/configlet_build_config.py b/ansible_collections/arista/avd/plugins/modules/configlet_build_config.py index 706ee33ff74..e21ea40de80 100644 --- a/ansible_collections/arista/avd/plugins/modules/configlet_build_config.py +++ b/ansible_collections/arista/avd/plugins/modules/configlet_build_config.py @@ -1,13 +1,7 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- -# # Copyright (c) 2019-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, division, print_function - -__metaclass__ = type DOCUMENTATION = r""" --- @@ -47,9 +41,8 @@ configlet_extension: 'cfg' """ -import glob -import os import traceback +from pathlib import Path from ansible.module_utils.basic import AnsibleModule @@ -63,7 +56,7 @@ YAML_IMP_ERR = traceback.format_exc() -def get_configlet(src_folder="", prefix="AVD", extension="cfg"): +def get_configlet(src_folder: str = "", prefix: str = "AVD", extension: str = "cfg") -> dict: """ Get available configlets to deploy to CVP. @@ -76,38 +69,32 @@ def get_configlet(src_folder="", prefix="AVD", extension="cfg"): extension : str, optional File extension to lookup configlet file, by default 'cfg' - Returns + Returns: ------- dict Dictionary of configlets found in source folder. """ - src_configlets = glob.glob(f"{src_folder}/*.{extension}") + src_configlets = Path(src_folder).glob(f"*.{extension}") configlets = {} for file in src_configlets: - if prefix != "none": - name = prefix + "_" + os.path.splitext(os.path.basename(file))[0] - else: - name = os.path.splitext(os.path.basename(file))[0] - with open(file, "r", encoding="utf8") as file: - data = file.read() + name = prefix + "_" + file.stem if prefix != "none" else file.stem + with file.open(encoding="utf8") as stream: + data = stream.read() configlets[name] = data return configlets -def main(): +def main() -> None: """Main entry point for module execution.""" - # TODO - ansible module prefers constructor over literal - # for dict - # pylint: disable=use-dict-literal - argument_spec = dict( - configlet_dir=dict(type="str", required=True), - configlet_prefix=dict(type="str", required=True), - configlet_extension=dict(type="str", required=False, default="conf"), - destination=dict(type="str", required=False, default=None), - ) + argument_spec = { + "configlet_dir": {"type": "str", "required": True}, + "configlet_prefix": {"type": "str", "required": True}, + "configlet_extension": {"type": "str", "required": False, "default": "conf"}, + "destination": {"type": "str", "required": False, "default": None}, + } module = AnsibleModule(argument_spec=argument_spec, supports_check_mode=False) - result = dict(changed=False) + result = {"changed": False} if not HAS_YAML: module.fail_json(msg="yaml lib is required for this module") @@ -115,12 +102,14 @@ def main(): # If set, build configlet topology if module.params["configlet_dir"] is not None: result["cvp_configlets"] = get_configlet( - src_folder=module.params["configlet_dir"], prefix=module.params["configlet_prefix"], extension=module.params["configlet_extension"] + src_folder=module.params["configlet_dir"], + prefix=module.params["configlet_prefix"], + extension=module.params["configlet_extension"], ) # Write vars to file if set by user if module.params["destination"] is not None: - with open(module.params["destination"], "w", encoding="utf8") as file: + with Path(module.params["destination"]).open("w", encoding="utf8") as file: yaml.dump(result, file) module.exit_json(**result) diff --git a/ansible_collections/arista/avd/plugins/modules/inventory_to_container.py b/ansible_collections/arista/avd/plugins/modules/inventory_to_container.py index fbfb9b727fc..55c8c3cb005 100644 --- a/ansible_collections/arista/avd/plugins/modules/inventory_to_container.py +++ b/ansible_collections/arista/avd/plugins/modules/inventory_to_container.py @@ -1,13 +1,7 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- -# # Copyright (c) 2019-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, division, print_function - -__metaclass__ = type DOCUMENTATION = r""" --- @@ -78,21 +72,24 @@ save_topology: true """ -import glob -import os import traceback +from pathlib import Path +from typing import Any from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.errors import AnsibleValidationError TREELIB_IMP_ERR = None + try: - from treelib import Tree + from treelib import Node, Tree HAS_TREELIB = True except ImportError: HAS_TREELIB = False TREELIB_IMP_ERR = traceback.format_exc() + Tree = Node = object + YAML_IMP_ERR = None try: import yaml @@ -108,7 +105,7 @@ CVP_ROOT_CONTAINER = "Tenant" -def is_in_filter(hostname_filter=None, hostname="eos"): +def is_in_filter(hostname_filter: list | None = None, hostname: str = "eos") -> bool: """ Check if device is part of the filter or not. @@ -119,24 +116,19 @@ def is_in_filter(hostname_filter=None, hostname="eos"): hostname : str Device hostname to compare against filter. - Returns + Returns: ------- boolean True if device hostname is part of filter. False if not. """ - # W102 Workaround to avoid list as default value. if hostname_filter is None: hostname_filter = ["all"] - if "all" in hostname_filter: - return True - elif any(element in hostname for element in hostname_filter): - return True - return False + return "all" in hostname_filter or any(element in hostname for element in hostname_filter) -def isIterable(testing_object=None): +def is_iterable(testing_object: Any = None) -> bool | None: """ Test if an object is iterable or not. @@ -148,35 +140,33 @@ def isIterable(testing_object=None): Object to test if it is iterable or not, by default None """ try: - iter(testing_object) # noqa - return True + iter(testing_object) except TypeError: return False + return True + -def isLeaf(tree, nid): +def is_leaf(tree: Tree, nid: Node) -> bool: """ Test if NodeID is a leaf with no nid attached to it. Parameters ---------- - tree : treelib.Tree + tree : Tree Tree where NID is defined. - nid : treelib.Node + nid : Node NodeID to test. - Returns + Returns: ------- boolean True if node is a leaf, false in other situation """ - if nid and len(tree.is_branch(nid)) == 0: - return True - else: - return False + return bool(nid and len(tree.is_branch(nid)) == 0) -def get_configlet(src_folder="", prefix="AVD", extension="cfg", device_filter=None): +def get_configlet(src_folder: str = "", prefix: str = "AVD", extension: str = "cfg", device_filter: list | None = None) -> dict: """ Get available configlets to deploy to CVP. @@ -191,7 +181,7 @@ def get_configlet(src_folder="", prefix="AVD", extension="cfg", device_filter=No device_filter: list, optional List of filter to compare device configlet and to select only a subset of configlet. - Returns + Returns: ------- dict Dictionary of configlets found in source folder. @@ -200,22 +190,19 @@ def get_configlet(src_folder="", prefix="AVD", extension="cfg", device_filter=No if device_filter is None: device_filter = ["all"] - src_configlets = glob.glob(src_folder + "/*." + extension) + src_configlets = Path(src_folder).glob(f"*.{extension}") configlets = {} for file in src_configlets: # Build structure only if configlet match device_filter. - if is_in_filter(hostname=os.path.splitext(os.path.basename(file))[0], hostname_filter=device_filter): - if prefix != "none": - name = prefix + "_" + os.path.splitext(os.path.basename(file))[0] - else: - name = os.path.splitext(os.path.basename(file))[0] - with open(file, "r", encoding="utf8") as file: - data = file.read() + if is_in_filter(hostname=file.stem, hostname_filter=device_filter): + name = prefix + "_" + file.stem if prefix != "none" else file.stem + with file.open(encoding="utf8") as stream: + data = stream.read() configlets[name] = data return configlets -def get_device_option_value(device_data_dict, option_name): +def get_device_option_value(device_data_dict: dict, option_name: str) -> str | None: """ get_device_option_value Extract value of a host_var defined in inventory file. @@ -225,23 +212,24 @@ def get_device_option_value(device_data_dict, option_name): Parameters ---------- device_data_dict : dict - List of options defined under device. + Dict of options defined under device. option_name : string Name of option searched by function. - Returns + Returns: ------- string Value set for variable, else None """ - if isIterable(device_data_dict): + if is_iterable(device_data_dict): for option in device_data_dict: if option_name == option: return device_data_dict[option] return None + return None -def serialize_yaml_inventory_data(dict_inventory, parent_container=None, tree_topology=None): +def serialize_yaml_inventory_data(dict_inventory: dict, parent_container: str | None = None, tree_topology: Tree | None = None) -> Tree: """ Build a tree topology from YAML inventory file content. @@ -251,15 +239,15 @@ def serialize_yaml_inventory_data(dict_inventory, parent_container=None, tree_to Inventory YAML content. parent_container : str, optional Registration of container N-1 for recursive function, by default None - tree_topology : treelib.Tree, optional + tree_topology : Tree, optional Tree topology built over iteration, by default None - Returns + Returns: ------- - treelib.Tree + Tree complete container tree topology. """ - if isIterable(dict_inventory): + if is_iterable(dict_inventory): # Working with ROOT container for Fabric if tree_topology is None: # initiate tree topology and add ROOT under Tenant @@ -270,22 +258,23 @@ def serialize_yaml_inventory_data(dict_inventory, parent_container=None, tree_to # Recursive Inventory read for k1, v1 in dict_inventory.items(): # Read a leaf - if isIterable(v1) and "children" not in v1: + if is_iterable(v1) and "children" not in v1: tree_topology.create_node(k1, k1, parent=parent_container) # If subgroup has kids - if isIterable(v1) and "children" in v1: + if is_iterable(v1) and "children" in v1: tree_topology.create_node(k1, k1, parent=parent_container) serialize_yaml_inventory_data(dict_inventory=v1["children"], parent_container=k1, tree_topology=tree_topology) - elif k1 == "children" and isIterable(v1): + elif k1 == "children" and is_iterable(v1): # Extract sub-group information for k2, v2 in v1.items(): # Add subgroup to tree tree_topology.create_node(k2, k2, parent=parent_container) serialize_yaml_inventory_data(dict_inventory=v2, parent_container=k2, tree_topology=tree_topology) return tree_topology + return None -def get_devices(dict_inventory, search_container=None, devices=None, device_filter=None): +def get_devices(dict_inventory: dict | None, search_container: str | None = None, devices: list[str] | None = None, device_filter: list | None = None) -> list: """ Get devices attached to a container. @@ -300,7 +289,7 @@ def get_devices(dict_inventory, search_container=None, devices=None, device_filt device_filter: list, optional List of filter to compare device name and to select only a subset of devices. - Returns + Returns: ------- list List of found devices. @@ -313,7 +302,7 @@ def get_devices(dict_inventory, search_container=None, devices=None, device_filt for k1, v1 in dict_inventory.items(): # Read a leaf - if k1 == search_container and isIterable(v1) and "hosts" in v1: + if k1 == search_container and is_iterable(v1) and "hosts" in v1: for dev, data in v1["hosts"].items(): if ( is_in_filter(hostname_filter=device_filter, hostname=dev) @@ -321,16 +310,16 @@ def get_devices(dict_inventory, search_container=None, devices=None, device_filt ): devices.append(dev) # If subgroup has kids - if isIterable(v1) and "children" in v1: + if is_iterable(v1) and "children" in v1: get_devices(dict_inventory=v1["children"], search_container=search_container, devices=devices, device_filter=device_filter) - elif k1 == "children" and isIterable(v1): + elif k1 == "children" and is_iterable(v1): # Extract sub-group information - for k2, v2 in v1.items(): + for v2 in v1.values(): get_devices(dict_inventory=v2, search_container=search_container, devices=devices, device_filter=device_filter) return devices -def get_containers(inventory_content, parent_container, device_filter): +def get_containers(inventory_content: dict, parent_container: str, device_filter: list | None) -> dict: """ get_containers - Build Container topology to build on CoudVision. @@ -343,7 +332,7 @@ def get_containers(inventory_content, parent_container, device_filter): device_filter : list, optional List of filter to compare device name and to select only a subset of devices. - Returns + Returns: ------- JSON CVP Container structure to use with cv_container. @@ -359,7 +348,7 @@ def get_containers(inventory_content, parent_container, device_filter): if container == parent_container: data["parent_container"] = CVP_ROOT_CONTAINER elif parent.tag != CVP_ROOT_CONTAINER: - if isLeaf(tree=tree_dc, nid=container): + if is_leaf(tree=tree_dc, nid=container): devices = get_devices(dict_inventory=inventory_content, search_container=container, devices=[], device_filter=device_filter) data["devices"] = devices data["parent_container"] = parent.tag @@ -367,22 +356,19 @@ def get_containers(inventory_content, parent_container, device_filter): return container_json -def main(): +def main() -> None: """Main entry point for module execution.""" - # TODO - ansible module prefers constructor over literal - # for dict - # pylint: disable=use-dict-literal - argument_spec = dict( - inventory=dict(type="str", required=False), - container_root=dict(type="str", required=True), - configlet_dir=dict(type="str", required=False), - configlet_prefix=dict(type="str", required=False, default="AVD"), - destination=dict(type="str", required=False), - device_filter=dict(type="list", elements="str", default="all"), - ) + argument_spec = { + "inventory": {"type": "str", "required": False}, + "container_root": {"type": "str", "required": True}, + "configlet_dir": {"type": "str", "required": False}, + "configlet_prefix": {"type": "str", "required": False, "default": "AVD"}, + "destination": {"type": "str", "required": False}, + "device_filter": {"type": "list", "elements": "str", "default": "all"}, + } module = AnsibleModule(argument_spec=argument_spec, supports_check_mode=False) - result = dict(changed=False) + result = {"changed": False} if not HAS_YAML: module.fail_json(msg="yaml lib is required for this module") @@ -398,22 +384,27 @@ def main(): parent_container = module.params["container_root"] # Build containers & devices topology inventory_content = "" - with open(inventory_file, "r", encoding="utf8") as stream: + with Path(inventory_file).open(encoding="utf8") as stream: try: # add a constructor to return "!VAULT" for inline vault variables # to avoid the parse yaml.SafeLoader.add_constructor("!vault", lambda _, __: "!VAULT") inventory_content = yaml.safe_load(stream) except yaml.YAMLError as exc: - raise AnsibleValidationError("Failed to parse inventory file") from exc + msg = "Failed to parse inventory file" + raise AnsibleValidationError(msg) from exc result["cvp_topology"] = get_containers( - inventory_content=inventory_content, parent_container=parent_container, device_filter=module.params["device_filter"] + inventory_content=inventory_content, + parent_container=parent_container, + device_filter=module.params["device_filter"], ) # If set, build configlet topology if module.params["configlet_dir"] is not None: result["cvp_configlets"] = get_configlet( - src_folder=module.params["configlet_dir"], prefix=module.params["configlet_prefix"], device_filter=module.params["device_filter"] + src_folder=module.params["configlet_dir"], + prefix=module.params["configlet_prefix"], + device_filter=module.params["device_filter"], ) module.exit_json(**result) diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/ansible_eos_device.py b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/ansible_eos_device.py index 132372dae7f..0441572b97d 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/ansible_eos_device.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/ansible_eos_device.py @@ -7,7 +7,7 @@ from functools import partial from json import JSONDecodeError, loads from logging import getLogger -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING from urllib.error import HTTPError from ansible.errors import AnsibleActionFail, AnsibleConnectionFailure @@ -24,7 +24,7 @@ AnsibleActionFail( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) logger = getLogger(__name__) @@ -41,6 +41,8 @@ AntaDevice = object if TYPE_CHECKING: + from collections.abc import Generator + from ansible.plugins.connection import ConnectionBase from anta.models import AntaCommand @@ -79,7 +81,7 @@ def __init__(self, name: str, connection: ConnectionBase, tags: list | None = No if not self.check_mode and not hasattr(connection, "_sub_plugin"): raise AristaAvdError( message="AVD could not determine the Ansible connection plugin used. " - "Please ensure that the 'ansible_network_os' and 'ansible_connection' variables are set to 'eos' and 'httpapi' respectively for this host." + "Please ensure that the 'ansible_network_os' and 'ansible_connection' variables are set to 'eos' and 'httpapi' respectively for this host.", ) # In check_mode we don't care that we cannot connect to the device if self.check_mode or (plugin_name := connection._sub_plugin.get("name")) == ANSIBLE_EOS_PLUGIN_NAME: @@ -87,7 +89,7 @@ def __init__(self, name: str, connection: ConnectionBase, tags: list | None = No else: raise AristaAvdError( message=f"The provided Ansible connection does not use EOS HttpApi plugin: {plugin_name}. " - "Please ensure that the 'ansible_network_os' and 'ansible_connection' variables are set to 'eos' and 'httpapi' respectively for this host." + "Please ensure that the 'ansible_network_os' and 'ansible_connection' variables are set to 'eos' and 'httpapi' respectively for this host.", ) @property @@ -104,7 +106,7 @@ def __rich_repr__(self) -> Generator: if __DEBUG__: yield "_connection", connection_vars - async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: + async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: # noqa: ARG002 """Collect device command result using Ansible HttpApi connection plugin. Supports outformat 'json' and 'text' as output structure. @@ -112,6 +114,9 @@ async def _collect(self, command: AntaCommand, *, collection_id: str | None = No Args: ---- command (AntaCommand): The command to collect. + + Keyword Args: + ------------- collection_id (str, optional): This parameter is not used in this implementation. Defaults to None. If there is an exception while collecting the command, the exception will be propagated diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/avdtestbase.py b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/avdtestbase.py index 0b53a1dc3ea..c6102fbfcee 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/avdtestbase.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/avdtestbase.py @@ -4,11 +4,13 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Mapping +from typing import TYPE_CHECKING from .mixins import DeviceUtilsMixin, ValidationMixin if TYPE_CHECKING: + from collections.abc import Mapping + from .config_manager import ConfigManager LOGGER = logging.getLogger(__name__) @@ -58,7 +60,7 @@ def render(self) -> dict: If `test_definition` is not set or returns a falsy value (e.g., None), an empty dictionary will be returned instead. - Returns + Returns: ------- dict: The test definition if available and valid; otherwise, an empty dictionary. """ diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/config_manager.py b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/config_manager.py index 45eb3ccd6ed..4f255a786ff 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/config_manager.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/config_manager.py @@ -4,9 +4,9 @@ from __future__ import annotations import logging +from collections.abc import Mapping from functools import cached_property from ipaddress import ip_interface -from typing import Mapping from ansible.errors import AnsibleActionFail @@ -22,7 +22,7 @@ AnsibleActionFail( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) LOGGER = logging.getLogger(__name__) @@ -88,7 +88,7 @@ def get_vtep_mapping(self) -> list[tuple[str, str]]: def _get_loopback_mappings(self) -> dict: """Generate the loopback mappings for the eos_validate_state tests, which are used in AvdTestBase subclasses. - Returns + Returns: ------- dict: A dictionary containing: - "loopback0_mapping": A list of tuples where each tuple contains a hostname and its Loopback0 IP address. diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/csv_report.py b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/csv_report.py index 3f572037c0f..3f0bc1a5d1b 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/csv_report.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/csv_report.py @@ -4,9 +4,10 @@ from __future__ import annotations import csv -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import Generator from io import TextIOWrapper from .results_manager import ResultsManager @@ -31,7 +32,7 @@ def generate_rows(self) -> Generator[dict, None, None]: Results are sourced from `failed_tests` or `all_tests`, based on whether the report includes only failed tests or all results. - Yields + Yields: ------ Generator[dict, None, None]: A generator of test result dictionaries representing CSV rows. """ diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/get_anta_results.py b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/get_anta_results.py index 4fccf2e4e1d..1acc74b0f4d 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/get_anta_results.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/get_anta_results.py @@ -25,7 +25,7 @@ AnsibleActionFail( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) if TYPE_CHECKING: @@ -147,7 +147,7 @@ def load_custom_catalogs(catalog_files: list[Path]) -> dict: with file.open("r", encoding="UTF-8") as fd: catalog = load(fd, Loader=CSafeLoader) catalog_list.append(catalog) - except (OSError, YAMLError) as error: + except (OSError, YAMLError) as error: # noqa: PERF203 TODO: Investigate and improve code to avoid try/except inside loop msg = f"Failed to load the custom ANTA catalog from {file}: {error!s}" raise AristaAvdError(msg) from error @@ -241,5 +241,5 @@ def create_dry_run_report(device_name: str, catalog: AntaCatalog, manager: Resul categories=categories, description=description, custom_field=custom_field, - ) + ), ) diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/md_report.py b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/md_report.py index 3d378123854..19ef3f895e8 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/md_report.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/md_report.py @@ -5,9 +5,10 @@ import re from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, ClassVar, Generator +from typing import TYPE_CHECKING, ClassVar if TYPE_CHECKING: + from collections.abc import Generator from io import TextIOWrapper from .results_manager import ResultsManager @@ -137,7 +138,7 @@ def safe_markdown(self, text: str | None) -> str: ---------- text (str): The text to escape markdown characters from. - Returns + Returns: ------- str: The text with escaped markdown characters. """ diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/mixins.py b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/mixins.py index f6fe6aa8963..3eaa0edd821 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/mixins.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/mixins.py @@ -19,7 +19,7 @@ AnsibleActionFail( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) LOGGER = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def update_interface_shutdown(self, interface: dict, host: str | None = None) -> """ host_struct_cfg = self.config_manager.get_host_structured_config(host=host) if host else self.structured_config if "Ethernet" in get(interface, "name", ""): - interface["shutdown"] = default(get(interface, "shutdown"), get(host_struct_cfg, "interface_defaults.ethernet.shutdown"), False) + interface["shutdown"] = default(get(interface, "shutdown"), get(host_struct_cfg, "interface_defaults.ethernet.shutdown"), False) # noqa: FBT003 else: interface["shutdown"] = get(interface, "shutdown", default=False) @@ -120,7 +120,7 @@ class ValidationMixin: It should be used as a mixin class in the AvdTestBase classes. """ - # TODO @carl-baillargeon: Split the validate_data method into two methods: one for expected key-value pairs and one for required keys. + # TODO: @carl-baillargeon: Split the validate_data method into two methods: one for expected key-value pairs and one for required keys. def validate_data( self, data: dict | None = None, diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/results_manager.py b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/results_manager.py index 212e8104799..396a27cef99 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/results_manager.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/eos_validate_state_utils/results_manager.py @@ -139,7 +139,7 @@ def update_results(self, result: dict) -> None: def total_tests(self) -> int: """Calculates the total number of tests processed. - Returns + Returns: ------- int: The total number of tests. """ @@ -149,7 +149,7 @@ def total_tests(self) -> int: def sorted_category_stats(self) -> dict: """A property that returns the category_stats dictionary sorted by key name. - Returns + Returns: ------- dict: The sorted category_stats dictionary. """ diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/merge/mergecatalog.py b/ansible_collections/arista/avd/plugins/plugin_utils/merge/mergecatalog.py index 9439ea43176..cd274554774 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/merge/mergecatalog.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/merge/mergecatalog.py @@ -16,7 +16,7 @@ AnsibleActionFail( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) try: diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/pyavd_wrappers.py b/ansible_collections/arista/avd/plugins/plugin_utils/pyavd_wrappers.py index 0a9c5dc9554..58f98a98aad 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/pyavd_wrappers.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/pyavd_wrappers.py @@ -1,19 +1,21 @@ # Copyright (c) 2019-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, annotations, division, print_function - -__metaclass__ = type +from __future__ import annotations import warnings from functools import partial, wraps -from typing import Callable, Literal +from typing import TYPE_CHECKING, Any, Literal from ansible.errors import AnsibleFilterError, AnsibleInternalError, AnsibleUndefinedVariable from ansible.module_utils.basic import to_native from ansible.utils.display import Display from jinja2.exceptions import UndefinedError +if TYPE_CHECKING: + from collections.abc import Callable + from typing import NoReturn + display = Display() @@ -24,10 +26,10 @@ class RaiseOnUse: Used with Ansible try/except import logic to not fail on import of plugins, but instead fail on first use. """ - def __init__(self, exception: Exception): + def __init__(self, exception: Exception) -> None: self.exception = exception - def __call__(self, *args, **kwargs): + def __call__(self, *_args: Any, **_kwargs: Any) -> NoReturn: raise self.exception @@ -38,11 +40,12 @@ def wrap_plugin(plugin_type: Literal["filter", "test"], name: str) -> Callable: } if plugin_type not in plugin_map: - raise AnsibleInternalError(f"Wrong plugin type {plugin_type} passed to wrap_plugin.") + msg = f"Wrong plugin type {plugin_type} passed to wrap_plugin." + raise AnsibleInternalError(msg) def wrap_plugin_decorator(func: Callable) -> Callable: @wraps(func) - def plugin_wrapper(*args, **kwargs): + def plugin_wrapper(*args: Any, **kwargs: Any) -> Any: """Wrapper function for plugins. NOTE: if the same warning is raised multiple times, Ansible Display() will print only one @@ -53,11 +56,14 @@ def plugin_wrapper(*args, **kwargs): if w: for warning in w: display.warning(str(warning.message)) - return result except UndefinedError as e: - raise AnsibleUndefinedVariable(f"{plugin_type.capitalize()} '{name}' failed: {to_native(e)}", orig_exc=e) from e + msg = f"{plugin_type.capitalize()} '{name}' failed: {to_native(e)}" + raise AnsibleUndefinedVariable(msg, orig_exc=e) from e except Exception as e: - raise plugin_map[plugin_type](f"{plugin_type.capitalize()} '{name}' failed: {to_native(e)}", orig_exc=e) from e + msg = f"{plugin_type.capitalize()} '{name}' failed: {to_native(e)}" + raise plugin_map[plugin_type](msg, orig_exc=e) from e + + return result return plugin_wrapper diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/schema/avdschematools.py b/ansible_collections/arista/avd/plugins/plugin_utils/schema/avdschematools.py index bbdcb73c8d7..9618540e2a2 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/schema/avdschematools.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/schema/avdschematools.py @@ -3,13 +3,17 @@ # that can be found in the LICENSE file. from __future__ import annotations -from typing import Generator +from typing import TYPE_CHECKING from ansible.errors import AnsibleActionFail -from ansible.utils.display import Display from ansible_collections.arista.avd.plugins.plugin_utils.pyavd_wrappers import RaiseOnUse +if TYPE_CHECKING: + from collections.abc import Generator + + from ansible.utils.display import Display + try: from pyavd._errors import AristaAvdError, AvdDeprecationWarning from pyavd._schema.avdschema import AvdSchema @@ -25,19 +29,17 @@ class AvdSchemaTools: - """ - Tools that wrap the various schema components for easy reuse in Ansible plugins - """ + """Tools that wrap the various schema components for easy reuse in Ansible plugins.""" def __init__( self, hostname: str, ansible_display: Display, - schema: dict = None, - schema_id: str = None, - conversion_mode: str = None, - validation_mode: str = None, - plugin_name: str = None, + schema: dict | None = None, + schema_id: str | None = None, + conversion_mode: str | None = None, + validation_mode: str | None = None, + plugin_name: str | None = None, ) -> None: self._set_schema(schema, schema_id) self.hostname = hostname @@ -48,22 +50,26 @@ def __init__( def _set_schema(self, schema: dict | None, schema_id: str | None) -> None: if schema is None and schema_id is None: - raise AnsibleActionFail("Either argument 'schema' or 'schema_id' must be set") + msg = "Either argument 'schema' or 'schema_id' must be set" + raise AnsibleActionFail(msg) try: self.avdschema = AvdSchema(schema=schema, schema_id=schema_id) except AristaAvdError as e: - raise AnsibleActionFail("Invalid Schema!") from e + msg = "Invalid Schema!" + raise AnsibleActionFail(msg) from e def _set_conversion_mode(self, conversion_mode: str | None) -> None: if conversion_mode is None: conversion_mode = DEFAULT_CONVERSION_MODE if not isinstance(conversion_mode, str): - raise AnsibleActionFail("The argument 'conversion_mode' must be a string") + msg = "The argument 'conversion_mode' must be a string" + raise AnsibleActionFail(msg) if conversion_mode not in VALID_CONVERSION_MODES: - raise AnsibleActionFail(f"Invalid value '{conversion_mode}' for the argument 'conversion_mode'. Must be one of {VALID_CONVERSION_MODES}") + msg = f"Invalid value '{conversion_mode}' for the argument 'conversion_mode'. Must be one of {VALID_CONVERSION_MODES}" + raise AnsibleActionFail(msg) self.conversion_mode = conversion_mode @@ -72,19 +78,22 @@ def _set_validation_mode(self, validation_mode: str | None) -> None: validation_mode = DEFAULT_VALIDATION_MODE if not isinstance(validation_mode, str): - raise AnsibleActionFail("The argument 'validation_mode' must be a string") + msg = "The argument 'validation_mode' must be a string" + raise AnsibleActionFail(msg) if validation_mode not in VALID_VALIDATION_MODES: - raise AnsibleActionFail(f"Invalid value '{validation_mode}' for the argument 'validation_mode'. Must be one of {VALID_VALIDATION_MODES}") + msg = f"Invalid value '{validation_mode}' for the argument 'validation_mode'. Must be one of {VALID_VALIDATION_MODES}" + raise AnsibleActionFail(msg) self.validation_mode = validation_mode def convert_data(self, data: dict) -> int: """ - Convert data according to the schema (convert_types) + Convert data according to the schema (convert_types). + The data conversion is done in-place (updating the original "data" dict). - Returns + Returns: ------- int : number of conversions done """ @@ -108,9 +117,9 @@ def convert_data(self, data: dict) -> int: def validate_data(self, data: dict) -> int: """ - Validate data according to the schema + Validate data according to the schema. - Returns + Returns: ------- int : number of validation errors """ @@ -123,7 +132,7 @@ def validate_data(self, data: dict) -> int: def convert_and_validate_data(self, data: dict, return_counters: bool = False) -> dict: """ - Convert & Validate data according to the schema + Convert & Validate data according to the schema. Calls conversion and validation methods and gather resulting messages @@ -157,6 +166,7 @@ def convert_and_validate_data(self, data: dict, return_counters: bool = False) - def handle_validation_exceptions(self, exceptions: Generator, mode: str) -> int: """ Iterate through the Generator of exceptions. + This method is actually where the content of the generator gets executed. It displays various messages depending on the `mode` parameter @@ -183,7 +193,7 @@ def handle_validation_exceptions(self, exceptions: Generator, mode: str) -> int: date=exception.date, collection_name=self.plugin_name, removed=exception.removed, - ) + ), ) self.ansible_display.deprecated( @@ -200,28 +210,27 @@ def handle_validation_exceptions(self, exceptions: Generator, mode: str) -> int: continue message = f"[{self.hostname}]: {exception}" if mode == "error": - self.ansible_display.error(message, False) + self.ansible_display.error(message, wrap_text=False) elif mode == "info": self.ansible_display.display(message) elif mode == "debug": self.ansible_display.v(message) else: - # mode == "warning" - self.ansible_display.warning(message, False) + # when mode == "warning" + self.ansible_display.warning(message, wrap_text=False) return counter def validate_schema(self) -> int: """ - Validate the loaded schema according to the meta-schema + Validate the loaded schema according to the meta-schema. Returns int with number of validation errors """ - # avd_schema.validate_schema returns a generator, which we iterate through in handle_exceptions to perform the actual conversions. exceptions = self.avdschema.validate_schema(self.avdschema._schema) return self.handle_validation_exceptions(exceptions, "error") - def build_result_message(self, conversions: int = 0, validation_errors: int = 0, schema_validation_errors: int = 0): + def build_result_message(self, conversions: int = 0, validation_errors: int = 0, schema_validation_errors: int = 0) -> str | None: result_messages = [] if conversions: diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/utils/compile_searchpath.py b/ansible_collections/arista/avd/plugins/plugin_utils/utils/compile_searchpath.py index fde6e0ea5bf..c02c811a0e7 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/utils/compile_searchpath.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/utils/compile_searchpath.py @@ -1,17 +1,17 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from os.path import join as path_join +from pathlib import Path -def compile_searchpath(searchpath: list): +def compile_searchpath(searchpath: list) -> list[str]: """ - Create a new searchpath by inserting new items with <>/templates into the existing searchpath + Create a new searchpath by inserting new items with <>/templates into the existing searchpath. This is copying the behavior of the "ansible.builtin.template" lookup module, and is necessary to be able to load templates from all supported paths. - Example + Example: ------- compile_searchpath(["patha", "pathb", "pathc"]) -> ["patha", "patha/templates", "pathb", "pathb/templates", "pathc", "pathc/templates"] @@ -21,14 +21,13 @@ def compile_searchpath(searchpath: list): searchpath : list of str List of Paths - Returns + Returns: ------- list of str List of both original and extra paths with "/templates" added. """ - newsearchpath = [] for p in searchpath: - newsearchpath.append(path_join(p, "templates")) + newsearchpath.append(str(Path(p, "templates"))) newsearchpath.append(p) return newsearchpath diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/utils/cprofile_decorator.py b/ansible_collections/arista/avd/plugins/plugin_utils/utils/cprofile_decorator.py index e86db9ed574..1bf1a91d47d 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/utils/cprofile_decorator.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/utils/cprofile_decorator.py @@ -6,7 +6,10 @@ import cProfile import pstats from functools import wraps -from typing import Any, Callable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable def cprofile(sort_by: str = "cumtime") -> Callable: diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/utils/get_templar.py b/ansible_collections/arista/avd/plugins/plugin_utils/utils/get_templar.py index 266202a9650..7f79219d89a 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/utils/get_templar.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/utils/get_templar.py @@ -14,8 +14,8 @@ def get_templar(action_plugin_instance: ActionBase, task_vars: dict) -> Templar: """ - Return a new instance of Ansible Templar Class based on the - "._templar" from the given action_plugin_instance. + Return a new instance of Ansible Templar Class based on the "._templar" from the given action_plugin_instance. + The new instance is loaded with new searchpath based on ".ansible_search_path" from the given task_vars. """ diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/utils/log_message.py b/ansible_collections/arista/avd/plugins/plugin_utils/utils/log_message.py index 8f08113e395..38bcfd8a789 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/utils/log_message.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/utils/log_message.py @@ -34,7 +34,8 @@ def log_message( """ # Validate logging level if log_level.upper() not in LOGGING_LEVELS: - raise ValueError("Invalid logging level. Please choose from DEBUG, INFO, WARNING, ERROR, CRITICAL.") + msg = "Invalid logging level. Please choose from DEBUG, INFO, WARNING, ERROR, CRITICAL." + raise ValueError(msg) dot_notation = f"{key_path}.{key}" if key_path else f"{key}" msg_type = "is missing" if not value else f"!= '{value}'" diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/utils/python_to_ansible_logging_handler.py b/ansible_collections/arista/avd/plugins/plugin_utils/utils/python_to_ansible_logging_handler.py index c4f975f64ac..ea4d68f9365 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/utils/python_to_ansible_logging_handler.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/utils/python_to_ansible_logging_handler.py @@ -15,7 +15,7 @@ class PythonToAnsibleHandler(Handler): """ - Logging Handler that makes a bridge between Ansible display and plugin Result objects + Logging Handler that makes a bridge between Ansible display and plugin Result objects. It is used to: * send ERROR or CRITICAL logs to result[stderr] and failed the plugins @@ -31,15 +31,13 @@ def __init__(self, result: dict, display: Display) -> None: self.result = result def emit(self, record: LogRecord) -> None: - """ - Custom emit function that reads the message level - """ + """Custom emit function that reads the message level.""" message = self._format_msg(record) if record.levelno in [logging.CRITICAL, logging.ERROR]: self.result.setdefault("stderr_lines", []).append(message) - self.result["stderr"] = self.result.setdefault("stderr", "") + f"{str(message)}\n" + self.result["stderr"] = self.result.setdefault("stderr", "") + f"{message!s}\n" self.result["failed"] = True - elif record.levelno in [logging.WARN, logging.WARNING]: + elif record.levelno in [logging.WARNING, logging.WARNING]: self.result.setdefault("warnings", []).append(message) elif record.levelno == logging.INFO: self.display.v(str(message)) @@ -47,27 +45,22 @@ def emit(self, record: LogRecord) -> None: self.display.vvv(str(message)) def _format_msg(self, record: LogRecord) -> str: - """ - Used to format an augmented LogRecord that contains the 'hostname' attribute - """ + """Used to format an augmented LogRecord that contains the 'hostname' attribute.""" return f"<{record.hostname}> {self.format(record)}" if hasattr(record, "hostname") else self.format(record) class PythonToAnsibleContextFilter(Filter): """ - Logging Filter to extend the LogRecord that goes through it with an - extra attribute 'hostname'. For this, it needs to be initialized with a hostname. + Logging Filter to extend the LogRecord that goes through it with an extra attribute 'hostname'. For this, it needs to be initialized with a hostname. This extra attribute can then be used in the PythonToAnsibleHandler to format the messages """ - def __init__(self, hostname: str): + def __init__(self, hostname: str) -> None: super().__init__() self.hostname = hostname def filter(self, record: LogRecord) -> bool: - """ - Add self.hostname as an attribute to the LogRecord - """ + """Add self.hostname as an attribute to the LogRecord.""" record.hostname = self.hostname return True diff --git a/ansible_collections/arista/avd/plugins/plugin_utils/utils/yaml_dumper.py b/ansible_collections/arista/avd/plugins/plugin_utils/utils/yaml_dumper.py index f5889de8590..9814f5ed1cf 100644 --- a/ansible_collections/arista/avd/plugins/plugin_utils/utils/yaml_dumper.py +++ b/ansible_collections/arista/avd/plugins/plugin_utils/utils/yaml_dumper.py @@ -3,6 +3,8 @@ # that can be found in the LICENSE file. from __future__ import annotations +from typing import Any + try: from yaml import CSafeDumper as YamlDumper except ImportError: @@ -11,7 +13,7 @@ # https://ttl255.com/yaml-anchors-and-aliases-and-how-to-disable-them/ class NoAliasDumper(YamlDumper): - def ignore_aliases(self, data): + def ignore_aliases(self, _data: Any) -> bool: return True diff --git a/ansible_collections/arista/avd/plugins/test/contains.py b/ansible_collections/arista/avd/plugins/test/contains.py index 175e992a890..9fa310b9c1d 100644 --- a/ansible_collections/arista/avd/plugins/test/contains.py +++ b/ansible_collections/arista/avd/plugins/test/contains.py @@ -1,28 +1,25 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -# -# arista.avd.contains -# -# Example: -# A is [1, 2] -# B is [3, 4] -# C is [2, 3] -# -# Jinja test examples: -# {% if A is arista.avd.contains(B) %} => false -# {% if B is arista.avd.contains(C) %} => true -# {% if C is arista.avd.contains(A) %} => true -# {% if C is arista.avd.contains(B) %} => true -# -# {% if A is arista.avd.contains(0) %} => false -# {% if B is arista.avd.contains(1) %} => false -# {% if C is arista.avd.contains(2) %} => true -# {% if D is arista.avd.contains(3) %} => false <- Protecting against undefined gracefully. -from __future__ import absolute_import, division, print_function +""" +arista.avd.contains test plugin. + +Example: +A = [1, 2] +B = [3, 4] +C = [2, 3] -__metaclass__ = type +Jinja test examples: +{% if A is arista.avd.contains(B) %} => false +{% if B is arista.avd.contains(C) %} => true +{% if C is arista.avd.contains(A) %} => true +{% if C is arista.avd.contains(B) %} => true +{% if A is arista.avd.contains(0) %} => false +{% if B is arista.avd.contains(1) %} => false +{% if C is arista.avd.contains(2) %} => true +{% if D is arista.avd.contains(3) %} => false <- Protecting against undefined gracefully. +""" from ansible.errors import AnsibleFilterError @@ -37,7 +34,7 @@ AnsibleFilterError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) @@ -81,6 +78,6 @@ """ -class TestModule(object): - def tests(self): +class TestModule: + def tests(self) -> dict: return {"contains": wrap_test(PLUGIN_NAME)(contains)} diff --git a/ansible_collections/arista/avd/plugins/test/defined.py b/ansible_collections/arista/avd/plugins/test/defined.py index 7067247eec6..375a6793876 100644 --- a/ansible_collections/arista/avd/plugins/test/defined.py +++ b/ansible_collections/arista/avd/plugins/test/defined.py @@ -1,29 +1,26 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -# -# arista.avd.defined -# -# Example: -# A is undefined -# B is none -# C is "c" -# D is "d" -# -# Jinja test examples: -# {% if A is arista.avd.defined %} => false -# {% if B is arista.avd.defined %} => false -# {% if C is arista.avd.defined %} => true -# {% if D is arista.avd.defined %} => true -# -# {% if A is arista.avd.defined("c") %} => false -# {% if B is arista.avd.defined("c") %} => false -# {% if C is arista.avd.defined("c") %} => true -# {% if D is arista.avd.defined("c") %} => false -from __future__ import absolute_import, division, print_function +""" +arista.avd.defined test plugin. + +Example: +A is undefined +B is none +C is "c" +D is "d" -__metaclass__ = type +Jinja test examples: +{% if A is arista.avd.defined %} => false +{% if B is arista.avd.defined %} => false +{% if C is arista.avd.defined %} => true +{% if D is arista.avd.defined %} => true +{% if A is arista.avd.defined("c") %} => false +{% if B is arista.avd.defined("c") %} => false +{% if C is arista.avd.defined("c") %} => true +{% if D is arista.avd.defined("c") %} => false +""" from ansible.errors import AnsibleTemplateError @@ -38,7 +35,7 @@ AnsibleTemplateError( f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) @@ -97,6 +94,6 @@ """ -class TestModule(object): - def tests(self): +class TestModule: + def tests(self) -> dict: return {"defined": wrap_test(PLUGIN_NAME)(defined)} diff --git a/ansible_collections/arista/avd/plugins/vars/global_vars.py b/ansible_collections/arista/avd/plugins/vars/global_vars.py index 931ac5da3e1..e86d5aa6d10 100644 --- a/ansible_collections/arista/avd/plugins/vars/global_vars.py +++ b/ansible_collections/arista/avd/plugins/vars/global_vars.py @@ -1,9 +1,7 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, division, print_function -__metaclass__ = type DOCUMENTATION = r""" --- @@ -85,12 +83,13 @@ """ -import os +from pathlib import Path +from typing import Any from ansible.errors import AnsibleParserError from ansible.inventory.group import Group from ansible.inventory.host import Host -from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils._text import to_native from ansible.plugins.vars import BaseVarsPlugin from ansible.utils.vars import combine_vars @@ -98,42 +97,36 @@ class VarsModule(BaseVarsPlugin): - def find_variable_source(self, path, loader): - """ - Return the source files from which to load data, - if the path is a directory - lookup vars file inside - """ + def find_variable_source(self, path: str, loader: object) -> list: + """Return the source files from which to load data, if the path is a directory - lookup vars file inside.""" global_vars_paths = self.get_option("paths") extensions = self.get_option("_valid_extensions") found_files = [] for g_path in global_vars_paths: - b_opath = os.path.realpath(to_bytes(os.path.join(path, g_path))) - opath = to_text(b_opath) try: - if not os.path.exists(b_opath): + opath = Path(path, g_path) + if not opath.exists(): # file does not exist, skip it self._display.vvv(f"Path: {opath} does not exist - skipping") continue self._display.vvv(f"Adding Path: {opath} to global variables") - if os.path.isdir(b_opath): + if opath.is_dir(): self._display.debug(f"\tProcessing dir {opath}") - res = loader._get_dir_vars_files(opath, extensions) - self._display.debug(f"Found variable files {str(res)}") + res = loader._get_dir_vars_files(str(opath), extensions) + self._display.debug(f"Found variable files {res!s}") found_files.extend(res) else: - found_files.append(b_opath) + found_files.append(str(opath)) except Exception as e: raise AnsibleParserError(to_native(e)) from e return found_files - def get_vars(self, loader, path, entities, cache=True): - """ - Return global variables for the `all` group in the inventory file - """ - global FOUND + def get_vars(self, loader: object, path: str, entities: Any, _cache: bool = True) -> dict: + """Return global variables for the `all` group in the inventory file.""" + global FOUND # noqa: PLW0603 TODO: improve to avoid using global if not isinstance(entities, list): entities = [entities] @@ -145,12 +138,11 @@ def get_vars(self, loader, path, entities, cache=True): if not isinstance(entity, (Host, Group)): # Changed the error message because the TYPE_REGEX of ansible was triggering # unidiomatic-typecheck because of the `or` word before the type call... - raise AnsibleParserError(f"Supplied entity is of type {type(entity)} but must be of type Host or Group instead") + msg = f"Supplied entity is of type {type(entity)} but must be of type Host or Group instead" + raise AnsibleParserError(msg) if entity.name != "all": continue - print(entity.name, path) - for path in FOUND: new_data = loader.load_from_file(path, cache=True, unsafe=True) if new_data: diff --git a/ansible_collections/arista/avd/roles/eos_designs/python_modules/interface_descriptions/__init__.py b/ansible_collections/arista/avd/roles/eos_designs/python_modules/interface_descriptions/__init__.py index 360d52c3978..9b9ce4f92ba 100644 --- a/ansible_collections/arista/avd/roles/eos_designs/python_modules/interface_descriptions/__init__.py +++ b/ansible_collections/arista/avd/roles/eos_designs/python_modules/interface_descriptions/__init__.py @@ -13,7 +13,7 @@ AnsibleActionFail( "The 'arista.avd.eos_designs' collection requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) diff --git a/ansible_collections/arista/avd/roles/eos_designs/python_modules/ip_addressing/__init__.py b/ansible_collections/arista/avd/roles/eos_designs/python_modules/ip_addressing/__init__.py index 36dee1d7bcd..4a62f2fe9a4 100644 --- a/ansible_collections/arista/avd/roles/eos_designs/python_modules/ip_addressing/__init__.py +++ b/ansible_collections/arista/avd/roles/eos_designs/python_modules/ip_addressing/__init__.py @@ -12,7 +12,7 @@ AnsibleActionFail( "The 'arista.avd.eos_designs' collection requires the 'pyavd' Python library. Got import error", orig_exc=e, - ) + ), ) diff --git a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestconnectivity.py b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestconnectivity.py index 467d813734e..6de1d965c9d 100644 --- a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestconnectivity.py +++ b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestconnectivity.py @@ -14,9 +14,7 @@ class AvdTestP2PIPReachability(AvdTestBase): - """ - AvdTestP2PIPReachability class for P2P IP reachability tests. - """ + """AvdTestP2PIPReachability class for P2P IP reachability tests.""" anta_module = "anta.tests.connectivity" @@ -57,17 +55,15 @@ def test_definition(self) -> dict | None: "VerifyReachability": { "hosts": [{"source": src_ip, "destination": dst_ip, "vrf": "default", "repeat": 1}], "result_overwrite": {"custom_field": custom_field}, - } - } + }, + }, ) return {self.anta_module: anta_tests} if anta_tests else None class AvdTestInbandReachability(AvdTestBase): - """ - AvdTestInbandReachability class for inband management reachability tests. - """ + """AvdTestInbandReachability class for inband management reachability tests.""" anta_module = "anta.tests.connectivity" @@ -106,17 +102,15 @@ def test_definition(self) -> dict | None: "VerifyReachability": { "hosts": [{"source": src_ip, "destination": dst_ip, "vrf": vrf, "repeat": 1}], "result_overwrite": {"custom_field": custom_field}, - } - } + }, + }, ) return {self.anta_module: anta_tests} if anta_tests else None class AvdTestLoopback0Reachability(AvdTestBase): - """ - AvdTestLoopback0Reachability class for Loopback0 reachability tests. - """ + """AvdTestLoopback0Reachability class for Loopback0 reachability tests.""" anta_module = "anta.tests.connectivity" @@ -156,17 +150,15 @@ def test_definition(self) -> dict | None: "VerifyReachability": { "hosts": [{"source": src_ip, "destination": dst_ip, "vrf": "default", "repeat": 1}], "result_overwrite": {"custom_field": custom_field}, - } - } + }, + }, ) return {self.anta_module: anta_tests} if anta_tests else None class AvdTestLLDPTopology(AvdTestBase): - """ - AvdTestLLDPTopology class for the LLDP topology tests. - """ + """AvdTestLLDPTopology class for the LLDP topology tests.""" anta_module = "anta.tests.connectivity" @@ -212,11 +204,11 @@ def test_definition(self) -> dict | None: "port": str(interface["name"]), "neighbor_device": str(peer), "neighbor_port": str(interface["peer_interface"]), - } + }, ], "result_overwrite": {"custom_field": custom_field}, - } - } + }, + }, ) return {self.anta_module: anta_tests} if anta_tests else None diff --git a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtesthardware.py b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtesthardware.py index 9b324407c7c..e9f09f6abaf 100644 --- a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtesthardware.py +++ b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtesthardware.py @@ -10,9 +10,7 @@ class AvdTestHardware(AvdTestBase): - """ - AvdTestHardware class for hardware tests. - """ + """AvdTestHardware class for hardware tests.""" anta_module = "anta.tests.hardware" @@ -39,20 +37,20 @@ def test_definition(self) -> dict: "VerifyEnvironmentPower": { "states": pwr_supply_states, "result_overwrite": {"custom_field": f"Accepted States: {self.format_list(pwr_supply_states)}"}, - } + }, }, { "VerifyEnvironmentCooling": { "states": fan_states, "result_overwrite": {"custom_field": f"Accepted States: {self.format_list(fan_states)}"}, - } + }, }, {"VerifyTemperature": None}, { "VerifyTransceiversManufacturers": { "manufacturers": xcvr_manufacturers, "result_overwrite": {"custom_field": f"Accepted Manufacturers: {self.format_list(xcvr_manufacturers)}"}, - } + }, }, ] diff --git a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestinterfaces.py b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestinterfaces.py index f2798f5c050..8e632caf6b9 100644 --- a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestinterfaces.py +++ b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestinterfaces.py @@ -13,18 +13,16 @@ class AvdTestInterfacesState(AvdTestBase): - """ - AvdTestInterfacesState class for interfaces state tests. - """ + """AvdTestInterfacesState class for interfaces state tests.""" anta_module = "anta.tests.interfaces" - interfaces_to_test = [ + interfaces_to_test = ( "ethernet_interfaces", "port_channel_interfaces", "vlan_interfaces", "loopback_interfaces", "dps_interfaces", - ] + ) @cached_property def test_definition(self) -> dict | None: @@ -34,7 +32,6 @@ def test_definition(self) -> dict | None: Returns: test_definition (dict): ANTA test definition. """ - anta_tests = [] required_keys = ["name", "shutdown"] diff --git a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestmlag.py b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestmlag.py index 64cfb5ace20..0234862bef7 100644 --- a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestmlag.py +++ b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestmlag.py @@ -12,9 +12,7 @@ class AvdTestMLAG(AvdTestBase): - """ - AvdTestMLAG class for MLAG tests. - """ + """AvdTestMLAG class for MLAG tests.""" anta_module = "anta.tests.mlag" diff --git a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestrouting.py b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestrouting.py index 52bd17817e7..a551026e3ee 100644 --- a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestrouting.py +++ b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestrouting.py @@ -9,15 +9,13 @@ from ansible_collections.arista.avd.plugins.plugin_utils.eos_validate_state_utils.avdtestbase import AvdTestBase from ansible_collections.arista.avd.plugins.plugin_utils.utils import get -from ..bgp_constants import BGP_ADDRESS_FAMILIES +from ..bgp_constants import BGP_ADDRESS_FAMILIES # noqa: TID252 Will be fixed when moving to pyavd LOGGER = logging.getLogger(__name__) class AvdTestRoutingTable(AvdTestBase): - """ - AvdTestRoutingTable class for routing table entry verification tests. - """ + """AvdTestRoutingTable class for routing table entry verification tests.""" anta_module = "anta.tests.routing.generic" @@ -37,7 +35,6 @@ def add_test(mapping: list) -> None: Avoids duplicate tests for the same IP address (e.g. MLAG VTEPs). """ - processed_ips = set() for peer, ip in mapping: @@ -50,8 +47,8 @@ def add_test(mapping: list) -> None: "VerifyRoutingTableEntry": { "routes": [ip], "result_overwrite": {"custom_field": f"Route: {ip} - Peer: {peer}"}, - } - } + }, + }, ) processed_ips.add(ip) @@ -82,11 +79,11 @@ class AvdTestBGP(AvdTestBase): """ anta_module = "anta.tests.routing" - anta_tests = {} + anta_tests = {} # noqa: RUF012 def add_test(self, afi: str, bgp_neighbor_ip: str, bgp_peer: str, description: str, safi: str | None = None) -> dict: """Add a BGP test definition with the proper input parameters.""" - custom_field = f"BGP {description} Peer: {''.join([bgp_peer, ' (IP: ', bgp_neighbor_ip, ')']) if bgp_peer is not None else bgp_neighbor_ip}" + custom_field = f"BGP {description} Peer: {f'{bgp_peer} (IP: {bgp_neighbor_ip})' if bgp_peer is not None else bgp_neighbor_ip}" address_family = {"afi": afi, "peers": [bgp_neighbor_ip]} if safi: @@ -137,7 +134,7 @@ def create_tests( def test_definition(self) -> dict | None: """Generates the proper ANTA test definition for all BGP tests. - Returns + Returns: ------- test_definition (dict): ANTA test definition. diff --git a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestsecurity.py b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestsecurity.py index 37b2a0a5195..2c6e7dc54a8 100644 --- a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestsecurity.py +++ b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestsecurity.py @@ -13,9 +13,7 @@ class AvdTestAPIHttpsSSL(AvdTestBase): - """ - AvdTestAPIHttpsSSL class for eAPI HTTPS SSL tests. - """ + """AvdTestAPIHttpsSSL class for eAPI HTTPS SSL tests.""" anta_module = "anta.tests.security" @@ -41,6 +39,7 @@ def test_definition(self) -> dict | None: class AvdTestIPSecurity(AvdTestBase): """ AvdTestIPSecurity class for IP security connection tests. + It validates the state of IPv4 security connections for a specified peer, ensuring they are established. It specifically focuses on IPv4 security connections within the default VRF. In its current state, the test validates only IPsec connections defined as static peers under the `router path-selection` section of the configuration. @@ -71,7 +70,9 @@ def test_definition(self) -> dict | None: for peer_idx, peer in enumerate(path_group["static_peers"]): if self.validate_data( - data=peer, data_path=f"router_path_selection.path_groups.[{group_idx}].static_peers.[{peer_idx}]", required_keys="router_ip" + data=peer, + data_path=f"router_path_selection.path_groups.[{group_idx}].static_peers.[{peer_idx}]", + required_keys="router_ip", ): peer_address = peer["router_ip"] vrf = "default" # TODO: Keeping the vrf name static for now. We may need to change later on. @@ -81,8 +82,8 @@ def test_definition(self) -> dict | None: "VerifySpecificIPSecConn": { "ip_security_connections": [{"peer": peer_address, "vrf": vrf}], "result_overwrite": {"custom_field": f"IPv4 Peer: {peer_address} VRF: {vrf}"}, - } - } + }, + }, ) added_peers.add((peer_address, vrf)) diff --git a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdteststun.py b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdteststun.py index fcb1b99a6c2..6ecb47a984c 100644 --- a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdteststun.py +++ b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdteststun.py @@ -16,6 +16,7 @@ class AvdTestStun(AvdTestBase): """ AvdTestStun class for STUN tests. + Validates the presence of a STUN client translation for a given source IPv4 address and port for WAN scenarios. The list of expected translations for each device is built by searching through router_path_selection.path_groups.local_interfaces. """ @@ -64,8 +65,8 @@ def test_definition(self) -> dict | None: "VerifyStunClient": { "stun_clients": [{"source_address": source_address, "source_port": source_port}], "result_overwrite": {"custom_field": f"Source IPv4 Address: {source_address} Source Port: {source_port}"}, - } - } + }, + }, ) # Return the ANTA tests as a dictionary diff --git a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestsystem.py b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestsystem.py index 194bb9c8cad..3fb2ee13f34 100644 --- a/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestsystem.py +++ b/ansible_collections/arista/avd/roles/eos_validate_state/python_modules/tests/avdtestsystem.py @@ -9,9 +9,7 @@ class AvdTestNTP(AvdTestBase): - """ - AvdTestNTP class for NTP tests. - """ + """AvdTestNTP class for NTP tests.""" anta_module = "anta.tests.system" @@ -23,7 +21,6 @@ def test_definition(self) -> dict: Returns: test_definition (dict): ANTA test definition. """ - anta_tests = [ {"VerifyNTP": None}, ] @@ -32,9 +29,7 @@ def test_definition(self) -> dict: class AvdTestReloadCause(AvdTestBase): - """ - AvdTestReloadCause class for the reload cause of the device. - """ + """AvdTestReloadCause class for the reload cause of the device.""" anta_module = "anta.tests.system" @@ -46,7 +41,6 @@ def test_definition(self) -> dict: Returns: test_definition (dict): ANTA test definition. """ - anta_tests = [ {"VerifyReloadCause": None}, ] diff --git a/ansible_collections/arista/avd/tests/unit/action/test_verify_requirements.py b/ansible_collections/arista/avd/tests/unit/action/test_verify_requirements.py index 6ae40de6036..cc9bf752059 100644 --- a/ansible_collections/arista/avd/tests/unit/action/test_verify_requirements.py +++ b/ansible_collections/arista/avd/tests/unit/action/test_verify_requirements.py @@ -1,11 +1,11 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -__metaclass__ = type import os from collections import namedtuple from importlib.metadata import PackageNotFoundError +from pathlib import Path from unittest.mock import patch import pytest @@ -21,26 +21,15 @@ @pytest.mark.parametrize( - "mocked_version, expected_return", + ("mocked_version", "expected_return"), [ - ( - (2, 2, 2, "final", 0), - False, - ), - ( - (MIN_PYTHON_SUPPORTED_VERSION[0], MIN_PYTHON_SUPPORTED_VERSION[1], 42, "final", 0), - True, - ), - ( - (MIN_PYTHON_SUPPORTED_VERSION[0], MIN_PYTHON_SUPPORTED_VERSION[1] + 1, 42, "final", 0), - True, - ), + ((2, 2, 2, "final", 0), False), + ((MIN_PYTHON_SUPPORTED_VERSION[0], MIN_PYTHON_SUPPORTED_VERSION[1], 42, "final", 0), True), + ((MIN_PYTHON_SUPPORTED_VERSION[0], MIN_PYTHON_SUPPORTED_VERSION[1] + 1, 42, "final", 0), True), ], ) -def test__validate_python_version(mocked_version, expected_return): - """ - TODO - could add the expected stderr - """ +def test__validate_python_version(mocked_version, expected_return) -> None: + """TODO: - could add the expected stderr.""" info = {} result = {} # As in ansible module result version_info = namedtuple("version_info", "major minor micro releaselevel serial") @@ -62,7 +51,7 @@ def test__validate_python_version(mocked_version, expected_return): @pytest.mark.parametrize( - "n_reqs, mocked_version, requirement_version, expected_return", + ("n_reqs", "mocked_version", "requirement_version", "expected_return"), [ pytest.param( 1, @@ -101,11 +90,11 @@ def test__validate_python_version(mocked_version, expected_return): ), ], ) -def test__validate_python_requirements(n_reqs, mocked_version, requirement_version, expected_return): +def test__validate_python_requirements(n_reqs, mocked_version, requirement_version, expected_return) -> None: """ - Running with n_reqs requirements + Running with n_reqs requirements. - TODO - check the results + TODO: - check the results - not testing for wrongly formatted requirements """ result = {} @@ -119,7 +108,7 @@ def test__validate_python_requirements(n_reqs, mocked_version, requirement_versi @pytest.mark.parametrize( - "mocked_running_version, deprecated_version, expected_return", + ("mocked_running_version", "deprecated_version", "expected_return"), [ pytest.param( "2.16", @@ -141,10 +130,8 @@ def test__validate_python_requirements(n_reqs, mocked_version, requirement_versi # ), ], ) -def test__validate_ansible_version(mocked_running_version, deprecated_version, expected_return): - """ - TODO - check that the requires_ansible is picked up from the correct place - """ +def test__validate_ansible_version(mocked_running_version, deprecated_version, expected_return) -> None: + """TODO: - check that the requires_ansible is picked up from the correct place.""" info = {} result = {} # As in ansible module result ret = _validate_ansible_version("arista.avd", mocked_running_version, info, result) @@ -155,50 +142,20 @@ def test__validate_ansible_version(mocked_running_version, deprecated_version, e @pytest.mark.parametrize( - "n_reqs, mocked_version, requirement_version, expected_return", + ("n_reqs", "mocked_version", "requirement_version", "expected_return"), [ - pytest.param( - 1, - "4.3", - ">=4.2", - True, - id="valid version", - ), - pytest.param( - 1, - "4.3", - None, - True, - id="no required version", - ), - pytest.param( - 2, - "4.0", - ">=4.2", - False, - id="invalid version", - ), - pytest.param( - 1, - None, - ">=4.2", - False, - id="missing requirement", - ), - pytest.param( - 0, - None, - None, - True, - id="no requirement", - ), + pytest.param(1, "4.3", ">=4.2", True, id="valid version"), + pytest.param(1, "4.3", None, True, id="no required version"), + pytest.param(2, "4.0", ">=4.2", False, id="invalid version"), + pytest.param(1, None, ">=4.2", False, id="missing requirement"), + pytest.param(0, None, None, True, id="no requirement"), ], ) -def test__validate_ansible_collections(n_reqs, mocked_version, requirement_version, expected_return): +def test__validate_ansible_collections(n_reqs, mocked_version, requirement_version, expected_return) -> None: """ - Running with n_reqs requirements + Running with n_reqs requirements in the collection file. - TODO - check the results + TODO: - check the results - not testing for wrongly formatted collection.yml file """ result = {} @@ -211,39 +168,50 @@ def test__validate_ansible_collections(n_reqs, mocked_version, requirement_versi for collection in metadata["collections"]: collection["version"] = requirement_version - with patch("ansible_collections.arista.avd.plugins.action.verify_requirements.yaml.safe_load") as patched_safe_load, patch( - "ansible_collections.arista.avd.plugins.action.verify_requirements._get_collection_path" - ) as patched__get_collection_path, patch( - "ansible_collections.arista.avd.plugins.action.verify_requirements._get_collection_version" - ) as patched__get_collection_version, patch( - "ansible_collections.arista.avd.plugins.action.verify_requirements.open" + with ( + patch("ansible_collections.arista.avd.plugins.action.verify_requirements.Path.open"), + patch("ansible_collections.arista.avd.plugins.action.verify_requirements.yaml.safe_load") as patched_safe_load, + patch( + "ansible_collections.arista.avd.plugins.action.verify_requirements._get_collection_path", + ) as patched__get_collection_path, + patch( + "ansible_collections.arista.avd.plugins.action.verify_requirements._get_collection_version", + ) as patched__get_collection_version, + patch( + "ansible_collections.arista.avd.plugins.action.verify_requirements.open", + ), ): patched_safe_load.return_value = metadata - patched__get_collection_path.return_value = "dummy" + patched__get_collection_path.return_value = "/collections/foo/bar" if mocked_version is None and n_reqs > 0: # First call is for arista.avd - patched__get_collection_path.side_effect = ["dummy", ModuleNotFoundError()] + patched__get_collection_path.side_effect = ["/collections/foo/bar", ModuleNotFoundError()] patched__get_collection_version.return_value = mocked_version ret = _validate_ansible_collections("arista.avd", result) assert ret == expected_return -def test__get_running_collection_version_git_not_installed(): - """ - Verify that when git is not found in PATH the function returns properly - """ +def test__get_running_collection_version_git_not_installed() -> None: + """Verify that when git is not found in PATH the function returns properly.""" # setting PATH to empty string to make sure git is not present os.environ["PATH"] = "" # setting ANSIBLE_VERBOSITY to trigger the log message when raising the exception os.environ["ANSIBLE_VERBOSITY"] = "3" result = {} - with patch("ansible_collections.arista.avd.plugins.action.verify_requirements._get_collection_path") as patched__get_collection_path, patch( - "ansible_collections.arista.avd.plugins.action.verify_requirements._get_collection_version" - ) as patched__get_collection_version, patch("ansible_collections.arista.avd.plugins.action.verify_requirements.display") as patched_display: + with ( + patch("ansible_collections.arista.avd.plugins.action.verify_requirements.Path") as patched_Path, + patch("ansible_collections.arista.avd.plugins.action.verify_requirements._get_collection_path") as patched__get_collection_path, + patch( + "ansible_collections.arista.avd.plugins.action.verify_requirements._get_collection_version", + ) as patched__get_collection_version, + patch("ansible_collections.arista.avd.plugins.action.verify_requirements.display") as patched_display, + ): patched__get_collection_path.return_value = "." patched__get_collection_version.return_value = "42.0.0" + # TODO: Path is less kind than os.path was + patched_Path.return_value = Path("/collections/foo/bar/__synthetic__/blah") _get_running_collection_version("dummy", result) patched_display.vvv.assert_called_once_with("Could not find 'git' executable, returning collection version") - assert result == {"collection": {"name": "dummy", "path": "", "version": "42.0.0"}} + assert result == {"collection": {"name": "dummy", "path": "/collections/foo/bar", "version": "42.0.0"}} diff --git a/ansible_collections/arista/avd/tests/unit/modules/test_configlet_build_config.py b/ansible_collections/arista/avd/tests/unit/modules/test_configlet_build_config.py index f465a63ecef..b8998214f98 100644 --- a/ansible_collections/arista/avd/tests/unit/modules/test_configlet_build_config.py +++ b/ansible_collections/arista/avd/tests/unit/modules/test_configlet_build_config.py @@ -1,9 +1,7 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, division, print_function -__metaclass__ = type import os @@ -21,24 +19,21 @@ class TestConfigletBuildConfig: - def verify_configlets(self, src_folder, prefix, extension, output): + def verify_configlets(self, src_folder, prefix, extension, output) -> None: suffixes = [".cfg"] - for dirpath, dirnames, filenames in os.walk(src_folder): + for dirpath, _dirnames, filenames in os.walk(src_folder): for filename in filenames: filesplit = os.path.splitext(filename) - if not prefix: - key = filesplit[0] - else: - key = prefix + "_" + filesplit[0] + key = filesplit[0] if not prefix else prefix + "_" + filesplit[0] if filesplit[1] in suffixes: - assert key in output.keys() + assert key in output # Compare contents of each file - with open(os.path.join(dirpath, filename), "r", encoding="utf8") as f: + with open(os.path.join(dirpath, filename), encoding="utf8") as f: assert f.read() == output[key] @pytest.mark.parametrize("DATA", CONFIGLETS_DATA.values(), ids=CONFIGLETS_DATA.keys()) - def test_get_configlet(self, DATA): + def test_get_configlet(self, DATA) -> None: prefix = DATA.get("prefix", None) extension = DATA.get("extension", "cfg") src_folder = DATA["src_folder"] @@ -51,11 +46,11 @@ def test_get_configlet(self, DATA): assert isinstance(output, dict) self.verify_configlets(src_folder, prefix, extension, output) - def test_get_configlet_invalid_source(self): + def test_get_configlet_invalid_source(self) -> None: output = get_configlet() assert output == {} - def test_get_configlet_none_prefix(self): + def test_get_configlet_none_prefix(self) -> None: extension = "cfg" output = get_configlet(src_folder=CONFIGLETS_DIR, prefix="none") assert isinstance(output, dict) diff --git a/ansible_collections/arista/avd/tests/unit/modules/test_inventory_to_container.py b/ansible_collections/arista/avd/tests/unit/modules/test_inventory_to_container.py index 18f62fcfe79..a5b995d32b8 100644 --- a/ansible_collections/arista/avd/tests/unit/modules/test_inventory_to_container.py +++ b/ansible_collections/arista/avd/tests/unit/modules/test_inventory_to_container.py @@ -1,9 +1,7 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, division, print_function -__metaclass__ = type import json import logging @@ -18,8 +16,8 @@ get_device_option_value, get_devices, is_in_filter, - isIterable, - isLeaf, + is_iterable, + is_leaf, ) from ansible_collections.arista.avd.plugins.modules.inventory_to_container import serialize_yaml_inventory_data as serialize @@ -80,133 +78,133 @@ @pytest.fixture(scope="session") def inventory(): yaml.SafeLoader.add_constructor("!vault", lambda _, __: "!VAULT") - with open(INVENTORY_FILE, "r", encoding="utf8") as stream: + with open(INVENTORY_FILE, encoding="utf8") as stream: try: inventory_content = yaml.safe_load(stream) except yaml.YAMLError as e: - logging.error(e) + logging.exception(e) return None return inventory_content class TestInventoryToContainer: - def test_is_in_filter_default_filter(self): + def test_is_in_filter_default_filter(self) -> None: output = is_in_filter(hostname=HOSTNAME_VALID) assert output - def test_is_in_filter_valid_hostname(self): + def test_is_in_filter_valid_hostname(self) -> None: output = is_in_filter(hostname_filter=HOSTNAME_FILTER_VALID, hostname=HOSTNAME_VALID) assert output - def test_is_in_filter_invalid_hostname(self): + def test_is_in_filter_invalid_hostname(self) -> None: output = is_in_filter(hostname_filter=HOSTNAME_FILTER_VALID, hostname=HOSTNAME_INVALID) assert output is False # TODO: Check if this is a valid testcase. Add a type check? - def test_is_in_filter_invalid_filter(self): + def test_is_in_filter_invalid_filter(self) -> None: output = is_in_filter(hostname_filter=HOSTNAME_FILTER_INVALID, hostname=HOSTNAME_VALID) assert output - def test_isIterable_default_iterable(self): - output = isIterable() + def test_is_iterable_default_iterable(self) -> None: + output = is_iterable() assert output is False @pytest.mark.parametrize("DATA", IS_ITERABLE_VALID) - def test_isIterable_valid_iterable(self, DATA): - output = isIterable(DATA) + def test_is_iterable_valid_iterable(self, DATA) -> None: + output = is_iterable(DATA) assert output @pytest.mark.parametrize("DATA", IS_ITERABLE_INVALID) - def test_isIterable_invalid_iterable(self, DATA): - output = isIterable(DATA) + def test_is_iterable_invalid_iterable(self, DATA) -> None: + output = is_iterable(DATA) assert output is False - def test_isLeaf_valid_leaf(self): - output = isLeaf(TREELIB, TREELIB_VALID_LEAF) + def test_is_leaf_valid_leaf(self) -> None: + output = is_leaf(TREELIB, TREELIB_VALID_LEAF) assert output - def test_isLeaf_invalid_leaf(self): - output = isLeaf(TREELIB, TREELIB_INVALID_LEAF) + def test_is_leaf_invalid_leaf(self) -> None: + output = is_leaf(TREELIB, TREELIB_INVALID_LEAF) assert output is False - def test_isLeaf_none_leaf(self): - output = isLeaf(TREELIB, None) + def test_is_leaf_none_leaf(self) -> None: + output = is_leaf(TREELIB, None) assert output is False - def test_get_device_option_value_valid(self, inventory): + def test_get_device_option_value_valid(self, inventory) -> None: data = inventory["all"]["children"]["CVP"]["hosts"] output = get_device_option_value(device_data_dict=data, option_name="cv_server") assert output assert isinstance(output, dict) - def test_get_device_option_value_invalid(self, inventory): + def test_get_device_option_value_invalid(self, inventory) -> None: data = inventory["all"]["children"]["CVP"]["hosts"] output = get_device_option_value(device_data_dict=data, option_name="is_deployed") assert output is None - def test_get_device_option_value_none(self, inventory): + def test_get_device_option_value_none(self, inventory) -> None: data = inventory["all"]["children"]["CVP"]["hosts"] output = get_device_option_value(device_data_dict=data, option_name=None) assert output is None - def test_get_device_option_value_empty_data(self, inventory): + def test_get_device_option_value_empty_data(self, inventory) -> None: output = get_device_option_value(device_data_dict=None, option_name="cv_server") assert output is None - def test_get_devices_empty_inventory(self): + def test_get_devices_empty_inventory(self) -> None: output = get_devices(None) assert output is None - def test_get_devices_default_search_container(self, inventory): + def test_get_devices_default_search_container(self, inventory) -> None: output = get_devices(inventory) assert output is None - def test_get_devices_non_default_search_container(self, inventory): + def test_get_devices_non_default_search_container(self, inventory) -> None: output = get_devices(inventory, search_container=SEARCH_CONTAINER, devices=[]) assert output == GET_DEVICES - def test_get_devices_preexisting_devices(self, inventory): + def test_get_devices_preexisting_devices(self, inventory) -> None: devices = ["TEST_DEVICE"] output = get_devices(inventory, search_container=SEARCH_CONTAINER, devices=devices) - assert output == ["TEST_DEVICE"] + GET_DEVICES + assert output == ["TEST_DEVICE", *GET_DEVICES] - def test_get_devices_preexisting_devices_with_device_filter(self, inventory): + def test_get_devices_preexisting_devices_with_device_filter(self, inventory) -> None: output = get_devices(inventory, search_container=SEARCH_CONTAINER, devices=[], device_filter=[GET_DEVICE_FILTER]) assert [GET_DEVICE_FILTER in item for item in output] @pytest.mark.parametrize("DATA", [None]) - def test_serialize_empty_inventory(self, DATA): + def test_serialize_empty_inventory(self, DATA) -> None: output = serialize(DATA) assert output is None - def test_serialize_valid_inventory(self, inventory): + def test_serialize_valid_inventory(self, inventory) -> None: output = serialize(inventory) assert isinstance(output, treelib.tree.Tree) tree_dict = json.loads(output.to_json()) - assert (list(tree_dict.keys()))[0] == ROOT_CONTAINER + assert next(iter(tree_dict.keys())) == ROOT_CONTAINER @pytest.mark.parametrize("DATA", PARENT_CONTAINER.values(), ids=PARENT_CONTAINER.keys()) - def test_serialize_parent_container(self, DATA, inventory): + def test_serialize_parent_container(self, DATA, inventory) -> None: output = serialize(inventory, parent_container=DATA["parent"]) assert isinstance(output, treelib.tree.Tree) tree_dict = json.loads(output.to_json()) - assert (list(tree_dict.keys()))[0] == ROOT_CONTAINER + assert next(iter(tree_dict.keys())) == ROOT_CONTAINER - def test_serialize_none_parent_container_with_tree_topology(self, inventory): + def test_serialize_none_parent_container_with_tree_topology(self, inventory) -> None: tree = treelib.tree.Tree() output = serialize(inventory, tree_topology=tree) assert isinstance(output, treelib.tree.Tree) tree_dict = json.loads(output.to_json()) - assert (list(tree_dict.keys()))[0] == "all" + assert next(iter(tree_dict.keys())) == "all" - def test_serialize_non_default_parent_container_with_tree_topology(self, inventory): + def test_serialize_non_default_parent_container_with_tree_topology(self, inventory) -> None: tree = treelib.tree.Tree() tree.create_node(NON_DEFAULT_PARENT_CONTAINER, NON_DEFAULT_PARENT_CONTAINER) output = serialize(inventory, parent_container=NON_DEFAULT_PARENT_CONTAINER, tree_topology=tree) tree_dict = json.loads(output.to_json()) - assert (list(tree_dict.keys()))[0] == NON_DEFAULT_PARENT_CONTAINER + assert next(iter(tree_dict.keys())) == NON_DEFAULT_PARENT_CONTAINER @pytest.mark.parametrize("DATA", PARENT_CONTAINER.values(), ids=PARENT_CONTAINER.keys()) - def test_get_containers(self, DATA, inventory): + def test_get_containers(self, DATA, inventory) -> None: output = get_containers(inventory, parent_container=DATA["parent"], device_filter=["all"]) assert output == DATA["expected_output"] diff --git a/ansible_collections/arista/avd/tests/unit/plugins/plugin_utils/eos_validate_state_utils/test_catalog.py b/ansible_collections/arista/avd/tests/unit/plugins/plugin_utils/eos_validate_state_utils/test_catalog.py index 4b3f28813c5..ec70e3e7506 100644 --- a/ansible_collections/arista/avd/tests/unit/plugins/plugin_utils/eos_validate_state_utils/test_catalog.py +++ b/ansible_collections/arista/avd/tests/unit/plugins/plugin_utils/eos_validate_state_utils/test_catalog.py @@ -1,14 +1,12 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from __future__ import absolute_import, division, print_function -__metaclass__ = type # import pytest # TODO class TestCatalog: - def test(self): + def test(self) -> None: pass diff --git a/pyproject.toml b/pyproject.toml index eb07b591e09..1b3b22bf15e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,3 +12,81 @@ extend_skip_glob = [ profile = "black" skip_gitignore = true line_length = 160 + +known_first_party = ["pyavd", "schema_tools"] + +[tool.ruff] +line-length = 160 +extend-exclude = [ + "python-avd/pyavd/_cv/api/**/*", + "python-avd/tests/**/*", + "ansible_collections/arista/avd/tests/**/*", +] +target-version = "py310" + +[tool.ruff.lint] +extend-select = ["ALL"] +ignore = [ + "ANN101", # Missing type annotation for `self` in method - we know what self is.. + "ANN102", # Missing type annotation for `cls` in classmethod - we know what cls is.. + "D203", # Ignoring conflicting D* warnings - one-blank-line-before-class + "D212", # Ignoring conflicting D* warnings - multi-line-summary-first-line + "COM812", # Ignoring conflicting rules that may cause conflicts when used with the formatter + "ISC001", # Ignoring conflicting rules that may cause conflicts when used with the formatter + "TD002", # We don't have require authors in TODO + "TD003", # We don't have an issue link for all TODOs today + "FIX002", # Line contains TODO - ignoring for ruff for now + "F821", # Disable undefined-name until resolution of #10451 + "SLF001", # Accessing private members - TODO: Improve code + "D100", # Missing docstring in public module - TODO: Improve code + "D101", # Missing docstring in public class - TODO: Improve code + "D102", # Missing docstring in public method - TODO: Improve code + "D103", # Missing docstring in public function - TODO: Improve code + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method - TODO: Improve code + "D106", # Missing docstring in public nested class - TODO: Improve code + "D107", # Missing docstring in `__init__` - TODO: Improve code + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed - TODO: Improve code + "C901", # complex-structure - TODO: Improve code + "FBT001", # Boolean-typed positional argument in function definition - TODO: Improve code + "FBT002", # Boolean default positional argument in function definition - TODO: Improve code + "PD011", # Use numpy instead of .values - False positive + "BLE001", # Do not catch blind exception: `Exception - TODO: Improve code + "PLR2004", # Magic value used in comparison - TODO: Evaluate + "DTZ005", # `datetime.datetime.now()` called without a `tz` argument - TODO: Improve code + "UP038", # UP038 Use `X | Y` in `isinstance` call instead of `(X, Y)` - Why would I? It impacts performance. +] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.per-file-ignores] +"ansible_collections/arista/avd/plugins/*/*.py" = [ + "E402", # Ansible plugins require a layout with imports below the docs + "INP001", # implicit namespace package. Add an `__init__.py` - Ansible plugins are not in packages +] +"ansible_collections/arista/avd/roles/eos_validate_state/python_modules/**/*.py" = [ + "INP001", # implicit namespace package. Add an `__init__.py` - Will be fixed once moved to PyAVD +] +"ansible_collections/arista/avd/molecule/eos_designs_unit_tests/custom_modules/*.py" = [ + "INP001", # implicit namespace package. Add an `__init__.py` - TODO: Evaluate or see if it is a false positive +] +"python-avd/scripts/**/*.py" = [ + "T201", # Interactive scripts can have print + "INP001", # implicit namespace package. Add an `__init__.py` - TODO: Evaluate or see if it is a false positive +] +"python-avd/pyavd/_cv/client/*.py" = [ + "B904", # Within an `except` clause, raise exceptions with `raise - TODO: Improve code +] + +[tool.ruff.lint.pylint] +max-args = 12 +max-branches = 54 +max-returns = 10 +max-statements = 148 + +[tool.ruff.lint.isort] +known-first-party = ["pyavd", "schema_tools"] + +[tool.ruff.format] +docstring-code-format = true diff --git a/python-avd/pyavd/_cv/api/arista/alert/v1/__init__.py b/python-avd/pyavd/_cv/api/arista/alert/v1/__init__.py index 155ddae5b4f..b05d9391ecc 100644 --- a/python-avd/pyavd/_cv/api/arista/alert/v1/__init__.py +++ b/python-avd/pyavd/_cv/api/arista/alert/v1/__init__.py @@ -575,6 +575,9 @@ class Settings(aristaproto.Message): zoom: "ZoomSettings" = aristaproto.message_field(19) """zoom is the global default settings for zoom""" + webhook: "WebhookSettings" = aristaproto.message_field(20) + """webhook is the auth settings for webhook""" + @dataclass(eq=False, repr=False) class EmailSettings(aristaproto.Message): @@ -617,7 +620,7 @@ class EmailSettings(aristaproto.Message): azure_o_auth: "AzureOAuth" = aristaproto.message_field(7) """ azure_o_auth used for auth when using an Azure smtp server - uses auth_username + uses auth_username, scopes is not required as we use https://outlook.office365.com/.default """ @@ -645,7 +648,14 @@ class AzureOAuth(aristaproto.Message): auth_uri: Optional[str] = aristaproto.message_field( 4, wraps=aristaproto.TYPE_STRING ) - """auth_uri is the URI used for OAuth""" + """ + auth_uri is the URI used for OAuth + this should always be https://login.microsoftonline.com/ unless using a very custom + set up, where the Azure enviroment is not running on microsoft servers + """ + + scopes: "___fmp__.RepeatedString" = aristaproto.message_field(5) + """scopes are the scopes that auth is granted for""" @dataclass(eq=False, repr=False) @@ -756,6 +766,17 @@ class MsTeamsSettings(aristaproto.Message): """url is the url of the webhook to send alerts to""" +@dataclass(eq=False, repr=False) +class WebhookSettings(aristaproto.Message): + """WebhookSettings contain the settings for sending alerts to a Webhook""" + + azure_o_auth: "AzureOAuth" = aristaproto.message_field(1) + """ + azure_o_auth used for auth when using an Azure smtp server + uses auth_username + """ + + @dataclass(eq=False, repr=False) class SyslogSettings(aristaproto.Message): """SyslogSettings contain the settings for sending alerts with syslog""" @@ -1468,6 +1489,11 @@ class WebhookEndpoint(aristaproto.Message): alert when true. """ + settings_override: "WebhookSettings" = aristaproto.message_field(6) + """ + settings_override is the override for the webhook global endpoint settings + """ + @dataclass(eq=False, repr=False) class SlackEndpoint(aristaproto.Message): @@ -1954,48 +1980,6 @@ class AlertStreamResponse(aristaproto.Message): """ -@dataclass(eq=False, repr=False) -class AlertBatchedStreamRequest(aristaproto.Message): - time: "__time__.TimeBounds" = aristaproto.message_field(3) - """ - TimeRange allows limiting response data to within a specified time window. - If this field is populated, at least one of the two time fields are required. - - For GetAll, the fields start and end can be used as follows: - - * end: Returns the state of each Alert at end. - * Each Alert response is fully-specified (all fields set). - * start: Returns the state of each Alert at start, followed by updates until now. - * Each Alert response at start is fully-specified, but updates may be partial. - * start and end: Returns the state of each Alert at start, followed by updates - until end. - * Each Alert response at start is fully-specified, but updates until end may - be partial. - - This field is not allowed in the Subscribe RPC. - """ - - max_messages: Optional[int] = aristaproto.message_field( - 4, wraps=aristaproto.TYPE_UINT32 - ) - """ - MaxMessages limits the maximum number of messages that can be contained in one batch. - MaxMessages is required to be at least 1. - The maximum number of messages in a batch is min(max_messages, INTERNAL_BATCH_LIMIT) - INTERNAL_BATCH_LIMIT is set based on the maximum message size. - """ - - -@dataclass(eq=False, repr=False) -class AlertBatchedStreamResponse(aristaproto.Message): - responses: List["AlertStreamResponse"] = aristaproto.message_field(1) - """ - Values are the values deemed relevant to the initiating request. - The length of this structure is guaranteed to be between (inclusive) 1 and - min(req.max_messages, INTERNAL_BATCH_LIMIT). - """ - - @dataclass(eq=False, repr=False) class AlertConfigRequest(aristaproto.Message): time: datetime = aristaproto.message_field(2) @@ -2065,48 +2049,6 @@ class AlertConfigStreamResponse(aristaproto.Message): """ -@dataclass(eq=False, repr=False) -class AlertConfigBatchedStreamRequest(aristaproto.Message): - time: "__time__.TimeBounds" = aristaproto.message_field(3) - """ - TimeRange allows limiting response data to within a specified time window. - If this field is populated, at least one of the two time fields are required. - - For GetAll, the fields start and end can be used as follows: - - * end: Returns the state of each AlertConfig at end. - * Each AlertConfig response is fully-specified (all fields set). - * start: Returns the state of each AlertConfig at start, followed by updates until now. - * Each AlertConfig response at start is fully-specified, but updates may be partial. - * start and end: Returns the state of each AlertConfig at start, followed by updates - until end. - * Each AlertConfig response at start is fully-specified, but updates until end may - be partial. - - This field is not allowed in the Subscribe RPC. - """ - - max_messages: Optional[int] = aristaproto.message_field( - 4, wraps=aristaproto.TYPE_UINT32 - ) - """ - MaxMessages limits the maximum number of messages that can be contained in one batch. - MaxMessages is required to be at least 1. - The maximum number of messages in a batch is min(max_messages, INTERNAL_BATCH_LIMIT) - INTERNAL_BATCH_LIMIT is set based on the maximum message size. - """ - - -@dataclass(eq=False, repr=False) -class AlertConfigBatchedStreamResponse(aristaproto.Message): - responses: List["AlertConfigStreamResponse"] = aristaproto.message_field(1) - """ - Values are the values deemed relevant to the initiating request. - The length of this structure is guaranteed to be between (inclusive) 1 and - min(req.max_messages, INTERNAL_BATCH_LIMIT). - """ - - @dataclass(eq=False, repr=False) class AlertConfigSetRequest(aristaproto.Message): value: "AlertConfig" = aristaproto.message_field(1) @@ -2611,42 +2553,6 @@ async def subscribe_meta( ): yield response - async def get_all_batched( - self, - alert_batched_stream_request: "AlertBatchedStreamRequest", - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional["MetadataLike"] = None - ) -> AsyncIterator["AlertBatchedStreamResponse"]: - async for response in self._unary_stream( - "/arista.alert.v1.AlertService/GetAllBatched", - alert_batched_stream_request, - AlertBatchedStreamResponse, - timeout=timeout, - deadline=deadline, - metadata=metadata, - ): - yield response - - async def subscribe_batched( - self, - alert_batched_stream_request: "AlertBatchedStreamRequest", - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional["MetadataLike"] = None - ) -> AsyncIterator["AlertBatchedStreamResponse"]: - async for response in self._unary_stream( - "/arista.alert.v1.AlertService/SubscribeBatched", - alert_batched_stream_request, - AlertBatchedStreamResponse, - timeout=timeout, - deadline=deadline, - metadata=metadata, - ): - yield response - class AlertConfigServiceStub(aristaproto.ServiceStub): async def get_one( @@ -2737,42 +2643,6 @@ async def set( metadata=metadata, ) - async def get_all_batched( - self, - alert_config_batched_stream_request: "AlertConfigBatchedStreamRequest", - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional["MetadataLike"] = None - ) -> AsyncIterator["AlertConfigBatchedStreamResponse"]: - async for response in self._unary_stream( - "/arista.alert.v1.AlertConfigService/GetAllBatched", - alert_config_batched_stream_request, - AlertConfigBatchedStreamResponse, - timeout=timeout, - deadline=deadline, - metadata=metadata, - ): - yield response - - async def subscribe_batched( - self, - alert_config_batched_stream_request: "AlertConfigBatchedStreamRequest", - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional["MetadataLike"] = None - ) -> AsyncIterator["AlertConfigBatchedStreamResponse"]: - async for response in self._unary_stream( - "/arista.alert.v1.AlertConfigService/SubscribeBatched", - alert_config_batched_stream_request, - AlertConfigBatchedStreamResponse, - timeout=timeout, - deadline=deadline, - metadata=metadata, - ): - yield response - class DefaultTemplateServiceStub(aristaproto.ServiceStub): async def get_one( @@ -3170,16 +3040,6 @@ async def subscribe_meta( ) -> AsyncIterator["MetaResponse"]: raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) - async def get_all_batched( - self, alert_batched_stream_request: "AlertBatchedStreamRequest" - ) -> AsyncIterator["AlertBatchedStreamResponse"]: - raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) - - async def subscribe_batched( - self, alert_batched_stream_request: "AlertBatchedStreamRequest" - ) -> AsyncIterator["AlertBatchedStreamResponse"]: - raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) - async def __rpc_get_one( self, stream: "grpclib.server.Stream[AlertRequest, AlertResponse]" ) -> None: @@ -3217,28 +3077,6 @@ async def __rpc_subscribe_meta( request, ) - async def __rpc_get_all_batched( - self, - stream: "grpclib.server.Stream[AlertBatchedStreamRequest, AlertBatchedStreamResponse]", - ) -> None: - request = await stream.recv_message() - await self._call_rpc_handler_server_stream( - self.get_all_batched, - stream, - request, - ) - - async def __rpc_subscribe_batched( - self, - stream: "grpclib.server.Stream[AlertBatchedStreamRequest, AlertBatchedStreamResponse]", - ) -> None: - request = await stream.recv_message() - await self._call_rpc_handler_server_stream( - self.subscribe_batched, - stream, - request, - ) - def __mapping__(self) -> Dict[str, grpclib.const.Handler]: return { "/arista.alert.v1.AlertService/GetOne": grpclib.const.Handler( @@ -3265,18 +3103,6 @@ def __mapping__(self) -> Dict[str, grpclib.const.Handler]: AlertStreamRequest, MetaResponse, ), - "/arista.alert.v1.AlertService/GetAllBatched": grpclib.const.Handler( - self.__rpc_get_all_batched, - grpclib.const.Cardinality.UNARY_STREAM, - AlertBatchedStreamRequest, - AlertBatchedStreamResponse, - ), - "/arista.alert.v1.AlertService/SubscribeBatched": grpclib.const.Handler( - self.__rpc_subscribe_batched, - grpclib.const.Cardinality.UNARY_STREAM, - AlertBatchedStreamRequest, - AlertBatchedStreamResponse, - ), } @@ -3307,16 +3133,6 @@ async def set( ) -> "AlertConfigSetResponse": raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) - async def get_all_batched( - self, alert_config_batched_stream_request: "AlertConfigBatchedStreamRequest" - ) -> AsyncIterator["AlertConfigBatchedStreamResponse"]: - raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) - - async def subscribe_batched( - self, alert_config_batched_stream_request: "AlertConfigBatchedStreamRequest" - ) -> AsyncIterator["AlertConfigBatchedStreamResponse"]: - raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) - async def __rpc_get_one( self, stream: "grpclib.server.Stream[AlertConfigRequest, AlertConfigResponse]" ) -> None: @@ -3364,28 +3180,6 @@ async def __rpc_set( response = await self.set(request) await stream.send_message(response) - async def __rpc_get_all_batched( - self, - stream: "grpclib.server.Stream[AlertConfigBatchedStreamRequest, AlertConfigBatchedStreamResponse]", - ) -> None: - request = await stream.recv_message() - await self._call_rpc_handler_server_stream( - self.get_all_batched, - stream, - request, - ) - - async def __rpc_subscribe_batched( - self, - stream: "grpclib.server.Stream[AlertConfigBatchedStreamRequest, AlertConfigBatchedStreamResponse]", - ) -> None: - request = await stream.recv_message() - await self._call_rpc_handler_server_stream( - self.subscribe_batched, - stream, - request, - ) - def __mapping__(self) -> Dict[str, grpclib.const.Handler]: return { "/arista.alert.v1.AlertConfigService/GetOne": grpclib.const.Handler( @@ -3418,18 +3212,6 @@ def __mapping__(self) -> Dict[str, grpclib.const.Handler]: AlertConfigSetRequest, AlertConfigSetResponse, ), - "/arista.alert.v1.AlertConfigService/GetAllBatched": grpclib.const.Handler( - self.__rpc_get_all_batched, - grpclib.const.Cardinality.UNARY_STREAM, - AlertConfigBatchedStreamRequest, - AlertConfigBatchedStreamResponse, - ), - "/arista.alert.v1.AlertConfigService/SubscribeBatched": grpclib.const.Handler( - self.__rpc_subscribe_batched, - grpclib.const.Cardinality.UNARY_STREAM, - AlertConfigBatchedStreamRequest, - AlertConfigBatchedStreamResponse, - ), } diff --git a/python-avd/pyavd/_cv/api/arista/configlet/v1/__init__.py b/python-avd/pyavd/_cv/api/arista/configlet/v1/__init__.py index 5023b4e4425..c9aa624bc94 100644 --- a/python-avd/pyavd/_cv/api/arista/configlet/v1/__init__.py +++ b/python-avd/pyavd/_cv/api/arista/configlet/v1/__init__.py @@ -375,6 +375,13 @@ class ConfigletStreamRequest(aristaproto.Message): subscriptions if filter(s) are sufficiently specific. """ + filter: "Filter" = aristaproto.message_field(2) + """ + For each Configlet in the list, all populated fields are considered ANDed together + as a filtering operation. Similarly, the list itself is ORed such that any individual + filter that matches a given Configlet is streamed to the user. + """ + time: "__time__.TimeBounds" = aristaproto.message_field(3) """ TimeRange allows limiting response data to within a specified time window. @@ -427,6 +434,13 @@ class ConfigletBatchedStreamRequest(aristaproto.Message): subscriptions if filter(s) are sufficiently specific. """ + filter: "Filter" = aristaproto.message_field(2) + """ + For each Configlet in the list, all populated fields are considered ANDed together + as a filtering operation. Similarly, the list itself is ORed such that any individual + filter that matches a given Configlet is streamed to the user. + """ + time: "__time__.TimeBounds" = aristaproto.message_field(3) """ TimeRange allows limiting response data to within a specified time window. @@ -991,6 +1005,13 @@ class ConfigletConfigStreamRequest(aristaproto.Message): subscriptions if filter(s) are sufficiently specific. """ + filter: "Filter" = aristaproto.message_field(2) + """ + For each ConfigletConfig in the list, all populated fields are considered ANDed together + as a filtering operation. Similarly, the list itself is ORed such that any individual + filter that matches a given ConfigletConfig is streamed to the user. + """ + time: "__time__.TimeBounds" = aristaproto.message_field(3) """ TimeRange allows limiting response data to within a specified time window. @@ -1045,6 +1066,13 @@ class ConfigletConfigBatchedStreamRequest(aristaproto.Message): subscriptions if filter(s) are sufficiently specific. """ + filter: "Filter" = aristaproto.message_field(2) + """ + For each ConfigletConfig in the list, all populated fields are considered ANDed together + as a filtering operation. Similarly, the list itself is ORed such that any individual + filter that matches a given ConfigletConfig is streamed to the user. + """ + time: "__time__.TimeBounds" = aristaproto.message_field(3) """ TimeRange allows limiting response data to within a specified time window. @@ -1180,6 +1208,13 @@ class ConfigletConfigDeleteAllRequest(aristaproto.Message): A filtered DeleteAll will use GetAll with filter to find things to delete. """ + filter: "Filter" = aristaproto.message_field(2) + """ + For each ConfigletConfig in the list, all populated fields are considered ANDed together + as a filtering operation. Similarly, the list itself is ORed such that any individual + filter that matches a given ConfigletConfig will be deleted. + """ + @dataclass(eq=False, repr=False) class ConfigletConfigDeleteAllResponse(aristaproto.Message): diff --git a/python-avd/pyavd/_cv/api/arista/imagestatus/v1/__init__.py b/python-avd/pyavd/_cv/api/arista/imagestatus/v1/__init__.py index 9c0ba368577..212fa4c7478 100644 --- a/python-avd/pyavd/_cv/api/arista/imagestatus/v1/__init__.py +++ b/python-avd/pyavd/_cv/api/arista/imagestatus/v1/__init__.py @@ -201,6 +201,12 @@ class ErrorCode(aristaproto.Enum): with a non-2GB EOS or a non-2GB device is incompatible with a 2GB-EOS. """ + EOS_EXTENSION_VERSION_INCOMPATIBLE = 16 + """ + ERROR_CODE_EOS_EXTENSION_VERSION_INCOMPATIBLE represents the case where the given extension + version doesn't support the given EOS version. + """ + class WarningCode(aristaproto.Enum): """WarningCode indicates warnings produced during image validations.""" @@ -305,6 +311,19 @@ class WarningCode(aristaproto.Enum): """ +class InfoCode(aristaproto.Enum): + """InfoCode indicates info messages produced during image validations.""" + + UNSPECIFIED = 0 + """INFO_CODE_UNSPECIFIED indicates info code is unspecified.""" + + NEWER_VERSION_AVAILABLE = 1 + """ + INFO_CODE_NEWER_VERSION_AVAILABLE represents cases where a newer EOS maintainance + release is available for download. + """ + + @dataclass(eq=False, repr=False) class SoftwareImage(aristaproto.Message): """ @@ -639,13 +658,19 @@ class Summary(aristaproto.Message): errors: "ImageErrors" = aristaproto.message_field(3) """ errors are the image errors encountered while validating the image. These are - displayed on the change control review page (for changes made outside the workspace). + displayed on the workspace build results page. """ warnings: "ImageWarnings" = aristaproto.message_field(4) """ warnings are the image warnings encountered while validating the image. These are - displayed on the change control review page (for changes made outside the workspace). + displayed on the workspace build results page. + """ + + infos: "ImageInfos" = aristaproto.message_field(5) + """ + infos are the image infos encountered while validating the image. These are + displayed on the workspace build results page. """ @@ -703,6 +728,32 @@ class ImageWarnings(aristaproto.Message): """values is a list of image warnings.""" +@dataclass(eq=False, repr=False) +class ImageInfo(aristaproto.Message): + """ImageInfo wraps `InfoCode` enum with a reason string.""" + + sku: Optional[str] = aristaproto.message_field(1, wraps=aristaproto.TYPE_STRING) + """sku represents the name of the sku.""" + + info_code: "InfoCode" = aristaproto.enum_field(2) + """info_code is the info code.""" + + info_msg: Optional[str] = aristaproto.message_field( + 3, wraps=aristaproto.TYPE_STRING + ) + """info_msg provides a description of the info.""" + + +@dataclass(eq=False, repr=False) +class ImageInfos(aristaproto.Message): + """ + ImageInfos is the list of info messages reported by CVP when handling image validations. + """ + + values: List["ImageInfo"] = aristaproto.message_field(1) + """values is a list of image infos.""" + + @dataclass(eq=False, repr=False) class MetaResponse(aristaproto.Message): time: datetime = aristaproto.message_field(1) diff --git a/python-avd/pyavd/_cv/api/arista/workspace/v1/__init__.py b/python-avd/pyavd/_cv/api/arista/workspace/v1/__init__.py index 5cb57b2a4ed..aa03319c8b3 100644 --- a/python-avd/pyavd/_cv/api/arista/workspace/v1/__init__.py +++ b/python-avd/pyavd/_cv/api/arista/workspace/v1/__init__.py @@ -666,6 +666,9 @@ class ImageValidationResult(aristaproto.Message): ) """image_input_error indicates any errors in image inputs.""" + infos: "__imagestatus_v1__.ImageInfos" = aristaproto.message_field(5) + """infos are any info messages about the generated image.""" + @dataclass(eq=False, repr=False) class BuildStageState(aristaproto.Message): diff --git a/python-avd/pyavd/_cv/client/__init__.py b/python-avd/pyavd/_cv/client/__init__.py index 303861e4198..947826a136e 100644 --- a/python-avd/pyavd/_cv/client/__init__.py +++ b/python-avd/pyavd/_cv/client/__init__.py @@ -4,6 +4,7 @@ from __future__ import annotations import ssl +from typing import TYPE_CHECKING from grpclib.client import Channel from requests import JSONDecodeError, post @@ -18,6 +19,10 @@ from .utils import UtilsMixin from .workspace import WorkspaceMixin +if TYPE_CHECKING: + from types import TracebackType + from typing import Self + class CVClient( ChangeControlMixin, @@ -72,22 +77,19 @@ def __init__( self._password = password self._verify_certs = verify_certs - async def __aenter__(self) -> CVClient: - """ - Using asynchronous context manager since grpclib must be initialized inside an asyncio loop. - """ + async def __aenter__(self) -> Self: + """Using asynchronous context manager since grpclib must be initialized inside an asyncio loop.""" self._connect() return self - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + async def __aexit__(self, _exc_type: type[BaseException] | None, _exc_val: BaseException | None, _exc_tb: TracebackType | None) -> None: self._channel.close() self._channel = None def _connect(self) -> None: - # TODO: - # - Verify connection - # - Handle multinode clusters - # - Detect supported API versions and set instance properties accordingly. + # TODO: Verify connection + # TODO: Handle multinode clusters + # TODO: Detect supported API versions and set instance properties accordingly. if not self._token: self._set_token() @@ -108,6 +110,7 @@ def _connect(self) -> None: def _set_token(self) -> None: """ Uses username/password for authenticating via REST. + Sets the session token into self._token to be used for gRPC channel. TODO: Handle multinode clusters @@ -116,7 +119,8 @@ def _set_token(self) -> None: return if not self._username or not self._password: - raise CVClientException("Unable to authenticate. Missing token or username/password.") + msg = "Unable to authenticate. Missing token or username/password." + raise CVClientException(msg) if not self._verify_certs: # Accepting SonarLint issue: We are purposely implementing no verification of certs. @@ -128,10 +132,14 @@ def _set_token(self) -> None: context = None try: - response = post( - "https://" + self._servers[0] + "/cvpservice/login/authenticate.do", auth=(self._username, self._password), verify=self._verify_certs, json={} + response = post( # noqa: S113 TODO: Add configurable timeout + "https://" + self._servers[0] + "/cvpservice/login/authenticate.do", + auth=(self._username, self._password), + verify=self._verify_certs, + json={}, ) self._token = response.json()["sessionId"] except (KeyError, JSONDecodeError) as e: - raise CVClientException("Unable to get token from CloudVision server. Please supply service account token instead of username/password.") from e + msg = "Unable to get token from CloudVision server. Please supply service account token instead of username/password." + raise CVClientException(msg) from e diff --git a/python-avd/pyavd/_cv/client/change_control.py b/python-avd/pyavd/_cv/client/change_control.py index 2395afe4b01..bff0152b7fe 100644 --- a/python-avd/pyavd/_cv/client/change_control.py +++ b/python-avd/pyavd/_cv/client/change_control.py @@ -3,11 +3,10 @@ # that can be found in the LICENSE file. from __future__ import annotations -from datetime import datetime from logging import getLogger from typing import TYPE_CHECKING, Literal -from ..api.arista.changecontrol.v1 import ( +from pyavd._cv.api.arista.changecontrol.v1 import ( ApproveConfig, ApproveConfigServiceStub, ApproveConfigSetRequest, @@ -24,9 +23,12 @@ ChangeControlStreamRequest, FlagConfig, ) + from .exceptions import get_cv_client_exception if TYPE_CHECKING: + from datetime import datetime + from aristaproto import _DateTime from . import CVClient @@ -42,9 +44,7 @@ class ChangeControlMixin: - """ - Only to be used as mixin on CVClient class. - """ + """Only to be used as mixin on CVClient class.""" workspace_api_version: Literal["v1"] = "v1" @@ -55,7 +55,7 @@ async def get_change_control( timeout: float = 10.0, ) -> ChangeControl: """ - Get Change Control using arista.changecontrol.v1.ChangeControlService.GetOne API + Get Change Control using arista.changecontrol.v1.ChangeControlService.GetOne API. Parameters: change_control_id: Unique identifier of the Change Control. @@ -73,11 +73,11 @@ async def get_change_control( try: response = await client.get_one(request, metadata=self._metadata, timeout=timeout) - return response.value - except Exception as e: raise get_cv_client_exception(e, f"Change Control ID '{change_control_id}'") or e + return response.value + async def set_change_control( self: CVClient, change_control_id: str, @@ -86,7 +86,7 @@ async def set_change_control( timeout: float = 10.0, ) -> ChangeControlConfigSetResponse: """ - Set Change Control details using arista.changecontrol.v1.ChangeControlConfigService.Set API + Set Change Control details using arista.changecontrol.v1.ChangeControlConfigService.Set API. Parameters: change_control_id: Unique identifier of the Change Control. @@ -102,17 +102,17 @@ async def set_change_control( value=ChangeControlConfig( key=ChangeControlKey(id=change_control_id), change=ChangeConfig(name=name, notes=description), - ) + ), ) client = ChangeControlConfigServiceStub(self._channel) try: response = await client.set(request, metadata=self._metadata, timeout=timeout) - return response.value - except Exception as e: raise get_cv_client_exception(e, f"Change Control ID '{change_control_id}'") or e + return response.value + async def approve_change_control( self: CVClient, change_control_id: str, @@ -121,7 +121,7 @@ async def approve_change_control( timeout: float = 10.0, ) -> ApproveConfig: """ - Get Change Control using arista.changecontrol.v1.ChangeControlService.GetOne API + Get Change Control using arista.changecontrol.v1.ChangeControlService.GetOne API. Parameters: change_control_id: Unique identifier of the Change Control. @@ -139,17 +139,17 @@ async def approve_change_control( key=ChangeControlKey(id=change_control_id), approve=FlagConfig(value=True, notes=description), version=timestamp, - ) + ), ) client = ApproveConfigServiceStub(self._channel) try: response = await client.set(request, metadata=self._metadata, timeout=timeout) - return response.value - except Exception as e: raise get_cv_client_exception(e, f"Approving Change Control ID '{change_control_id}' for timestamp '{timestamp}'") or e + return response.value + async def start_change_control( self: CVClient, change_control_id: str, @@ -157,7 +157,7 @@ async def start_change_control( timeout: float = 10.0, ) -> ChangeControlConfig: """ - Set Change Control details using arista.changecontrol.v1.ChangeControlConfigService.Set API + Set Change Control details using arista.changecontrol.v1.ChangeControlConfigService.Set API. Parameters: change_control_id: Unique identifier of the Change Control. @@ -171,17 +171,17 @@ async def start_change_control( value=ChangeControlConfig( key=ChangeControlKey(id=change_control_id), start=FlagConfig(value=True, notes=description), - ) + ), ) client = ChangeControlConfigServiceStub(self._channel) try: response = await client.set(request, metadata=self._metadata, timeout=timeout) - return response.value - except Exception as e: raise get_cv_client_exception(e, f"Change Control ID '{change_control_id}'") or e + return response.value + async def wait_for_change_control_state( self: CVClient, cc_id: str, @@ -190,6 +190,7 @@ async def wait_for_change_control_state( ) -> ChangeControl: """ Monitor a Change control using arista.changecontrol.v1.ChangeControlService.Subscribe API for a response to the given cc_id. + Blocks until a response is returned or timed out. Parameters: @@ -204,7 +205,7 @@ async def wait_for_change_control_state( partial_eq_filter=[ ChangeControl( key=ChangeControlKey(id=cc_id), - ) + ), ], ) client = ChangeControlServiceStub(self._channel) @@ -212,10 +213,9 @@ async def wait_for_change_control_state( responses = client.subscribe(request, metadata=self._metadata, timeout=timeout) async for response in responses: LOGGER.debug("wait_for_change_control_complete: Response is '%s.'", response) - if hasattr(response, "value"): - if response.value.status == CHANGE_CONTROL_STATUS_MAP[state]: - LOGGER.info("wait_for_change_control_complete: Got response for request '%s': %s", cc_id, response.value.status) - return response.value + if hasattr(response, "value") and response.value.status == CHANGE_CONTROL_STATUS_MAP[state]: + LOGGER.info("wait_for_change_control_complete: Got response for request '%s': %s", cc_id, response.value.status) + return response.value LOGGER.debug("wait_for_change_control_complete: Status of change control is '%s.'", response) except Exception as e: diff --git a/python-avd/pyavd/_cv/client/configlet.py b/python-avd/pyavd/_cv/client/configlet.py index 097879e92d4..641fa8d37e3 100644 --- a/python-avd/pyavd/_cv/client/configlet.py +++ b/python-avd/pyavd/_cv/client/configlet.py @@ -3,11 +3,10 @@ # that can be found in the LICENSE file. from __future__ import annotations -from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Literal -from ..api.arista.configlet.v1 import ( +from pyavd._cv.api.arista.configlet.v1 import ( Configlet, ConfigletAssignment, ConfigletAssignmentConfig, @@ -26,12 +25,15 @@ ConfigletStreamRequest, MatchPolicy, ) -from ..api.arista.time import TimeBounds -from ..api.fmp import RepeatedString +from pyavd._cv.api.arista.time import TimeBounds +from pyavd._cv.api.fmp import RepeatedString + from .constants import DEFAULT_API_TIMEOUT from .exceptions import get_cv_client_exception if TYPE_CHECKING: + from datetime import datetime + from . import CVClient @@ -43,9 +45,7 @@ class ConfigletMixin: - """ - Only to be used as mixin on CVClient class. - """ + """Only to be used as mixin on CVClient class.""" configlet_api_version: Literal["v1"] = "v1" @@ -72,22 +72,20 @@ async def get_configlet_containers( if container_ids: for container_id in container_ids: request.partial_eq_filter.append( - ConfigletAssignment(key=ConfigletAssignmentKey(workspace_id=workspace_id, configlet_assignment_id=container_id)) + ConfigletAssignment(key=ConfigletAssignmentKey(workspace_id=workspace_id, configlet_assignment_id=container_id)), ) else: request.partial_eq_filter.append(ConfigletAssignment(key=ConfigletAssignmentKey(workspace_id=workspace_id))) client = ConfigletAssignmentServiceStub(self._channel) - configlet_assignments = [] try: responses = client.get_all(request, metadata=self._metadata, timeout=timeout) - async for response in responses: - configlet_assignments.append(response.value) - return configlet_assignments - + configlet_assignments = [response.value async for response in responses] except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', ConfigletAssignment ID '{container_ids}'") or e + return configlet_assignments + async def set_configlet_container( self: CVClient, workspace_id: str, @@ -122,16 +120,16 @@ async def set_configlet_container( query=query, child_assignment_ids=RepeatedString(values=child_assignment_ids), match_policy=ASSIGNMENT_MATCH_POLICY_MAP.get(match_policy), - ) + ), ) client = ConfigletAssignmentConfigServiceStub(self._channel) try: response = await client.set(request, metadata=self._metadata, timeout=timeout) - return response.value - except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', ConfigletAssignment ID '{container_id}'") or e + return response.value + async def set_configlet_containers( self: CVClient, workspace_id: str, @@ -150,7 +148,6 @@ async def set_configlet_containers( Returns: ConfigletAssignmentKey objects after being set including any server-generated values. """ - request = ConfigletAssignmentConfigSetSomeRequest( values=[ ConfigletAssignmentConfig( @@ -163,20 +160,17 @@ async def set_configlet_containers( match_policy=ASSIGNMENT_MATCH_POLICY_MAP.get(match_policy), ) for container_id, display_name, description, configlet_ids, query, child_assignment_ids, match_policy in containers - ] + ], ) client = ConfigletAssignmentConfigServiceStub(self._channel) - assignment_keys = [] try: responses = client.set_some(request, metadata=self._metadata, timeout=timeout + len(request.values) * 0.5) - async for response in responses: - assignment_keys.append(response.key) - - return assignment_keys - + assignment_keys = [response.key async for response in responses] except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Containers '{containers}'") or e + return assignment_keys + async def delete_configlet_container( self: CVClient, workspace_id: str, @@ -197,16 +191,16 @@ async def delete_configlet_container( value=ConfigletAssignmentConfig( key=ConfigletAssignmentKey(workspace_id=workspace_id, configlet_assignment_id=assignment_id), remove=True, - ) + ), ) client = ConfigletAssignmentConfigServiceStub(self._channel) try: response = await client.set(request, metadata=self._metadata, timeout=timeout) - return response.value - except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', ConfigletAssignment ID '{assignment_id}'") or e + return response.value + async def get_configlets( self: CVClient, workspace_id: str, @@ -216,6 +210,7 @@ async def get_configlets( ) -> list[Configlet]: """ Get Configlets using arista.configlet.v1.ConfigletServiceStub.GetAll API. + Missing objects will not produce an error. Parameters: @@ -235,17 +230,15 @@ async def get_configlets( request.partial_eq_filter.append(Configlet(key=ConfigletKey(workspace_id=workspace_id))) client = ConfigletServiceStub(self._channel) - configlets = [] + try: responses = client.get_all(request, metadata=self._metadata, timeout=timeout) - async for response in responses: - configlets.append(response.value) - - return configlets - + configlets = [response.value async for response in responses] except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Configlet IDs '{configlet_ids}'") or e + return configlets + async def set_configlet( self: CVClient, workspace_id: str, @@ -275,16 +268,16 @@ async def set_configlet( display_name=display_name, description=description, body=body, - ) + ), ) client = ConfigletConfigServiceStub(self._channel) try: response = await client.set(request, metadata=self._metadata, timeout=timeout) - return response.value - except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Configlet ID '{configlet_id}'") or e + return response.value + async def set_configlet_from_file( self: CVClient, workspace_id: str, @@ -314,16 +307,16 @@ async def set_configlet_from_file( display_name=display_name, description=description, body=Path(file).read_text(encoding="UTF-8"), - ) + ), ) client = ConfigletConfigServiceStub(self._channel) try: response = await client.set(request, metadata=self._metadata, timeout=timeout) - return response.value - except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Configlet ID '{configlet_id}', File '{file}'") or e + return response.value + async def delete_configlets( self: CVClient, workspace_id: str, @@ -347,16 +340,14 @@ async def delete_configlets( ConfigletConfig( key=ConfigletKey(workspace_id=workspace_id, configlet_id=configlet_id), remove=True, - ) + ), ) client = ConfigletConfigServiceStub(self._channel) - configlet_configs = [] try: responses = client.set_some(request, metadata=self._metadata, timeout=timeout) - async for response in responses: - configlet_configs.append(response.key) - - return configlet_configs + configlet_configs = [response.key async for response in responses] except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Configlet IDs '{configlet_ids}'") or e + + return configlet_configs diff --git a/python-avd/pyavd/_cv/client/exceptions.py b/python-avd/pyavd/_cv/client/exceptions.py index a1f02d86ac4..b7b5bd6ad3e 100644 --- a/python-avd/pyavd/_cv/client/exceptions.py +++ b/python-avd/pyavd/_cv/client/exceptions.py @@ -32,37 +32,37 @@ def get_cv_client_exception(exception: Exception, cv_client_details: str | None return None -class CVClientException(Exception): - """Base exception""" +class CVClientException(Exception): # noqa: N818 + """Base exception.""" class CVTimeoutError(CVClientException): - """API call timed out""" + """API call timed out.""" class CVResourceNotFound(CVClientException): - """CloudVision Resource not found""" + """CloudVision Resource not found.""" class CVResourceInvalidState(CVClientException): - """Invalid state for CloudVision Resource""" + """Invalid state for CloudVision Resource.""" class CVWorkspaceBuildTimeout(CVClientException): - """Build of CloudVision Workspace timed out""" + """Build of CloudVision Workspace timed out.""" class CVWorkspaceBuildFailed(CVClientException): - """Build of CloudVision Workspace failed""" + """Build of CloudVision Workspace failed.""" class CVWorkspaceSubmitFailed(CVClientException): - """Build of CloudVision Workspace failed""" + """Build of CloudVision Workspace failed.""" class CVWorkspaceStateTimeout(CVClientException): - """Timed out waiting for Workspace to get to the expected state""" + """Timed out waiting for Workspace to get to the expected state.""" class CVChangeControlFailed(CVClientException): - """CloudVision ChangeControl failed during execution""" + """CloudVision ChangeControl failed during execution.""" diff --git a/python-avd/pyavd/_cv/client/inventory.py b/python-avd/pyavd/_cv/client/inventory.py index ee543278f20..831e7967f9d 100644 --- a/python-avd/pyavd/_cv/client/inventory.py +++ b/python-avd/pyavd/_cv/client/inventory.py @@ -3,21 +3,21 @@ # that can be found in the LICENSE file. from __future__ import annotations -from datetime import datetime from typing import TYPE_CHECKING, Literal -from ..api.arista.inventory.v1 import Device, DeviceKey, DeviceServiceStub, DeviceStreamRequest -from ..api.arista.time import TimeBounds +from pyavd._cv.api.arista.inventory.v1 import Device, DeviceKey, DeviceServiceStub, DeviceStreamRequest +from pyavd._cv.api.arista.time import TimeBounds + from .exceptions import get_cv_client_exception if TYPE_CHECKING: + from datetime import datetime + from . import CVClient class InventoryMixin: - """ - Only to be used as mixin on CVClient class. - """ + """Only to be used as mixin on CVClient class.""" inventory_api_version: Literal["v1"] = "v1" @@ -48,16 +48,13 @@ async def get_inventory_devices( key=DeviceKey(device_id=serial_number), system_mac_address=system_mac_address, hostname=hostname, - ) + ), ) client = DeviceServiceStub(self._channel) - inventory_devices = [] try: responses = client.get_all(request, metadata=self._metadata, timeout=timeout) - async for response in responses: - inventory_devices.append(response.value) - - return inventory_devices - + inventory_devices = [response.value async for response in responses] except Exception as e: raise get_cv_client_exception(e, f"devices '{devices}'") or e + + return inventory_devices diff --git a/python-avd/pyavd/_cv/client/studio.py b/python-avd/pyavd/_cv/client/studio.py index 860405589ad..f4cbf57bd88 100644 --- a/python-avd/pyavd/_cv/client/studio.py +++ b/python-avd/pyavd/_cv/client/studio.py @@ -4,11 +4,10 @@ from __future__ import annotations import json -from datetime import datetime from logging import getLogger from typing import TYPE_CHECKING, Any, Literal -from ..api.arista.studio.v1 import ( +from pyavd._cv.api.arista.studio.v1 import ( Inputs, InputsConfig, InputsConfigServiceStub, @@ -27,12 +26,15 @@ StudioRequest, StudioServiceStub, ) -from ..api.arista.time import TimeBounds -from ..api.fmp import RepeatedString +from pyavd._cv.api.arista.time import TimeBounds +from pyavd._cv.api.fmp import RepeatedString + from .constants import DEFAULT_API_TIMEOUT from .exceptions import CVResourceNotFound, get_cv_client_exception if TYPE_CHECKING: + from datetime import datetime + from . import CVClient LOGGER = getLogger(__name__) @@ -41,9 +43,7 @@ class StudioMixin: - """ - Only to be used as mixin on CVClient class. - """ + """Only to be used as mixin on CVClient class.""" studio_api_version: Literal["v1"] = "v1" @@ -78,7 +78,6 @@ async def get_studio( client = StudioServiceStub(self._channel) try: response = await client.get_one(request, metadata=self._metadata, timeout=timeout) - return response.value except Exception as e: # pylint: disable=broad-exception-caught e = get_cv_client_exception(e, f"Studio ID '{studio_id}, Workspace ID '{workspace_id}'") or e if isinstance(e, CVResourceNotFound): @@ -86,7 +85,10 @@ async def get_studio( # This simply means the studio itself was not changed in this workspace. pass else: - raise e + raise + else: + # We get here if no exception was raised, meaining the studio was changed in the workspace. + return response.value # If we get here, it means no studio was returned by the workspace call. # So now we fetch the studio config from the workspace to see if the studio was deleted in this workspace. @@ -95,17 +97,17 @@ async def get_studio( StudioConfig( key=StudioKey(studio_id=studio_id, workspace_id=workspace_id), remove=True, - ) + ), ], time=TimeBounds(start=None, end=time), ) client = StudioConfigServiceStub(self._channel) try: responses = client.get_all(request, metadata=self._metadata, timeout=timeout) - async for response in responses: + async for _response in responses: # If we get here it means we got an entry with "removed: True" so no need to look further. - raise CVResourceNotFound("The studio was deleted in the workspace.", f"Studio ID '{studio_id}, Workspace ID '{workspace_id}'") - + msg = "The studio was deleted in the workspace." + raise CVResourceNotFound(msg, f"Studio ID '{studio_id}, Workspace ID '{workspace_id}'") # noqa: TRY301 TODO: Improve error handling except Exception as e: raise get_cv_client_exception(e, f"Studio ID '{studio_id}, Workspace ID '{workspace_id}'") or e @@ -118,10 +120,11 @@ async def get_studio( client = StudioServiceStub(self._channel) try: response = await client.get_one(request, metadata=self._metadata, timeout=timeout) - return response.value except Exception as e: raise get_cv_client_exception(e, f"Studio ID '{studio_id}, Workspace ID '{workspace_id}'") or e + return response.value + async def get_studio_inputs( self: CVClient, studio_id: str, @@ -152,7 +155,7 @@ async def get_studio_inputs( partial_eq_filter=[ Inputs( key=InputsKey(studio_id=studio_id, workspace_id=workspace_id), - ) + ), ], time=time, ) @@ -187,14 +190,14 @@ async def get_studio_inputs( InputsConfig( key=InputsKey(studio_id=studio_id, workspace_id=workspace_id), remove=True, - ) + ), ], time=time, ) client = InputsConfigServiceStub(self._channel) try: responses = client.get_all(request, metadata=self._metadata, timeout=timeout) - async for response in responses: + async for _response in responses: # If we get here it means we got an entry with "removed: True" so no need to look further. return default_value @@ -207,7 +210,7 @@ async def get_studio_inputs( partial_eq_filter=[ Inputs( key=InputsKey(studio_id=studio_id, workspace_id=""), - ) + ), ], time=time, ) @@ -223,11 +226,11 @@ async def get_studio_inputs( data=studio_inputs, value=json.loads(response.value.inputs), ) - - return studio_inputs or default_value except Exception as e: raise get_cv_client_exception(e, f"Studio ID '{studio_id}, Workspace ID '{workspace_id}'") or e + return studio_inputs or default_value + async def get_studio_inputs_with_path( self: CVClient, studio_id: str, @@ -267,19 +270,19 @@ async def get_studio_inputs_with_path( client = InputsServiceStub(self._channel) try: response = await client.get_one(request, metadata=self._metadata, timeout=timeout) - - # We only get a response if the inputs are set/changed in the workspace. - if response.value.inputs is not None: - return json.loads(response.value.inputs) - return default_value - except Exception as e: # pylint: disable=broad-exception-caught e = get_cv_client_exception(e, f"Studio ID '{studio_id}, Workspace ID '{workspace_id}', Path '{input_path}'") or e if isinstance(e, CVResourceNotFound) and workspace_id != "": # Ignore this error, since it simply means we have to check if inputs got deleted in this workspace or fetch from mainline as last resort. pass else: - raise e + raise + else: + # We get here if no exception was raised. + # We only get a response if the inputs are set/changed in the workspace. + if response.value.inputs is not None: + return json.loads(response.value.inputs) + return default_value # If we get here, it means no inputs were returned by the workspace call. # So now we fetch the inputs config from the workspace to see if the inputs were deleted in this workspace. @@ -297,7 +300,7 @@ async def get_studio_inputs_with_path( client = InputsConfigServiceStub(self._channel) try: responses = client.get_all(request, metadata=self._metadata, timeout=timeout) - async for response in responses: + async for _response in responses: # If we get here it means we got an entry with "removed: True" so no need to look further. return default_value @@ -317,15 +320,16 @@ async def get_studio_inputs_with_path( client = InputsServiceStub(self._channel) try: response = await client.get_one(request, metadata=self._metadata, timeout=timeout) - if response.value.inputs is not None: - return json.loads(response.value.inputs) - return default_value except Exception as e: # pylint: disable=broad-exception-caught e = get_cv_client_exception(e, f"Studio ID '{studio_id}, Workspace ID '{workspace_id}', Path '{input_path}'") or e if isinstance(e, CVResourceNotFound): # Ignore this error, since it simply means we no inputs are in the studio so we will return the default value. return default_value - raise e + raise + + if response.value.inputs is not None: + return json.loads(response.value.inputs) + return default_value async def set_studio_inputs( self: CVClient, @@ -358,16 +362,16 @@ async def set_studio_inputs( path=RepeatedString(values=input_path), ), inputs=json.dumps(inputs), - ) + ), ) client = InputsConfigServiceStub(self._channel) try: response = await client.set(request, metadata=self._metadata, timeout=timeout) - return response.value - except Exception as e: raise get_cv_client_exception(e, f"Studio ID '{studio_id}, Workspace ID '{workspace_id}', Path '{input_path}'") or e + return response.value + async def get_topology_studio_inputs( self: CVClient, workspace_id: str, @@ -376,9 +380,6 @@ async def get_topology_studio_inputs( timeout: float = DEFAULT_API_TIMEOUT, ) -> list[dict]: """ - TODO: Once the topology studio inputs API is public, this function can be replaced by the _future variant. - It will probably need some version detection to see if the API is supported. - Get Topology Studio Inputs using arista.studio.v1.InputsService.GetAll and arista.studio.v1.InputsConfigService.GetAll APIs. Parameters: @@ -392,7 +393,11 @@ async def get_topology_studio_inputs( """ topology_inputs: list[dict] = [] studio_inputs: dict = await self.get_studio_inputs( - studio_id=TOPOLOGY_STUDIO_ID, workspace_id=workspace_id, default_value={}, time=time, timeout=timeout + studio_id=TOPOLOGY_STUDIO_ID, + workspace_id=workspace_id, + default_value={}, + time=time, + timeout=timeout, ) for device_entry in studio_inputs.get("devices", []): if not isinstance(device_entry, dict): @@ -419,7 +424,7 @@ async def get_topology_studio_inputs( } for interface in interfaces ], - } + }, ) return topology_inputs @@ -430,9 +435,6 @@ async def set_topology_studio_inputs( timeout: float = DEFAULT_API_TIMEOUT, ) -> list[InputsKey]: """ - TODO: Once the topology studio inputs API is public, this function can be replaced by the _future variant. - It will probably need some version detection to see if the API is supported. - Set Topology Studio Inputs using arista.studio.v1.InputsConfigService.Set API. Parameters: @@ -469,7 +471,7 @@ async def set_topology_studio_inputs( path=RepeatedString(values=["devices", str(device_index), "inputs", "device"]), ), inputs=json.dumps(device_info), - ) + ), ) index_offset = len(studio_inputs.get("devices", [])) @@ -489,106 +491,15 @@ async def set_topology_studio_inputs( path=RepeatedString(values=["devices", str(device_index)]), ), inputs=json.dumps(device_entry), - ) + ), ) input_keys = [] client = InputsConfigServiceStub(self._channel) try: responses = client.set_some(request, metadata=self._metadata, timeout=timeout + len(request.values) * 0.1) - async for response in responses: - input_keys.append(response.key) - - return input_keys - + input_keys = [response.key async for response in responses] except Exception as e: raise get_cv_client_exception(e, f"Studio ID '{TOPOLOGY_STUDIO_ID}, Workspace ID '{workspace_id}', Devices '{device_inputs}'") or e - # Future versions for once topology studio API is available. - # - # async def _future__get_topology_studio_inputs( - # self: CVClient, - # workspace_id: str, - # device_ids: list[str] | None = None, - # time: datetime | None = None, - # timeout: float = DEFAULT_API_TIMEOUT, - # ) -> list[TopologyInput]: - # """ - # TODO: Once the topology studio inputs API is public, this function can be put in place. - # It will probably need some version detection to see if the API is supported. - - # Get Topology Studio Inputs using arista.studio.v1.TopologyInputsService.GetAll and arista.studio.v1.TopologyInputsConfigService.GetAll APIs. - - # Parameters: - # workspace_id: Unique identifier of the Workspace for which the information is fetched. Use "" for mainline. - # device_ids: List of Device IDs / Serial numbers to get inputs for. - # time: Timestamp from which the information is fetched. `now()` if not set. - # timeout: Timeout in seconds. - - # Returns: - # Inputs object. - # """ - # request = TopologyInputStreamRequest(partial_eq_filter=[], time=time) - # if device_ids: - # for device_id in device_ids: - # request.partial_eq_filter.append( - # TopologyInput( - # key=TopologyInputKey(workspace_id=workspace_id, device_id=device_id), - # ) - # ) - # else: - # request.partial_eq_filter.append( - # TopologyInput( - # key=TopologyInputKey(workspace_id=workspace_id), - # ) - # ) - # client = TopologyInputServiceStub(self._channel) - # topology_inputs = [] - # try: - # responses = client.get_all(request, metadata=self._metadata, timeout=timeout) - # async for response in responses: - # topology_inputs.append(response.value) - # return topology_inputs - # except Exception as e: - # raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Device IDs '{device_ids}'") or e - - # async def _future_set_topology_studio_inputs( - # self: CVClient, - # workspace_id: str, - # device_inputs: list[tuple[str, str]], - # timeout: float = DEFAULT_API_TIMEOUT, - # ) -> list[TopologyInputKey]: - # """ - # TODO: Once the topology studio inputs API is public, this function can be put in place. - # It will probably need some version detection to see if the API is supported. - - # Set Topology Studio Inputs using arista.studio.v1.TopologyInputsConfigService.Set API. - - # Parameters: - # workspace_id: Unique identifier of the Workspace for which the information is set. - # device_inputs: List of Tuples with the format (, ). - # timeout: Timeout in seconds. - - # Returns: - # TopologyInputKey objects after being set including any server-generated values. - # """ - # request = TopologyInputConfigSetSomeRequest( - # values=[ - # TopologyInputConfig( - # key=TopologyInputKey(workspace_id=workspace_id, device_id=device_id), - # device_info=DeviceInfo(device_id=device_id, hostname=hostname), - # ) - # for device_id, hostname in device_inputs - # ] - # ) - - # client = TopologyInputConfigServiceStub(self._channel) - # topology_input_keys = [] - # try: - # responses = client.set_some(request, metadata=self._metadata, timeout=timeout) - # async for response in responses: - # topology_input_keys.append(response.key) - # return topology_input_keys - - # except Exception as e: - # raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Device IDs '{device_inputs}'") or e + return input_keys diff --git a/python-avd/pyavd/_cv/client/swg.py b/python-avd/pyavd/_cv/client/swg.py index b9568b2b2b6..cc50ae34f9e 100644 --- a/python-avd/pyavd/_cv/client/swg.py +++ b/python-avd/pyavd/_cv/client/swg.py @@ -7,7 +7,7 @@ from logging import getLogger from typing import TYPE_CHECKING, Literal -from ..api.arista.swg.v1 import ( +from pyavd._cv.api.arista.swg.v1 import ( EndpointConfig, EndpointConfigServiceStub, EndpointConfigSetRequest, @@ -17,6 +17,7 @@ ServiceName, SwgKey, ) + from .constants import DEFAULT_API_TIMEOUT from .exceptions import get_cv_client_exception @@ -32,9 +33,7 @@ class SwgMixin: - """ - Only to be used as mixin on CVClient class. - """ + """Only to be used as mixin on CVClient class.""" swg_api_version: Literal["v1"] = "v1" @@ -46,7 +45,7 @@ async def set_swg_device( timeout: float = DEFAULT_API_TIMEOUT, ) -> tuple[datetime, EndpointConfig]: """ - Set SWG Endpoints using arista.swg.v1.EndpointStatusService.Set API + Set SWG Endpoints using arista.swg.v1.EndpointStatusService.Set API. Parameters: device_id: Unique identifier of the Device - typically serial number. @@ -62,18 +61,18 @@ async def set_swg_device( value=EndpointConfig( key=SwgKey(device_id=device_id, service_name=ELEMENT_TYPE_MAP[service]), address=location, - ) + ), ) client = EndpointConfigServiceStub(self._channel) try: LOGGER.info("set_swg_device: Setting location for '%s': %s", device_id, location) response = await client.set(request, metadata=self._metadata, timeout=timeout) - return response.time, response.value - except Exception as e: raise get_cv_client_exception(e, f"set_swg_device: Device ID '{device_id}', service '{service}', location '{location}'") or e + return response.time, response.value + async def wait_for_swg_endpoint_status( self: CVClient, device_id: str, @@ -82,7 +81,7 @@ async def wait_for_swg_endpoint_status( timeout: float = DEFAULT_API_TIMEOUT, ) -> EndpointStatus: """ - Subscribe and wait for one SWG Endpoint using arista.swg.v1.EndpointStatusService.Subscribe API + Subscribe and wait for one SWG Endpoint using arista.swg.v1.EndpointStatusService.Subscribe API. Parameters: device_id: Unique identifier of the Device - typically serial number. diff --git a/python-avd/pyavd/_cv/client/tag.py b/python-avd/pyavd/_cv/client/tag.py index 1332c672772..b9c24a7353e 100644 --- a/python-avd/pyavd/_cv/client/tag.py +++ b/python-avd/pyavd/_cv/client/tag.py @@ -3,10 +3,9 @@ # that can be found in the LICENSE file. from __future__ import annotations -from datetime import datetime from typing import TYPE_CHECKING, Literal -from ..api.arista.tag.v2 import ( +from pyavd._cv.api.arista.tag.v2 import ( CreatorType, ElementType, Tag, @@ -26,10 +25,13 @@ TagServiceStub, TagStreamRequest, ) -from ..api.arista.time import TimeBounds +from pyavd._cv.api.arista.time import TimeBounds + from .exceptions import get_cv_client_exception if TYPE_CHECKING: + from datetime import datetime + from . import CVClient ELEMENT_TYPE_MAP = { @@ -47,9 +49,7 @@ class TagMixin: - """ - Only to be used as mixin on CVClient class. - """ + """Only to be used as mixin on CVClient class.""" tags_api_version: Literal["v2"] = "v2" # TODO: Ensure the to document that we only support v2 of this api - hence only the CV versions supporting that. @@ -93,9 +93,7 @@ async def get_tags( client = TagServiceStub(self._channel) try: responses = client.get_all(request, metadata=self._metadata, timeout=timeout) - async for response in responses: - tags.append(response.value) - + tags.extend(response.value async for response in responses) except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '' (main), Element Type '{element_type}', Creator Type '{creator_type}'") or e @@ -123,11 +121,11 @@ async def get_tags( self._remove_item_from_list(tag, tags, self._match_tags) else: self._upsert_item_in_list(tag, tags, self._match_tags) - return tags - except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Element Type '{element_type}', Creator Type '{creator_type}'") or e + return tags + async def set_tags( self: CVClient, workspace_id: str, @@ -158,22 +156,20 @@ async def set_tags( element_type=ELEMENT_TYPE_MAP[element_type], label=label, value=value, - ) - ) + ), + ), ) - tag_keys = [] client = TagConfigServiceStub(self._channel) try: responses = client.set_some(request, metadata=self._metadata, timeout=timeout + len(request.values) * 0.1) - async for response in responses: - # Recreating a full tag object. Since we just created it, it *must* be a user created tag. - tag_keys.append(response.key) - return tag_keys - + # Recreating a full tag object. Since we just created it, it *must* be a user created tag. + tag_keys = [response.key async for response in responses] except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Element Type '{element_type}'") or e + return tag_keys + async def get_tag_assignments( self: CVClient, workspace_id: str, @@ -213,9 +209,7 @@ async def get_tag_assignments( client = TagAssignmentServiceStub(self._channel) try: responses = client.get_all(request, metadata=self._metadata, timeout=timeout) - async for response in responses: - tag_assignments.append(response.value) - + tag_assignments.extend(response.value async for response in responses) except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '' (main), Element Type '{element_type}', Creator Type '{creator_type}'") or e @@ -243,11 +237,11 @@ async def get_tag_assignments( self._remove_item_from_list(tag_assignment, tag_assignments, self._match_tag_assignments) else: self._upsert_item_in_list(tag_assignment, tag_assignments, self._match_tag_assignments) - return tag_assignments - except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Element Type '{element_type}', Creator Type '{creator_type}'") or e + return tag_assignments + async def set_tag_assignments( self: CVClient, workspace_id: str, @@ -280,21 +274,19 @@ async def set_tag_assignments( value=value, device_id=device_id, interface_id=interface_id, - ) - ) + ), + ), ) - tag_assignment_keys = [] client = TagAssignmentConfigServiceStub(self._channel) try: responses = client.set_some(request, metadata=self._metadata, timeout=timeout + len(request.values) * 0.1) - async for response in responses: - tag_assignment_keys.append(response.key) - return tag_assignment_keys - + tag_assignment_keys = [response.key async for response in responses] except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Element Type '{element_type}'") or e + return tag_assignment_keys + async def delete_tag_assignments( self: CVClient, workspace_id: str, @@ -329,32 +321,26 @@ async def delete_tag_assignments( interface_id=interface_id, ), remove=True, - ) + ), ) - tag_assignment_keys = [] client = TagAssignmentConfigServiceStub(self._channel) try: responses = client.set_some(request, metadata=self._metadata, timeout=timeout + len(request.values) * 0.1) - async for response in responses: - tag_assignment_keys.append(response.key) - return tag_assignment_keys - + tag_assignment_keys = [response.key async for response in responses] except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Element Type '{element_type}'") or e + return tag_assignment_keys + @staticmethod def _match_tags(a: Tag, b: Tag) -> bool: - """ - Match up the properties of two tags without looking at the Workspace and Creator Type fields. - """ + """Match up the properties of two tags without looking at the Workspace and Creator Type fields.""" return all([a.key.element_type == b.key.element_type, a.key.label == b.key.label, a.key.value == b.key.value]) @staticmethod def _match_tag_assignments(a: TagAssignment, b: TagAssignment) -> bool: - """ - Match up the properties of two tag assignments without looking at the Workspace and Creator Type fields. - """ + """Match up the properties of two tag assignments without looking at the Workspace and Creator Type fields.""" return all( [ a.key.element_type == b.key.element_type, @@ -362,5 +348,5 @@ def _match_tag_assignments(a: TagAssignment, b: TagAssignment) -> bool: a.key.value == b.key.value, a.key.device_id == b.key.device_id, a.key.interface_id == b.key.interface_id, - ] + ], ) diff --git a/python-avd/pyavd/_cv/client/utils.py b/python-avd/pyavd/_cv/client/utils.py index 9c0bcf1b7ee..06e1f9ffd43 100644 --- a/python-avd/pyavd/_cv/client/utils.py +++ b/python-avd/pyavd/_cv/client/utils.py @@ -3,19 +3,19 @@ # that can be found in the LICENSE file. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from collections.abc import Callable + from . import CVClient class UtilsMixin: - """ - Only to be used as mixin on CVClient class. - """ + """Only to be used as mixin on CVClient class.""" @staticmethod - def _remove_item_from_list(itm, lst: list, matcher: Callable) -> None: + def _remove_item_from_list(itm: Any, lst: list, matcher: Callable) -> None: """ Remove one item from the given list. @@ -30,7 +30,7 @@ def _remove_item_from_list(itm, lst: list, matcher: Callable) -> None: return @staticmethod - def _upsert_item_in_list(itm, lst: list, matcher: Callable) -> None: + def _upsert_item_in_list(itm: Any, lst: list, matcher: Callable) -> None: """ Update or append one item from the given list. @@ -43,7 +43,7 @@ def _upsert_item_in_list(itm, lst: list, matcher: Callable) -> None: lst.append(itm) - def _set_value_from_path(self: CVClient, path: list[str], data: list | dict, value) -> None: + def _set_value_from_path(self: CVClient, path: list[str], data: list | dict, value: Any) -> None: """ Recursive function to walk through data to set value on path, creating any level needed. @@ -59,7 +59,8 @@ def _set_value_from_path(self: CVClient, path: list[str], data: list | dict, val if isinstance(value, dict) and isinstance(data, dict): data.update(value) return - raise RuntimeError(f"Path '{path}', value type '{type(value)}' cannot be set on data type '{type(data)}'") + msg = f"Path '{path}', value type '{type(value)}' cannot be set on data type '{type(data)}'" + raise RuntimeError(msg) # Convert '0' to 0. path = [int(element) if str(element).isnumeric() else element for element in path] if len(path) == 1: @@ -69,7 +70,8 @@ def _set_value_from_path(self: CVClient, path: list[str], data: list | dict, val # We ignore the actual integer value and just append the item to the list. data.append(value) else: - raise RuntimeError(f"Path '{path}' cannot be set on data of type '{type(data)}'") + msg = f"Path '{path}' cannot be set on data of type '{type(data)}'" + raise RuntimeError(msg) return # Two or more elements in path. @@ -98,9 +100,8 @@ def _set_value_from_path(self: CVClient, path: list[str], data: list | dict, val self._set_value_from_path(path[1:], data[index], value) else: - raise RuntimeError(f"Path '{path}', value type '{type(value)}' cannot be set on data of type '{type(data)}'") - - return None + msg = f"Path '{path}', value type '{type(value)}' cannot be set on data of type '{type(data)}'" + raise TypeError(msg) def _get_value_from_path(self: CVClient, path: list[str], data: list | dict, default_value: Any = None) -> Any: """ @@ -123,7 +124,8 @@ def _get_value_from_path(self: CVClient, path: list[str], data: list | dict, def # Convert '0' to 0. path = [int(element) if str(element).isnumeric() else element for element in path] if isinstance(path[0], int) and not isinstance(data, list): - raise TypeError(f"Path element is '{path[0]}' but data is not a list (got '{type(data)}').") + msg = f"Path element is '{path[0]}' but data is not a list (got '{type(data)}')." + raise TypeError(msg) try: return self._get_value_from_path(path[1:], data[path[0]]) diff --git a/python-avd/pyavd/_cv/client/workspace.py b/python-avd/pyavd/_cv/client/workspace.py index d14341d1898..4564d974502 100644 --- a/python-avd/pyavd/_cv/client/workspace.py +++ b/python-avd/pyavd/_cv/client/workspace.py @@ -3,12 +3,11 @@ # that can be found in the LICENSE file. from __future__ import annotations -from datetime import datetime from logging import getLogger from typing import TYPE_CHECKING, Literal from uuid import uuid4 -from ..api.arista.workspace.v1 import ( +from pyavd._cv.api.arista.workspace.v1 import ( Request, RequestParams, Response, @@ -22,9 +21,12 @@ WorkspaceServiceStub, WorkspaceStreamRequest, ) + from .exceptions import get_cv_client_exception if TYPE_CHECKING: + from datetime import datetime + from . import CVClient LOGGER = getLogger(__name__) @@ -40,9 +42,7 @@ class WorkspaceMixin: - """ - Only to be used as mixin on CVClient class. - """ + """Only to be used as mixin on CVClient class.""" workspace_api_version: Literal["v1"] = "v1" @@ -53,7 +53,7 @@ async def get_workspace( timeout: float = 10.0, ) -> Workspace: """ - Get Workspace using arista.workspace.v1.WorkspaceService.GetOne API + Get Workspace using arista.workspace.v1.WorkspaceService.GetOne API. Parameters: workspace_id: Unique identifier the workspace. @@ -73,11 +73,11 @@ async def get_workspace( try: response = await client.get_one(request, metadata=self._metadata, timeout=timeout) - return response.value - except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}'") or e + return response.value + async def create_workspace( self: CVClient, workspace_id: str, @@ -86,7 +86,7 @@ async def create_workspace( timeout: float = 10.0, ) -> WorkspaceConfig: """ - Create Workspace using arista.workspace.v1.WorkspaceConfigService.Set API + Create Workspace using arista.workspace.v1.WorkspaceConfigService.Set API. Parameters: workspace_id: Unique identifier the workspace. @@ -102,7 +102,7 @@ async def create_workspace( key=WorkspaceKey(workspace_id=workspace_id), display_name=display_name, description=description, - ) + ), ) client = WorkspaceConfigServiceStub(self._channel) response = await client.set(request, metadata=self._metadata, timeout=timeout) @@ -114,7 +114,7 @@ async def abandon_workspace( timeout: float = 10.0, ) -> WorkspaceConfig: """ - Abandon Workspace using arista.workspace.v1.WorkspaceConfigService.Set API + Abandon Workspace using arista.workspace.v1.WorkspaceConfigService.Set API. Parameters: workspace_id: Unique identifier the workspace. @@ -130,7 +130,7 @@ async def abandon_workspace( request_params=RequestParams( request_id=f"req-{uuid4()}", ), - ) + ), ) client = WorkspaceConfigServiceStub(self._channel) response = await client.set(request, metadata=self._metadata, timeout=timeout) @@ -142,7 +142,7 @@ async def build_workspace( timeout: float = 10.0, ) -> WorkspaceConfig: """ - Request a build of the Workspace using arista.workspace.v1.WorkspaceConfigService.Set API + Request a build of the Workspace using arista.workspace.v1.WorkspaceConfigService.Set API. Parameters: workspace_id: Unique identifier the workspace. @@ -158,7 +158,7 @@ async def build_workspace( request_params=RequestParams( request_id=f"req-{uuid4()}", ), - ) + ), ) client = WorkspaceConfigServiceStub(self._channel) response = await client.set(request, metadata=self._metadata, timeout=timeout) @@ -170,7 +170,7 @@ async def delete_workspace( timeout: float = 10.0, ) -> WorkspaceKey: """ - Delete Workspace using arista.workspace.v1.WorkspaceConfigService.Delete API + Delete Workspace using arista.workspace.v1.WorkspaceConfigService.Delete API. Parameters: workspace_id: Unique identifier the workspace. @@ -191,7 +191,7 @@ async def submit_workspace( timeout: float = 10.0, ) -> WorkspaceConfig: """ - Request submission of the Workspace using arista.workspace.v1.WorkspaceConfigService.Set API + Request submission of the Workspace using arista.workspace.v1.WorkspaceConfigService.Set API. Parameters: workspace_id: Unique identifier the Workspace. @@ -201,13 +201,12 @@ async def submit_workspace( Returns: WorkspaceConfig object after being set including any server-generated values. """ - request = WorkspaceConfigSetRequest( WorkspaceConfig( key=WorkspaceKey(workspace_id=workspace_id), request=Request.SUBMIT_FORCE if force else Request.SUBMIT, request_params=RequestParams(request_id=f"req-{uuid4()}"), - ) + ), ) client = WorkspaceConfigServiceStub(self._channel) response = await client.set(request, metadata=self._metadata, timeout=timeout) @@ -222,6 +221,7 @@ async def wait_for_workspace_response( ) -> tuple[Response, Workspace]: """ Monitor a Workspace using arista.workspace.v1.WorkspaceService.Subscribe API for a response to the given request_id. + Blocks until a response is returned or timed out. Parameters: @@ -236,7 +236,7 @@ async def wait_for_workspace_response( partial_eq_filter=[ Workspace( key=WorkspaceKey(workspace_id=workspace_id), - ) + ), ], ) client = WorkspaceServiceStub(self._channel) @@ -247,7 +247,9 @@ async def wait_for_workspace_response( LOGGER.info("wait_for_workspace_response: Got response for request '%s': %s", request_id, response.value.responses.values[request_id]) return response.value.responses.values[request_id], response.value LOGGER.debug( - "wait_for_workspace_response: Got workspace update but not for request_id '%s'. Workspace State: %s", request_id, response.value.state + "wait_for_workspace_response: Got workspace update but not for request_id '%s'. Workspace State: %s", + request_id, + response.value.state, ) except Exception as e: raise get_cv_client_exception(e, f"Workspace ID '{workspace_id}', Request ID '{request_id}") or e diff --git a/python-avd/pyavd/_cv/workflows/create_workspace_on_cv.py b/python-avd/pyavd/_cv/workflows/create_workspace_on_cv.py index 003ea5ed46d..8de425c13ad 100644 --- a/python-avd/pyavd/_cv/workflows/create_workspace_on_cv.py +++ b/python-avd/pyavd/_cv/workflows/create_workspace_on_cv.py @@ -4,11 +4,15 @@ from __future__ import annotations from logging import getLogger +from typing import TYPE_CHECKING -from ..api.arista.workspace.v1 import WorkspaceState -from ..client import CVClient -from ..client.exceptions import CVResourceInvalidState, CVResourceNotFound -from .models import CVWorkspace +from pyavd._cv.api.arista.workspace.v1 import WorkspaceState +from pyavd._cv.client.exceptions import CVResourceInvalidState, CVResourceNotFound + +if TYPE_CHECKING: + from pyavd._cv.client import CVClient + + from .models import CVWorkspace LOGGER = getLogger(__name__) @@ -16,6 +20,7 @@ async def create_workspace_on_cv(workspace: CVWorkspace, cv_client: CVClient) -> None: """ Create or update a Workspace from the given workspace object. + In-place update the workspace state. """ LOGGER.info("create_workspace_on_cv: %s", workspace) @@ -24,7 +29,8 @@ async def create_workspace_on_cv(workspace: CVWorkspace, cv_client: CVClient) -> if existing_workspace.state == WorkspaceState.PENDING: workspace.state = "pending" else: - raise CVResourceInvalidState("The requested workspace is not in state 'pending'") + msg = "The requested workspace is not in state 'pending'" + raise CVResourceInvalidState(msg) except CVResourceNotFound: await cv_client.create_workspace(workspace_id=workspace.id, display_name=workspace.name, description=workspace.description) workspace.state = "pending" diff --git a/python-avd/pyavd/_cv/workflows/deploy_configs_to_cv.py b/python-avd/pyavd/_cv/workflows/deploy_configs_to_cv.py index 3c92e3bd48e..c7b423d3ea7 100644 --- a/python-avd/pyavd/_cv/workflows/deploy_configs_to_cv.py +++ b/python-avd/pyavd/_cv/workflows/deploy_configs_to_cv.py @@ -5,10 +5,14 @@ from asyncio import gather from logging import getLogger +from typing import TYPE_CHECKING -from ..._utils import batch -from ..client import CVClient -from .models import CVEosConfig, DeployToCvResult +from pyavd._utils import batch + +if TYPE_CHECKING: + from pyavd._cv.client import CVClient + + from .models import CVEosConfig, DeployToCvResult LOGGER = getLogger(__name__) @@ -21,7 +25,7 @@ async def deploy_configs_to_cv(configs: list[CVEosConfig], result: DeployToCvResult, cv_client: CVClient) -> None: """ - Deploy given configs using "Static Configlet Studio" + Deploy given configs using "Static Configlet Studio". - Create/verify a single configuration container named "AVD Configurations". - Upload Configlets and assign to devices. @@ -29,17 +33,16 @@ async def deploy_configs_to_cv(configs: list[CVEosConfig], result: DeployToCvRes TODO: See if this can be optimized to check if the configlets are already in place and correct. A hash would have been nice. TODO: Split long configs into multiple configlets for 990KB chunks. Need to figure out how to batch it. """ - LOGGER.info("deploy_configs_to_cv: %s", len(configs)) if not configs: return - # Build Todo with CVEosConfig objects that exist on CloudVision. Add the rest to skipped. + # Build TODO: with CVEosConfig objects that exist on CloudVision. Add the rest to skipped. result.skipped_configs.extend(config for config in configs if not config.device._exists_on_cv) LOGGER.info("deploy_configs_to_cv: %s skipped configs because the devices are missing from CloudVision.", len(result.skipped_configs)) todo_configs = [config for config in configs if config.device._exists_on_cv] - LOGGER.info("deploy_configs_to_cv: %s todo configs.", len(todo_configs)) + LOGGER.info("deploy_configs_to_cv: %s TODO: configs.", len(todo_configs)) # No need to continue if we have nothing to do. if not todo_configs: @@ -55,7 +58,8 @@ async def deploy_configs_to_cv(configs: list[CVEosConfig], result: DeployToCvRes async def deploy_configlets_to_cv(configs: list[CVEosConfig], workspace_id: str, cv_client: CVClient) -> None: """ - Bluntly setting configs like nothing was there. Only create missing containers + Bluntly setting configs like nothing was there. Only create missing containers. + TODO: Fetch config checksums for existing configs and only upload what is needed. """ configlet_coroutines = [] @@ -68,7 +72,7 @@ async def deploy_configlets_to_cv(configs: list[CVEosConfig], workspace_id: str, file=config.file, display_name=config.configlet_name or f"{CONFIGLET_NAME_PREFIX}{config.device.hostname}", description=f"Configuration created and uploaded by AVD for {config.device.hostname}", - ) + ), ) LOGGER.info("deploy_configs_to_cv: Deploying %s configlets in batches of %s.", len(configlet_coroutines), PARALLEL_COROUTINES) @@ -80,7 +84,8 @@ async def deploy_configlets_to_cv(configs: list[CVEosConfig], workspace_id: str, async def get_existing_device_container_ids_from_root_container(workspace_id: str, cv_client: CVClient) -> list[str]: """ Get or create root level container for AVD configurations. Using the hardcoded id from CONFIGLET_CONTAINER_ID. - Then return the list of existing device container ids. (Always empty if we just created the root container) + + Then return the list of existing device container ids. (Always empty if we just created the root container). """ root_cv_containers = await cv_client.get_configlet_containers(workspace_id=workspace_id, container_ids=[CONFIGLET_CONTAINER_ID]) LOGGER.info("get_or_create_configlet_root_container: Got AVD root container? %s", bool(root_cv_containers)) @@ -98,7 +103,10 @@ async def get_existing_device_container_ids_from_root_container(workspace_id: st ) # Add the root level container to the list of root level containers using the studio inputs API (!?!) root_containers: list = await cv_client.get_studio_inputs_with_path( - studio_id=STATIC_CONFIGLET_STUDIO_ID, workspace_id=workspace_id, input_path=["configletAssignmentRoots"], default_value=[] + studio_id=STATIC_CONFIGLET_STUDIO_ID, + workspace_id=workspace_id, + input_path=["configletAssignmentRoots"], + default_value=[], ) LOGGER.info("deploy_configs_to_cv: Found %s root containers.", len(root_containers)) if CONFIGLET_CONTAINER_ID not in root_containers: @@ -116,6 +124,7 @@ async def get_existing_device_container_ids_from_root_container(workspace_id: st async def deploy_configlet_containers_to_cv(configs: list[CVEosConfig], workspace_id: str, cv_client: CVClient) -> None: """ Identify existing containers and ensure they have the correct configuration. + Then update/create as needed. TODO: Refactor to set_some on supported CV versions @@ -159,7 +168,7 @@ async def deploy_configlet_containers_to_cv(configs: list[CVEosConfig], workspac description=description, query=query, configlet_ids=configlet_ids, - ) + ), ) LOGGER.info("deploy_configs_to_cv: Deploying %s configlet assignments / containers in batches of %s.", len(container_coroutines), PARALLEL_COROUTINES) diff --git a/python-avd/pyavd/_cv/workflows/deploy_cv_pathfinder_metadata_to_cv.py b/python-avd/pyavd/_cv/workflows/deploy_cv_pathfinder_metadata_to_cv.py index a9ebd94d2d9..d95264db98d 100644 --- a/python-avd/pyavd/_cv/workflows/deploy_cv_pathfinder_metadata_to_cv.py +++ b/python-avd/pyavd/_cv/workflows/deploy_cv_pathfinder_metadata_to_cv.py @@ -5,13 +5,17 @@ from copy import deepcopy from logging import getLogger +from typing import TYPE_CHECKING -from ..._utils import get, get_v2 -from ..._utils.password_utils.password import simple_7_decrypt -from ..api.arista.studio.v1 import InputSchema -from ..client import CVClient -from ..client.exceptions import CVResourceNotFound -from .models import CVDevice, CVPathfinderMetadata, DeployToCvResult +from pyavd._cv.client.exceptions import CVResourceNotFound +from pyavd._utils import get, get_v2 +from pyavd._utils.password_utils.password import simple_7_decrypt + +if TYPE_CHECKING: + from pyavd._cv.api.arista.studio.v1 import InputSchema + from pyavd._cv.client import CVClient + + from .models import CVDevice, CVPathfinderMetadata, DeployToCvResult LOGGER = getLogger(__name__) @@ -26,7 +30,7 @@ def is_pathfinder_location_supported(studio_schema: InputSchema) -> bool: - """Detect if pathfinder location is supported by the metadata studio""" + """Detect if pathfinder location is supported by the metadata studio.""" pathfinder_group_fields = get_v2(studio_schema, "fields.values.pathfinderGroup.group_props.members.values") if pathfinder_group_fields is None: return False @@ -35,23 +39,24 @@ def is_pathfinder_location_supported(studio_schema: InputSchema) -> bool: def is_avt_hop_count_supported(studio_schema: InputSchema) -> bool: - """Detect if AVT hop count is supported by the metadata studio""" + """Detect if AVT hop count is supported by the metadata studio.""" return bool(get_v2(studio_schema, "fields.values.avtHopCount")) def is_internet_exit_zscaler_supported(studio_schema: InputSchema) -> bool: - """Detect if zscaler internet exit is supported by the metadata studio""" + """Detect if zscaler internet exit is supported by the metadata studio.""" return bool(get_v2(studio_schema, "fields.values.zscaler")) def is_applications_supported(studio_schema: InputSchema) -> bool: - """Detect if applications is supported by the metadata studio""" + """Detect if applications is supported by the metadata studio.""" return bool(get_v2(studio_schema, "fields.values.applications")) async def get_metadata_studio_schema(result: DeployToCvResult, cv_client: CVClient) -> InputSchema | None: """ Download and return the input schema for the cv pathfinder metadata studio. + Returns None if the metadata studio is not found. """ try: @@ -69,9 +74,7 @@ async def get_metadata_studio_schema(result: DeployToCvResult, cv_client: CVClie def update_general_metadata(metadata: dict, studio_inputs: dict, studio_schema: InputSchema) -> list[str]: - """ - In-place update general metadata in studio_inputs. - """ + """In-place update general metadata in studio_inputs.""" warnings = [] # Temporary fix for default values in metadata studio @@ -109,7 +112,7 @@ def update_general_metadata(metadata: dict, studio_inputs: dict, studio_schema: ], "regions": get(metadata, "regions", default=[]), "vrfs": get(metadata, "vrfs", default=[]), - } + }, ) if applications := generate_applications_metadata(metadata): @@ -146,7 +149,7 @@ def upsert_pathfinder(metadata: dict, device: CVDevice, studio_inputs: dict, stu "circuitId": interface.get("circuit_id", ""), "pathgroup": interface.get("pathgroup", ""), "publicIp": interface.get("public_ip", ""), - } + }, }, "tags": {"query": f"interface:{interface.get('name', '')}@{device.serial_number}"}, } @@ -207,7 +210,7 @@ def upsert_edge(metadata: dict, device: CVDevice, studio_inputs: dict, studio_sc "carrier": interface.get("carrier", ""), "circuitId": interface.get("circuit_id", ""), "pathgroup": interface.get("pathgroup", ""), - } + }, }, "tags": {"query": f"interface:{interface.get('name', '')}@{device.serial_number}"}, } @@ -311,7 +314,6 @@ async def deploy_cv_pathfinder_metadata_to_cv(cv_pathfinder_metadata: list[CVPat vni: 100 ``` """ - LOGGER.info("deploy_cv_pathfinder_metadata_to_cv: Got cv_pathfinder_metadata for %s devices", len(cv_pathfinder_metadata)) if not cv_pathfinder_metadata: @@ -322,7 +324,9 @@ async def deploy_cv_pathfinder_metadata_to_cv(cv_pathfinder_metadata: list[CVPat # Get existing studio inputs existing_studio_inputs = await cv_client.get_studio_inputs( - studio_id=CV_PATHFINDER_METADATA_STUDIO_ID, workspace_id=result.workspace.id, default_value=CV_PATHFINDER_DEFAULT_STUDIO_INPUTS + studio_id=CV_PATHFINDER_METADATA_STUDIO_ID, + workspace_id=result.workspace.id, + default_value=CV_PATHFINDER_DEFAULT_STUDIO_INPUTS, ) studio_inputs = deepcopy(existing_studio_inputs) @@ -370,7 +374,7 @@ async def deploy_cv_pathfinder_metadata_to_cv(cv_pathfinder_metadata: list[CVPat for pathfinder in pathfinders: result.warnings.extend( - upsert_pathfinder(metadata=pathfinder.metadata, device=pathfinder.device, studio_inputs=studio_inputs, studio_schema=studio_schema) + upsert_pathfinder(metadata=pathfinder.metadata, device=pathfinder.device, studio_inputs=studio_inputs, studio_schema=studio_schema), ) for edge in edges: @@ -385,7 +389,8 @@ async def deploy_cv_pathfinder_metadata_to_cv(cv_pathfinder_metadata: list[CVPat def generate_internet_exit_metadata(metadata: dict, device: CVDevice, studio_schema: InputSchema) -> tuple[dict, list[str]]: """ Generate internet-exit related metadata for one device. - To be inserted into edge router metadata under "services" + + To be inserted into edge router metadata under "services". Returns metadata dict and list of any warnings raised. """ @@ -441,7 +446,7 @@ def generate_internet_exit_metadata(metadata: dict, device: CVDevice, studio_sch } for vpn_credential in internet_exit_policy["vpn_credentials"] ], - } + }, ) services_dict["zscaler"]["tunnels"].extend( { @@ -457,7 +462,8 @@ def generate_internet_exit_metadata(metadata: dict, device: CVDevice, studio_sch def generate_applications_metadata(metadata: dict) -> dict: """ Generate application traffic recognition related metadata for one patfinder. - To be inserted into the common metadata under "applications" + + To be inserted into the common metadata under "applications". """ if (applications := get(metadata, "applications")) is None: return {} diff --git a/python-avd/pyavd/_cv/workflows/deploy_studio_inputs_to_cv.py b/python-avd/pyavd/_cv/workflows/deploy_studio_inputs_to_cv.py index c89e922e1cf..275d10ba173 100644 --- a/python-avd/pyavd/_cv/workflows/deploy_studio_inputs_to_cv.py +++ b/python-avd/pyavd/_cv/workflows/deploy_studio_inputs_to_cv.py @@ -5,10 +5,14 @@ from asyncio import gather from logging import getLogger +from typing import TYPE_CHECKING -from ..._utils import batch -from ..client import CVClient -from .models import CVStudioInputs, DeployToCvResult +from pyavd._utils import batch + +if TYPE_CHECKING: + from pyavd._cv.client import CVClient + + from .models import CVStudioInputs, DeployToCvResult LOGGER = getLogger(__name__) @@ -22,45 +26,26 @@ def get_studio_id(studio_input: CVStudioInputs) -> str: return studio_input.studio_id -# TODO: Cleanup if we decide to live with the limitation of overlapping inputs in favor of performance. -# async def set_multiple_studio_inputs_on_one_studio(workspace_id: str, studio_inputs: list[CVStudioInputs], cv_client: CVClient) -> None: -# """ -# Setting multiple CVStudioInputs on the same studio one by one to avoid race conditions when updating overlapping -# """ -# for studio_input in studio_inputs: -# await cv_client.set_studio_inputs( -# studio_id=studio_input.studio_id, workspace_id=workspace_id, inputs=studio_input.inputs, input_path=studio_input.input_path -# ) - - async def deploy_studio_inputs_to_cv(studio_inputs: list[CVStudioInputs], result: DeployToCvResult, cv_client: CVClient) -> None: """ Deploy given Studio Inputs. It is not supported to deploy overlapping studio inputs for the same studio. """ - LOGGER.info("deploy_studio_inputs_to_cv: %s", len(studio_inputs)) if not studio_inputs: return - # TODO: Cleanup if we decide to live with the limitation of overlapping inputs in favor of performance. - # grouped_by_studio_id = groupby(sorted(studio_inputs, get_studio_id), get_studio_id) - - studio_inputs_coroutines = [] - # TODO: Cleanup if we decide to live with the limitation of overlapping inputs in favor of performance. - # for studio_id, studio_inputs_for_one_studio in grouped_by_studio_id: - # studio_inputs_coroutines.append( - # set_multiple_studio_inputs_on_one_studio(workspace_id=result.workspace.id, studio_inputs=studio_inputs_for_one_studio, cv_client=cv_client) - # ) - - for studio_input in studio_inputs: - studio_inputs_coroutines.append( - cv_client.set_studio_inputs( - studio_id=studio_input.studio_id, workspace_id=result.workspace.id, inputs=studio_input.inputs, input_path=studio_input.input_path - ) + studio_inputs_coroutines = [ + cv_client.set_studio_inputs( + studio_id=studio_input.studio_id, + workspace_id=result.workspace.id, + inputs=studio_input.inputs, + input_path=studio_input.input_path, ) + for studio_input in studio_inputs + ] # Deploy studio inputs in parallel in batches of 20. LOGGER.info("deploy_studio_inputs_to_cv: Deploying %s Studio Inputs in batches of 20.", len(studio_inputs_coroutines)) diff --git a/python-avd/pyavd/_cv/workflows/deploy_tags_to_cv.py b/python-avd/pyavd/_cv/workflows/deploy_tags_to_cv.py index 1e66153dd0e..24ee8f5ab36 100644 --- a/python-avd/pyavd/_cv/workflows/deploy_tags_to_cv.py +++ b/python-avd/pyavd/_cv/workflows/deploy_tags_to_cv.py @@ -4,10 +4,13 @@ from __future__ import annotations from logging import getLogger +from typing import TYPE_CHECKING -from ..client import CVClient from .models import CVDeviceTag, CVInterfaceTag, CVWorkspace +if TYPE_CHECKING: + from pyavd._cv.client import CVClient + LOGGER = getLogger(__name__) @@ -44,12 +47,9 @@ async def deploy_tags_to_cv( if not tags: return - if isinstance(tags[0], CVInterfaceTag): - tag_type = "interface" - else: - tag_type = "device" + tag_type = "interface" if isinstance(tags[0], CVInterfaceTag) else "device" - # Build Todo with CVDevice/CVInterfaceTag objects that exist on CloudVision. Add the rest to skipped. + # Build TODO: with CVDevice/CVInterfaceTag objects that exist on CloudVision. Add the rest to skipped. skipped_tags.extend(tag for tag in tags if tag.device is not None and not tag.device._exists_on_cv) todo_tags = [tag for tag in tags if tag.device is None or tag.device._exists_on_cv] @@ -67,7 +67,7 @@ async def deploy_tags_to_cv( if tags_to_add: await cv_client.set_tags(workspace_id=workspace.id, tags=[(tag.label, tag.value) for tag in tags_to_add], element_type=tag_type) - # Remove entries with no assignment from todo and add to deployed. + # Remove entries with no assignment from TODO: and add to deployed. deployed_tags.extend(tag for tag in todo_tags if tag.device is None) todo_tags = [tag for tag in todo_tags if tag.device is not None] @@ -86,7 +86,7 @@ async def deploy_tags_to_cv( ] LOGGER.info("deploy_tags_to_cv: Got %s tag assignments", len(existing_assignments)) - # Move all existing assignments from todo to deployed. + # Move all existing assignments from TODO: to deployed. deployed_tags.extend(tag for tag in todo_tags if (tag.label, tag.value, tag.device.serial_number, getattr(tag, "interface", None)) in existing_assignments) todo_tags = [tag for tag in todo_tags if (tag.label, tag.value, tag.device.serial_number, getattr(tag, "interface", None)) not in existing_assignments] @@ -98,7 +98,7 @@ async def deploy_tags_to_cv( element_type=tag_type, ) - # Move all todo to deployed. + # Move all TODO: to deployed. deployed_tags.extend(todo_tags) # Now we start removing assignments depending on strict_tags or not. diff --git a/python-avd/pyavd/_cv/workflows/deploy_to_cv.py b/python-avd/pyavd/_cv/workflows/deploy_to_cv.py index bb2c7d9f904..9a972bb5d06 100644 --- a/python-avd/pyavd/_cv/workflows/deploy_to_cv.py +++ b/python-avd/pyavd/_cv/workflows/deploy_to_cv.py @@ -5,8 +5,9 @@ from logging import getLogger -from ..client import CVClient -from ..client.exceptions import CVClientException +from pyavd._cv.client import CVClient +from pyavd._cv.client.exceptions import CVClientException + from .create_workspace_on_cv import create_workspace_on_cv from .deploy_configs_to_cv import deploy_configs_to_cv from .deploy_cv_pathfinder_metadata_to_cv import deploy_cv_pathfinder_metadata_to_cv @@ -42,7 +43,7 @@ async def deploy_to_cv( cv_pathfinder_metadata: list[CVPathfinderMetadata] | None = None, skip_missing_devices: bool = False, strict_tags: bool = True, - timeouts: CVTimeOuts | None = None, # pylint: disable=unused-argument + timeouts: CVTimeOuts | None = None, # pylint: disable=unused-argument # noqa: ARG001 ) -> DeployToCvResult: """ Deploy various objects to CloudVision. diff --git a/python-avd/pyavd/_cv/workflows/finalize_change_control_on_cv.py b/python-avd/pyavd/_cv/workflows/finalize_change_control_on_cv.py index f04a650391e..b6510c6be4a 100644 --- a/python-avd/pyavd/_cv/workflows/finalize_change_control_on_cv.py +++ b/python-avd/pyavd/_cv/workflows/finalize_change_control_on_cv.py @@ -4,11 +4,15 @@ from __future__ import annotations from logging import getLogger +from typing import TYPE_CHECKING -from ..api.arista.changecontrol.v1 import ChangeControl, ChangeControlStatus -from ..client import CVClient -from ..client.exceptions import CVChangeControlFailed -from .models import CVChangeControl +from pyavd._cv.api.arista.changecontrol.v1 import ChangeControl, ChangeControlStatus +from pyavd._cv.client.exceptions import CVChangeControlFailed + +if TYPE_CHECKING: + from pyavd._cv.client import CVClient + + from .models import CVChangeControl LOGGER = getLogger(__name__) @@ -35,10 +39,10 @@ def get_change_control_state(cv_change_control: ChangeControl) -> str: async def finalize_change_control_on_cv(change_control: CVChangeControl, cv_client: CVClient) -> None: """ Update and finalize a Change Control on CloudVision from the given result.CVChangeControl object. + Depending on the requested state the Change Control will be left in pending approval, approved, started, completed or canceled. In-place update the CVChangeControl object. """ - LOGGER.info("finalize_change_control_on_cv: %s", change_control) cv_change_control = await cv_client.get_change_control(change_control_id=change_control.id) @@ -67,9 +71,11 @@ async def finalize_change_control_on_cv(change_control: CVChangeControl, cv_clie # TODO: Add cancel/delete # For all other requested states we first need to approve. - if not change_control.state == "approved": + if change_control.state != "approved": await cv_client.approve_change_control( - change_control_id=change_control.id, timestamp=cv_change_control.change.time, description="Automatic approval by AVD" + change_control_id=change_control.id, + timestamp=cv_change_control.change.time, + description="Automatic approval by AVD", ) change_control.state = "approved" LOGGER.info("finalize_change_control_on_cv: %s", change_control) @@ -90,7 +96,8 @@ async def finalize_change_control_on_cv(change_control: CVChangeControl, cv_clie if cv_change_control.error is not None: change_control.state = "failed" LOGGER.info("finalize_change_control_on_cv: %s", change_control) - raise CVChangeControlFailed(f"Change control failed during execution {change_control.id}: {cv_change_control.error}") + msg = f"Change control failed during execution {change_control.id}: {cv_change_control.error}" + raise CVChangeControlFailed(msg) change_control.state = "completed" LOGGER.info("finalize_change_control_on_cv: %s", change_control) diff --git a/python-avd/pyavd/_cv/workflows/finalize_workspace_on_cv.py b/python-avd/pyavd/_cv/workflows/finalize_workspace_on_cv.py index d9ed0a3fa86..a9d92623367 100644 --- a/python-avd/pyavd/_cv/workflows/finalize_workspace_on_cv.py +++ b/python-avd/pyavd/_cv/workflows/finalize_workspace_on_cv.py @@ -4,11 +4,15 @@ from __future__ import annotations from logging import getLogger +from typing import TYPE_CHECKING -from ..api.arista.workspace.v1 import ResponseStatus, WorkspaceState -from ..client import CVClient -from ..client.exceptions import CVWorkspaceBuildFailed, CVWorkspaceSubmitFailed -from .models import CVWorkspace +from pyavd._cv.api.arista.workspace.v1 import ResponseStatus, WorkspaceState +from pyavd._cv.client.exceptions import CVWorkspaceBuildFailed, CVWorkspaceSubmitFailed + +if TYPE_CHECKING: + from pyavd._cv.client import CVClient + + from .models import CVWorkspace LOGGER = getLogger(__name__) @@ -25,10 +29,10 @@ async def finalize_workspace_on_cv(workspace: CVWorkspace, cv_client: CVClient) -> None: """ Finalize a Workspace from the given result.CVWorkspace object. + Depending on the requested state the Workspace will be left in pending, built, submitted, abandoned or deleted. In-place update the workspace state and creates/updates a ChangeControl object on the result object if applicable. """ - LOGGER.info("finalize_workspace_on_cv: %s", workspace) if workspace.requested_state in (workspace.state, "pending"): @@ -39,10 +43,11 @@ async def finalize_workspace_on_cv(workspace: CVWorkspace, cv_client: CVClient) if build_result.status != ResponseStatus.SUCCESS: workspace.state = "build failed" LOGGER.info("finalize_workspace_on_cv: %s", workspace) - raise CVWorkspaceBuildFailed( + msg = ( f"Failed to build workspace {workspace.id}: {build_result}. " f"See details: https://{cv_client._servers[0]}/cv/provisioning/workspaces?ws={workspace.id}" ) + raise CVWorkspaceBuildFailed(msg) workspace.state = "built" LOGGER.info("finalize_workspace_on_cv: %s", workspace) @@ -53,12 +58,14 @@ async def finalize_workspace_on_cv(workspace: CVWorkspace, cv_client: CVClient) if workspace.requested_state == "submitted" and workspace.state == "built": workspace_config = await cv_client.submit_workspace(workspace_id=workspace.id, force=workspace.force) submit_result, cv_workspace = await cv_client.wait_for_workspace_response( - workspace_id=workspace.id, request_id=workspace_config.request_params.request_id + workspace_id=workspace.id, + request_id=workspace_config.request_params.request_id, ) if submit_result.status != ResponseStatus.SUCCESS: workspace.state = "submit failed" LOGGER.info("finalize_workspace_on_cv: %s", workspace) - raise CVWorkspaceSubmitFailed(f"Failed to submit workspace {workspace.id}: {submit_result}") + msg = f"Failed to submit workspace {workspace.id}: {submit_result}" + raise CVWorkspaceSubmitFailed(msg) workspace.state = "submitted" if cv_workspace.cc_ids.values: diff --git a/python-avd/pyavd/_cv/workflows/models.py b/python-avd/pyavd/_cv/workflows/models.py index 1d4d9558aa6..5e4057de6e1 100644 --- a/python-avd/pyavd/_cv/workflows/models.py +++ b/python-avd/pyavd/_cv/workflows/models.py @@ -115,13 +115,11 @@ class DeployToCvResult: warnings: list = field(default_factory=list) workspace: CVWorkspace | None = field(default_factory=CVWorkspace) change_control: CVChangeControl | None = None - # deployed_devices: list[CVDevice] = field(default_factory=list) deployed_configs: list[CVEosConfig] = field(default_factory=list) deployed_device_tags: list[CVDeviceTag] = field(default_factory=list) deployed_interface_tags: list[CVInterfaceTag] = field(default_factory=list) deployed_studio_inputs: list[CVStudioInputs] = field(default_factory=list) deployed_cv_pathfinder_metadata: list[CVPathfinderMetadata] = field(default_factory=list) - # skipped_devices: list[CVDevice] = field(default_factory=list) skipped_configs: list[CVEosConfig] = field(default_factory=list) skipped_device_tags: list[CVDeviceTag] = field(default_factory=list) skipped_interface_tags: list[CVInterfaceTag] = field(default_factory=list) @@ -156,7 +154,7 @@ class CVEosConfig: @dataclass class CVTimeOuts: - """Timeouts in seconds""" + """Timeouts in seconds.""" workspace_build_timeout: float = 300.0 change_control_creation_timeout: float = 300.0 diff --git a/python-avd/pyavd/_cv/workflows/verify_devices_on_cv.py b/python-avd/pyavd/_cv/workflows/verify_devices_on_cv.py index b9806e29206..4131d642384 100644 --- a/python-avd/pyavd/_cv/workflows/verify_devices_on_cv.py +++ b/python-avd/pyavd/_cv/workflows/verify_devices_on_cv.py @@ -4,35 +4,46 @@ from __future__ import annotations from logging import getLogger +from typing import TYPE_CHECKING + +from pyavd._cv.client.exceptions import CVResourceNotFound -from ..client import CVClient -from ..client.exceptions import CVResourceNotFound from .models import CVDevice +if TYPE_CHECKING: + from pyavd._cv.client import CVClient + LOGGER = getLogger(__name__) -async def verify_devices_on_cv(devices: list[CVDevice], workspace_id: str, skip_missing_devices: bool, warnings: list[Exception], cv_client: CVClient) -> None: - """ - Verify that the given Devices are already present in the CloudVision Inventory & I&T Studio. - """ +async def verify_devices_on_cv( + *, devices: list[CVDevice], workspace_id: str, skip_missing_devices: bool, warnings: list[Exception], cv_client: CVClient +) -> None: + """Verify that the given Devices are already present in the CloudVision Inventory & I&T Studio.""" LOGGER.info("verify_devices_on_cv: %s", len(devices)) # Return if we have nothing to do. if not devices: return - existing_devices = await verify_devices_in_cloudvision_inventory(devices, skip_missing_devices, warnings, cv_client) + existing_devices = await verify_devices_in_cloudvision_inventory( + devices=devices, skip_missing_devices=skip_missing_devices, warnings=warnings, cv_client=cv_client + ) await verify_devices_in_topology_studio(existing_devices, workspace_id, cv_client) return async def verify_devices_in_cloudvision_inventory( - devices: list[CVDevice], skip_missing_devices: bool, warnings: list[Exception], cv_client: CVClient + *, + devices: list[CVDevice], + skip_missing_devices: bool, + warnings: list[Exception], + cv_client: CVClient, ) -> list[CVDevice]: """ - Verify that the given Devices are already present in the CloudVision Inventory - and in-place update the given objects with missing information like + Verify that the given Devices are already present in the CloudVision Inventory. + + Then in-place update the given objects with missing information like system MAC address and serial number. Hostname is always set for a device, but to support initial rollout, the hostname will not @@ -42,13 +53,12 @@ async def verify_devices_in_cloudvision_inventory( Returns a list of CVDevice objects found to exist on CloudVision. """ - # Using set to only include a device once. - device_tuples = set( + device_tuples = { (device.serial_number, device.system_mac_address, device.hostname if not any([device.serial_number, device.system_mac_address]) else None) for device in devices if device._exists_on_cv is None - ) + } LOGGER.info("verify_devices_in_cloudvision_inventory: %s unique devices.", len(device_tuples)) found_devices = await cv_client.get_inventory_devices(devices=device_tuples) @@ -94,7 +104,7 @@ async def verify_devices_in_cloudvision_inventory( # If a device is not found, we will set _exist_on_cv back to False. existing_devices = [device for device in devices if device._exists_on_cv] # Using set to only include a device once. - existing_device_tuples = set((device.serial_number, device.system_mac_address, device.hostname) for device in existing_devices) + existing_device_tuples = {(device.serial_number, device.system_mac_address, device.hostname) for device in existing_devices} LOGGER.info( "verify_devices_in_cloudvision_inventory: %s existing device objects for %s unique devices in inventory", @@ -103,7 +113,9 @@ async def verify_devices_in_cloudvision_inventory( ) if missing_devices := [device for device in devices if not device._exists_on_cv]: - warnings.append(missing_devices_handler(missing_devices, skip_missing_devices, "CloudVision Device Inventory")) + warnings.append( + missing_devices_handler(missing_devices=missing_devices, skip_missing_devices=skip_missing_devices, context="CloudVision Device Inventory") + ) return existing_devices @@ -117,8 +129,7 @@ async def verify_devices_in_topology_studio(existing_devices: list[CVDevice], wo Existing devices are updated with hostname and system mac address. Missing devices are added with device id, hostname, system mac address. """ - - existing_device_tuples = set((device.serial_number, device.hostname, device.system_mac_address) for device in existing_devices) + existing_device_tuples = {(device.serial_number, device.hostname, device.system_mac_address) for device in existing_devices} cv_topology_inputs = await cv_client.get_topology_studio_inputs( workspace_id=workspace_id, @@ -133,9 +144,7 @@ async def verify_devices_in_topology_studio(existing_devices: list[CVDevice], wo update_topology_inputs = [] for serial_number, hostname, system_mac_address in existing_device_tuples: - if serial_number not in topology_inputs_dict_by_serial: - update_topology_inputs.append((serial_number, hostname, system_mac_address)) - elif ( + if serial_number not in topology_inputs_dict_by_serial or ( hostname != topology_inputs_dict_by_serial[serial_number]["hostname"] or system_mac_address != topology_inputs_dict_by_serial[serial_number]["mac_address"] ): @@ -146,14 +155,15 @@ async def verify_devices_in_topology_studio(existing_devices: list[CVDevice], wo await cv_client.set_topology_studio_inputs(workspace_id=workspace_id, device_inputs=update_topology_inputs) -def missing_devices_handler(missing_devices: list[CVDevice], skip_missing_devices: bool, context: str) -> Exception: +def missing_devices_handler(*, missing_devices: list[CVDevice], skip_missing_devices: bool, context: str) -> Exception: """ - Handle missing devices: + Handle missing devices. + - Raises if skip_missing_devices is False. - Return Exception if skip_missing_devices is True. """ # Using set to only include a device once. - missing_device_tuples = set((device.serial_number, device.system_mac_address, device.hostname) for device in missing_devices) + missing_device_tuples = {(device.serial_number, device.system_mac_address, device.hostname) for device in missing_devices} # Notice these are new objects only used for the exception. unique_missing_devices = [CVDevice(hostname, serial_number, system_mac_address) for serial_number, system_mac_address, hostname in missing_device_tuples] LOGGER.warning( diff --git a/python-avd/pyavd/_eos_designs/avdfacts.py b/python-avd/pyavd/_eos_designs/avdfacts.py index d6b82b11271..aedbf0fc6e0 100644 --- a/python-avd/pyavd/_eos_designs/avdfacts.py +++ b/python-avd/pyavd/_eos_designs/avdfacts.py @@ -4,21 +4,22 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from .shared_utils import SharedUtils class AvdFacts: - def __init__(self, hostvars: dict, shared_utils: SharedUtils): + def __init__(self, hostvars: dict, shared_utils: SharedUtils) -> None: self._hostvars = hostvars self.shared_utils = shared_utils @classmethod - def __keys(cls): # pylint: disable=bad-option-value, unused-private-member # CH Sep-22: Some pylint bug. + def __keys(cls) -> list[str]: # pylint: disable=bad-option-value, unused-private-member # CH Sep-22: Some pylint bug. """ Get all class attributes including those of base Classes and Mixins. + Using MRO, which is the same way Python resolves attributes. """ keys = [] @@ -29,34 +30,27 @@ def __keys(cls): # pylint: disable=bad-option-value, unused-private-member # CH return keys @classmethod - def keys(cls): + def keys(cls) -> list[str]: """ - Return the list of "keys" + Return the list of "keys". Actually the returned list are the names of attributes not starting with "_" and using cached_property class. The "_" check is added to allow support for "internal" cached_properties storing temporary values. """ - return [key for key in cls.__keys() if not key.startswith("_") and isinstance(getattr(cls, key), cached_property)] @classmethod - def internal_keys(cls): - """ - Return a list containing the names of attributes starting with "_" and using cached_property class. - """ - + def internal_keys(cls) -> list[str]: + """Return a list containing the names of attributes starting with "_" and using cached_property class.""" return [key for key in cls.__keys() if key.startswith("_") and isinstance(getattr(cls, key), cached_property)] - def get(self, key, default_value=None): - """ - Emulate the builtin dict .get method - """ - + def get(self, key: str, default_value: Any = None) -> Any: + """Emulate the builtin dict .get method.""" if key in self.keys(): return getattr(self, key) return default_value - def render(self): + def render(self) -> dict: """ Return a dictionary of all @cached_property values. @@ -66,6 +60,6 @@ def render(self): """ return {key: getattr(self, key) for key in self.keys() if getattr(self, key) is not None} - def clear_cache(self): + def clear_cache(self) -> None: for key in self.keys() + self.internal_keys(): self.__dict__.pop(key, None) diff --git a/python-avd/pyavd/_eos_designs/eos_designs_facts/__init__.py b/python-avd/pyavd/_eos_designs/eos_designs_facts/__init__.py index 0f81294ca8b..86621adb72f 100644 --- a/python-avd/pyavd/_eos_designs/eos_designs_facts/__init__.py +++ b/python-avd/pyavd/_eos_designs/eos_designs_facts/__init__.py @@ -5,9 +5,10 @@ from functools import cached_property -from ..._errors import AristaAvdError -from ..._utils import get -from ..avdfacts import AvdFacts +from pyavd._eos_designs.avdfacts import AvdFacts +from pyavd._errors import AristaAvdError +from pyavd._utils import get + from .mlag import MlagMixin from .overlay import OverlayMixin from .short_esi import ShortEsiMixin @@ -26,15 +27,13 @@ class EosDesignsFacts(AvdFacts, MlagMixin, ShortEsiMixin, OverlayMixin, WanMixin @cached_property def id(self) -> int | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.id @cached_property def type(self) -> str: """ - Exposed in avd_switch_facts + Exposed in avd_switch_facts. switch.type fact set based on type variable """ @@ -42,50 +41,38 @@ def type(self) -> str: @cached_property def platform(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.platform @cached_property def is_deployed(self) -> bool: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.is_deployed @cached_property def serial_number(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.serial_number @cached_property def mgmt_interface(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.mgmt_interface @cached_property def mgmt_ip(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.mgmt_ip @cached_property def mpls_lsr(self) -> bool: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.mpls_lsr @cached_property def evpn_multicast(self) -> bool | None: """ - Exposed in avd_switch_facts + Exposed in avd_switch_facts. This method _must_ be in EosDesignsFacts and not in SharedUtils, since it reads the SharedUtils instance on the peer. This is only possible when running from EosDesignsFacts, since this is the only time where we can access the actual @@ -95,133 +82,106 @@ def evpn_multicast(self) -> bool | None: return None if get(self._hostvars, "evpn_multicast") is True and self.shared_utils.vtep is True: if not (self.shared_utils.underlay_multicast is True and self.shared_utils.igmp_snooping_enabled is not False): - raise AristaAvdError( - "'evpn_multicast: True' is only supported in combination with 'underlay_multicast: True' and 'igmp_snooping_enabled : True'" - ) + msg = "'evpn_multicast: True' is only supported in combination with 'underlay_multicast: True' and 'igmp_snooping_enabled : True'" + raise AristaAvdError(msg) if self.shared_utils.mlag is True: peer_eos_designs_facts: EosDesignsFacts = self.shared_utils.mlag_peer_facts if self.shared_utils.overlay_rd_type_admin_subfield == peer_eos_designs_facts.shared_utils.overlay_rd_type_admin_subfield: - raise AristaAvdError( - "For MLAG devices Route Distinguisher must be unique when 'evpn_multicast: True' since it will create a multi-vtep configuration." - ) + msg = "For MLAG devices Route Distinguisher must be unique when 'evpn_multicast: True' since it will create a multi-vtep configuration." + raise AristaAvdError(msg) return True return None @cached_property def loopback_ipv4_pool(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.underlay_router is True: return self.shared_utils.loopback_ipv4_pool return None @cached_property def uplink_ipv4_pool(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.underlay_router: return self.shared_utils.uplink_ipv4_pool return None @cached_property def downlink_pools(self) -> dict | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.underlay_router: return self.shared_utils.downlink_pools return None @cached_property def bgp_as(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.underlay_router is True: return self.shared_utils.bgp_as + return None @cached_property def underlay_routing_protocol(self) -> str: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.underlay_routing_protocol @cached_property def vtep_loopback_ipv4_pool(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.vtep is True: return self.shared_utils.vtep_loopback_ipv4_pool return None @cached_property def inband_mgmt_subnet(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.configure_parent_for_inband_mgmt: return self.shared_utils.inband_mgmt_subnet + return None @cached_property def inband_mgmt_ipv6_subnet(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.configure_parent_for_inband_mgmt_ipv6: return self.shared_utils.inband_mgmt_ipv6_subnet + return None @cached_property def inband_mgmt_vlan(self) -> int | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.configure_parent_for_inband_mgmt or self.shared_utils.configure_parent_for_inband_mgmt_ipv6: return self.shared_utils.inband_mgmt_vlan + return None @cached_property def inband_ztp(self) -> bool | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.inband_ztp @cached_property def inband_ztp_vlan(self) -> int | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.inband_ztp: return self.shared_utils.inband_mgmt_vlan + return None @cached_property def inband_ztp_lacp_fallback_delay(self) -> int | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.inband_ztp_lacp_fallback_delay @cached_property def dc_name(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.dc_name @cached_property def group(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.group @cached_property def router_id(self) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.router_id diff --git a/python-avd/pyavd/_eos_designs/eos_designs_facts/mlag.py b/python-avd/pyavd/_eos_designs/eos_designs_facts/mlag.py index c6abf1af7aa..93541b9aac3 100644 --- a/python-avd/pyavd/_eos_designs/eos_designs_facts/mlag.py +++ b/python-avd/pyavd/_eos_designs/eos_designs_facts/mlag.py @@ -13,42 +13,35 @@ class MlagMixin: """ Mixin Class used to generate some of the EosDesignsFacts. + Class should only be used as Mixin to the EosDesignsFacts class Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def mlag_peer(self: EosDesignsFacts) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.mlag: return self.shared_utils.mlag_peer return None @cached_property def mlag_port_channel_id(self: EosDesignsFacts) -> int | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.mlag: return self.shared_utils.mlag_port_channel_id return None @cached_property def mlag_interfaces(self: EosDesignsFacts) -> list | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.mlag: return self.shared_utils.mlag_interfaces return None @cached_property def mlag_ip(self: EosDesignsFacts) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.mlag: return self.shared_utils.mlag_ip return None @@ -56,7 +49,7 @@ def mlag_ip(self: EosDesignsFacts) -> str | None: @cached_property def mlag_l3_ip(self: EosDesignsFacts) -> str | None: """ - Exposed in avd_switch_facts + Exposed in avd_switch_facts. Only if L3 and not running rfc5549 for both underlay and overlay """ @@ -71,7 +64,7 @@ def mlag_l3_ip(self: EosDesignsFacts) -> str | None: @cached_property def mlag_switch_ids(self: EosDesignsFacts) -> dict | None: """ - Exposed in avd_switch_facts + Exposed in avd_switch_facts. Returns the switch id's of both primary and secondary switches for a given node group {"primary": int, "secondary": int} diff --git a/python-avd/pyavd/_eos_designs/eos_designs_facts/overlay.py b/python-avd/pyavd/_eos_designs/eos_designs_facts/overlay.py index 4e7df3cd619..e3f25e0f698 100644 --- a/python-avd/pyavd/_eos_designs/eos_designs_facts/overlay.py +++ b/python-avd/pyavd/_eos_designs/eos_designs_facts/overlay.py @@ -6,7 +6,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._utils import get +from pyavd._utils import get if TYPE_CHECKING: from . import EosDesignsFacts @@ -15,28 +15,25 @@ class OverlayMixin: """ Mixin Class used to generate some of the EosDesignsFacts. - Class should only be used as Mixin to the EosDesignsFacts class + + Class should only be used as Mixin to the EosDesignsFacts class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def evpn_role(self: EosDesignsFacts) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.evpn_role @cached_property def mpls_overlay_role(self: EosDesignsFacts) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.mpls_overlay_role @cached_property def evpn_route_servers(self: EosDesignsFacts) -> list: """ - Exposed in avd_switch_facts + Exposed in avd_switch_facts. For evpn clients the default value for EVPN Route Servers is the content of the uplink_switches variable set elsewhere. For all other evpn roles there is no default. @@ -49,19 +46,16 @@ def evpn_route_servers(self: EosDesignsFacts) -> list: @cached_property def mpls_route_reflectors(self: EosDesignsFacts) -> list | None: - """ - Exposed in avd_switch_facts - """ - if self.shared_utils.underlay_router is True: - if self.mpls_overlay_role in ["client", "server"] or (self.evpn_role in ["client", "server"] and self.overlay["evpn_mpls"]): - return get(self.shared_utils.switch_data_combined, "mpls_route_reflectors") + """Exposed in avd_switch_facts.""" + if self.shared_utils.underlay_router is True and ( + self.mpls_overlay_role in ["client", "server"] or (self.evpn_role in ["client", "server"] and self.overlay["evpn_mpls"]) + ): + return get(self.shared_utils.switch_data_combined, "mpls_route_reflectors") return None @cached_property def overlay(self: EosDesignsFacts) -> dict | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.underlay_router is True: return { "peering_address": self.shared_utils.overlay_peering_address, @@ -71,9 +65,7 @@ def overlay(self: EosDesignsFacts) -> dict | None: @cached_property def vtep_ip(self: EosDesignsFacts) -> str | None: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" if self.shared_utils.vtep: return self.shared_utils.vtep_ip return None diff --git a/python-avd/pyavd/_eos_designs/eos_designs_facts/short_esi.py b/python-avd/pyavd/_eos_designs/eos_designs_facts/short_esi.py index 85826fcab74..1ce8b8f2e0a 100644 --- a/python-avd/pyavd/_eos_designs/eos_designs_facts/short_esi.py +++ b/python-avd/pyavd/_eos_designs/eos_designs_facts/short_esi.py @@ -8,7 +8,7 @@ from hashlib import sha256 from typing import TYPE_CHECKING -from ..._utils import default, get +from pyavd._utils import default, get if TYPE_CHECKING: from . import EosDesignsFacts @@ -17,29 +17,28 @@ class ShortEsiMixin: """ Mixin Class used to generate some of the EosDesignsFacts. - Class should only be used as Mixin to the EosDesignsFacts class + + Class should only be used as Mixin to the EosDesignsFacts class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def _short_esi(self: EosDesignsFacts) -> str: """ - If short_esi is set to "auto" we will use sha256 to create a - unique short_esi value based on various uplink information. + If short_esi is set to "auto" we will use sha256 to create a unique short_esi value based on various uplink information. Note: Secondary MLAG switch should have the same short-esi value as primary MLAG switch. """ - if self.shared_utils.mlag_role == "secondary": - # On the MLAG Secondary use short-esi from MLAG primary - if (peer_short_esi := self.shared_utils.mlag_peer_facts._short_esi) is not None: - return peer_short_esi + # On the MLAG Secondary use short-esi from MLAG primary + if self.shared_utils.mlag_role == "secondary" and (peer_short_esi := self.shared_utils.mlag_peer_facts._short_esi) is not None: + return peer_short_esi short_esi = get(self.shared_utils.switch_data_combined, "short_esi") if short_esi == "auto": esi_seed_1 = "".join(self.shared_utils.uplink_switches[:2]) esi_seed_2 = "".join(self.shared_utils.uplink_switch_interfaces[:2]) esi_seed_3 = "".join(default(self.shared_utils.uplink_interfaces, [])[:2]) esi_seed_4 = default(self.shared_utils.group, "") - esi_hash = sha256(f"{esi_seed_1}{esi_seed_2}{esi_seed_3}{esi_seed_4}".encode("UTF-8")).hexdigest() + esi_hash = sha256(f"{esi_seed_1}{esi_seed_2}{esi_seed_3}{esi_seed_4}".encode()).hexdigest() short_esi = re.sub(r"([0-9a-f]{4})", r"\1:", esi_hash)[:14] return short_esi diff --git a/python-avd/pyavd/_eos_designs/eos_designs_facts/uplinks.py b/python-avd/pyavd/_eos_designs/eos_designs_facts/uplinks.py index 6bf0d7d671c..c18ed3abcd2 100644 --- a/python-avd/pyavd/_eos_designs/eos_designs_facts/uplinks.py +++ b/python-avd/pyavd/_eos_designs/eos_designs_facts/uplinks.py @@ -7,9 +7,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._errors import AristaAvdError -from ..._utils import append_if_not_duplicate, get, unique -from ...j2filters import list_compress, range_expand +from pyavd._errors import AristaAvdError +from pyavd._utils import append_if_not_duplicate, get, unique +from pyavd.j2filters import list_compress, range_expand if TYPE_CHECKING: from . import EosDesignsFacts @@ -18,28 +18,26 @@ class UplinksMixin: """ Mixin Class used to generate some of the EosDesignsFacts. - Class should only be used as Mixin to the EosDesignsFacts class + + Class should only be used as Mixin to the EosDesignsFacts class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def max_parallel_uplinks(self: EosDesignsFacts) -> int: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.max_parallel_uplinks @cached_property def max_uplink_switches(self: EosDesignsFacts) -> int: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return self.shared_utils.max_uplink_switches @cached_property def _uplink_port_channel_id(self: EosDesignsFacts) -> int: """ For MLAG secondary get the uplink_port_channel_id from the peer's facts. + We don't need to validate it (1-2000), since it will be validated on the peer. For MLAG primary or none MLAG, take the value of 'uplink_port_channel_id' if set, @@ -53,13 +51,13 @@ def _uplink_port_channel_id(self: EosDesignsFacts) -> int: # MLAG Secondary peer_uplink_port_channel_id = self.shared_utils.mlag_peer_facts._uplink_port_channel_id # check that port-channel IDs are the same as on primary - if uplink_port_channel_id is not None: - if uplink_port_channel_id != peer_uplink_port_channel_id: - raise AristaAvdError( - f"'uplink_port_channel_id' on '{self.shared_utils.hostname}' is set to {uplink_port_channel_id} and is not matching" - f" {peer_uplink_port_channel_id} set on MLAG peer." - " The 'uplink_port_channel_id' must be matching on MLAG peers." - ) + if uplink_port_channel_id is not None and uplink_port_channel_id != peer_uplink_port_channel_id: + msg = ( + f"'uplink_port_channel_id' on '{self.shared_utils.hostname}' is set to {uplink_port_channel_id} and is not matching" + f" {peer_uplink_port_channel_id} set on MLAG peer." + " The 'uplink_port_channel_id' must be matching on MLAG peers." + ) + raise AristaAvdError(msg) return peer_uplink_port_channel_id # MLAG Primary or not MLAG. @@ -68,11 +66,9 @@ def _uplink_port_channel_id(self: EosDesignsFacts) -> int: uplink_port_channel_id = int("".join(re.findall(r"\d", self.shared_utils.uplink_interfaces[0]))) # produce an error if the switch is MLAG and port-channel ID is above 2000 - if self.shared_utils.mlag: - if not 1 <= uplink_port_channel_id <= 2000: - raise AristaAvdError( - f"'uplink_port_channel_id' must be between 1 and 2000 for MLAG switches. Got '{uplink_port_channel_id}' on '{self.shared_utils.hostname}'." - ) + if self.shared_utils.mlag and not 1 <= uplink_port_channel_id <= 2000: + msg = f"'uplink_port_channel_id' must be between 1 and 2000 for MLAG switches. Got '{uplink_port_channel_id}' on '{self.shared_utils.hostname}'." + raise AristaAvdError(msg) return uplink_port_channel_id @@ -80,6 +76,7 @@ def _uplink_port_channel_id(self: EosDesignsFacts) -> int: def _uplink_switch_port_channel_id(self: EosDesignsFacts) -> int: """ For MLAG secondary get the uplink_switch_port_channel_id from the peer's facts. + We don't need to validate it (1-2000), since it will be validated on the peer. For MLAG primary or none MLAG, take the value of 'uplink_switch_port_channel_id' if set, @@ -94,13 +91,13 @@ def _uplink_switch_port_channel_id(self: EosDesignsFacts) -> int: # MLAG Secondary peer_uplink_switch_port_channel_id = self.shared_utils.mlag_peer_facts._uplink_switch_port_channel_id # check that port-channel IDs are the same as on primary - if uplink_switch_port_channel_id is not None: - if uplink_switch_port_channel_id != peer_uplink_switch_port_channel_id: - raise AristaAvdError( - f"'uplink_switch_port_channel_id'expected_error_message on '{self.shared_utils.hostname}' is set to {uplink_switch_port_channel_id} and" - f" is not matching {peer_uplink_switch_port_channel_id} set on MLAG peer. The 'uplink_switch_port_channel_id' must be matching on MLAG" - " peers." - ) + if uplink_switch_port_channel_id is not None and uplink_switch_port_channel_id != peer_uplink_switch_port_channel_id: + msg = ( + f"'uplink_switch_port_channel_id'expected_error_message on '{self.shared_utils.hostname}' is set to {uplink_switch_port_channel_id} and" + f" is not matching {peer_uplink_switch_port_channel_id} set on MLAG peer. The 'uplink_switch_port_channel_id' must be matching on MLAG" + " peers." + ) + raise AristaAvdError(msg) return peer_uplink_switch_port_channel_id # MLAG Primary or not MLAG. @@ -111,19 +108,19 @@ def _uplink_switch_port_channel_id(self: EosDesignsFacts) -> int: # produce an error if the uplink switch is MLAG and port-channel ID is above 2000 uplink_switch_facts: EosDesignsFacts = self.shared_utils.get_peer_facts(self.shared_utils.uplink_switches[0], required=True) - if uplink_switch_facts.shared_utils.mlag: - if not 1 <= uplink_switch_port_channel_id <= 2000: - raise AristaAvdError( - f"'uplink_switch_port_channel_id' must be between 1 and 2000 for MLAG switches. Got '{uplink_switch_port_channel_id}' on" - f" '{self.shared_utils.hostname}'." - ) + if uplink_switch_facts.shared_utils.mlag and not 1 <= uplink_switch_port_channel_id <= 2000: + msg = ( + f"'uplink_switch_port_channel_id' must be between 1 and 2000 for MLAG switches. Got '{uplink_switch_port_channel_id}' on" + f" '{self.shared_utils.hostname}'." + ) + raise AristaAvdError(msg) return uplink_switch_port_channel_id @cached_property def uplinks(self: EosDesignsFacts) -> list: """ - Exposed in avd_switch_facts + Exposed in avd_switch_facts. List of uplinks with all parameters @@ -136,21 +133,23 @@ def uplinks(self: EosDesignsFacts) -> list: get_uplink = self._get_port_channel_uplink elif self.shared_utils.uplink_type == "p2p-vrfs": if self.shared_utils.network_services_l3 is False or self.shared_utils.underlay_router is False: - raise AristaAvdError("'underlay_router' and 'network_services.l3' must be 'true' for the node_type_key when using 'p2p-vrfs' as 'uplink_type'.") + msg = "'underlay_router' and 'network_services.l3' must be 'true' for the node_type_key when using 'p2p-vrfs' as 'uplink_type'." + raise AristaAvdError(msg) get_uplink = self._get_p2p_vrfs_uplink elif self.shared_utils.uplink_type == "lan": if self.shared_utils.network_services_l3 is False or self.shared_utils.underlay_router is False: - raise AristaAvdError("'underlay_router' and 'network_services.l3' must be 'true' for the node_type_key when using 'lan' as 'uplink_type'.") + msg = "'underlay_router' and 'network_services.l3' must be 'true' for the node_type_key when using 'lan' as 'uplink_type'." + raise AristaAvdError(msg) if len(self.shared_utils.uplink_interfaces) > 1: - raise AristaAvdError(f"'uplink_type: lan' only supports a single uplink interface. Got {self.shared_utils.uplink_interfaces}.") + msg = f"'uplink_type: lan' only supports a single uplink interface. Got {self.shared_utils.uplink_interfaces}." + raise AristaAvdError(msg) # TODO: Adjust error message when we add lan-port-channel support. - # raise AristaAvdError( - # "'uplink_type: lan' only supports a single uplink interface. " - # f"Got {self._uplink_interfaces}. Consider 'uplink_type: lan-port-channel' if applicable." - # ) + # uplink_type: lan' only supports a single uplink interface. + # Got {self._uplink_interfaces}. Consider 'uplink_type: lan-port-channel' if applicable. get_uplink = self._get_l2_uplink else: - raise AristaAvdError(f"Invalid uplink_type '{self.shared_utils.uplink_type}'.") + msg = f"Invalid uplink_type '{self.shared_utils.uplink_type}'." + raise AristaAvdError(msg) uplinks = [] uplink_switches = self.shared_utils.uplink_switches @@ -172,9 +171,7 @@ def uplinks(self: EosDesignsFacts) -> list: return uplinks def _get_p2p_uplink(self: EosDesignsFacts, uplink_index: int, uplink_interface: str, uplink_switch: str, uplink_switch_interface: str) -> dict: - """ - Return a single uplink dictionary for uplink_type p2p - """ + """Return a single uplink dictionary for uplink_type p2p.""" uplink_switch_facts: EosDesignsFacts = self.shared_utils.get_peer_facts(uplink_switch, required=True) uplink = { "interface": uplink_interface, @@ -220,9 +217,7 @@ def _get_p2p_uplink(self: EosDesignsFacts, uplink_index: int, uplink_interface: return uplink def _get_port_channel_uplink(self: EosDesignsFacts, uplink_index: int, uplink_interface: str, uplink_switch: str, uplink_switch_interface: str) -> dict: - """ - Return a single uplink dictionary for uplink_type port-channel - """ + """Return a single uplink dictionary for uplink_type port-channel.""" uplink_switch_facts: EosDesignsFacts = self.shared_utils.get_peer_facts(uplink_switch, required=True) # Reusing get_l2_uplink @@ -251,14 +246,12 @@ def _get_port_channel_uplink(self: EosDesignsFacts, uplink_index: int, uplink_in def _get_l2_uplink( self: EosDesignsFacts, - uplink_index: int, # pylint: disable=unused-argument + uplink_index: int, # pylint: disable=unused-argument # noqa: ARG002 uplink_interface: str, uplink_switch: str, uplink_switch_interface: str, ) -> dict: - """ - Return a single uplink dictionary for an L2 uplink. Reused for both uplink_type port-channel, lan and TODO lan-port-channel. - """ + """Return a single uplink dictionary for an L2 uplink. Reused for both uplink_type port-channel, lan and TODO: lan-port-channel.""" uplink_switch_facts: EosDesignsFacts = self.shared_utils.get_peer_facts(uplink_switch, required=True) uplink = { "interface": uplink_interface, @@ -315,9 +308,7 @@ def _get_l2_uplink( return uplink def _get_p2p_vrfs_uplink(self: EosDesignsFacts, uplink_index: int, uplink_interface: str, uplink_switch: str, uplink_switch_interface: str) -> dict: - """ - Return a single uplink dictionary for uplink_type p2p-vrfs - """ + """Return a single uplink dictionary for uplink_type p2p-vrfs.""" uplink_switch_facts: EosDesignsFacts = self.shared_utils.get_peer_facts(uplink_switch, required=True) # Reusing regular p2p logic for main interface. @@ -351,7 +342,12 @@ def _get_p2p_vrfs_uplink(self: EosDesignsFacts, uplink_index: int, uplink_interf subinterface["structured_config"] = self.shared_utils.uplink_structured_config append_if_not_duplicate( - uplink["subinterfaces"], "vrf", subinterface, context="Uplink subinterfaces", context_keys=["interface", "vrf"], ignore_same_dict=True + uplink["subinterfaces"], + "vrf", + subinterface, + context="Uplink subinterfaces", + context_keys=["interface", "vrf"], + ignore_same_dict=True, ) return uplink @@ -359,7 +355,7 @@ def _get_p2p_vrfs_uplink(self: EosDesignsFacts, uplink_index: int, uplink_interf @cached_property def uplink_peers(self: EosDesignsFacts) -> list: """ - Exposed in avd_switch_facts + Exposed in avd_switch_facts. List of all **unique** uplink peers @@ -373,6 +369,7 @@ def uplink_peers(self: EosDesignsFacts) -> list: def _default_downlink_interfaces(self: EosDesignsFacts) -> list: """ internal _default_downlink_interfaces set based on default_interfaces. - Parsed by downstream switches during eos_designs_facts phase + + Parsed by downstream switches during eos_designs_facts phase. """ return range_expand(get(self.shared_utils.default_interfaces, "downlink_interfaces", default=[])) diff --git a/python-avd/pyavd/_eos_designs/eos_designs_facts/vlans.py b/python-avd/pyavd/_eos_designs/eos_designs_facts/vlans.py index e3e8fdccf1b..b322bd2627b 100644 --- a/python-avd/pyavd/_eos_designs/eos_designs_facts/vlans.py +++ b/python-avd/pyavd/_eos_designs/eos_designs_facts/vlans.py @@ -7,8 +7,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._utils import get -from ...j2filters import convert_dicts, list_compress, range_expand +from pyavd._utils import get +from pyavd.j2filters import convert_dicts, list_compress, range_expand if TYPE_CHECKING: from . import EosDesignsFacts @@ -17,14 +17,15 @@ class VlansMixin: """ Mixin Class used to generate some of the EosDesignsFacts. - Class should only be used as Mixin to the EosDesignsFacts class + + Class should only be used as Mixin to the EosDesignsFacts class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def vlans(self: EosDesignsFacts) -> str: """ - Exposed in avd_switch_facts + Exposed in avd_switch_facts. Return the compressed list of vlans to be defined on this switch @@ -36,9 +37,7 @@ def vlans(self: EosDesignsFacts) -> str: return list_compress(self._vlans) def _parse_adapter_settings(self: EosDesignsFacts, adapter_settings: dict) -> tuple[set, set]: - """ - Parse the given adapter_settings and return relevant vlans and trunk_groups - """ + """Parse the given adapter_settings and return relevant vlans and trunk_groups.""" vlans = set() trunk_groups = set(adapter_settings.get("trunk_groups", [])) if "vlans" in adapter_settings and adapter_settings["vlans"] not in ["all", "", None]: @@ -71,7 +70,7 @@ def _parse_adapter_settings(self: EosDesignsFacts, adapter_settings: dict) -> tu @cached_property def _local_endpoint_vlans_and_trunk_groups(self: EosDesignsFacts) -> tuple[set, set]: """ - Return list of vlans and list of trunk groups used by connected_endpoints on this switch + Return list of vlans and list of trunk groups used by connected_endpoints on this switch. Also includes the inband_management_vlan """ @@ -107,8 +106,8 @@ def _local_endpoint_vlans_and_trunk_groups(self: EosDesignsFacts) -> tuple[set, for switch_regex in network_port_item.get("switches", []): # The match test is built on Python re.match which tests from the beginning of the string #} # Since the user would not expect "DC1-LEAF1" to also match "DC-LEAF11" we will force ^ and $ around the regex - switch_regex = rf"^{switch_regex}$" - if not re.match(switch_regex, self.shared_utils.hostname): + raw_switch_regex = rf"^{switch_regex}$" + if not re.match(raw_switch_regex, self.shared_utils.hostname): # Skip entry if no match continue @@ -128,6 +127,7 @@ def _local_endpoint_vlans_and_trunk_groups(self: EosDesignsFacts) -> tuple[set, def _downstream_switch_endpoint_vlans_and_trunk_groups(self: EosDesignsFacts) -> tuple[set, set]: """ Return set of vlans and set of trunk groups used by downstream switches. + Traverse any downstream L2 switches so ensure we can provide connectivity to any vlans / trunk groups used by them. """ if not self.shared_utils.any_network_services: @@ -148,6 +148,7 @@ def _downstream_switch_endpoint_vlans_and_trunk_groups(self: EosDesignsFacts) -> def _mlag_peer_endpoint_vlans_and_trunk_groups(self: EosDesignsFacts) -> tuple[set, set]: """ Return set of vlans and set of trunk groups used by connected_endpoints on the MLAG peer. + This could differ from local vlans and trunk groups if a connected endpoint is only connected to one leaf. """ if not self.shared_utils.mlag: @@ -160,7 +161,9 @@ def _mlag_peer_endpoint_vlans_and_trunk_groups(self: EosDesignsFacts) -> tuple[s @cached_property def _endpoint_vlans_and_trunk_groups(self: EosDesignsFacts) -> tuple[set, set]: """ - Return set of vlans and set of trunk groups used by connected_endpoints on this switch, + Return set of vlans and set of trunk groups. + + The trunk groups are those used by connected_endpoints on this switch, downstream switches but NOT mlag peer (since we would have circular references then). """ local_endpoint_vlans, local_endpoint_trunk_groups = self._local_endpoint_vlans_and_trunk_groups @@ -171,7 +174,8 @@ def _endpoint_vlans_and_trunk_groups(self: EosDesignsFacts) -> tuple[set, set]: def _endpoint_vlans(self: EosDesignsFacts) -> set[int]: """ Return set of vlans in use by endpoints connected to this switch, downstream switches or MLAG peer. - Ex: {1, 20, 21, 22, 23} or set() + + Ex: {1, 20, 21, 22, 23} or set(). """ if not self.shared_utils.filter_only_vlans_in_use: return set() @@ -188,7 +192,8 @@ def _endpoint_vlans(self: EosDesignsFacts) -> set[int]: def endpoint_vlans(self: EosDesignsFacts) -> str | None: """ Return compressed list of vlans in use by endpoints connected to this switch or MLAG peer. - Ex: "1,20-30" or "" + + Ex: "1,20-30" or "". """ if self.shared_utils.filter_only_vlans_in_use: return list_compress(list(self._endpoint_vlans)) @@ -197,9 +202,7 @@ def endpoint_vlans(self: EosDesignsFacts) -> str | None: @cached_property def _endpoint_trunk_groups(self: EosDesignsFacts) -> set[str]: - """ - Return set of trunk_groups in use by endpoints connected to this switch, downstream switches or MLAG peer. - """ + """Return set of trunk_groups in use by endpoints connected to this switch, downstream switches or MLAG peer.""" if not self.shared_utils.filter_only_vlans_in_use: return set() @@ -214,6 +217,7 @@ def _endpoint_trunk_groups(self: EosDesignsFacts) -> set[str]: def local_endpoint_trunk_groups(self: EosDesignsFacts) -> list[str]: """ Return list of trunk_groups in use by endpoints connected to this switch only. + Used for only applying the trunk groups in config that are relevant on this device This is a subset of endpoint_trunk_groups which is used for filtering. """ @@ -227,6 +231,7 @@ def local_endpoint_trunk_groups(self: EosDesignsFacts) -> list[str]: def endpoint_trunk_groups(self: EosDesignsFacts) -> list[str]: """ Return list of trunk_groups in use by endpoints connected to this switch, downstream switches or MLAG peer. + Used for filtering which vlans we configure on the device. This is a superset of local_endpoint_trunk_groups. """ return list(self._endpoint_trunk_groups) @@ -235,7 +240,8 @@ def endpoint_trunk_groups(self: EosDesignsFacts) -> list[str]: def _vlans(self: EosDesignsFacts) -> list[int]: """ Return list of vlans after filtering network services. - The filter is based on filter.tenants, filter.tags and filter.only_vlans_in_use + + The filter is based on filter.tenants, filter.tags and filter.only_vlans_in_use. Ex. [1, 2, 3 ,4 ,201, 3021] """ diff --git a/python-avd/pyavd/_eos_designs/eos_designs_facts/wan.py b/python-avd/pyavd/_eos_designs/eos_designs_facts/wan.py index 103901bff16..de73d8d3196 100644 --- a/python-avd/pyavd/_eos_designs/eos_designs_facts/wan.py +++ b/python-avd/pyavd/_eos_designs/eos_designs_facts/wan.py @@ -6,7 +6,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...j2filters import natural_sort +from pyavd.j2filters import natural_sort if TYPE_CHECKING: from . import EosDesignsFacts @@ -14,7 +14,8 @@ class WanMixin: """ - Mixin Class providing a subset of EosDesignsFacts + Mixin Class providing a subset of EosDesignsFacts. + Class should only be used as Mixin to the EosDesignsFacts class Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -22,11 +23,12 @@ class WanMixin: @cached_property def wan_path_groups(self: EosDesignsFacts) -> list | None: """ - TODO: Also add the path_groups importing any of our connected path groups. - Need to find out if we need to resolve recursive imports. + Return the list of WAN path_groups directly connected to this router. - Return the list of WAN path_groups directly connected to this router, with a list of dictionaries - containing the (interface, ip_address) in the path_group. + Each with a list of dictionaries containing the (interface, ip_address) in the path_group. + + TODO: Also add the path_groups importing any of our connected path groups. + Need to find out if we need to resolve recursive imports. """ if not self.shared_utils.is_wan_server: return None @@ -36,7 +38,7 @@ def wan_path_groups(self: EosDesignsFacts) -> list | None: @cached_property def wan_router_uplink_vrfs(self: EosDesignsFacts) -> list[str] | None: """ - Exposed in avd_switch_facts + Exposed in avd_switch_facts. Return the list of VRF names present on uplink switches. These VRFs will be attracted (configured) on WAN "clients" (edge/transit) unless filtered. diff --git a/python-avd/pyavd/_eos_designs/interface_descriptions/__init__.py b/python-avd/pyavd/_eos_designs/interface_descriptions/__init__.py index 047cc7a178f..ade5a32d245 100644 --- a/python-avd/pyavd/_eos_designs/interface_descriptions/__init__.py +++ b/python-avd/pyavd/_eos_designs/interface_descriptions/__init__.py @@ -4,15 +4,19 @@ from __future__ import annotations from collections import ChainMap +from typing import TYPE_CHECKING, Any + +from pyavd._eos_designs.avdfacts import AvdFacts -from ..avdfacts import AvdFacts -from .models import InterfaceDescriptionData from .utils import UtilsMixin +if TYPE_CHECKING: + from .models import InterfaceDescriptionData + class AvdInterfaceDescriptions(AvdFacts, UtilsMixin): """ - Class used to render Interface Descriptions either from custom Jinja2 templates or using default Python Logic + Class used to render Interface Descriptions either from custom Jinja2 templates or using default Python Logic. Since some templates might contain certain legacy variables (switch_*), those are mapped from the switch.* model @@ -29,12 +33,14 @@ class AvdInterfaceDescriptions(AvdFacts, UtilsMixin): - Breaking changes may happen between major releases or in rare cases for bug fixes. """ - def _template(self, template_path, **kwargs): + def _template(self, template_path: str, **kwargs: Any) -> str: template_vars = ChainMap(kwargs, self._hostvars) return self.shared_utils.template_var(template_path, template_vars) def underlay_ethernet_interface(self, data: InterfaceDescriptionData) -> str: """ + Called for each underlay ethernet interface. + Available data: - link_type - peer @@ -45,7 +51,7 @@ def underlay_ethernet_interface(self, data: InterfaceDescriptionData) -> str: - type - vrf - wan_carrier - - wan_circuit_id + - wan_circuit_id. """ desc = self.underlay_ethernet_interfaces( link_type=data.link_type, @@ -83,6 +89,8 @@ def underlay_ethernet_interfaces(self, link_type: str, link_peer: str, link_peer def underlay_port_channel_interface(self, data: InterfaceDescriptionData) -> str: """ + Called for each underlay port-channel interface. + Available data: - peer - peer_channel_group_id @@ -90,10 +98,12 @@ def underlay_port_channel_interface(self, data: InterfaceDescriptionData) -> str - mpls_overlay_role - mpls_lsr - overlay_routing_protocol - - type + - type. """ return self.underlay_port_channel_interfaces( - link_peer=data.peer, link_peer_channel_group_id=data.peer_channel_group_id, link_channel_description=data.port_channel_description + link_peer=data.peer, + link_peer_channel_group_id=data.peer_channel_group_id, + link_channel_description=data.port_channel_description, ) def underlay_port_channel_interfaces(self, link_peer: str, link_peer_channel_group_id: int, link_channel_description: str | None) -> str: @@ -117,12 +127,14 @@ def underlay_port_channel_interfaces(self, link_peer: str, link_peer_channel_gro def mlag_ethernet_interface(self, data: InterfaceDescriptionData) -> str: """ + Called for each mlag ethernet interface. + Available data: - peer_interface - mpls_overlay_role - mpls_lsr - overlay_routing_protocol - - type + - type. """ return self.mlag_ethernet_interfaces(mlag_interface=data.peer_interface) @@ -135,16 +147,18 @@ def mlag_ethernet_interfaces(self, mlag_interface: str) -> str: def mlag_port_channel_interface( self, - data: InterfaceDescriptionData, # pylint: disable=unused-argument # NOSONAR + data: InterfaceDescriptionData, # pylint: disable=unused-argument # NOSONAR # noqa: ARG002 ) -> str: """ + Called for each mlag port-channel interface. + Available data: - mlag_peer - peer_channel_group_id - mpls_overlay_role - mpls_lsr - overlay_routing_protocol - - type + - type. """ return self.mlag_port_channel_interfaces() @@ -157,6 +171,8 @@ def mlag_port_channel_interfaces(self) -> str: def connected_endpoints_ethernet_interface(self, data: InterfaceDescriptionData) -> str: """ + Called for each connected endpoint ethernet interface. + Available data: - peer - peer_interface @@ -164,17 +180,19 @@ def connected_endpoints_ethernet_interface(self, data: InterfaceDescriptionData) - mpls_overlay_role - mpls_lsr - overlay_routing_protocol - - type + - type. """ return self.connected_endpoints_ethernet_interfaces(peer=data.peer, peer_interface=data.peer_interface, adapter_description=data.description) - def connected_endpoints_ethernet_interfaces(self, peer: str = None, peer_interface: str = None, adapter_description: str = None) -> str: + def connected_endpoints_ethernet_interfaces( + self, peer: str | None = None, peer_interface: str | None = None, adapter_description: str | None = None + ) -> str: """ If a jinja template is configured, use it. + If not, use the adapter.description or default to _ TODO: AVD5.0.0 move this to the new function. """ - if template_path := self.shared_utils.interface_descriptions_templates.get("connected_endpoints_ethernet_interfaces"): return self._template(template_path, peer=peer, peer_interface=peer_interface, adapter_description=adapter_description) @@ -186,6 +204,8 @@ def connected_endpoints_ethernet_interfaces(self, peer: str = None, peer_interfa def connected_endpoints_port_channel_interface(self, data: InterfaceDescriptionData) -> str: """ + Called for each connected endpoint port-channel interface. + Available data: - peer - description @@ -193,21 +213,27 @@ def connected_endpoints_port_channel_interface(self, data: InterfaceDescriptionD - mpls_overlay_role - mpls_lsr - overlay_routing_protocol - - type + - type. """ return self.connected_endpoints_port_channel_interfaces( - peer=data.peer, adapter_description=data.description, adapter_port_channel_description=data.port_channel_description + peer=data.peer, + adapter_description=data.description, + adapter_port_channel_description=data.port_channel_description, ) def connected_endpoints_port_channel_interfaces( - self, peer: str = None, adapter_description: str = None, adapter_port_channel_description: str = None + self, + peer: str | None = None, + adapter_description: str | None = None, + adapter_port_channel_description: str | None = None, ) -> str: - """If a jinja template is configured, use it. + """ + If a jinja template is configured, use it. + If not, return the _ or default to _ TODO: AVD5.0.0 move this to the new function. """ - if template_path := self.shared_utils.interface_descriptions_templates.get("connected_endpoints_port_channel_interfaces"): return self._template( template_path, @@ -221,12 +247,14 @@ def connected_endpoints_port_channel_interfaces( def router_id_loopback_interface(self, data: InterfaceDescriptionData) -> str: """ + Called for device. + Available data: - description - mpls_overlay_role - mpls_lsr - overlay_routing_protocol - - type + - type. """ return self.overlay_loopback_interface(overlay_loopback_description=data.description) diff --git a/python-avd/pyavd/_eos_designs/interface_descriptions/models.py b/python-avd/pyavd/_eos_designs/interface_descriptions/models.py index e714c99daa1..421fd71204a 100644 --- a/python-avd/pyavd/_eos_designs/interface_descriptions/models.py +++ b/python-avd/pyavd/_eos_designs/interface_descriptions/models.py @@ -7,13 +7,12 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ..shared_utils import SharedUtils + from pyavd._eos_designs.shared_utils import SharedUtils class InterfaceDescriptionData: """ - This class is used as transport of data between AVD code and - instances of AvdInterfaceDescriptions class or subclasses hereof. + This class is used as transport of data between AVD code and instances of AvdInterfaceDescriptions class or subclasses hereof. Attributes starting with _ are internal and may change at any time. @@ -59,7 +58,7 @@ def __init__( vrf: str | None = None, wan_carrier: str | None = None, wan_circuit_id: str | None = None, - ): + ) -> None: self._shared_utils = shared_utils self.description = description self.interface = interface diff --git a/python-avd/pyavd/_eos_designs/interface_descriptions/utils.py b/python-avd/pyavd/_eos_designs/interface_descriptions/utils.py index f3c2c0f82cf..04af015e659 100644 --- a/python-avd/pyavd/_eos_designs/interface_descriptions/utils.py +++ b/python-avd/pyavd/_eos_designs/interface_descriptions/utils.py @@ -13,7 +13,8 @@ class UtilsMixin: """ Mixin Class with internal functions. - Class should only be used as Mixin to an AvdInterfaceDescriptions class + + Class should only be used as Mixin to an AvdInterfaceDescriptions class. """ @cached_property diff --git a/python-avd/pyavd/_eos_designs/ip_addressing/__init__.py b/python-avd/pyavd/_eos_designs/ip_addressing/__init__.py index 947a1c6d01d..b90127e3316 100644 --- a/python-avd/pyavd/_eos_designs/ip_addressing/__init__.py +++ b/python-avd/pyavd/_eos_designs/ip_addressing/__init__.py @@ -3,16 +3,18 @@ # that can be found in the LICENSE file. import ipaddress from collections import ChainMap +from typing import Any + +from pyavd._eos_designs.avdfacts import AvdFacts +from pyavd._errors import AristaAvdError +from pyavd._utils import get_ip_from_pool -from ..._errors import AristaAvdError -from ..._utils import get_ip_from_pool -from ..avdfacts import AvdFacts from .utils import UtilsMixin class AvdIpAddressing(AvdFacts, UtilsMixin): """ - Class used to render IP addresses either from custom Jinja2 templates or using default Python Logic + Class used to render IP addresses either from custom Jinja2 templates or using default Python Logic. Since some templates might contain certain legacy variables (switch_*), those are mapped from the switch.* model @@ -22,27 +24,23 @@ class AvdIpAddressing(AvdFacts, UtilsMixin): """ def _ip(self, pool: str, prefixlen: int, subnet_offset: int, ip_offset: int) -> str: - """ - Shortcut to get_ip_from_pool in case any custom subclasses are using this - """ + """Shortcut to get_ip_from_pool in case any custom subclasses are using this.""" return get_ip_from_pool(pool, prefixlen, subnet_offset, ip_offset) - def _template(self, template_path, **kwargs): + def _template(self, template_path: str, **kwargs: Any) -> str: template_vars = ChainMap(kwargs, self._hostvars) return self.shared_utils.template_var(template_path, template_vars) def _mlag_ip(self, pool: str, ip_offset: int, address_family: str = "ipv4") -> str: """ - Different addressing algorithms: + Different addressing algorithms. + - first_id: offset from pool is `(mlag_primary_id - 1) * 2` - odd_id: offset from pool is `(odd_id - 1) * 2`. Requires MLAG pair to have a node with odd and a node with an even ID - same_subnet: offset from pool is always 0. All MLAG pairs will be using the same subnet (default /31). Requires the pool to have the same prefix length. """ - if address_family == "ipv6": - prefixlen = self._fabric_ip_addressing_mlag_ipv6_prefix_length - else: - prefixlen = self._fabric_ip_addressing_mlag_ipv4_prefix_length + prefixlen = self._fabric_ip_addressing_mlag_ipv6_prefix_length if address_family == "ipv6" else self._fabric_ip_addressing_mlag_ipv4_prefix_length if self._fabric_ipaddress_mlag_algorithm == "odd_id": offset = self._mlag_odd_id_based_offset return get_ip_from_pool(pool, prefixlen, offset, ip_offset) @@ -50,7 +48,8 @@ def _mlag_ip(self, pool: str, ip_offset: int, address_family: str = "ipv4") -> s if self._fabric_ipaddress_mlag_algorithm == "same_subnet": pool_network = ipaddress.ip_network(pool, strict=False) if pool_network.prefixlen != prefixlen: - raise AristaAvdError(f"MLAG same_subnet addressing requires the pool to be a /{prefixlen}") + msg = f"MLAG same_subnet addressing requires the pool to be a /{prefixlen}" + raise AristaAvdError(msg) return get_ip_from_pool(pool, prefixlen, 0, ip_offset) # Use default first_id @@ -58,9 +57,7 @@ def _mlag_ip(self, pool: str, ip_offset: int, address_family: str = "ipv4") -> s return get_ip_from_pool(pool, prefixlen, offset, ip_offset) def mlag_ibgp_peering_ip_primary(self, mlag_ibgp_peering_ipv4_pool: str) -> str: - """ - Return IP for L3 Peerings in VRFs for MLAG Primary - """ + """Return IP for L3 Peerings in VRFs for MLAG Primary.""" if template_path := self.shared_utils.ip_addressing_templates.get("mlag_ibgp_peering_ip_primary"): return self._template( template_path, @@ -70,9 +67,7 @@ def mlag_ibgp_peering_ip_primary(self, mlag_ibgp_peering_ipv4_pool: str) -> str: return self._mlag_ip(mlag_ibgp_peering_ipv4_pool, 0) def mlag_ibgp_peering_ip_secondary(self, mlag_ibgp_peering_ipv4_pool: str) -> str: - """ - Return IP for L3 Peerings in VRFs for MLAG Secondary - """ + """Return IP for L3 Peerings in VRFs for MLAG Secondary.""" if template_path := self.shared_utils.ip_addressing_templates.get("mlag_ibgp_peering_ip_secondary"): return self._template( template_path, @@ -83,7 +78,7 @@ def mlag_ibgp_peering_ip_secondary(self, mlag_ibgp_peering_ipv4_pool: str) -> st def mlag_ip_primary(self) -> str: """ - Return IP for MLAG Primary + Return IP for MLAG Primary. Default pool is "mlag_peer_ipv4_pool" """ @@ -110,7 +105,7 @@ def mlag_ip_primary(self) -> str: def mlag_ip_secondary(self) -> str: """ - Return IP for MLAG Secondary + Return IP for MLAG Secondary. Default pool is "mlag_peer_ipv4_pool" """ @@ -137,7 +132,7 @@ def mlag_ip_secondary(self) -> str: def mlag_l3_ip_primary(self) -> str: """ - Return IP for L3 Peerings for MLAG Primary + Return IP for L3 Peerings for MLAG Primary. Default pool is "mlag_peer_l3_ipv4_pool" """ @@ -153,7 +148,7 @@ def mlag_l3_ip_primary(self) -> str: def mlag_l3_ip_secondary(self) -> str: """ - Return IP for L3 Peerings for MLAG Secondary + Return IP for L3 Peerings for MLAG Secondary. Default pool is "mlag_peer_l3_ipv4_pool" """ @@ -168,10 +163,7 @@ def mlag_l3_ip_secondary(self) -> str: return self._mlag_ip(self._mlag_peer_l3_ipv4_pool, 1) def p2p_uplinks_ip(self, uplink_switch_index: int) -> str: - """ - Return Child IP for P2P Uplinks - """ - + """Return Child IP for P2P Uplinks.""" uplink_switch_index = int(uplink_switch_index) if template_path := self.shared_utils.ip_addressing_templates.get("p2p_uplinks_ip"): return self._template( @@ -185,10 +177,7 @@ def p2p_uplinks_ip(self, uplink_switch_index: int) -> str: return get_ip_from_pool(p2p_ipv4_pool, prefixlen, offset, 1) def p2p_uplinks_peer_ip(self, uplink_switch_index: int) -> str: - """ - Return Parent IP for P2P Uplinks - """ - + """Return Parent IP for P2P Uplinks.""" uplink_switch_index = int(uplink_switch_index) if template_path := self.shared_utils.ip_addressing_templates.get("p2p_uplinks_peer_ip"): return self._template( @@ -204,10 +193,10 @@ def p2p_uplinks_peer_ip(self, uplink_switch_index: int) -> str: def p2p_vrfs_uplinks_ip( self, uplink_switch_index: int, - vrf: str, # pylint: disable=unused-argument # NOSONAR + vrf: str, # pylint: disable=unused-argument # NOSONAR # noqa: ARG002 ) -> str: """ - Return Child IP for P2P-VRFs Uplinks + Return Child IP for P2P-VRFs Uplinks. Unless overridden in a custom IP addressing module, this will just reuse the regular ip addressing logic. """ @@ -216,10 +205,10 @@ def p2p_vrfs_uplinks_ip( def p2p_vrfs_uplinks_peer_ip( self, uplink_switch_index: int, - vrf: str, # pylint: disable=unused-argument # NOSONAR + vrf: str, # pylint: disable=unused-argument # NOSONAR # noqa: ARG002 ) -> str: """ - Return Parent IP for P2P-VRFs Uplinks + Return Parent IP for P2P-VRFs Uplinks. Unless overridden in a custom IP addressing module, this will just reuse the regular ip addressing logic. """ @@ -227,7 +216,7 @@ def p2p_vrfs_uplinks_peer_ip( def router_id(self) -> str: """ - Return IP address for Router ID + Return IP address for Router ID. If "loopback_ipv4_address" is set, it is used. Default pool is "loopback_ipv4_pool" @@ -249,7 +238,7 @@ def router_id(self) -> str: def ipv6_router_id(self) -> str: """ - Return IPv6 address for Router ID + Return IPv6 address for Router ID. Default pool is "loopback_ipv6_pool" Default offset from pool is `id + loopback_ipv6_offset` @@ -267,7 +256,7 @@ def ipv6_router_id(self) -> str: def vtep_ip_mlag(self) -> str: """ - Return IP address for VTEP for MLAG Leaf + Return IP address for VTEP for MLAG Leaf. If "vtep_loopback_ipv4_address" is set, it is used. Default pool is "vtep_loopback_ipv4_pool" @@ -314,7 +303,8 @@ def vtep_ip(self) -> str: def vrf_loopback_ip(self, pool: str) -> str: """ Return IP address for a Loopback interface based on the given pool. - Default offset from pool is `id + loopback_ipv4_offset` + + Default offset from pool is `id + loopback_ipv4_offset`. Used for "vtep_diagnostic.loopback". """ @@ -325,11 +315,11 @@ def evpn_underlay_l3_multicast_group( self, underlay_l3_multicast_group_ipv4_pool: str, vrf_vni: int, - vrf_id: int, # pylint: disable=unused-argument + vrf_id: int, # pylint: disable=unused-argument # noqa: ARG002 evpn_underlay_l3_multicast_group_ipv4_pool_offset: int, ) -> str: """ - Return IP address to be used for EVPN underlay L3 multicast group + Return IP address to be used for EVPN underlay L3 multicast group. TODO: Change algorithm to use VRF ID instead of VRF VNI as offset. """ @@ -342,8 +332,6 @@ def evpn_underlay_l2_multicast_group( vlan_id: int, underlay_l2_multicast_group_ipv4_pool_offset: int, ) -> str: - """ - Return IP address to be used for EVPN underlay L2 multicast group - """ + """Return IP address to be used for EVPN underlay L2 multicast group.""" offset = vlan_id - 1 + underlay_l2_multicast_group_ipv4_pool_offset return get_ip_from_pool(underlay_l2_multicast_group_ipv4_pool, 32, offset, 0) diff --git a/python-avd/pyavd/_eos_designs/ip_addressing/utils.py b/python-avd/pyavd/_eos_designs/ip_addressing/utils.py index 0b7acb518ab..2b2bed7c57d 100644 --- a/python-avd/pyavd/_eos_designs/ip_addressing/utils.py +++ b/python-avd/pyavd/_eos_designs/ip_addressing/utils.py @@ -6,9 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._errors import AristaAvdError, AristaAvdMissingVariableError -from ..._utils import get -from ...j2filters import range_expand +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import get +from pyavd.j2filters import range_expand if TYPE_CHECKING: from . import AvdIpAddressing @@ -17,19 +17,22 @@ class UtilsMixin: """ Mixin Class with internal functions. - Class should only be used as Mixin to an AvdIpAddressing class + + Class should only be used as Mixin to an AvdIpAddressing class. """ @cached_property def _mlag_primary_id(self: AvdIpAddressing) -> int: if self.shared_utils.mlag_switch_ids is None or self.shared_utils.mlag_switch_ids.get("primary") is None: - raise AristaAvdMissingVariableError("'mlag_switch_ids' is required to calculate MLAG IP addresses") + msg = "'mlag_switch_ids' is required to calculate MLAG IP addresses" + raise AristaAvdMissingVariableError(msg) return self.shared_utils.mlag_switch_ids["primary"] @cached_property def _mlag_secondary_id(self: AvdIpAddressing) -> int: if self.shared_utils.mlag_switch_ids is None or self.shared_utils.mlag_switch_ids.get("secondary") is None: - raise AristaAvdMissingVariableError("'mlag_switch_ids' is required to calculate MLAG IP addresses") + msg = "'mlag_switch_ids' is required to calculate MLAG IP addresses" + raise AristaAvdMissingVariableError(msg) return self.shared_utils.mlag_switch_ids["secondary"] @cached_property @@ -63,13 +66,15 @@ def _mlag_peer_l3_ipv4_pool(self: AvdIpAddressing) -> str: @cached_property def _uplink_ipv4_pool(self: AvdIpAddressing) -> str: if self.shared_utils.uplink_ipv4_pool is None: - raise AristaAvdMissingVariableError("'uplink_ipv4_pool' is required to calculate uplink IP addresses") + msg = "'uplink_ipv4_pool' is required to calculate uplink IP addresses" + raise AristaAvdMissingVariableError(msg) return self.shared_utils.uplink_ipv4_pool @cached_property def _id(self: AvdIpAddressing) -> int: if self.shared_utils.id is None: - raise AristaAvdMissingVariableError("'id' is required to calculate IP addresses") + msg = "'id' is required to calculate IP addresses" + raise AristaAvdMissingVariableError(msg) return self.shared_utils.id @cached_property @@ -111,14 +116,14 @@ def _vtep_loopback_ipv4_pool(self: AvdIpAddressing) -> str: @cached_property def _mlag_odd_id_based_offset(self: AvdIpAddressing) -> int: """ - Return the subnet offset for an MLAG pair based on odd id + Return the subnet offset for an MLAG pair based on odd id. Requires a pair of odd and even IDs """ - # Verify a mix of odd and even IDs if (self._mlag_primary_id % 2) == (self._mlag_secondary_id % 2): - raise AristaAvdError("MLAG compact addressing mode requires all MLAG pairs to have a single odd and even ID") + msg = "MLAG compact addressing mode requires all MLAG pairs to have a single odd and even ID" + raise AristaAvdError(msg) odd_id = self._mlag_primary_id if odd_id % 2 == 0: @@ -128,12 +133,11 @@ def _mlag_odd_id_based_offset(self: AvdIpAddressing) -> int: def _get_downlink_ipv4_pool_and_offset(self: AvdIpAddressing, uplink_switch_index: int) -> tuple[str, int]: """ - Returns the downlink IP pool and offset as a tuple according to the uplink_switch_index + Returns the downlink IP pool and offset as a tuple according to the uplink_switch_index. Offset is the matching interface's index in the list of downlink_interfaces (None, None) is returned if downlink_pools are not used """ - uplink_switch_interface = self.shared_utils.uplink_switch_interfaces[uplink_switch_index] uplink_switch = self.shared_utils.uplink_switches[uplink_switch_index] peer_facts = self.shared_utils.get_peer_facts(uplink_switch, required=True) @@ -150,14 +154,17 @@ def _get_downlink_ipv4_pool_and_offset(self: AvdIpAddressing, uplink_switch_inde return (get(downlink_pool_and_interfaces, "ipv4_pool"), interface_index) # If none of the interfaces match up, throw error - raise AristaAvdError( + msg = ( f"'downlink_pools' was defined at uplink_switch, but one of the 'uplink_switch_interfaces' ({uplink_switch_interface}) " "in the downlink_switch does not match any of the downlink_pools" ) + raise AristaAvdError( + msg, + ) def _get_p2p_ipv4_pool_and_offset(self: AvdIpAddressing, uplink_switch_index: int) -> tuple[str, int]: """ - Returns IP pool and offset as a tuple according to the uplink_switch_index + Returns IP pool and offset as a tuple according to the uplink_switch_index. Uplink pool or downlink pool is returned with its corresponding offset A downlink pool's offset is the matching interface's index in the list of downlink_interfaces @@ -165,7 +172,6 @@ def _get_p2p_ipv4_pool_and_offset(self: AvdIpAddressing, uplink_switch_index: in One and only one of these pools are required to be set, otherwise an error will be thrown """ - uplink_pool = self.shared_utils.uplink_ipv4_pool if uplink_pool is not None: uplink_offset = ((self._id - 1) * self._max_uplink_switches * self._max_parallel_uplinks) + uplink_switch_index @@ -173,15 +179,15 @@ def _get_p2p_ipv4_pool_and_offset(self: AvdIpAddressing, uplink_switch_index: in downlink_pool, downlink_offset = self._get_downlink_ipv4_pool_and_offset(uplink_switch_index) if uplink_pool is not None and downlink_pool is not None: - raise AristaAvdError( + msg = ( f"Unable to assign IPs for uplinks. 'uplink_ipv4_pool' ({uplink_pool}) on this switch cannot be combined " f"with 'downlink_pools' ({downlink_pool}) on any uplink switch." ) + raise AristaAvdError(msg) if uplink_pool is None and downlink_pool is None: - raise AristaAvdMissingVariableError( - "Unable to assign IPs for uplinks. Either 'uplink_ipv4_pool' on this switch or 'downlink_pools' on all the uplink switches" - ) + msg = "Unable to assign IPs for uplinks. Either 'uplink_ipv4_pool' on this switch or 'downlink_pools' on all the uplink switches" + raise AristaAvdMissingVariableError(msg) if uplink_pool is not None: return (uplink_pool, uplink_offset) diff --git a/python-avd/pyavd/_eos_designs/shared_utils/__init__.py b/python-avd/pyavd/_eos_designs/shared_utils/__init__.py index 60ce02a2049..50d5180ebaf 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/__init__.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/__init__.py @@ -52,8 +52,7 @@ class SharedUtils( FlowTrackingMixin, ): """ - Class with commonly used methods / cached_properties to be shared between all the python modules - loaded in eos_designs. + Class with commonly used methods / cached_properties to be shared between all the python modules loaded in eos_designs. This class is instantiated in 'EosDesignsFacts' class and set as 'shared_utils' property. This class is also instantiated in 'eos_designs_structured_config' and the instance is given as argument to @@ -66,6 +65,6 @@ class SharedUtils( The class cannot be overridden. """ - def __init__(self, hostvars: dict, templar) -> None: + def __init__(self, hostvars: dict, templar: object) -> None: self.hostvars = hostvars self.templar = templar diff --git a/python-avd/pyavd/_eos_designs/shared_utils/bgp_peer_groups.py b/python-avd/pyavd/_eos_designs/shared_utils/bgp_peer_groups.py index a32bebe2356..9689ed0f038 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/bgp_peer_groups.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/bgp_peer_groups.py @@ -6,7 +6,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._utils import get +from pyavd._utils import get if TYPE_CHECKING: from . import SharedUtils @@ -14,23 +14,24 @@ class BgpPeerGroupsMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property - def bgp_peer_groups(self: SharedUtils): + def bgp_peer_groups(self: SharedUtils) -> dict | None: """ - Get bgp_peer_groups configurations or fallback to defaults + Get bgp_peer_groups configurations or fallback to defaults. Supporting legacy uppercase keys as well. """ if not self.underlay_router: return None - BGP_PEER_GROUPS = [ - # (key, default_name, default_bfd) + default_bgp_peer_groups = [ + # key, default_name, default_bfd # Default BFD is set to None when not True, to avoid generating config for disabling BFD ("ipv4_underlay_peers", "IPv4-UNDERLAY-PEERS", None), ("mlag_ipv4_underlay_peer", "MLAG-IPv4-UNDERLAY-PEER", None), @@ -44,7 +45,7 @@ def bgp_peer_groups(self: SharedUtils): ] bgp_peer_groups = {} - for key, default_name, default_bfd in BGP_PEER_GROUPS: + for key, default_name, default_bfd in default_bgp_peer_groups: bgp_peer_groups[key] = { "name": get(self.hostvars, f"bgp_peer_groups.{key}.name", default=default_name), "password": get(self.hostvars, f"bgp_peer_groups.{key}.password"), @@ -61,7 +62,9 @@ def bgp_peer_groups(self: SharedUtils): if get(self.hostvars, f"bgp_peer_groups.{key}.bfd", default=default_bfd): bgp_peer_groups[key]["bfd_timers"] = get( - self.hostvars, f"bgp_peer_groups.{key}.bfd_timers", default={"interval": 1000, "min_rx": 1000, "multiplier": 10} + self.hostvars, + f"bgp_peer_groups.{key}.bfd_timers", + default={"interval": 1000, "min_rx": 1000, "multiplier": 10}, ) return bgp_peer_groups diff --git a/python-avd/pyavd/_eos_designs/shared_utils/connected_endpoints_keys.py b/python-avd/pyavd/_eos_designs/shared_utils/connected_endpoints_keys.py index c2a22f510e9..5013681f700 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/connected_endpoints_keys.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/connected_endpoints_keys.py @@ -6,8 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._utils import get -from ...j2filters import convert_dicts +from pyavd._utils import get +from pyavd.j2filters import convert_dicts if TYPE_CHECKING: from . import SharedUtils @@ -31,20 +31,20 @@ class ConnectedEndpointsKeysMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def connected_endpoints_keys(self: SharedUtils) -> list: """ - Return connected_endpoints_keys filtered for invalid entries and unused keys + Return connected_endpoints_keys filtered for invalid entries and unused keys. NOTE: This method is called _before_ any schema validation, since we need to resolve connected_endpoints_keys dynamically """ connected_endpoints_keys = [] # Support legacy data model by converting nested dict to list of dict connected_endpoints_keys = convert_dicts(get(self.hostvars, "connected_endpoints_keys", default=DEFAULT_CONNECTED_ENDPOINTS_KEYS), "key") - connected_endpoints_keys = [entry for entry in connected_endpoints_keys if entry.get("key") is not None and self.hostvars.get(entry["key"]) is not None] - return connected_endpoints_keys + return [entry for entry in connected_endpoints_keys if entry.get("key") is not None and self.hostvars.get(entry["key"]) is not None] diff --git a/python-avd/pyavd/_eos_designs/shared_utils/cv_topology.py b/python-avd/pyavd/_eos_designs/shared_utils/cv_topology.py index 45da1e06785..59c517d8140 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/cv_topology.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/cv_topology.py @@ -6,8 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._utils import get, get_item -from ...j2filters import range_expand +from pyavd._utils import get, get_item +from pyavd.j2filters import range_expand if TYPE_CHECKING: from . import SharedUtils @@ -15,15 +15,17 @@ class CvTopology: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def cv_topology(self: SharedUtils) -> dict | None: """ - Returns the cv_topology for this device like + Returns the cv_topology for this device. + { hostname: , platform: , @@ -33,8 +35,8 @@ def cv_topology(self: SharedUtils) -> dict | None: neighbor: neighbor_interface: } - ] - + ]. + } """ if get(self.hostvars, "use_cv_topology") is not True: return None @@ -52,11 +54,13 @@ def cv_topology(self: SharedUtils) -> dict | None: def cv_topology_platform(self: SharedUtils) -> str | None: if self.cv_topology is not None: return self.cv_topology.get("platform") + return None @cached_property def cv_topology_config(self: SharedUtils) -> dict: """ - Returns dict with keys derived from cv topology (or empty dict) + Returns dict with keys derived from cv topology (or empty dict). + { uplink_interfaces: uplink_switches: @@ -76,7 +80,7 @@ def cv_topology_config(self: SharedUtils) -> dict: "uplink_interfaces", required=True, org_key="Found 'use_cv_topology:true' so 'default_interfaces.[].uplink_interfaces'", - ) + ), ) config = {} for uplink_interface in default_uplink_interfaces: @@ -94,13 +98,13 @@ def cv_topology_config(self: SharedUtils) -> dict: "mlag_interfaces", required=True, org_key="Found 'use_cv_topology:true' so 'default_interfaces.[].mlag_interfaces'", - ) + ), ) for mlag_interface in default_mlag_interfaces: if cv_interface := get_item(cv_interfaces, "name", mlag_interface): config.setdefault("mlag_interfaces", []).append(cv_interface["name"]) # TODO: Set mlag_peer once we get a user-defined var for that. - # config["mlag_peer"] = cv_interface["neighbor"] + # TODO: config["mlag_peer"] = cv_interface["neighbor"] for cv_interface in cv_interfaces: if cv_interface["name"].startswith("Management"): diff --git a/python-avd/pyavd/_eos_designs/shared_utils/filtered_tenants.py b/python-avd/pyavd/_eos_designs/shared_utils/filtered_tenants.py index 909a392cfdc..472cf7ebce6 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/filtered_tenants.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/filtered_tenants.py @@ -6,9 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._errors import AristaAvdError, AristaAvdMissingVariableError -from ..._utils import default, get, get_item, merge, unique -from ...j2filters import convert_dicts, natural_sort, range_expand +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import default, get, get_item, merge, unique +from pyavd.j2filters import convert_dicts, natural_sort, range_expand if TYPE_CHECKING: from . import SharedUtils @@ -16,8 +16,9 @@ class FilteredTenantsMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -25,6 +26,7 @@ class FilteredTenantsMixin: def filtered_tenants(self: SharedUtils) -> list[dict]: """ Return sorted tenants list from all network_services_keys and filtered based on filter_tenants. + Keys of Tenant data model will be converted to lists. All sub data models like vrfs and l2vlans are also converted and filtered. """ @@ -35,15 +37,11 @@ def filtered_tenants(self: SharedUtils) -> list[dict]: filter_tenants = self.filter_tenants for network_services_key in self.network_services_keys: tenants = convert_dicts(get(self.hostvars, network_services_key["name"]), "name") - for tenant in tenants: - if tenant["name"] in filter_tenants or "all" in filter_tenants: - filtered_tenants.append( - { - **tenant, - "l2vlans": self.filtered_l2vlans(tenant), - "vrfs": self.filtered_vrfs(tenant), - } - ) + filtered_tenants.extend( + {**tenant, "l2vlans": self.filtered_l2vlans(tenant), "vrfs": self.filtered_vrfs(tenant)} + for tenant in tenants + if tenant["name"] in filter_tenants or "all" in filter_tenants + ) no_vrf_default = all(vrf["name"] != "default" for tenant in filtered_tenants for vrf in tenant["vrfs"]) if self.is_wan_router and no_vrf_default: @@ -64,7 +62,7 @@ def filtered_tenants(self: SharedUtils) -> list[dict]: } ], "l2vlans": [], - } + }, ) elif self.is_wan_router: # It is enough to check only the first occurrence of default VRF as some other piece of code @@ -72,11 +70,11 @@ def filtered_tenants(self: SharedUtils) -> list[dict]: for tenant in filtered_tenants: if (vrf_default := get_item(tenant["vrfs"], "name", "default")) is None: continue - if "evpn" in vrf_default.get("address_families", ["evpn"]): - if self.underlay_filter_peer_as: - raise AristaAvdError( - "WAN configuration requires EVPN to be enabled for VRF 'default'. Got 'address_families: {vrf_default['address_families']}." - ) + if "evpn" in vrf_default.get("address_families", ["evpn"]) and self.underlay_filter_peer_as: + msg = "WAN configuration requires EVPN to be enabled for VRF 'default'. Got 'address_families: {vrf_default['address_families']}." + raise AristaAvdError( + msg, + ) break return natural_sort(filtered_tenants, "name") @@ -84,6 +82,7 @@ def filtered_tenants(self: SharedUtils) -> list[dict]: def filtered_l2vlans(self: SharedUtils, tenant: dict) -> list[dict]: """ Return sorted and filtered l2vlan list from given tenant. + Filtering based on l2vlan tags. """ if not self.network_services_l2: @@ -98,19 +97,18 @@ def filtered_l2vlans(self: SharedUtils, tenant: dict) -> list[dict]: for l2vlan in l2vlans: l2vlan["evpn_vlan_bundle"] = get(l2vlan, "evpn_vlan_bundle", default=tenant_evpn_vlan_bundle) - l2vlans = [ + return [ # Copy and set tenant key on all l2vlans {**l2vlan, "tenant": tenant["name"]} for l2vlan in l2vlans if self.is_accepted_vlan(l2vlan) and ("all" in self.filter_tags or set(l2vlan.get("tags", ["all"])).intersection(self.filter_tags)) ] - return l2vlans - def is_accepted_vlan(self: SharedUtils, vlan: dict) -> bool: """ - Check if vlan is in accepted_vlans list - If filter.only_vlans_in_use is True also check if vlan id or trunk group is assigned to connected endpoint + Check if vlan is in accepted_vlans list. + + If filter.only_vlans_in_use is True also check if vlan id or trunk group is assigned to connected endpoint. """ vlan_id = int(vlan["id"]) @@ -125,7 +123,7 @@ def is_accepted_vlan(self: SharedUtils, vlan: dict) -> bool: return True # Picking this up from facts so this would fail if accessed when shared_utils is run before facts - # TODO see if this can be optimized + # TODO: see if this can be optimized endpoint_trunk_groups = set(self.get_switch_fact("endpoint_trunk_groups", required=False) or []) if self.enable_trunk_groups and vlan.get("trunk_groups") and endpoint_trunk_groups.intersection(vlan["trunk_groups"]): return True @@ -136,6 +134,7 @@ def is_accepted_vlan(self: SharedUtils, vlan: dict) -> bool: def accepted_vlans(self: SharedUtils) -> list[int]: """ The 'vlans' switch fact is a string representing a vlan range (ex. "1-200"). + For l2 switches return intersection of vlans from this switch and vlans from uplink switches. For anything else return the expanded vlans from this switch. """ @@ -160,7 +159,7 @@ def accepted_vlans(self: SharedUtils) -> list[int]: def is_accepted_vrf(self: SharedUtils, vrf: dict) -> bool: """ - Returns True if + Returns True if. - filter.allow_vrfs == ["all"] OR VRF is included in filter.allow_vrfs. @@ -193,6 +192,7 @@ def is_forced_vrf(self: SharedUtils, vrf: dict) -> bool: def filtered_vrfs(self: SharedUtils, tenant: dict) -> list[dict]: """ Return sorted and filtered vrf list from given tenant. + Filtering based on svi tags, l3interfaces, loopbacks or self.is_forced_vrf() check. Keys of VRF data model will be converted to lists. """ @@ -278,7 +278,7 @@ def filtered_vrfs(self: SharedUtils, tenant: dict) -> list[dict]: @cached_property def svi_profiles(self: SharedUtils) -> list[dict]: """ - Return list of svi_profiles + Return list of svi_profiles. The key "nodes" is filtered to only contain one item with the relevant dict from "nodes" or {} """ @@ -293,7 +293,7 @@ def svi_profiles(self: SharedUtils) -> list[dict]: def get_merged_svi_config(self: SharedUtils, svi: dict) -> list[dict]: """ - Return structured config for one svi after inheritance + Return structured config for one svi after inheritance. Handle inheritance of node config as svi_profiles in two levels: @@ -357,6 +357,7 @@ def get_merged_svi_config(self: SharedUtils, svi: dict) -> list[dict]: def filtered_svis(self: SharedUtils, vrf: dict) -> list[dict]: """ Return sorted and filtered svi list from given tenant vrf. + Filtering based on accepted vlans since eos_designs_facts already filtered that on tags and trunk_groups. """ @@ -383,14 +384,15 @@ def endpoint_vlans(self: SharedUtils) -> list: endpoint_vlans = self.get_switch_fact("endpoint_vlans", required=False) if not endpoint_vlans: return [] - return [int(id) for id in range_expand(endpoint_vlans)] + return [int(vlan_id) for vlan_id in range_expand(endpoint_vlans)] @staticmethod - def get_vrf_id(vrf, required: bool = True) -> int | None: + def get_vrf_id(vrf: dict, required: bool = True) -> int | None: vrf_id = default(vrf.get("vrf_id"), vrf.get("vrf_vni")) if vrf_id is None: if required: - raise AristaAvdMissingVariableError(f"'vrf_id' or 'vrf_vni' for VRF '{vrf['name']} must be set.") + msg = f"'vrf_id' or 'vrf_vni' for VRF '{vrf['name']} must be set." + raise AristaAvdMissingVariableError(msg) return None return int(vrf_id) @@ -398,13 +400,14 @@ def get_vrf_id(vrf, required: bool = True) -> int | None: def get_vrf_vni(vrf: dict) -> int: vrf_vni = default(vrf.get("vrf_vni"), vrf.get("vrf_id")) if vrf_vni is None: - raise AristaAvdMissingVariableError(f"'vrf_vni' or 'vrf_id' for VRF '{vrf['name']} must be set.") + msg = f"'vrf_vni' or 'vrf_id' for VRF '{vrf['name']} must be set." + raise AristaAvdMissingVariableError(msg) return int(vrf_vni) @cached_property def vrfs(self: SharedUtils) -> list: """ - Return the list of vrfs to be defined on this switch + Return the list of vrfs to be defined on this switch. Ex. ["default", "prod"] """ @@ -422,6 +425,7 @@ def vrfs(self: SharedUtils) -> list: def get_additional_svi_config(svi_config: dict, svi: dict, vrf: dict) -> None: """ Adding IP helpers and OSPF for SVIs via a common function. + Used for SVIs and for subinterfaces when uplink_type: lan. The given svi_config is updated in-place. @@ -439,7 +443,7 @@ def get_additional_svi_config(svi_config: dict, svi: dict, vrf: dict) -> None: "ospf_area": svi["ospf"].get("area", "0"), "ospf_network_point_to_point": svi["ospf"].get("point_to_point", False), "ospf_cost": svi["ospf"].get("cost"), - } + }, ) ospf_authentication = svi["ospf"].get("authentication") if ospf_authentication == "simple" and (ospf_simple_auth_key := svi["ospf"].get("simple_auth_key")) is not None: diff --git a/python-avd/pyavd/_eos_designs/shared_utils/flow_tracking.py b/python-avd/pyavd/_eos_designs/shared_utils/flow_tracking.py index 68c79979ea6..08bc391a594 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/flow_tracking.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/flow_tracking.py @@ -7,7 +7,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Literal -from ..._utils import get +from pyavd._utils import get if TYPE_CHECKING: from . import SharedUtils @@ -15,16 +15,16 @@ class FlowTrackingMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def flow_tracking_type(self: SharedUtils) -> str: default_flow_tracker_type = get(self.node_type_key_data, "default_flow_tracker_type", "sampled") - flow_tracker_type = get(self.switch_data_combined, "flow_tracker_type", default=default_flow_tracker_type) - return flow_tracker_type + return get(self.switch_data_combined, "flow_tracker_type", default=default_flow_tracker_type) @cached_property def default_flow_tracker_name(self: SharedUtils) -> str: @@ -32,9 +32,7 @@ def default_flow_tracker_name(self: SharedUtils) -> str: @cached_property def fabric_flow_tracking(self: SharedUtils) -> defaultdict: - """ - Return fabric level flow tracking settings for all data models - """ + """Return fabric level flow tracking settings for all data models.""" configured_values = get(self.hostvars, "fabric_flow_tracking", default={}) # By default, flow tracker is `hardware` type named `FLOW-TRACKER` @@ -42,7 +40,7 @@ def fabric_flow_tracking(self: SharedUtils) -> defaultdict: lambda: { "enabled": None, "name": self.default_flow_tracker_name, - } + }, ) # By default, flow tracking is enabled only on DPS interfaces @@ -73,9 +71,7 @@ def get_flow_tracker( "dps_interfaces", ], ) -> dict: - """ - Return flow_tracking settings for a link, falling back to the fabric flow_tracking_settings if not defined. - """ + """Return flow_tracking settings for a link, falling back to the fabric flow_tracking_settings if not defined.""" link_tracker_enabled, link_tracker_name = None, None if link_settings is not None: link_tracker_enabled = get(link_settings, "flow_tracking.enabled") diff --git a/python-avd/pyavd/_eos_designs/shared_utils/inband_management.py b/python-avd/pyavd/_eos_designs/shared_utils/inband_management.py index 4f06e387392..44a380447fa 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/inband_management.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/inband_management.py @@ -7,8 +7,8 @@ from ipaddress import ip_network from typing import TYPE_CHECKING -from ..._errors import AristaAvdMissingVariableError -from ..._utils import default, get +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import default, get if TYPE_CHECKING: from . import SharedUtils @@ -16,8 +16,9 @@ class InbandManagementMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -74,7 +75,7 @@ def inband_mgmt_vrf(self: SharedUtils) -> str | None: @cached_property def inband_mgmt_gateway(self: SharedUtils) -> str | None: """ - Inband management gateway + Inband management gateway. If inband_mgmt_ip is set but not via inband_mgmt_subnet we return the value of inband_mgmt_gateway. @@ -89,12 +90,12 @@ def inband_mgmt_gateway(self: SharedUtils) -> str | None: return get(self.switch_data_combined, "inband_mgmt_gateway") subnet = ip_network(self.inband_mgmt_subnet, strict=False) - return f"{str(subnet[1])}" + return f"{subnet[1]!s}" @cached_property def inband_mgmt_ipv6_gateway(self: SharedUtils) -> str | None: """ - Inband management ipv6 gateway + Inband management ipv6 gateway. If inband_mgmt_ipv6_address is set but not via inband_mgmt_ipv6_subnet we return the value of inband_mgmt_ipv6_gateway. @@ -109,16 +110,17 @@ def inband_mgmt_ipv6_gateway(self: SharedUtils) -> str | None: return get(self.switch_data_combined, "inband_mgmt_ipv6_gateway") subnet = ip_network(self.inband_mgmt_ipv6_subnet, strict=False) - return f"{str(subnet[1])}" + return f"{subnet[1]!s}" @cached_property def inband_mgmt_ip(self: SharedUtils) -> str | None: """ - Inband management IP + Inband management IP. + Set to either: - Value of inband_mgmt_ip - deducted IP from inband_mgmt_subnet & id - - None + - None. """ if (inband_mgmt_ip := get(self.switch_data_combined, "inband_mgmt_ip")) is not None: return inband_mgmt_ip @@ -127,7 +129,8 @@ def inband_mgmt_ip(self: SharedUtils) -> str | None: return None if self.id is None: - raise AristaAvdMissingVariableError(f"'id' is not set on '{self.hostname}' and is required to set inband_mgmt_ip from inband_mgmt_subnet") + msg = f"'id' is not set on '{self.hostname}' and is required to set inband_mgmt_ip from inband_mgmt_subnet" + raise AristaAvdMissingVariableError(msg) subnet = ip_network(self.inband_mgmt_subnet, strict=False) inband_mgmt_ip = str(subnet[3 + self.id]) @@ -136,11 +139,12 @@ def inband_mgmt_ip(self: SharedUtils) -> str | None: @cached_property def inband_mgmt_ipv6_address(self: SharedUtils) -> str | None: """ - Inband management IPv6 Address + Inband management IPv6 Address. + Set to either: - Value of inband_mgmt_ipv6_address - deduced IP from inband_mgmt_ipv6_subnet & id - - None + - None. """ if (inband_mgmt_ipv6_address := get(self.switch_data_combined, "inband_mgmt_ipv6_address")) is not None: return inband_mgmt_ipv6_address @@ -149,9 +153,8 @@ def inband_mgmt_ipv6_address(self: SharedUtils) -> str | None: return None if self.id is None: - raise AristaAvdMissingVariableError( - f"'id' is not set on '{self.hostname}' and is required to set inband_mgmt_ipv6_address from inband_mgmt_ipv6_subnet" - ) + msg = f"'id' is not set on '{self.hostname}' and is required to set inband_mgmt_ipv6_address from inband_mgmt_ipv6_subnet" + raise AristaAvdMissingVariableError(msg) subnet = ip_network(self.inband_mgmt_ipv6_subnet, strict=False) inband_mgmt_ipv6_address = str(subnet[3 + self.id]) @@ -160,7 +163,7 @@ def inband_mgmt_ipv6_address(self: SharedUtils) -> str | None: @cached_property def inband_mgmt_interface(self: SharedUtils) -> str | None: """ - Inband management Interface used only to set as source interface on various management protocols + Inband management Interface used only to set as source interface on various management protocols. For L2 switches defaults to Vlan For all other devices set to value of inband_mgmt_interface or None diff --git a/python-avd/pyavd/_eos_designs/shared_utils/interface_descriptions.py b/python-avd/pyavd/_eos_designs/shared_utils/interface_descriptions.py index a3ac1b73126..a0d7f8c4429 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/interface_descriptions.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/interface_descriptions.py @@ -6,8 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._utils import get, load_python_class, merge -from ..interface_descriptions import AvdInterfaceDescriptions +from pyavd._eos_designs.interface_descriptions import AvdInterfaceDescriptions +from pyavd._utils import get, load_python_class, merge if TYPE_CHECKING: from . import SharedUtils @@ -17,7 +17,8 @@ class InterfaceDescriptionsMixin: """ - Mixin Class providing a subset of SharedUtils + Mixin Class providing a subset of SharedUtils. + Class should only be used as Mixin to the SharedUtils class Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -25,8 +26,9 @@ class InterfaceDescriptionsMixin: @cached_property def interface_descriptions(self: SharedUtils) -> AvdInterfaceDescriptions: """ - Load the python_module defined in `templates.interface_descriptions.python_module` - Return an instance of the class defined by `templates.interface_descriptions.python_class_name` as cached_property + Load the python_module defined in `templates.interface_descriptions.python_module`. + + Return an instance of the class defined by `templates.interface_descriptions.python_class_name` as cached_property. """ module_path = self.interface_descriptions_templates.get("python_module") if module_path is None: @@ -45,9 +47,11 @@ def interface_descriptions(self: SharedUtils) -> AvdInterfaceDescriptions: @cached_property def interface_descriptions_templates(self: SharedUtils) -> dict: """ - Return dict with interface_descriptions templates set based on + Return dict with interface_descriptions templates. + + Set based on templates.interface_descriptions.* combined with (overridden by) - node_type_keys..interface_descriptions.* + node_type_keys..interface_descriptions.*. """ hostvar_templates = get(self.hostvars, "templates.interface_descriptions", default={}) node_type_templates = get(self.node_type_key_data, "interface_descriptions", default={}) diff --git a/python-avd/pyavd/_eos_designs/shared_utils/ip_addressing.py b/python-avd/pyavd/_eos_designs/shared_utils/ip_addressing.py index 3700f6c7d82..c60abe18bfa 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/ip_addressing.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/ip_addressing.py @@ -6,8 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._utils import get, load_python_class, merge -from ..ip_addressing import AvdIpAddressing +from pyavd._eos_designs.ip_addressing import AvdIpAddressing +from pyavd._utils import get, load_python_class, merge if TYPE_CHECKING: from . import SharedUtils @@ -17,8 +17,9 @@ class IpAddressingMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -62,9 +63,7 @@ def vtep_loopback_ipv4_address(self: SharedUtils) -> str: @cached_property def vtep_ip(self: SharedUtils) -> str: - """ - Render ipv4 address for vtep_ip using dynamically loaded python module. - """ + """Render ipv4 address for vtep_ip using dynamically loaded python module.""" if self.mlag is True: return self.ip_addressing.vtep_ip_mlag() @@ -77,8 +76,9 @@ def vtep_vvtep_ip(self: SharedUtils) -> str | None: @cached_property def ip_addressing(self: SharedUtils) -> AvdIpAddressing: """ - Load the python_module defined in `templates.ip_addressing.python_module` - Return an instance of the class defined by `templates.ip_addressing.python_class_name` as cached_property + Load the python_module defined in `templates.ip_addressing.python_module`. + + Return an instance of the class defined by `templates.ip_addressing.python_class_name` as cached_property. """ module_path = self.ip_addressing_templates.get("python_module") if module_path is None: @@ -97,9 +97,11 @@ def ip_addressing(self: SharedUtils) -> AvdIpAddressing: @cached_property def ip_addressing_templates(self: SharedUtils) -> dict: """ - Return dict with ip_addressing templates set based on + Return dict with ip_addressing templates. + + Set based on templates.ip_addressing.* combined with (overridden by) - node_type_keys..ip_addressing.* + node_type_keys..ip_addressing.*. """ hostvar_templates = get(self.hostvars, "templates.ip_addressing", default={}) node_type_templates = get(self.node_type_key_data, "ip_addressing", default={}) diff --git a/python-avd/pyavd/_eos_designs/shared_utils/l3_interfaces.py b/python-avd/pyavd/_eos_designs/shared_utils/l3_interfaces.py index 87be9552c1c..e879a9aaed7 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/l3_interfaces.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/l3_interfaces.py @@ -6,9 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._errors import AristaAvdMissingVariableError -from ..._utils import get, get_item, merge -from ..interface_descriptions import InterfaceDescriptionData +from pyavd._eos_designs.interface_descriptions.models import InterfaceDescriptionData +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import get, get_item, merge if TYPE_CHECKING: from . import SharedUtils @@ -16,8 +16,9 @@ class L3InterfacesMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -31,9 +32,7 @@ def sanitize_interface_name(self: SharedUtils, interface_name: str) -> str: return interface_name.replace("/", "_") def apply_l3_interfaces_profile(self: SharedUtils, l3_interface: dict) -> dict: - """ - Apply a profile to an l3_interface - """ + """Apply a profile to an l3_interface.""" if "profile" not in l3_interface: # Nothing to do return l3_interface @@ -48,15 +47,13 @@ def l3_interface_profiles(self: SharedUtils) -> list: return get(self.hostvars, "l3_interface_profiles", default=[]) # TODO: Add sflow knob under fabric_sflow to cover l3_interfaces defined under the node_types. - # @cached_property - # def _l3_interfaces_sflow(self) -> bool | None: - # return get(self._hostvars, f"fabric_sflow.{self.data_model}") + # TODO: @cached_property + # TODO: def _l3_interfaces_sflow(self) -> bool | None: + # TODO: return get(self._hostvars, f"fabric_sflow.{self.data_model}") @cached_property def l3_interfaces(self: SharedUtils) -> list: - """ - Returns the list of l3_interfaces, where any referenced profiles are applied. - """ + """Returns the list of l3_interfaces, where any referenced profiles are applied.""" if not (l3_interfaces := get(self.switch_data_combined, "l3_interfaces")): return [] @@ -77,13 +74,15 @@ def l3_interfaces_bgp_neighbors(self: SharedUtils) -> list: peer_as = get(bgp, "peer_as") if peer_as is None: - raise AristaAvdMissingVariableError(f"'l3_interfaces[{interface['name']}].bgp.peer_as' needs to be set to enable BGP.") + msg = f"'l3_interfaces[{interface['name']}].bgp.peer_as' needs to be set to enable BGP." + raise AristaAvdMissingVariableError(msg) is_intf_wan = get(interface, "wan_carrier") is not None prefix_list_in = get(bgp, "ipv4_prefix_list_in") if prefix_list_in is None and is_intf_wan: - raise AristaAvdMissingVariableError(f"BGP is enabled but 'bgp.ipv4_prefix_list_in' is not configured for l3_interfaces[{interface['name']}]") + msg = f"BGP is enabled but 'bgp.ipv4_prefix_list_in' is not configured for l3_interfaces[{interface['name']}]" + raise AristaAvdMissingVariableError(msg) description = interface.get("description") if not description: @@ -95,7 +94,7 @@ def l3_interfaces_bgp_neighbors(self: SharedUtils) -> list: peer_interface=interface.get("peer_interface"), wan_carrier=interface.get("wan_carrier"), wan_circuit_id=interface.get("wan_circuit_id"), - ) + ), ) neighbor = { diff --git a/python-avd/pyavd/_eos_designs/shared_utils/link_tracking_groups.py b/python-avd/pyavd/_eos_designs/shared_utils/link_tracking_groups.py index e9e5b16d22a..8aa87d70486 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/link_tracking_groups.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/link_tracking_groups.py @@ -6,7 +6,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._utils import get +from pyavd._utils import get if TYPE_CHECKING: from . import SharedUtils @@ -14,8 +14,9 @@ class LinkTrackingGroupsMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ diff --git a/python-avd/pyavd/_eos_designs/shared_utils/mgmt.py b/python-avd/pyavd/_eos_designs/shared_utils/mgmt.py index 4ed5cc781fe..4248f183431 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/mgmt.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/mgmt.py @@ -6,8 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._errors import AristaAvdMissingVariableError -from ..._utils import default, get +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import default, get if TYPE_CHECKING: from . import SharedUtils @@ -15,18 +15,21 @@ class MgmtMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def mgmt_interface(self: SharedUtils) -> str | None: """ + mgmt_interface. + mgmt_interface is inherited from Global var mgmt_interface -> Platform Settings management_interface -> - Fabric Topology data model mgmt_interface + Fabric Topology data model mgmt_interface. """ return default( get(self.switch_data_combined, "mgmt_interface"), @@ -66,14 +69,16 @@ def default_mgmt_method(self: SharedUtils) -> str | None: default_mgmt_method = get(self.hostvars, "default_mgmt_method", default="oob") if default_mgmt_method == "oob": if (self.mgmt_ip is None) and (self.ipv6_mgmt_ip is None): - raise AristaAvdMissingVariableError("'default_mgmt_method: oob' requires either 'mgmt_ip' or 'ipv6_mgmt_ip' to bet set.") + msg = "'default_mgmt_method: oob' requires either 'mgmt_ip' or 'ipv6_mgmt_ip' to bet set." + raise AristaAvdMissingVariableError(msg) return default_mgmt_method if default_mgmt_method == "inband": # Check for missing interface if self.inband_mgmt_interface is None: - raise AristaAvdMissingVariableError("'default_mgmt_method: inband' requires 'inband_mgmt_interface' to be set.") + msg = "'default_mgmt_method: inband' requires 'inband_mgmt_interface' to be set." + raise AristaAvdMissingVariableError(msg) return default_mgmt_method diff --git a/python-avd/pyavd/_eos_designs/shared_utils/misc.py b/python-avd/pyavd/_eos_designs/shared_utils/misc.py index 38b1eae2fdb..c2ec172a131 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/misc.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/misc.py @@ -5,21 +5,23 @@ from copy import deepcopy from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from ..._errors import AristaAvdError, AristaAvdMissingVariableError -from ..._utils import default, get -from ...j2filters import convert_dicts, natural_sort, range_expand +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import default, get +from pyavd.j2filters import convert_dicts, natural_sort, range_expand if TYPE_CHECKING: - from ..eos_designs_facts import EosDesignsFacts + from pyavd._eos_designs.eos_designs_facts import EosDesignsFacts + from . import SharedUtils class MiscMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -30,9 +32,7 @@ def all_fabric_devices(self: SharedUtils) -> list[str]: @cached_property def hostname(self: SharedUtils) -> str: - """ - hostname set based on inventory_hostname variable - """ + """Hostname set based on inventory_hostname variable.""" return get(self.hostvars, "inventory_hostname", required=True) @cached_property @@ -61,9 +61,7 @@ def filter_only_vlans_in_use(self: SharedUtils) -> bool: @cached_property def filter_tags(self: SharedUtils) -> list: - """ - Return filter.tags + group if defined - """ + """Return filter.tags + group if defined.""" filter_tags = get(self.switch_data_combined, "filter.tags", default=["all"]) if self.group is not None: filter_tags.append(self.group) @@ -97,9 +95,11 @@ def only_local_vlan_trunk_groups(self: SharedUtils) -> bool: @cached_property def system_mac_address(self: SharedUtils) -> str | None: """ + system_mac_address. + system_mac_address is inherited from Fabric Topology data model system_mac_address -> - Host variable var system_mac_address -> + Host variable var system_mac_address ->. """ return default(get(self.switch_data_combined, "system_mac_address"), get(self.hostvars, "system_mac_address")) @@ -119,7 +119,7 @@ def uplink_interfaces(self: SharedUtils) -> list: get(self.cv_topology_config, "uplink_interfaces"), get(self.default_interfaces, "uplink_interfaces"), [], - ) + ), ) @cached_property @@ -135,7 +135,8 @@ def uplink_switch_interfaces(self: SharedUtils) -> list: return [] if self.id is None: - raise AristaAvdMissingVariableError(f"'id' is not set on '{self.hostname}'") + msg = f"'id' is not set on '{self.hostname}'" + raise AristaAvdMissingVariableError(msg) uplink_switch_interfaces = [] uplink_switch_counter = {} @@ -153,10 +154,13 @@ def uplink_switch_interfaces(self: SharedUtils) -> list: if len(uplink_switch_facts._default_downlink_interfaces) > downlink_index: uplink_switch_interfaces.append(uplink_switch_facts._default_downlink_interfaces[downlink_index]) else: - raise AristaAvdError( + msg = ( f"'uplink_switch_interfaces' is not set on '{self.hostname}' and 'uplink_switch' '{uplink_switch}' " f"does not have 'downlink_interfaces[{downlink_index}]' set under 'default_interfaces'" ) + raise AristaAvdError( + msg, + ) return uplink_switch_interfaces @@ -167,24 +171,22 @@ def virtual_router_mac_address(self: SharedUtils) -> str | None: @cached_property def serial_number(self: SharedUtils) -> str | None: """ + serial_number. + serial_number is inherited from Fabric Topology data model serial_number -> - Host variable var serial_number + Host variable var serial_number. """ return default(get(self.switch_data_combined, "serial_number"), get(self.hostvars, "serial_number")) @cached_property def max_parallel_uplinks(self: SharedUtils) -> int: - """ - Exposed in avd_switch_facts - """ + """Exposed in avd_switch_facts.""" return get(self.switch_data_combined, "max_parallel_uplinks", default=1) @cached_property def max_uplink_switches(self: SharedUtils) -> int: - """ - max_uplink_switches will default to the length of uplink_switches - """ + """max_uplink_switches will default to the length of uplink_switches.""" return get(self.switch_data_combined, "max_uplink_switches", default=len(self.uplink_switches)) @cached_property @@ -208,12 +210,12 @@ def shutdown_bgp_towards_undeployed_peers(self: SharedUtils) -> bool: @cached_property def bfd_multihop(self: SharedUtils) -> dict: - DEFAULT_BFD_MULTIHOP = { + default_bfd_multihop = { "interval": 300, "min_rx": 300, "multiplier": 3, } - return get(self.hostvars, "bfd_multihop", default=DEFAULT_BFD_MULTIHOP) + return get(self.hostvars, "bfd_multihop", default=default_bfd_multihop) @cached_property def evpn_ebgp_multihop(self: SharedUtils) -> int: @@ -246,12 +248,12 @@ def rack(self: SharedUtils) -> str | None: @cached_property def network_services_keys(self: SharedUtils) -> list[dict]: """ - Return sorted network_services_keys filtered for invalid entries and unused keys + Return sorted network_services_keys filtered for invalid entries and unused keys. NOTE: This method is called _before_ any schema validation, since we need to resolve network_services_keys dynamically """ - DEFAULT_NETWORK_SERVICES_KEYS = [{"name": "tenants"}] - network_services_keys = get(self.hostvars, "network_services_keys", default=DEFAULT_NETWORK_SERVICES_KEYS) + default_network_services_keys = [{"name": "tenants"}] + network_services_keys = get(self.hostvars, "network_services_keys", default=default_network_services_keys) network_services_keys = [entry for entry in network_services_keys if entry.get("name") is not None and self.hostvars.get(entry["name"]) is not None] return natural_sort(network_services_keys, "name") @@ -305,6 +307,7 @@ def pod_name(self: SharedUtils) -> str | None: def fabric_ip_addressing_mlag_algorithm(self: SharedUtils) -> str: """ This method fetches the MLAG algorithm value from host variables. + It defaults to 'first_id' if the variable is not defined. """ return get(self.hostvars, "fabric_ip_addressing.mlag.algorithm", default="first_id") @@ -342,9 +345,10 @@ def default_interface_mtu(self: SharedUtils) -> int | None: default_default_interface_mtu = get(self.hostvars, "default_interface_mtu") return get(self.platform_settings, "default_interface_mtu", default=default_default_interface_mtu) - def get_switch_fact(self: SharedUtils, key, required=True): + def get_switch_fact(self: SharedUtils, key: str, required: bool = True) -> Any: """ Return facts from EosDesignsFacts. + We need to go via avd_switch_facts since PyAVD does not expose "switch.*" in get_avdfacts. """ return get(self.hostvars, f"avd_switch_facts..{self.hostname}..switch..{key}", required=required, org_key=f"switch.{key}", separator="..") @@ -356,8 +360,7 @@ def evpn_multicast(self: SharedUtils) -> bool: @cached_property def new_network_services_bgp_vrf_config(self: SharedUtils) -> bool: """ - Return whether or not to use the new behavior when generating - BGP VRF configuration + Return whether or not to use the new behavior when generating BGP VRF configuration. TODO: Change default to True in all cases in AVD 5.0.0 and remove in AVD 6.0.0 """ @@ -368,9 +371,10 @@ def new_network_services_bgp_vrf_config(self: SharedUtils) -> bool: def ipv4_acls(self: SharedUtils) -> dict: return {acl["name"]: acl for acl in get(self.hostvars, "ipv4_acls", default=[])} - def get_ipv4_acl(self: SharedUtils, name: str, interface_name: str, *, interface_ip: str | None = None, peer_ip: str | None = None): + def get_ipv4_acl(self: SharedUtils, name: str, interface_name: str, *, interface_ip: str | None = None, peer_ip: str | None = None) -> dict: """ Get one IPv4 ACL from "ipv4_acls" where fields have been substituted. + If any substitution is done, the ACL name will get "_" appended. """ org_ipv4_acl = get(self.ipv4_acls, name, required=True, org_key=f"ipv4_acls[name={name}]") @@ -398,9 +402,10 @@ def get_ipv4_acl(self: SharedUtils, name: str, interface_name: str, *, interface def _get_ipv4_acl_field_with_substitution(field_value: str, replacements: dict[str, str], field_context: str, interface_name: str) -> str: """ Checks one field if the value can be substituted. + The given "replacements" dict will be parsed as: key: substitution field to look for - value: replacement value to set + value: replacement value to set. If a replacement is done, but the value is None, an error will be raised. """ @@ -409,11 +414,14 @@ def _get_ipv4_acl_field_with_substitution(field_value: str, replacements: dict[s continue if value is None: - raise AristaAvdError( + msg = ( f"Unable to perform substitution of the value '{key}' defined under '{field_context}', " f"since no substitution value was found for interface '{interface_name}'. " "Make sure to set the appropriate fields on the interface." ) + raise AristaAvdError( + msg, + ) return value diff --git a/python-avd/pyavd/_eos_designs/shared_utils/mlag.py b/python-avd/pyavd/_eos_designs/shared_utils/mlag.py index 8ec7b307e4a..fa8746478d4 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/mlag.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/mlag.py @@ -6,21 +6,23 @@ from functools import cached_property from ipaddress import ip_interface from re import findall -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from ..._errors import AristaAvdError, AristaAvdMissingVariableError -from ..._utils import default, get -from ...j2filters import range_expand +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import default, get +from pyavd.j2filters import range_expand if TYPE_CHECKING: - from ..eos_designs_facts import EosDesignsFacts + from pyavd._eos_designs.eos_designs_facts import EosDesignsFacts + from . import SharedUtils class MlagMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -34,9 +36,7 @@ def mlag(self: SharedUtils) -> bool: @cached_property def group(self: SharedUtils) -> str | None: - """ - group set to "node_group" name or None - """ + """Group set to "node_group" name or None.""" return get(self.switch_data, "group") @cached_property @@ -55,7 +55,7 @@ def mlag_interfaces(self: SharedUtils) -> list: get(self.cv_topology_config, "mlag_interfaces"), get(self.default_interfaces, "mlag_interfaces"), [], - ) + ), ) @cached_property @@ -85,7 +85,8 @@ def mlag_role(self: SharedUtils) -> str | None: return "primary" if self.switch_data_node_group_nodes[1]["name"] == self.hostname: return "secondary" - raise AristaAvdError("Unable to detect MLAG role") + msg = "Unable to detect MLAG role" + raise AristaAvdError(msg) return None @cached_property @@ -94,7 +95,8 @@ def mlag_peer(self: SharedUtils) -> str: return self.switch_data_node_group_nodes[1]["name"] if self.mlag_role == "secondary": return self.switch_data_node_group_nodes[0]["name"] - raise AristaAvdError("Unable to find MLAG peer within same node group") + msg = "Unable to find MLAG peer within same node group" + raise AristaAvdError(msg) @cached_property def mlag_l3(self: SharedUtils) -> bool: @@ -123,7 +125,7 @@ def mlag_peer_l3_ip(self: SharedUtils) -> str | None: def mlag_peer_id(self: SharedUtils) -> int: return self.get_mlag_peer_fact("id") - def get_mlag_peer_fact(self: SharedUtils, key, required=True): + def get_mlag_peer_fact(self: SharedUtils, key: str, required: bool = True) -> Any: return get(self.mlag_peer_facts, key, required=required, org_key=f"avd_switch_facts.({self.mlag_peer}).switch.{key}") @cached_property @@ -139,45 +141,48 @@ def mlag_peer_mgmt_ip(self: SharedUtils) -> str | None: @cached_property def mlag_ip(self: SharedUtils) -> str | None: - """ - Render ipv4 address for mlag_ip using dynamically loaded python module. - """ + """Render ipv4 address for mlag_ip using dynamically loaded python module.""" if self.mlag_role == "primary": return self.ip_addressing.mlag_ip_primary() if self.mlag_role == "secondary": return self.ip_addressing.mlag_ip_secondary() + return None @cached_property def mlag_l3_ip(self: SharedUtils) -> str | None: - """ - Render ipv4 address for mlag_l3_ip using dynamically loaded python module. - """ + """Render ipv4 address for mlag_l3_ip using dynamically loaded python module.""" if self.mlag_peer_l3_vlan is None: return None if self.mlag_role == "primary": return self.ip_addressing.mlag_l3_ip_primary() if self.mlag_role == "secondary": return self.ip_addressing.mlag_l3_ip_secondary() + return None @cached_property def mlag_switch_ids(self: SharedUtils) -> dict | None: """ - Returns the switch id's of both primary and secondary switches for a given node group - {"primary": int, "secondary": int} + Returns the switch id's of both primary and secondary switches for a given node group. + + {"primary": int, "secondary": int}. """ if self.mlag_role == "primary": if self.id is None: - raise AristaAvdMissingVariableError(f"'id' is not set on '{self.hostname}' and is required to compute MLAG ids") + msg = f"'id' is not set on '{self.hostname}' and is required to compute MLAG ids" + raise AristaAvdMissingVariableError(msg) return {"primary": self.id, "secondary": self.mlag_peer_id} if self.mlag_role == "secondary": if self.id is None: - raise AristaAvdMissingVariableError(f"'id' is not set on '{self.hostname}' and is required to compute MLAG ids") + msg = f"'id' is not set on '{self.hostname}' and is required to compute MLAG ids" + raise AristaAvdMissingVariableError(msg) return {"primary": self.mlag_peer_id, "secondary": self.id} + return None @cached_property def mlag_port_channel_id(self: SharedUtils) -> int: if not self.mlag_interfaces: - raise AristaAvdMissingVariableError(f"'mlag_interfaces' not set on '{self.hostname}.") + msg = f"'mlag_interfaces' not set on '{self.hostname}." + raise AristaAvdMissingVariableError(msg) default_mlag_port_channel_id = int("".join(findall(r"\d", self.mlag_interfaces[0]))) return get(self.switch_data_combined, "mlag_port_channel_id", default_mlag_port_channel_id) diff --git a/python-avd/pyavd/_eos_designs/shared_utils/node_type.py b/python-avd/pyavd/_eos_designs/shared_utils/node_type.py index ed56b5dfdf3..28bb143f060 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/node_type.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/node_type.py @@ -7,8 +7,8 @@ from re import search from typing import TYPE_CHECKING -from ..._errors import AristaAvdMissingVariableError -from ..._utils import get +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import get if TYPE_CHECKING: from . import SharedUtils @@ -16,29 +16,26 @@ class NodeTypeMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def type(self: SharedUtils) -> str: - """ - type fact set based on type variable - """ + """Type fact set based on type variable.""" if (node_type := get(self.hostvars, "type")) is not None: return node_type if self.default_node_type: return self.default_node_type - raise AristaAvdMissingVariableError(f"'type' for host {self.hostname}") + msg = f"'type' for host {self.hostname}" + raise AristaAvdMissingVariableError(msg) @cached_property def default_node_type(self: SharedUtils) -> str | None: - """ - default_node_type set based on hostname, returning - first node type matching a regex in default_node_types - """ + """default_node_type set based on hostname, returning first node type matching a regex in default_node_types.""" default_node_types = get(self.hostvars, "default_node_types", default=[]) for default_node_type in default_node_types: @@ -50,34 +47,35 @@ def default_node_type(self: SharedUtils) -> str | None: @cached_property def cvp_tag_topology_hint_type(self: SharedUtils) -> str: - """ - topology_tag_type set based on - node_type_keys..cvp_tags.topology_hint_type - """ + """topology_tag_type set based on node_type_keys..cvp_tags.topology_hint_type.""" return get(self.node_type_key_data, "cvp_tags.topology_hint_type", default="endpoint") @cached_property def connected_endpoints(self: SharedUtils) -> bool: """ + Should we configure connected endpoints? + connected_endpoints set based on - node_type_keys..connected_endpoints + node_type_keys..connected_endpoints. """ return get(self.node_type_key_data, "connected_endpoints", default=False) @cached_property def underlay_router(self: SharedUtils) -> bool: """ + Is this an underlay router? + underlay_router set based on - node_type_keys..underlay_router + node_type_keys..underlay_router. """ return get(self.node_type_key_data, "underlay_router", default=True) @cached_property def uplink_type(self: SharedUtils) -> str: """ - uplink_type set based on - .nodes.[].uplink_type and - node_type_keys..uplink_type + Uplink type. + + uplink_type set based on .nodes.[].uplink_type and node_type_keys..uplink_type. """ default_uplink_type = get(self.node_type_key_data, "uplink_type", default="p2p") return get(self.switch_data_combined, "uplink_type", default=default_uplink_type) @@ -85,63 +83,67 @@ def uplink_type(self: SharedUtils) -> str: @cached_property def network_services_l1(self: SharedUtils) -> bool: """ - network_services_l1 set based on - node_type_keys..network_services.l1 + Should we configure L1 network services? + + network_services_l1 set based on node_type_keys..network_services.l1. """ return get(self.node_type_key_data, "network_services.l1", default=False) @cached_property def network_services_l2(self: SharedUtils) -> bool: """ - network_services_l2 set based on - node_type_keys..network_services.l2 + Should we configure L2 network services? + + network_services_l2 set based on node_type_keys..network_services.l2. """ return get(self.node_type_key_data, "network_services.l2", default=False) @cached_property def network_services_l3(self: SharedUtils) -> bool: """ - network_services_l3 set based on - node_type_keys..network_services.l3 and - . | nodes.<> >.evpn_services_l2_only + Should we configure L3 network services? + + network_services_l3 set based on node_type_keys..network_services.l3 + and . | nodes.<> >.evpn_services_l2_only. """ - if self.vtep is True: - # network_services_l3 override based on evpn_services_l2_only - if get(self.switch_data_combined, "evpn_services_l2_only") is True: - return False + # network_services_l3 override based on evpn_services_l2_only + if self.vtep is True and get(self.switch_data_combined, "evpn_services_l2_only") is True: + return False return get(self.node_type_key_data, "network_services.l3", default=False) @cached_property def network_services_l2_as_subint(self: SharedUtils) -> bool: """ - network_services_l2_as_subint set based on - node_type_keys..network_services.l3 for uplink_type "lan" or "lan-port-channel" + Should we deploy SVIs as subinterfaces? - This is used when deploying SVIs as subinterfaces. + network_services_l2_as_subint set based on + node_type_keys..network_services.l3 for uplink_type "lan" or "lan-port-channel". """ return self.network_services_l3 and self.uplink_type in ["lan", "lan-port-channel"] @cached_property def any_network_services(self: SharedUtils) -> bool: - """ - Returns True if either L1, L2 or L3 network_services are enabled - """ + """Returns True if either L1, L2 or L3 network_services are enabled.""" return self.network_services_l1 is True or self.network_services_l2 is True or self.network_services_l3 is True @cached_property def mpls_lsr(self: SharedUtils) -> bool: """ + Is this an MPLS LSR? + mpls_lsr set based on - node_type_keys..mpls_lsr + node_type_keys..mpls_lsr. """ return get(self.node_type_key_data, "mpls_lsr", default=False) @cached_property def vtep(self: SharedUtils) -> bool: """ + Is this a VTEP? + vtep set based on .nodes.[].vtep and - node_type_keys..vtep + node_type_keys..vtep. """ default_vtep = get(self.node_type_key_data, "vtep") return get(self.switch_data_combined, "vtep", default=default_vtep) is True diff --git a/python-avd/pyavd/_eos_designs/shared_utils/node_type_keys.py b/python-avd/pyavd/_eos_designs/shared_utils/node_type_keys.py index aa2f745940b..772c4a5d59d 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/node_type_keys.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/node_type_keys.py @@ -6,9 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._errors import AristaAvdMissingVariableError -from ..._utils import get -from ...j2filters import convert_dicts +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import get +from pyavd.j2filters import convert_dicts if TYPE_CHECKING: from . import SharedUtils @@ -166,31 +166,27 @@ class NodeTypeKeysMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def node_type_keys(self: SharedUtils) -> list: - """ - NOTE: This method is called _before_ any schema validation, since we need to resolve node_type_keys dynamically - - """ + """NOTE: This method is called _before_ any schema validation, since we need to resolve node_type_keys dynamically.""" design_type = get(self.hostvars, "design.type", default="l3ls-evpn") default_node_type_keys_for_our_design = get(DEFAULT_NODE_TYPE_KEYS, design_type) node_type_keys = get(self.hostvars, "node_type_keys", default=default_node_type_keys_for_our_design) - node_type_keys = convert_dicts(node_type_keys, "key") - return node_type_keys + return convert_dicts(node_type_keys, "key") @cached_property def node_type_key_data(self: SharedUtils) -> dict: - """ - node_type_key_data containing settings for this node_type. - """ + """node_type_key_data containing settings for this node_type.""" for node_type_key in self.node_type_keys: if node_type_key["type"] == self.type: return node_type_key # Not found - raise AristaAvdMissingVariableError(f"node_type_keys.[type=={self.type}]") + msg = f"node_type_keys.[type=={self.type}]" + raise AristaAvdMissingVariableError(msg) diff --git a/python-avd/pyavd/_eos_designs/shared_utils/overlay.py b/python-avd/pyavd/_eos_designs/shared_utils/overlay.py index cf85ff09795..05bd5f5b8f8 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/overlay.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/overlay.py @@ -8,8 +8,8 @@ from re import fullmatch from typing import TYPE_CHECKING -from ..._errors import AristaAvdError, AristaAvdMissingVariableError -from ..._utils import get +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import get if TYPE_CHECKING: from . import SharedUtils @@ -17,16 +17,15 @@ class OverlayMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def vtep_loopback(self: SharedUtils) -> str: - """ - The default is Loopback1 except for WAN devices where the default is Dps1. - """ + """The default is Loopback1 except for WAN devices where the default is Dps1.""" default_vtep_loopback = "Dps1" if self.is_wan_router else "Loopback1" return get(self.switch_data_combined, "vtep_loopback", default=default_vtep_loopback) @@ -79,7 +78,7 @@ def overlay_rd_type_vrf_admin_subfield(self: SharedUtils) -> str: vrf_admin_subfield_offset = self.overlay_rd_type["vrf_admin_subfield_offset"] return self.get_rd_admin_subfield_value(vrf_admin_subfield, vrf_admin_subfield_offset) - def get_rd_admin_subfield_value(self: SharedUtils, admin_subfield, admin_subfield_offset): + def get_rd_admin_subfield_value(self: SharedUtils, admin_subfield: str, admin_subfield_offset: int) -> str: if admin_subfield == "overlay_loopback_ip": return self.router_id @@ -91,19 +90,19 @@ def get_rd_admin_subfield_value(self: SharedUtils, admin_subfield, admin_subfiel if admin_subfield == "switch_id": if self.id is None: - raise AristaAvdMissingVariableError(f"'id' is not set on '{self.hostname}' and 'overlay_rd_type_admin_subfield' is set to 'switch_id'") + msg = f"'id' is not set on '{self.hostname}' and 'overlay_rd_type_admin_subfield' is set to 'switch_id'" + raise AristaAvdMissingVariableError(msg) return self.id + admin_subfield_offset - if fullmatch(r"[0-9]+", str(admin_subfield)): + if fullmatch(r"\d+", str(admin_subfield)): return str(int(admin_subfield) + admin_subfield_offset) try: ip_address(admin_subfield) - return admin_subfield except ValueError: - pass + return self.router_id - return self.router_id + return admin_subfield @cached_property def evpn_gateway_vxlan_l2(self: SharedUtils) -> bool: @@ -120,27 +119,26 @@ def evpn_gateway_vxlan_l3_inter_domain(self: SharedUtils) -> bool: @cached_property def overlay_routing_protocol_address_family(self: SharedUtils) -> str: overlay_routing_protocol_address_family = get(self.hostvars, "overlay_routing_protocol_address_family", default="ipv4") - if overlay_routing_protocol_address_family == "ipv6": - if not (self.underlay_ipv6 is True and self.underlay_rfc5549): - raise AristaAvdError( - "'overlay_routing_protocol_address_family: ipv6' is only supported in combination with 'underlay_ipv6: True' and 'underlay_rfc5549: True'" - ) + if overlay_routing_protocol_address_family == "ipv6" and not (self.underlay_ipv6 is True and self.underlay_rfc5549): + msg = "'overlay_routing_protocol_address_family: ipv6' is only supported in combination with 'underlay_ipv6: True' and 'underlay_rfc5549: True'" + raise AristaAvdError( + msg, + ) return overlay_routing_protocol_address_family @cached_property def evpn_encapsulation(self: SharedUtils) -> str: - """ - EVPN encapsulation based on fabric_evpn_encapsulation and node default_evpn_encapsulation. - """ + """EVPN encapsulation based on fabric_evpn_encapsulation and node default_evpn_encapsulation.""" return get(self.hostvars, "fabric_evpn_encapsulation", default=get(self.node_type_key_data, "default_evpn_encapsulation", default="vxlan")) @cached_property def evpn_soo(self: SharedUtils) -> str: """ Site-Of-Origin used as BGP extended community. + - For regular VTEPs this is :1 - For WAN routers this is : - - Otherwise this is :1 + - Otherwise this is :1. TODO: Reconsider if suffix should just be :1 for all WAN routers. """ @@ -173,9 +171,7 @@ def overlay_evpn(self: SharedUtils) -> bool: @cached_property def overlay_mpls(self: SharedUtils) -> bool: - """ - Set overlay_mpls to enable MPLS as the primary overlay - """ + """Set overlay_mpls to enable MPLS as the primary overlay.""" return any([self.overlay_evpn_mpls, self.overlay_vpn_ipv4, self.overlay_vpn_ipv6]) and not self.overlay_evpn_vxlan @cached_property diff --git a/python-avd/pyavd/_eos_designs/shared_utils/platform.py b/python-avd/pyavd/_eos_designs/shared_utils/platform.py index 2351392dd26..13b51d3a6e4 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/platform.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/platform.py @@ -7,7 +7,7 @@ from re import search from typing import TYPE_CHECKING -from ..._utils import default, get +from pyavd._utils import default, get if TYPE_CHECKING: from . import SharedUtils @@ -187,8 +187,9 @@ class PlatformMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -214,9 +215,7 @@ def platform_settings(self: SharedUtils) -> dict: @cached_property def default_interfaces(self: SharedUtils) -> dict: - """ - default_interfaces set based on default_interfaces - """ + """default_interfaces set based on default_interfaces.""" default_interfaces = get(self.hostvars, "default_interfaces", default=[]) device_platform = default(self.platform, "default") diff --git a/python-avd/pyavd/_eos_designs/shared_utils/ptp.py b/python-avd/pyavd/_eos_designs/shared_utils/ptp.py index 39074139277..25ebb511724 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/ptp.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/ptp.py @@ -6,7 +6,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._utils import default, get, get_item +from pyavd._utils import default, get, get_item if TYPE_CHECKING: from . import SharedUtils @@ -47,8 +47,9 @@ class PtpMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -72,7 +73,5 @@ def ptp_profile(self: SharedUtils) -> dict: @cached_property def ptp_profiles(self: SharedUtils) -> list: - """ - Return ptp_profiles - """ + """Return ptp_profiles.""" return get(self.hostvars, "ptp_profiles", default=DEFAULT_PTP_PROFILES) diff --git a/python-avd/pyavd/_eos_designs/shared_utils/routing.py b/python-avd/pyavd/_eos_designs/shared_utils/routing.py index 9a066809cf4..eabaf6af2c9 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/routing.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/routing.py @@ -6,9 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._errors import AristaAvdError, AristaAvdMissingVariableError -from ..._utils import get -from ...j2filters import range_expand +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import get +from pyavd.j2filters import range_expand if TYPE_CHECKING: from . import SharedUtils @@ -16,22 +16,21 @@ class RoutingMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def underlay_routing_protocol(self: SharedUtils) -> str: default_underlay_routing_protocol = get(self.node_type_key_data, "default_underlay_routing_protocol", default="ebgp") - underlay_routing_protocol = str(get(self.hostvars, "underlay_routing_protocol", default=default_underlay_routing_protocol)).lower() - return underlay_routing_protocol + return str(get(self.hostvars, "underlay_routing_protocol", default=default_underlay_routing_protocol)).lower() @cached_property def overlay_routing_protocol(self: SharedUtils) -> str: default_overlay_routing_protocol = get(self.node_type_key_data, "default_overlay_routing_protocol", default="ebgp") - overlay_routing_protocol = str(get(self.hostvars, "overlay_routing_protocol", default=default_overlay_routing_protocol)).lower() - return overlay_routing_protocol + return str(get(self.hostvars, "overlay_routing_protocol", default=default_overlay_routing_protocol)).lower() @cached_property def overlay_address_families(self: SharedUtils) -> list: @@ -42,9 +41,7 @@ def overlay_address_families(self: SharedUtils) -> list: @cached_property def bgp(self: SharedUtils) -> bool: - """ - Boolean telling if BGP Routing should be configured. - """ + """Boolean telling if BGP Routing should be configured.""" return ( self.underlay_router and self.uplink_type in ["p2p", "p2p-vrfs", "lan"] @@ -59,31 +56,23 @@ def bgp(self: SharedUtils) -> bool: @cached_property def router_id(self: SharedUtils) -> str | None: - """ - Render IP address for router_id - """ + """Render IP address for router_id.""" if self.underlay_router: return self.ip_addressing.router_id() return None @cached_property def ipv6_router_id(self: SharedUtils) -> str | None: - """ - Render IPv6 address for router_id - """ + """Render IPv6 address for router_id.""" if self.underlay_router and self.underlay_ipv6: return self.ip_addressing.ipv6_router_id() return None @cached_property def isis_instance_name(self: SharedUtils) -> str | None: - if self.underlay_router: - if self.underlay_routing_protocol in ["isis", "isis-ldp", "isis-sr", "isis-sr-ldp"]: - if self.mpls_lsr: - default_isis_instance_name = "CORE" - else: - default_isis_instance_name = "EVPN_UNDERLAY" - return get(self.hostvars, "underlay_isis_instance_name", default=default_isis_instance_name) + if self.underlay_router and self.underlay_routing_protocol in ["isis", "isis-ldp", "isis-sr", "isis-sr-ldp"]: + default_isis_instance_name = "CORE" if self.mpls_lsr else "EVPN_UNDERLAY" + return get(self.hostvars, "underlay_isis_instance_name", default=default_isis_instance_name) return None @cached_property @@ -112,10 +101,13 @@ def bgp_as(self: SharedUtils) -> str | None: return bgp_as_range_expanded[self.mlag_switch_ids["primary"] - 1] if self.id is None: - raise AristaAvdMissingVariableError(f"'id' is not set on '{self.hostname}' and is required when expanding 'bgp_as'") + msg = f"'id' is not set on '{self.hostname}' and is required when expanding 'bgp_as'" + raise AristaAvdMissingVariableError(msg) return bgp_as_range_expanded[self.id - 1] except IndexError as exc: - raise AristaAvdError(f"Unable to allocate BGP AS: bgp_as range is too small ({len(bgp_as_range_expanded)}) for the id of the device") from exc + msg = f"Unable to allocate BGP AS: bgp_as range is too small ({len(bgp_as_range_expanded)}) for the id of the device" + raise AristaAvdError(msg) from exc + return None @cached_property def always_configure_ip_routing(self: SharedUtils) -> bool: diff --git a/python-avd/pyavd/_eos_designs/shared_utils/switch_data.py b/python-avd/pyavd/_eos_designs/shared_utils/switch_data.py index dfc3498d80a..c23c9d6bd23 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/switch_data.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/switch_data.py @@ -6,8 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._utils import get, merge -from ...j2filters import convert_dicts +from pyavd._utils import get, merge +from pyavd.j2filters import convert_dicts if TYPE_CHECKING: from . import SharedUtils @@ -15,15 +15,16 @@ class SwitchDataMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @cached_property def switch_data(self: SharedUtils) -> dict: """ - internal _switch_data containing inherited vars from fabric_topology data model + internal _switch_data containing inherited vars from fabric_topology data model. Vars are inherited like: .defaults -> @@ -31,7 +32,7 @@ def switch_data(self: SharedUtils) -> dict: .node_groups.[].nodes.[] -> .nodes.[] - Returns + Returns: ------- dict node_group : dict @@ -78,15 +79,10 @@ def switch_data(self: SharedUtils) -> dict: @property def switch_data_combined(self: SharedUtils) -> dict: - """ - switch_data_combined containing self._switch_data['combined'] for easier reference. - """ + """switch_data_combined containing self._switch_data['combined'] for easier reference.""" return self.switch_data["combined"] @cached_property def switch_data_node_group_nodes(self: SharedUtils) -> list: - """ - switch_data_node_group_nodes pointing to - self.switch_data['node_group']['nodes'] for easier reference. - """ + """switch_data_node_group_nodes pointing to self.switch_data['node_group']['nodes'] for easier reference.""" return get(self.switch_data, "node_group.nodes", default=[]) diff --git a/python-avd/pyavd/_eos_designs/shared_utils/underlay.py b/python-avd/pyavd/_eos_designs/shared_utils/underlay.py index 0212839db71..15422209adb 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/underlay.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/underlay.py @@ -6,7 +6,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..._utils import get, get_item +from pyavd._utils import get, get_item if TYPE_CHECKING: from . import SharedUtils @@ -14,8 +14,9 @@ class UnderlayMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -89,7 +90,7 @@ def underlay_multicast_rp_interfaces(self: SharedUtils) -> list[dict] | None: "name": f"Loopback{node_entry['loopback_number']}", "description": get(node_entry, "description", default="PIM RP"), "ip_address": f"{rp_entry['rp']}/32", - } + }, ) if underlay_multicast_rp_interfaces: @@ -111,7 +112,7 @@ def underlay_ospf_process_id(self: SharedUtils) -> int: @cached_property def underlay_ospf_area(self: SharedUtils) -> str: - return get(self.hostvars, "underlay_ospf_area", default="0.0.0.0") + return get(self.hostvars, "underlay_ospf_area", default="0.0.0.0") # noqa: S104 @cached_property def underlay_filter_peer_as(self: SharedUtils) -> bool: diff --git a/python-avd/pyavd/_eos_designs/shared_utils/utils.py b/python-avd/pyavd/_eos_designs/shared_utils/utils.py index 55b3a13c97f..dec1f54cf86 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/utils.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/utils.py @@ -6,24 +6,26 @@ from functools import lru_cache from typing import TYPE_CHECKING -from ..._errors import AristaAvdError -from ..._utils import get, get_item, merge, template_var +from pyavd._errors import AristaAvdError +from pyavd._utils import get, get_item, merge, template_var if TYPE_CHECKING: - from ...eos_designs_facts import EosDesignsFacts + from pyavd._eos_designs.eos_designs_facts import EosDesignsFacts + from . import SharedUtils class UtilsMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ def get_peer_facts(self: SharedUtils, peer_name: str, required: bool = True) -> EosDesignsFacts | dict | None: """ - util function to retrieve peer_facts for peer_name + util function to retrieve peer_facts for peer_name. returns avd_switch_facts.{peer_name}.switch @@ -42,19 +44,16 @@ def get_peer_facts(self: SharedUtils, peer_name: str, required: bool = True) -> ) def template_var(self: SharedUtils, template_file: str, template_vars: dict) -> str: - """ - Run the simplified templater using the passed Ansible "templar" engine. - """ + """Run the simplified templater using the passed Ansible "templar" engine.""" try: return template_var(template_file, template_vars, self.templar) except Exception as e: - raise AristaAvdError(f"Error during templating of template: {template_file}") from e + msg = f"Error during templating of template: {template_file}" + raise AristaAvdError(msg) from e - @lru_cache + @lru_cache # noqa: B019 def get_merged_port_profile(self: SharedUtils, profile_name: str) -> list: - """ - Return list of merged "port_profiles" where "parent_profile" has been applied. - """ + """Return list of merged "port_profiles" where "parent_profile" has been applied.""" port_profile = get_item(self.port_profiles, "profile", profile_name, default={}) if "parent_profile" in port_profile: parent_profile = get_item(self.port_profiles, "profile", port_profile["parent_profile"], default={}) @@ -67,6 +66,7 @@ def get_merged_port_profile(self: SharedUtils, profile_name: str) -> list: def get_merged_adapter_settings(self: SharedUtils, adapter_or_network_port_settings: dict) -> dict: """ Applies port-profiles to the given adapter_or_network_port and returns the combined result. + adapter_or_network_port can either be an adapter of a connected endpoint or one item under network_ports. """ profile_name = adapter_or_network_port_settings.get("profile") diff --git a/python-avd/pyavd/_eos_designs/shared_utils/wan.py b/python-avd/pyavd/_eos_designs/shared_utils/wan.py index a9f6527f956..915dca0589c 100644 --- a/python-avd/pyavd/_eos_designs/shared_utils/wan.py +++ b/python-avd/pyavd/_eos_designs/shared_utils/wan.py @@ -6,9 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING, Literal -from ..._errors import AristaAvdError, AristaAvdMissingVariableError -from ..._utils import default, get, get_ip_from_pool, get_item, strip_empties_from_dict -from ...j2filters import natural_sort +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import default, get, get_ip_from_pool, get_item, strip_empties_from_dict +from pyavd.j2filters import natural_sort if TYPE_CHECKING: from . import SharedUtils @@ -16,8 +16,9 @@ class WanMixin: """ - Mixin Class providing a subset of SharedUtils - Class should only be used as Mixin to the SharedUtils class + Mixin Class providing a subset of SharedUtils. + + Class should only be used as Mixin to the SharedUtils class. Using type-hint on self to get proper type-hints on attributes across all Mixins. """ @@ -33,11 +34,14 @@ def wan_role(self: SharedUtils) -> str | None: default_wan_role = get(self.node_type_key_data, "default_wan_role", default=None) wan_role = get(self.switch_data_combined, "wan_role", default=default_wan_role) if wan_role is not None and self.overlay_routing_protocol != "ibgp": - raise AristaAvdError("Only 'ibgp' is supported as 'overlay_routing_protocol' for WAN nodes.") + msg = "Only 'ibgp' is supported as 'overlay_routing_protocol' for WAN nodes." + raise AristaAvdError(msg) if wan_role == "server" and self.evpn_role != "server": - raise AristaAvdError("'wan_role' server requires 'evpn_role' server.") + msg = "'wan_role' server requires 'evpn_role' server." + raise AristaAvdError(msg) if wan_role == "client" and self.evpn_role != "client": - raise AristaAvdError("'wan_role' client requires 'evpn_role' client.") + msg = "'wan_role' client requires 'evpn_role' client." + raise AristaAvdError(msg) return wan_role @cached_property @@ -63,9 +67,7 @@ def wan_listen_ranges(self: SharedUtils) -> list: @cached_property def cv_pathfinder_transit_mode(self: SharedUtils) -> Literal["region", "zone"] | None: - """ - When wan_mode is CV Pathfinder, return the transit mode "region", "zone" or None. - """ + """When wan_mode is CV Pathfinder, return the transit mode "region", "zone" or None.""" if not self.is_cv_pathfinder_client: return None @@ -74,22 +76,19 @@ def cv_pathfinder_transit_mode(self: SharedUtils) -> Literal["region", "zone"] | @cached_property def wan_interfaces(self: SharedUtils) -> list: """ - As a first approach, only interfaces under node config l3_interfaces can be considered - as WAN interfaces. + As a first approach, only interfaces under node config l3_interfaces can be considered as WAN interfaces. + This may need to be made wider. This also may require a different format for the dictionaries inside the list. """ if not self.is_wan_router: return [] - wan_interfaces = [] - for interface in self.l3_interfaces: - if get(interface, "wan_carrier") is not None: - wan_interfaces.append(interface) - + wan_interfaces = [interface for interface in self.l3_interfaces if get(interface, "wan_carrier") is not None] if not wan_interfaces: + msg = "At least one WAN interface must be configured on a WAN router. Add WAN interfaces under `l3_interfaces` node setting with `wan_carrier` set." raise AristaAvdError( - "At least one WAN interface must be configured on a WAN router. Add WAN interfaces under `l3_interfaces` node setting with `wan_carrier` set." + msg, ) return wan_interfaces @@ -100,10 +99,11 @@ def wan_carriers(self: SharedUtils) -> list: @cached_property def wan_local_carriers(self: SharedUtils) -> list: """ - List of carriers present on this router based on the wan_interfaces with the associated WAN interfaces + List of carriers present on this router based on the wan_interfaces with the associated WAN interfaces. + interfaces: - name: ... - ip: ... (for route-servers the IP may come from wan_route_servers) + ip: ... (for route-servers the IP may come from wan_route_servers). """ if not self.is_wan_router: return [] @@ -128,8 +128,8 @@ def wan_local_carriers(self: SharedUtils) -> list: "public_ip": self.get_public_ip_for_wan_interface(interface), "connected_to_pathfinder": get(interface, "connected_to_pathfinder", default=True), "wan_circuit_id": get(interface, "wan_circuit_id"), - } - ) + }, + ), ) return list(local_carriers_dict.values()) @@ -137,7 +137,8 @@ def wan_local_carriers(self: SharedUtils) -> list: @cached_property def wan_path_groups(self: SharedUtils) -> list: """ - List of path-groups defined in the top level key `wan_path_groups` + List of path-groups defined in the top level key `wan_path_groups`. + Updating default preference for each path-group to 'preferred' if not set. """ path_groups = get(self.hostvars, "wan_path_groups", required=True) @@ -149,6 +150,7 @@ def wan_path_groups(self: SharedUtils) -> list: def wan_local_path_groups(self: SharedUtils) -> list: """ List of path_groups present on this router based on the local carriers. + Also add for each path_groups the local interfaces in a data structure interfaces: - name: ... @@ -179,15 +181,14 @@ def wan_local_path_groups(self: SharedUtils) -> list: @cached_property def wan_local_path_group_names(self: SharedUtils) -> list: - """ - Return a list of wan_local_path_group names to be used by HA peer and in various places - """ + """Return a list of wan_local_path_group names to be used by HA peer and in various places.""" return [path_group["name"] for path_group in self.wan_local_path_groups] @cached_property def this_wan_route_server(self: SharedUtils) -> dict: """ Returns the instance for this wan_rs found under wan_route_servers. + Should only be called when the device is actually a wan_rs. """ wan_route_servers = get(self.hostvars, "wan_route_servers", default=[]) @@ -195,7 +196,7 @@ def this_wan_route_server(self: SharedUtils) -> dict: def get_public_ip_for_wan_interface(self: SharedUtils, interface: dict) -> str: """ - Takes a dict which looks like `l3_interface` from node config + Takes a dict which looks like `l3_interface` from node config. If not a WAN route-server this returns public IP and if not found then the interface IP without a mask. @@ -217,17 +218,20 @@ def get_public_ip_for_wan_interface(self: SharedUtils, interface: dict) -> str: return interface["public_ip"] if interface["ip_address"] == "dhcp": - raise AristaAvdError( + msg = ( f"The IP address for WAN interface '{interface['name']}' on Route Server '{self.hostname}' is set to 'dhcp'. " "Clients need to peer with a static IP which must be set under the 'wan_route_servers.path_groups.interfaces' key." ) + raise AristaAvdError( + msg, + ) return interface["ip_address"].split("/", maxsplit=1)[0] @cached_property def wan_site(self: SharedUtils) -> dict | None: """ - WAN site for CV Pathfinder + WAN site for CV Pathfinder. The site is required for edges, but optional for pathfinders """ @@ -266,7 +270,7 @@ def wan_site(self: SharedUtils) -> dict | None: @cached_property def wan_region(self: SharedUtils) -> dict | None: """ - WAN region for CV Pathfinder + WAN region for CV Pathfinder. The region is required for edges, but optional for pathfinders """ @@ -280,7 +284,10 @@ def wan_region(self: SharedUtils) -> dict | None: return None regions = get( - self.hostvars, "cv_pathfinder_regions", required=True, org_key="'cv_pathfinder_regions' key must be set when 'wan_mode' is 'cv-pathfinder'." + self.hostvars, + "cv_pathfinder_regions", + required=True, + org_key="'cv_pathfinder_regions' key must be set when 'wan_mode' is 'cv-pathfinder'.", ) return get_item( @@ -294,7 +301,7 @@ def wan_region(self: SharedUtils) -> dict | None: @property def wan_zone(self: SharedUtils) -> dict: """ - WAN zone for Pathfinder + WAN zone for Pathfinder. Currently, only one default zone with ID 1 is supported. """ @@ -328,22 +335,29 @@ def filtered_wan_route_servers(self: SharedUtils) -> dict: # Only ibgp is supported for WAN so raise if peer from peer_facts BGP AS is different from ours. if bgp_as != self.bgp_as: - raise AristaAvdError(f"Only iBGP is supported for WAN, the BGP AS {bgp_as} on {wan_rs} is different from our own: {self.bgp_as}.") + msg = f"Only iBGP is supported for WAN, the BGP AS {bgp_as} on {wan_rs} is different from our own: {self.bgp_as}." + raise AristaAvdError(msg) # Prefer values coming from the input variables over peer facts vtep_ip = get(wan_rs_dict, "vtep_ip", default=peer_facts.get("vtep_ip")) wan_path_groups = get(wan_rs_dict, "path_groups", default=peer_facts.get("wan_path_groups")) if vtep_ip is None: - raise AristaAvdMissingVariableError( + msg = ( f"'vtep_ip' is missing for peering with {wan_rs}, either set it in under 'wan_route_servers' or something is wrong with the peer" " facts." ) - if wan_path_groups is None: raise AristaAvdMissingVariableError( + msg, + ) + if wan_path_groups is None: + msg = ( f"'wan_path_groups' is missing for peering with {wan_rs}, either set it in under 'wan_route_servers'" " or something is wrong with the peer facts." ) + raise AristaAvdMissingVariableError( + msg, + ) else: # Retrieve the values from the dictionary, making them required if the peer_facts were not found @@ -365,7 +379,7 @@ def filtered_wan_route_servers(self: SharedUtils) -> dict: } # If no common path-group then skip - # TODO - this may need to change when `import` path-groups is available + # TODO: - this may need to change when `import` path-groups is available if len(wan_rs_result_dict["wan_path_groups"]) > 0: wan_route_servers[wan_rs] = strip_empties_from_dict(wan_rs_result_dict) @@ -386,23 +400,17 @@ def should_connect_to_wan_rs(self: SharedUtils, path_groups: list) -> bool: @cached_property def is_cv_pathfinder_router(self: SharedUtils) -> bool: - """ - Return True is the current wan_mode is cv-pathfinder and the device is a wan router. - """ + """Return True is the current wan_mode is cv-pathfinder and the device is a wan router.""" return self.wan_mode == "cv-pathfinder" and self.is_wan_router @cached_property def is_cv_pathfinder_client(self: SharedUtils) -> bool: - """ - Return True is the current wan_mode is cv-pathfinder and the device is either an edge or a transit device - """ + """Return True is the current wan_mode is cv-pathfinder and the device is either an edge or a transit device.""" return self.is_cv_pathfinder_router and self.is_wan_client @cached_property def is_cv_pathfinder_server(self: SharedUtils) -> bool: - """ - Return True is the current wan_mode is cv-pathfinder and the device is a pathfinder device - """ + """Return True is the current wan_mode is cv-pathfinder and the device is a pathfinder device.""" return self.is_cv_pathfinder_router and self.is_wan_server @cached_property @@ -422,20 +430,19 @@ def cv_pathfinder_role(self: SharedUtils) -> str | None: @cached_property def wan_ha(self: SharedUtils) -> bool: - """ - Only trigger HA if 2 cv_pathfinder clients are in the same group and wan_ha.enabled is true - """ + """Only trigger HA if 2 cv_pathfinder clients are in the same group and wan_ha.enabled is true.""" if not (self.is_cv_pathfinder_client and len(self.switch_data_node_group_nodes) == 2): return False if (ha_enabled := get(self.switch_data_combined, "wan_ha.enabled")) is None: + msg = ( + "Placing two WAN routers in a common node group will trigger WAN HA in a future AVD release. " + "Currently WAN HA is in preview, so it will not be automatically enabled. " + "To avoid unplanned configuration changes once the feature is released, " + "it is currently required to set 'wan_ha.enabled' to 'true' or 'false'." + ) raise AristaAvdError( - ( - "Placing two WAN routers in a common node group will trigger WAN HA in a future AVD release. " - "Currently WAN HA is in preview, so it will not be automatically enabled. " - "To avoid unplanned configuration changes once the feature is released, " - "it is currently required to set 'wan_ha.enabled' to 'true' or 'false'." - ) + msg, ) return ha_enabled @@ -447,6 +454,7 @@ def wan_ha_ipsec(self: SharedUtils) -> bool: def wan_ha_path_group_name(self: SharedUtils) -> str: """ Return HA path group name for the WAN design. + Used in both network services and overlay python modules. """ return get(self.hostvars, "wan_ha.lan_ha_path_group_name", default="LAN_HA") @@ -454,8 +462,7 @@ def wan_ha_path_group_name(self: SharedUtils) -> str: @cached_property def is_first_ha_peer(self: SharedUtils) -> bool: """ - Returns True if the device is the first device in the node_group, - false otherwise. + Returns True if the device is the first device in the node_group, false otherwise. This should be called only from functions which have checked that HA is enabled. """ @@ -463,43 +470,34 @@ def is_first_ha_peer(self: SharedUtils) -> bool: @cached_property def wan_ha_peer(self: SharedUtils) -> str | None: - """ - Return the name of the WAN HA peer. - """ + """Return the name of the WAN HA peer.""" if not self.wan_ha: return None if self.is_first_ha_peer: return self.switch_data_node_group_nodes[1]["name"] if self.switch_data_node_group_nodes[1]["name"] == self.hostname: return self.switch_data_node_group_nodes[0]["name"] - raise AristaAvdError("Unable to find WAN HA peer within same node group") + msg = "Unable to find WAN HA peer within same node group" + raise AristaAvdError(msg) @cached_property def configured_wan_ha_interfaces(self: SharedUtils) -> set: - """ - Read the device wan_ha.ha_interfaces node settings - """ + """Read the device wan_ha.ha_interfaces node settings.""" return get(self.switch_data_combined, "wan_ha.ha_interfaces", default=[]) @cached_property def vrf_default_uplinks(self: SharedUtils) -> list: - """ - Return the uplinkss in VRF default - """ + """Return the uplinkss in VRF default.""" return [uplink for uplink in self.get_switch_fact("uplinks") if get(uplink, "vrf") is None] @cached_property def vrf_default_uplink_interfaces(self: SharedUtils) -> list: - """ - Return the uplink interfaces in VRF default - """ + """Return the uplink interfaces in VRF default.""" return [uplink["interface"] for uplink in self.vrf_default_uplinks] @cached_property def use_uplinks_for_wan_ha(self: SharedUtils) -> bool: - """ - Return true or false - """ + """Return true or false.""" interfaces = set(self.configured_wan_ha_interfaces) uplink_interfaces = set(self.vrf_default_uplink_interfaces) @@ -507,14 +505,16 @@ def use_uplinks_for_wan_ha(self: SharedUtils) -> bool: return True if not interfaces.intersection(uplink_interfaces): if len(interfaces) > 1: - raise AristaAvdError("AVD does not support multiple HA interfaces when not using uplinks.") + msg = "AVD does not support multiple HA interfaces when not using uplinks." + raise AristaAvdError(msg) return False - raise AristaAvdError("Either all `wan_ha.ha_interfaces` must be uplink interfaces or all of them must not be uplinks.") + msg = "Either all `wan_ha.ha_interfaces` must be uplink interfaces or all of them must not be uplinks." + raise AristaAvdError(msg) @cached_property def wan_ha_interfaces(self: SharedUtils) -> list: """ - Return the list of interfaces for WAN HA + Return the list of interfaces for WAN HA. If using uplinks for WAN HA, returns the filtered uplinks if self.configured_wan_ha_interfaces is not empty else returns all of them. @@ -527,8 +527,9 @@ def wan_ha_interfaces(self: SharedUtils) -> list: @cached_property def wan_ha_peer_ip_addresses(self: SharedUtils) -> list: """ - Read the IP addresses/prefix length from HA peer uplinks - Used also to generate the prefix list of the PEER HA prefixes + Read the IP addresses/prefix length from HA peer uplinks. + + Used also to generate the prefix list of the PEER HA prefixes. """ interfaces = set(self.configured_wan_ha_interfaces) ip_addresses = [] @@ -556,6 +557,7 @@ def wan_ha_peer_ip_addresses(self: SharedUtils) -> list: def wan_ha_ip_addresses(self: SharedUtils) -> list: """ Read the IP addresses/prefix length from this device uplinks used for HA. + Used to generate the prefix list. """ interfaces = set(self.configured_wan_ha_interfaces) @@ -605,9 +607,7 @@ def get_wan_ha_ip_address(self: SharedUtils, local: bool) -> str | None: return f"{ip_address}/31" def generate_lb_policy_name(self: SharedUtils, name: str) -> str: - """ - Returns LB-{name} - """ + """Returns LB-{name}.""" return f"LB-{name}" @cached_property diff --git a/python-avd/pyavd/_eos_designs/structured_config/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/__init__.py index 0474733a553..ca73ad6b09a 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/__init__.py @@ -4,11 +4,11 @@ from __future__ import annotations from collections import ChainMap +from typing import TYPE_CHECKING + +from pyavd._eos_designs.shared_utils import SharedUtils +from pyavd._utils import get, merge -from ..._utils import get, merge -from ...avd_schema_tools import AvdSchemaTools -from ..avdfacts import AvdFacts -from ..shared_utils import SharedUtils from .base import AvdStructuredConfigBase from .connected_endpoints import AvdStructuredConfigConnectedEndpoints from .core_interfaces_and_l3_edge import AvdStructuredConfigCoreInterfacesAndL3Edge @@ -21,6 +21,10 @@ from .overlay import AvdStructuredConfigOverlay from .underlay import AvdStructuredConfigUnderlay +if TYPE_CHECKING: + from pyavd._eos_designs.avdfacts import AvdFacts + from pyavd.avd_schema_tools import AvdSchemaTools + AVD_STRUCTURED_CONFIG_CLASSES = [ AvdStructuredConfigBase, AvdStructuredConfigMlag, @@ -46,7 +50,7 @@ def get_structured_config( - vars: dict, + vars: dict, # noqa: A002 input_schema_tools: AvdSchemaTools, output_schema_tools: AvdSchemaTools, result: dict, diff --git a/python-avd/pyavd/_eos_designs/structured_config/base/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/base/__init__.py index 0095b78f150..0ff38f0e1de 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/base/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/base/__init__.py @@ -5,10 +5,11 @@ from functools import cached_property -from ...._errors import AristaAvdMissingVariableError -from ...._utils import default, get, strip_null_from_data -from ....j2filters import convert_dicts, natural_sort -from ...avdfacts import AvdFacts +from pyavd._eos_designs.avdfacts import AvdFacts +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import default, get, strip_null_from_data +from pyavd.j2filters import convert_dicts, natural_sort + from .ntp import NtpMixin from .snmp_server import SnmpServerMixin @@ -37,16 +38,15 @@ def is_deployed(self) -> bool: @cached_property def serial_number(self) -> str | None: - """ - serial_number variable set based on serial_number fact - """ + """serial_number variable set based on serial_number fact.""" return self.shared_utils.serial_number @cached_property def router_bgp(self) -> dict | None: """ - router_bgp set based on switch.bgp_as, switch.bgp_defaults, router_id facts - and aggregating the values of bgp_maximum_paths and bgp_ecmp variables + Structured config for router_bgp. + + router_bgp set based on switch.bgp_as, switch.bgp_defaults, router_id facts and aggregating the values of bgp_maximum_paths and bgp_ecmp variables. """ if self.shared_utils.bgp_as is None: return None @@ -99,73 +99,62 @@ def router_bgp(self) -> dict | None: @cached_property def static_routes(self) -> list | None: - """ - static_routes set based on mgmt_gateway, mgmt_destination_networks and mgmt_interface_vrf - """ + """static_routes set based on mgmt_gateway, mgmt_destination_networks and mgmt_interface_vrf.""" if self.shared_utils.mgmt_gateway is None: return None - static_routes = [] if (mgmt_destination_networks := get(self._hostvars, "mgmt_destination_networks")) is not None: - for mgmt_destination_network in mgmt_destination_networks: - static_routes.append( - { - "vrf": self.shared_utils.mgmt_interface_vrf, - "destination_address_prefix": mgmt_destination_network, - "gateway": self.shared_utils.mgmt_gateway, - } - ) - else: - static_routes.append( + return [ { "vrf": self.shared_utils.mgmt_interface_vrf, - "destination_address_prefix": "0.0.0.0/0", + "destination_address_prefix": mgmt_destination_network, "gateway": self.shared_utils.mgmt_gateway, } - ) + for mgmt_destination_network in mgmt_destination_networks + ] - return static_routes + return [ + { + "vrf": self.shared_utils.mgmt_interface_vrf, + "destination_address_prefix": "0.0.0.0/0", + "gateway": self.shared_utils.mgmt_gateway, + } + ] @cached_property def ipv6_static_routes(self) -> list | None: - """ - ipv6_static_routes set based on ipv6_mgmt_gateway, ipv6_mgmt_destination_networks and mgmt_interface_vrf - """ + """ipv6_static_routes set based on ipv6_mgmt_gateway, ipv6_mgmt_destination_networks and mgmt_interface_vrf.""" if self.shared_utils.ipv6_mgmt_gateway is None or self.shared_utils.ipv6_mgmt_ip is None: return None - ipv6_static_routes = [] if (ipv6_mgmt_destination_networks := get(self._hostvars, "ipv6_mgmt_destination_networks")) is not None: - for mgmt_destination_network in ipv6_mgmt_destination_networks: - ipv6_static_routes.append( - { - "vrf": self.shared_utils.mgmt_interface_vrf, - "destination_address_prefix": mgmt_destination_network, - "gateway": self.shared_utils.ipv6_mgmt_gateway, - } - ) - else: - ipv6_static_routes.append( + return [ { "vrf": self.shared_utils.mgmt_interface_vrf, - "destination_address_prefix": "::/0", + "destination_address_prefix": mgmt_destination_network, "gateway": self.shared_utils.ipv6_mgmt_gateway, } - ) + for mgmt_destination_network in ipv6_mgmt_destination_networks + ] - return ipv6_static_routes + return [ + { + "vrf": self.shared_utils.mgmt_interface_vrf, + "destination_address_prefix": "::/0", + "gateway": self.shared_utils.ipv6_mgmt_gateway, + }, + ] @cached_property def service_routing_protocols_model(self) -> str: - """ - service_routing_protocols_model set to 'multi-agent' - """ + """service_routing_protocols_model set to 'multi-agent'.""" return "multi-agent" @cached_property def ip_routing(self) -> bool | None: """ For l3 devices, configure ip routing unless ip_routing_ipv6_interfaces is True. + For other devices only configure if "always_configure_ip_routing" is True. """ if not self.shared_utils.underlay_router and not self.shared_utils.always_configure_ip_routing: @@ -177,9 +166,7 @@ def ip_routing(self) -> bool | None: @cached_property def ipv6_unicast_routing(self) -> bool | None: - """ - ipv6_unicast_routing set based on underlay_rfc5549 and underlay_ipv6 - """ + """ipv6_unicast_routing set based on underlay_rfc5549 and underlay_ipv6.""" if not self.shared_utils.underlay_router and not self.shared_utils.always_configure_ip_routing: return None @@ -189,9 +176,7 @@ def ipv6_unicast_routing(self) -> bool | None: @cached_property def ip_routing_ipv6_interfaces(self) -> bool | None: - """ - ip_routing_ipv6_interfaces set based on underlay_rfc5549 variable - """ + """ip_routing_ipv6_interfaces set based on underlay_rfc5549 variable.""" if not self.shared_utils.underlay_router and not self.shared_utils.always_configure_ip_routing: return None @@ -201,10 +186,7 @@ def ip_routing_ipv6_interfaces(self) -> bool | None: @cached_property def router_multicast(self) -> dict | None: - """ - router_multicast set based on underlay_multicast, underlay_router - and switch.evpn_multicast facts - """ + """router_multicast set based on underlay_multicast, underlay_router and switch.evpn_multicast facts.""" if not self.shared_utils.underlay_multicast: return None @@ -216,16 +198,15 @@ def router_multicast(self) -> dict | None: @cached_property def hardware_counters(self) -> dict | None: - """ - hardware_counters set based on hardware_counters.features variable - """ + """hardware_counters set based on hardware_counters.features variable.""" return get(self._hostvars, "hardware_counters") @cached_property def hardware(self) -> dict | None: """ hardware set based on platform_speed_groups variable and switch.platform fact. - Converting nested dict to list of dict to support avd_v4.0 + + Converting nested dict to list of dict to support avd_v4.0. """ platform_speed_groups = get(self._hostvars, "platform_speed_groups") switch_platform = self.shared_utils.platform @@ -248,11 +229,12 @@ def hardware(self) -> dict | None: for speed_group in natural_sort(tmp_speed_groups): hardware["speed_groups"].append({"speed_group": speed_group, "serdes": tmp_speed_groups[speed_group]}) return hardware + return None @cached_property def daemon_terminattr(self) -> dict | None: """ - daemon_terminattr set based on cvp_instance_ip and cvp_instance_ips variables + daemon_terminattr set based on cvp_instance_ip and cvp_instance_ips variables. Updating cvaddrs and cvauth considering conditions for cvaas and cvp_on_prem IPs @@ -277,7 +259,7 @@ def daemon_terminattr(self) -> dict | None: daemon_terminattr["cvauth"] = { "method": "token-secure", # Ignoring sonar-lint false positive for tmp path since this is config for EOS - "token_file": get(self._hostvars, "cvp_token_file", "/tmp/cv-onboarding-token"), # NOSONAR + "token_file": get(self._hostvars, "cvp_token_file", "/tmp/cv-onboarding-token"), # NOSONAR # noqa: S108 } else: # updating for cvp_on_prem_ips @@ -292,7 +274,7 @@ def daemon_terminattr(self) -> dict | None: daemon_terminattr["cvauth"] = { "method": "token", # Ignoring sonar-lint false positive for tmp path since this is config for EOS - "token_file": get(self._hostvars, "cvp_token_file", "/tmp/token"), # NOSONAR + "token_file": get(self._hostvars, "cvp_token_file", "/tmp/token"), # NOSONAR # noqa: S108 } daemon_terminattr["cvvrf"] = self.shared_utils.mgmt_interface_vrf @@ -304,20 +286,18 @@ def daemon_terminattr(self) -> dict | None: @cached_property def vlan_internal_order(self) -> dict | None: - """ - vlan_internal_order set based on internal_vlan_order data-model - """ + """vlan_internal_order set based on internal_vlan_order data-model.""" if self.shared_utils.wan_role: return None - DEFAULT_INTERNAL_VLAN_ORDER = { + default_internal_vlan_order = { "allocation": "ascending", "range": { "beginning": 1006, "ending": 1199, }, } - return get(self._hostvars, "internal_vlan_order", default=DEFAULT_INTERNAL_VLAN_ORDER) + return get(self._hostvars, "internal_vlan_order", default=default_internal_vlan_order) @cached_property def transceiver_qsfp_default_mode_4x10(self) -> bool | None: @@ -331,35 +311,26 @@ def transceiver_qsfp_default_mode_4x10(self) -> bool | None: @cached_property def event_monitor(self) -> dict | None: - """ - event_monitor set based on event_monitor data-model - """ + """event_monitor set based on event_monitor data-model.""" if get(self._hostvars, "event_monitor") is True: return {"enabled": "true"} return None @cached_property def event_handlers(self) -> list | None: - """ - event_handlers set based on event_handlers data-model - """ + """event_handlers set based on event_handlers data-model.""" return get(self._hostvars, "event_handlers") @cached_property def load_interval(self) -> dict | None: - """ - load_interval set based on load_interval_default variable - """ + """load_interval set based on load_interval_default variable.""" if (load_interval_default := get(self._hostvars, "load_interval_default")) is not None: return {"default": load_interval_default} return None @cached_property def queue_monitor_length(self) -> dict | None: - """ - queue_monitor_length set based on queue_monitor_length data-model and - platform_settings.feature_support.queue_monitor_length_notify fact - """ + """queue_monitor_length set based on queue_monitor_length data-model and platform_settings.feature_support.queue_monitor_length_notify fact.""" if (queue_monitor_length := get(self._hostvars, "queue_monitor_length")) is None: return None @@ -371,9 +342,7 @@ def queue_monitor_length(self) -> dict | None: @cached_property def ip_name_servers(self) -> list | None: - """ - ip_name_servers set based on name_servers data-model and mgmt_interface_vrf - """ + """ip_name_servers set based on name_servers data-model and mgmt_interface_vrf.""" ip_name_servers = [ { "ip_address": name_server, @@ -388,18 +357,14 @@ def ip_name_servers(self) -> list | None: @cached_property def redundancy(self) -> dict | None: - """ - redundancy set based on redundancy data-model - """ + """Redundancy set based on redundancy data-model.""" if get(self._hostvars, "redundancy") is not None: return {"protocol": get(self._hostvars, "redundancy.protocol")} return None @cached_property def interface_defaults(self) -> dict | None: - """ - interface_defaults set based on default_interface_mtu - """ + """interface_defaults set based on default_interface_mtu.""" if self.shared_utils.default_interface_mtu is not None: return { "mtu": self.shared_utils.default_interface_mtu, @@ -408,10 +373,7 @@ def interface_defaults(self) -> dict | None: @cached_property def spanning_tree(self) -> dict | None: - """ - spanning_tree set based on spanning_tree_root_super, spanning_tree_mode - and spanning_tree_priority - """ + """spanning_tree set based on spanning_tree_root_super, spanning_tree_mode and spanning_tree_priority.""" if not self.shared_utils.network_services_l2: return {"mode": "none"} @@ -437,9 +399,7 @@ def spanning_tree(self) -> dict | None: @cached_property def service_unsupported_transceiver(self) -> dict | None: - """ - service_unsupported_transceiver based on unsupported_transceiver data-model - """ + """service_unsupported_transceiver based on unsupported_transceiver data-model.""" if (unsupported_transceiver := get(self._hostvars, "unsupported_transceiver")) is not None: return {"license_name": unsupported_transceiver.get("license_name"), "license_key": unsupported_transceiver.get("license_key")} @@ -447,9 +407,7 @@ def service_unsupported_transceiver(self) -> dict | None: @cached_property def local_users(self) -> list | None: - """ - local_users set based on local_users data model - """ + """local_users set based on local_users data model.""" if (local_users := get(self._hostvars, "local_users")) is None: return None @@ -457,18 +415,14 @@ def local_users(self) -> list | None: @cached_property def clock(self) -> dict | None: - """ - clock set based on timezone variable - """ + """Clock set based on timezone variable.""" if (timezone := get(self._hostvars, "timezone")) is not None: return {"timezone": timezone} return None @cached_property def vrfs(self) -> list: - """ - vrfs set based on mgmt_interface_vrf variable - """ + """Vrfs set based on mgmt_interface_vrf variable.""" mgmt_vrf_routing = get(self._hostvars, "mgmt_vrf_routing", default=False) vrf_settings = { "name": self.shared_utils.mgmt_interface_vrf, @@ -480,10 +434,7 @@ def vrfs(self) -> list: @cached_property def management_interfaces(self) -> list | None: - """ - management_interfaces set based on mgmt_interface, mgmt_ip, ipv6_mgmt_ip facts, - mgmt_gateway, ipv6_mgmt_gateway and mgmt_interface_vrf variables - """ + """management_interfaces set based on mgmt_interface, mgmt_ip, ipv6_mgmt_ip facts, mgmt_gateway, ipv6_mgmt_gateway and mgmt_interface_vrf variables.""" mgmt_interface = self.shared_utils.mgmt_interface if ( mgmt_interface is not None @@ -508,7 +459,7 @@ def management_interfaces(self) -> list | None: "ipv6_enable": True, "ipv6_address": self.shared_utils.ipv6_mgmt_ip, "ipv6_gateway": self.shared_utils.ipv6_mgmt_gateway, - } + }, ) return [interface_settings] @@ -517,9 +468,7 @@ def management_interfaces(self) -> list | None: @cached_property def management_security(self) -> dict | None: - """ - Return structured config for management_security. - """ + """Return structured config for management_security.""" if (entropy_sources := get(self.shared_utils.platform_settings, "security_entropy_sources")) is not None: return {"entropy_sources": entropy_sources} @@ -527,9 +476,7 @@ def management_security(self) -> dict | None: @cached_property def tcam_profile(self) -> dict | None: - """ - tcam_profile set based on platform_settings.tcam_profile fact - """ + """tcam_profile set based on platform_settings.tcam_profile fact.""" if (tcam_profile := get(self.shared_utils.platform_settings, "tcam_profile")) is not None: return {"system": tcam_profile} return None @@ -537,10 +484,11 @@ def tcam_profile(self) -> dict | None: @cached_property def platform(self) -> dict | None: """ - platform set based on: + platform set based on. + * platform_settings.lag_hardware_only, * platform_settings.trident_forwarding_table_partition and switch.evpn_multicast facts - * data_plane_cpu_allocation_max + * data_plane_cpu_allocation_max. """ platform = {} if (lag_hardware_only := get(self.shared_utils.platform_settings, "lag_hardware_only")) is not None: @@ -555,7 +503,8 @@ def platform(self) -> dict | None: elif self.shared_utils.is_wan_server: # For AutoVPN Route Reflectors and Pathfinders, running on CloudEOS, setting # this value is required for the solution to work. - raise AristaAvdMissingVariableError("For AutoVPN RRs and Pathfinders, 'data_plane_cpu_allocation_max' must be set") + msg = "For AutoVPN RRs and Pathfinders, 'data_plane_cpu_allocation_max' must be set" + raise AristaAvdMissingVariableError(msg) if platform: return platform @@ -563,19 +512,14 @@ def platform(self) -> dict | None: @cached_property def mac_address_table(self) -> dict | None: - """ - mac_address_table set based on mac_address_table data-model - """ + """mac_address_table set based on mac_address_table data-model.""" if (aging_time := get(self._hostvars, "mac_address_table.aging_time")) is not None: return {"aging_time": aging_time} return None @cached_property def queue_monitor_streaming(self) -> dict | None: - """ - queue_monitor_streaming set based on queue_monitor_streaming data-model - - """ + """queue_monitor_streaming set based on queue_monitor_streaming data-model.""" enable = get(self._hostvars, "queue_monitor_streaming.enable") vrf = get(self._hostvars, "queue_monitor_streaming.vrf") if enable is not True or vrf is None: @@ -593,15 +537,13 @@ def queue_monitor_streaming(self) -> dict | None: @cached_property def management_api_http(self) -> dict | None: - """ - management_api_http set based on management_eapi data-model - """ + """management_api_http set based on management_eapi data-model.""" if (management_eapi := get(self._hostvars, "management_eapi", default={"enable_https": True})) is None: return None management_api_http = {"enable_vrfs": [{"name": self.shared_utils.mgmt_interface_vrf}]} management_api = management_eapi.fromkeys(["enable_http", "enable_https", "default_services"]) - for key in dict(management_api).keys(): + for key in dict(management_api): if (value := management_eapi.get(key)) is not None: management_api[key] = value else: @@ -612,22 +554,19 @@ def management_api_http(self) -> dict | None: @cached_property def link_tracking_groups(self) -> list | None: - """ - link_tracking_groups - """ + """link_tracking_groups.""" return self.shared_utils.link_tracking_groups @cached_property def lacp(self) -> dict | None: - """ - lacp set based on lacp_port_id_range - """ + """Lacp set based on lacp_port_id_range.""" lacp_port_id_range = get(self.shared_utils.switch_data_combined, "lacp_port_id_range", default={}) if lacp_port_id_range.get("enabled") is not True: return None if (switch_id := self.shared_utils.id) is None: - raise AristaAvdMissingVariableError(f"'id' is not set on '{self.shared_utils.hostname}' to set LACP port ID ranges") + msg = f"'id' is not set on '{self.shared_utils.hostname}' to set LACP port ID ranges" + raise AristaAvdMissingVariableError(msg) node_group_length = max(len(self.shared_utils.switch_data_node_group_nodes), 1) port_range = int(get(lacp_port_id_range, "size", default=128)) @@ -641,21 +580,22 @@ def lacp(self) -> dict | None: "range": { "begin": begin, "end": end, - } - } + }, + }, } @cached_property def ptp(self) -> dict | None: """ Generates PTP config on node level as well as for interfaces, using various defaults. + - The following are set in default node_type_keys for design "l3ls-evpn": spine: default_ptp_priority1: 20 l3leaf: default_ptp_priority1: 30 PTP priority2 is set in the code below, calculated based on the node id: - default_priority2 = self.id % 256 + default_priority2 = self.id % 256. """ if not self.shared_utils.ptp_enabled: # Since we have overlapping data model "ptp" between eos_designs and eos_cli_config_gen, @@ -672,10 +612,15 @@ def ptp(self) -> dict | None: priority2 = get(self.shared_utils.switch_data_combined, "ptp.priority2") if priority2 is None: if self.shared_utils.id is None: - raise AristaAvdMissingVariableError(f"'id' must be set on '{self.shared_utils.hostname}' to set ptp priority2") + msg = f"'id' must be set on '{self.shared_utils.hostname}' to set ptp priority2" + raise AristaAvdMissingVariableError(msg) priority2 = self.shared_utils.id % 256 - default_auto_clock_identity = default(get(self._hostvars, "ptp_settings.auto_clock_identity"), get(self._hostvars, "ptp.auto_clock_identity"), True) + default_auto_clock_identity = default( + get(self._hostvars, "ptp_settings.auto_clock_identity"), + get(self._hostvars, "ptp.auto_clock_identity"), + True, # noqa: FBT003 + ) if get(self.shared_utils.switch_data_combined, "ptp.auto_clock_identity", default=default_auto_clock_identity) is True: clock_identity_prefix = get(self.shared_utils.switch_data_combined, "ptp.clock_identity_prefix", default="00:1C:73") default_clock_identity = f"{clock_identity_prefix}:{priority1:02x}:00:{priority2:02x}" @@ -724,14 +669,11 @@ def ptp(self) -> dict | None: }, }, } - ptp = strip_null_from_data(ptp, (None, {})) - return ptp + return strip_null_from_data(ptp, (None, {})) @cached_property def eos_cli(self) -> str | None: - """ - Aggregate the values of raw_eos_cli and platform_settings.platform_raw_eos_cli facts - """ + """Aggregate the values of raw_eos_cli and platform_settings.platform_raw_eos_cli facts.""" raw_eos_cli = get(self.shared_utils.switch_data_combined, "raw_eos_cli") platform_raw_eos_cli = get(self.shared_utils.platform_settings, "raw_eos_cli") if raw_eos_cli is not None or platform_raw_eos_cli is not None: @@ -740,9 +682,7 @@ def eos_cli(self) -> str | None: @cached_property def ip_radius_source_interfaces(self) -> list | None: - """ - Parse source_interfaces.radius and return list of source_interfaces. - """ + """Parse source_interfaces.radius and return list of source_interfaces.""" if (inputs := self._source_interfaces.get("radius")) is None: return None @@ -753,9 +693,7 @@ def ip_radius_source_interfaces(self) -> list | None: @cached_property def ip_tacacs_source_interfaces(self) -> list | None: - """ - Parse source_interfaces.tacacs and return list of source_interfaces. - """ + """Parse source_interfaces.tacacs and return list of source_interfaces.""" if (inputs := self._source_interfaces.get("tacacs")) is None: return None @@ -766,9 +704,7 @@ def ip_tacacs_source_interfaces(self) -> list | None: @cached_property def ip_ssh_client_source_interfaces(self) -> list | None: - """ - Parse source_interfaces.ssh_client and return list of source_interfaces. - """ + """Parse source_interfaces.ssh_client and return list of source_interfaces.""" if (inputs := self._source_interfaces.get("ssh_client")) is None: return None @@ -779,14 +715,14 @@ def ip_ssh_client_source_interfaces(self) -> list | None: @cached_property def ip_domain_lookup(self) -> dict | None: - """ - Parse source_interfaces.domain_lookup and return dict with nested source_interfaces list. - """ + """Parse source_interfaces.domain_lookup and return dict with nested source_interfaces list.""" if (inputs := self._source_interfaces.get("domain_lookup")) is None: return None if source_interfaces := self._build_source_interfaces( - inputs.get("mgmt_interface", False), inputs.get("inband_mgmt_interface", False), "IP Domain Lookup" + inputs.get("mgmt_interface", False), + inputs.get("inband_mgmt_interface", False), + "IP Domain Lookup", ): return {"source_interfaces": source_interfaces} @@ -794,14 +730,14 @@ def ip_domain_lookup(self) -> dict | None: @cached_property def ip_http_client_source_interfaces(self) -> list | None: - """ - Parse source_interfaces.http_client and return list of source_interfaces. - """ + """Parse source_interfaces.http_client and return list of source_interfaces.""" if (inputs := self._source_interfaces.get("http_client")) is None: return None if source_interfaces := self._build_source_interfaces( - inputs.get("mgmt_interface", False), inputs.get("inband_mgmt_interface", False), "IP HTTP Client" + inputs.get("mgmt_interface", False), + inputs.get("inband_mgmt_interface", False), + "IP HTTP Client", ): return source_interfaces diff --git a/python-avd/pyavd/_eos_designs/structured_config/base/ntp.py b/python-avd/pyavd/_eos_designs/structured_config/base/ntp.py index a4ec8516cf9..f181e244ad0 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/base/ntp.py +++ b/python-avd/pyavd/_eos_designs/structured_config/base/ntp.py @@ -6,8 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdError -from ...._utils import get, strip_null_from_data +from pyavd._errors import AristaAvdError +from pyavd._utils import get, strip_null_from_data + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,14 +18,13 @@ class NtpMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ntp(self: AvdStructuredConfigBase) -> dict | None: - """ - ntp set based on "ntp_settings" data-model. - """ + """Ntp set based on "ntp_settings" data-model.""" ntp_settings = get(self._hostvars, "ntp_settings") if not ntp_settings: return None @@ -36,7 +36,7 @@ def ntp(self: AvdStructuredConfigBase) -> dict | None: "authenticate_servers_only": ntp_settings.get("authenticate_servers_only"), "authentication_keys": ntp_settings.get("authentication_keys"), "trusted_keys": ntp_settings.get("trusted_keys"), - } + }, ) if "servers" not in ntp_settings: @@ -56,7 +56,8 @@ def ntp(self: AvdStructuredConfigBase) -> dict | None: if server_vrf == "use_mgmt_interface_vrf": has_mgmt_ip = (self.shared_utils.mgmt_ip is not None) or (self.shared_utils.ipv6_mgmt_ip is not None) if not has_mgmt_ip: - raise AristaAvdError("'ntp_settings.server_vrf' is set to 'use_mgmt_interface_vrf' but this node is missing an 'mgmt_ip'") + msg = "'ntp_settings.server_vrf' is set to 'use_mgmt_interface_vrf' but this node is missing an 'mgmt_ip'" + raise AristaAvdError(msg) # Replacing server_vrf with mgmt_interface_vrf server_vrf = self.shared_utils.mgmt_interface_vrf ntp["local_interface"] = { @@ -65,7 +66,8 @@ def ntp(self: AvdStructuredConfigBase) -> dict | None: } elif server_vrf == "use_inband_mgmt_vrf": if self.shared_utils.inband_mgmt_interface is None: - raise AristaAvdError("'ntp_settings.server_vrf' is set to 'use_inband_mgmt_vrf' but this node is missing configuration for inband management") + msg = "'ntp_settings.server_vrf' is set to 'use_inband_mgmt_vrf' but this node is missing configuration for inband management" + raise AristaAvdError(msg) # self.shared_utils.inband_mgmt_vrf returns None for the default VRF. # Replacing server_vrf with inband_mgmt_vrf or "default" server_vrf = self.shared_utils.inband_mgmt_vrf or "default" diff --git a/python-avd/pyavd/_eos_designs/structured_config/base/snmp_server.py b/python-avd/pyavd/_eos_designs/structured_config/base/snmp_server.py index 2052da7d198..d6c4e6a2ed4 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/base/snmp_server.py +++ b/python-avd/pyavd/_eos_designs/structured_config/base/snmp_server.py @@ -7,9 +7,10 @@ from hashlib import sha1 from typing import TYPE_CHECKING -from ...._errors import AristaAvdError, AristaAvdMissingVariableError -from ...._utils import get, replace_or_append_item, strip_null_from_data -from ....j2filters import natural_sort, snmp_hash +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import get, replace_or_append_item, strip_null_from_data +from pyavd.j2filters import natural_sort, snmp_hash + from .utils import UtilsMixin if TYPE_CHECKING: @@ -19,7 +20,8 @@ class SnmpServerMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property @@ -60,13 +62,14 @@ def snmp_server(self: AvdStructuredConfigBase) -> dict | None: "views": snmp_settings.get("views"), "groups": snmp_settings.get("groups"), "traps": snmp_settings.get("traps"), - } + }, ) def _snmp_engine_ids(self: AvdStructuredConfigBase, snmp_settings: dict) -> dict | None: """ Return dict of engine ids if "snmp_settings.compute_local_engineid" is True. - Otherwise return None + + Otherwise return None. """ if snmp_settings.get("compute_local_engineid") is not True: return None @@ -74,22 +77,25 @@ def _snmp_engine_ids(self: AvdStructuredConfigBase, snmp_settings: dict) -> dict compute_source = get(snmp_settings, "compute_local_engineid_source", default="hostname_and_ip") if compute_source == "hostname_and_ip": # Accepting SonarLint issue: The weak sha1 is not used for encryption. Just to create a unique engine id. - local_engine_id = sha1(f"{self.shared_utils.hostname}{self.shared_utils.mgmt_ip}".encode("utf-8")).hexdigest() # NOSONAR + local_engine_id = sha1(f"{self.shared_utils.hostname}{self.shared_utils.mgmt_ip}".encode()).hexdigest() # NOSONAR # noqa: S324 elif compute_source == "system_mac": if self.shared_utils.system_mac_address is None: - raise AristaAvdMissingVariableError("default_engine_id_from_system_mac: true requires system_mac_address to be set!") + msg = "default_engine_id_from_system_mac: true requires system_mac_address to be set!" + raise AristaAvdMissingVariableError(msg) # the default engine id on switches is derived as per the following formula local_engine_id = f"f5717f{str(self.shared_utils.system_mac_address).replace(':', '').lower()}00" else: # Unknown mode - raise AristaAvdError(f"'{compute_source}' is not a valid value to compute the engine ID, accepted values are 'hostname_and_ip' and 'system_mac'") + msg = f"'{compute_source}' is not a valid value to compute the engine ID, accepted values are 'hostname_and_ip' and 'system_mac'" + raise AristaAvdError(msg) return {"local": local_engine_id} def _snmp_location(self: AvdStructuredConfigBase, snmp_settings: dict) -> str | None: """ Return location if "snmp_settings.location" is True. - Otherwise return None + + Otherwise return None. """ if snmp_settings.get("location") is not True: return None @@ -107,7 +113,8 @@ def _snmp_location(self: AvdStructuredConfigBase, snmp_settings: dict) -> str | def _snmp_users(self: AvdStructuredConfigBase, snmp_settings: dict, engine_ids: dict | None) -> list | None: """ Return users if "snmp_settings.users" is set. - Otherwise return None + + Otherwise return None. Users will have computed localized keys if configured. """ @@ -154,10 +161,11 @@ def _snmp_users(self: AvdStructuredConfigBase, snmp_settings: dict, engine_ids: return snmp_users or None - def _snmp_hosts(self: AvdStructuredConfigBase, snmp_settings) -> list | None: + def _snmp_hosts(self: AvdStructuredConfigBase, snmp_settings: dict) -> list | None: """ Return hosts if "snmp_settings.hosts" is set. - Otherwise return None + + Otherwise return None. Hosts may have management VRFs dynamically set. """ @@ -191,35 +199,37 @@ def _snmp_hosts(self: AvdStructuredConfigBase, snmp_settings) -> list | None: # Add host without VRF field snmp_hosts.append(host) - for vrf in natural_sort(vrfs): - # Add host with VRF field. - snmp_hosts.append({**host, "vrf": vrf}) + # Add host with VRF field. + snmp_hosts.extend({**host, "vrf": vrf} for vrf in natural_sort(vrfs)) return snmp_hosts or None def _snmp_local_interfaces(self: AvdStructuredConfigBase, source_interfaces_inputs: dict | None) -> list | None: """ Return local_interfaces if "source_interfaces.snmp" is set. - Otherwise return None - """ + Otherwise return None. + """ if not source_interfaces_inputs: # Empty dict or None return None local_interfaces = self._build_source_interfaces( - source_interfaces_inputs.get("mgmt_interface", False), source_interfaces_inputs.get("inband_mgmt_interface", False), "SNMP" + source_interfaces_inputs.get("mgmt_interface", False), + source_interfaces_inputs.get("inband_mgmt_interface", False), + "SNMP", ) return local_interfaces or None def _snmp_vrfs(self: AvdStructuredConfigBase, snmp_settings: dict | None) -> list | None: """ - Return list of dicts for enabling/disabling SNMP for VRFs + Return list of dicts for enabling/disabling SNMP for VRFs. + Requires one of the following options to be set under snmp_settings: - vrfs - enable_mgmt_interface_vrf - enable_inband_mgmt_vrf - Otherwise return None + Otherwise return None. """ if snmp_settings is None: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/base/utils.py b/python-avd/pyavd/_eos_designs/structured_config/base/utils.py index c1ac8fff4ec..41a625cc860 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/base/utils.py +++ b/python-avd/pyavd/_eos_designs/structured_config/base/utils.py @@ -6,8 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdError, AristaAvdMissingVariableError -from ...._utils import get +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import get if TYPE_CHECKING: from . import AvdStructuredConfigBase @@ -16,6 +16,7 @@ class UtilsMixin: """ Mixin Class with internal functions. + Class should only be used as Mixin to a AvdStructuredConfig class or other Mixins. """ @@ -35,7 +36,8 @@ def _build_source_interfaces(self: AvdStructuredConfigBase, include_mgmt_interfa if include_mgmt_interface: if (self.shared_utils.mgmt_ip is None) and (self.shared_utils.ipv6_mgmt_ip is None): - raise AristaAvdMissingVariableError(f"Unable to configure {error_context} source-interface since 'mgmt_ip' or 'ipv6_mgmt_ip' are not set.") + msg = f"Unable to configure {error_context} source-interface since 'mgmt_ip' or 'ipv6_mgmt_ip' are not set." + raise AristaAvdMissingVariableError(msg) # mgmt_interface is always set (defaults to "Management1") so no need for error handling missing interface. source_interface = {"name": self.shared_utils.mgmt_interface} @@ -46,13 +48,15 @@ def _build_source_interfaces(self: AvdStructuredConfigBase, include_mgmt_interfa if include_inband_mgmt_interface: # Check for missing interface if self.shared_utils.inband_mgmt_interface is None: - raise AristaAvdMissingVariableError(f"Unable to configure {error_context} source-interface since 'inband_mgmt_interface' is not set.") + msg = f"Unable to configure {error_context} source-interface since 'inband_mgmt_interface' is not set." + raise AristaAvdMissingVariableError(msg) # Check for duplicate VRF # inband_mgmt_vrf returns None in case of VRF "default", but here we want the "default" VRF name to have proper duplicate detection. inband_mgmt_vrf = self.shared_utils.inband_mgmt_vrf or "default" if [source_interface for source_interface in source_interfaces if source_interface.get("vrf", "default") == inband_mgmt_vrf]: - raise AristaAvdError(f"Unable to configure multiple {error_context} source-interfaces for the same VRF '{inband_mgmt_vrf}'.") + msg = f"Unable to configure multiple {error_context} source-interfaces for the same VRF '{inband_mgmt_vrf}'." + raise AristaAvdError(msg) source_interface = {"name": self.shared_utils.inband_mgmt_interface} if self.shared_utils.inband_mgmt_vrf not in [None, "default"]: diff --git a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/__init__.py index a29e935318b..ddd32b18af1 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from ...avdfacts import AvdFacts +from pyavd._eos_designs.avdfacts import AvdFacts + from .ethernet_interfaces import EthernetInterfacesMixin from .monitor_sessions import MonitorSessionsMixin from .port_channel_interfaces import PortChannelInterfacesMixin @@ -27,9 +28,7 @@ class AvdStructuredConfigConnectedEndpoints( """ def render(self) -> dict: - """ - Wrap class render function with a check if connected_endpoints feature is enabled - """ + """Wrap class render function with a check if connected_endpoints feature is enabled.""" if self.shared_utils.connected_endpoints: return super().render() return {} diff --git a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/ethernet_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/ethernet_interfaces.py index f2414192462..118a064891e 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/ethernet_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/ethernet_interfaces.py @@ -8,10 +8,11 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdError, AristaAvdMissingVariableError -from ...._utils import append_if_not_duplicate, default, get, replace_or_append_item, strip_null_from_data -from ....j2filters import range_expand -from ...interface_descriptions import InterfaceDescriptionData +from pyavd._eos_designs.interface_descriptions.models import InterfaceDescriptionData +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import append_if_not_duplicate, default, get, replace_or_append_item, strip_null_from_data +from pyavd.j2filters import range_expand + from .utils import UtilsMixin if TYPE_CHECKING: @@ -21,20 +22,20 @@ class EthernetInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ethernet_interfaces(self: AvdStructuredConfigConnectedEndpoints) -> list | None: """ - Return structured config for ethernet_interfaces + Return structured config for ethernet_interfaces. Duplicate checks following these rules: - Silently overwrite duplicate network_ports with other network_ports. - Silently overwrite duplicate network_ports with connected_endpoints. - Do NOT overwrite connected_endpoints with other connected_endpoints. Instead we raise a duplicate error. """ - ethernet_interfaces = [] # List of ethernet_interfaces used for duplicate checks. @@ -104,14 +105,12 @@ def _update_ethernet_interface_cfg(self: AvdStructuredConfigConnectedEndpoints, "sflow": self._get_adapter_sflow(adapter), "flow_tracker": self._get_adapter_flow_tracking(adapter), "link_tracking_groups": self._get_adapter_link_tracking_groups(adapter), - } + }, ) return ethernet_interface def _get_ethernet_interface_cfg(self: AvdStructuredConfigConnectedEndpoints, adapter: dict, node_index: int, connected_endpoint: dict) -> dict: - """ - Return structured_config for one ethernet_interface - """ + """Return structured_config for one ethernet_interface.""" peer = connected_endpoint["name"] endpoint_ports: list = default( adapter.get("endpoint_ports"), @@ -125,10 +124,13 @@ def _get_ethernet_interface_cfg(self: AvdStructuredConfigConnectedEndpoints, ada # check lengths of lists nodes_length = len(adapter["switches"]) if len(adapter["switch_ports"]) != nodes_length or ("descriptions" in adapter and len(adapter["descriptions"]) != nodes_length): - raise AristaAvdError( + msg = ( f"Length of lists 'switches', 'switch_ports', and 'descriptions' (if used) must match for adapter. Check configuration for {peer}, adapter" f" switch_ports {adapter['switch_ports']}." ) + raise AristaAvdError( + msg, + ) # if 'descriptions' is set, it is preferred if (interface_descriptions := adapter.get("descriptions")) is not None: @@ -150,7 +152,7 @@ def _get_ethernet_interface_cfg(self: AvdStructuredConfigConnectedEndpoints, ada peer=peer, peer_interface=peer_interface, description=interface_description, - ) + ), ), "speed": adapter.get("speed"), "shutdown": not adapter.get("enabled", True), @@ -168,7 +170,7 @@ def _get_ethernet_interface_cfg(self: AvdStructuredConfigConnectedEndpoints, ada "id": channel_group_id, "mode": port_channel_mode, }, - } + }, ) if get(adapter, "port_channel.lacp_fallback.mode") == "static": ethernet_interface["lacp_port_priority"] = 8192 if node_index == 0 else 32768 @@ -176,18 +178,24 @@ def _get_ethernet_interface_cfg(self: AvdStructuredConfigConnectedEndpoints, ada elif get(adapter, "port_channel.lacp_fallback.mode") == "individual": # if fallback is set to individual a profile has to be defined if (profile_name := get(adapter, "port_channel.lacp_fallback.individual.profile")) is None: - raise AristaAvdMissingVariableError( + msg = ( "A Port-channel which is set to lacp fallback mode 'individual' must have a 'profile' defined. Profile definition is missing for" f" the connected endpoint with the name '{connected_endpoint['name']}'." ) + raise AristaAvdMissingVariableError( + msg, + ) # Verify that the referred profile exists under port_profiles if not (profile := self.shared_utils.get_merged_port_profile(profile_name)): - raise AristaAvdMissingVariableError( + msg = ( "The 'profile' of every port-channel lacp fallback individual setting must be defined in the 'port_profiles'. First occurrence seen" f" of a missing profile is '{get(adapter, 'port_channel.lacp_fallback.individual.profile')}' for the connected endpoint with the" f" name '{connected_endpoint['name']}'." ) + raise AristaAvdMissingVariableError( + msg, + ) ethernet_interface = self._update_ethernet_interface_cfg(profile, ethernet_interface, connected_endpoint) @@ -201,7 +209,12 @@ def _get_ethernet_interface_cfg(self: AvdStructuredConfigConnectedEndpoints, ada else: ethernet_interface = self._update_ethernet_interface_cfg(adapter, ethernet_interface, connected_endpoint) ethernet_interface["evpn_ethernet_segment"] = self._get_adapter_evpn_ethernet_segment_cfg( - adapter, short_esi, node_index, connected_endpoint, "auto", "single-active" + adapter, + short_esi, + node_index, + connected_endpoint, + "auto", + "single-active", ) # More common ethernet_interface settings diff --git a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/monitor_sessions.py b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/monitor_sessions.py index 7a97bdb527a..34e0f6fdd62 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/monitor_sessions.py +++ b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/monitor_sessions.py @@ -7,8 +7,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, get, groupby, merge, strip_null_from_data -from ....j2filters import range_expand +from pyavd._utils import append_if_not_duplicate, get, groupby, merge, strip_null_from_data +from pyavd.j2filters import range_expand + from .utils import UtilsMixin if TYPE_CHECKING: @@ -18,14 +19,13 @@ class MonitorSessionsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def monitor_sessions(self: AvdStructuredConfigConnectedEndpoints) -> list | None: - """ - Return structured_config for monitor_sessions - """ + """Return structured_config for monitor_sessions.""" if not self._monitor_session_configs: return None @@ -33,14 +33,14 @@ def monitor_sessions(self: AvdStructuredConfigConnectedEndpoints) -> list | None for session_name, session_configs in groupby(self._monitor_session_configs, "name"): # Convert iterator to list since we can only access it once. - session_configs = list(session_configs) - merged_settings = merge({}, session_configs, destructive_merge=False) + session_configs_list = list(session_configs) + merged_settings = merge({}, session_configs_list, destructive_merge=False) monitor_session = { "name": session_name, "sources": [], - "destinations": [session["interface"] for session in session_configs if session.get("role") == "destination"], + "destinations": [session["interface"] for session in session_configs_list if session.get("role") == "destination"], } - source_sessions = [session for session in session_configs if session.get("role") == "source"] + source_sessions = [session for session in session_configs_list if session.get("role") == "source"] for session in source_sessions: source = { "name": session["interface"], @@ -72,9 +72,7 @@ def monitor_sessions(self: AvdStructuredConfigConnectedEndpoints) -> list | None @cached_property def _monitor_session_configs(self: AvdStructuredConfigConnectedEndpoints) -> list: - """ - Return list of monitor session configs extracted from every interface - """ + """Return list of monitor session configs extracted from every interface.""" monitor_session_configs = [] for connected_endpoint in self._filtered_connected_endpoints: for adapter in connected_endpoint["adapters"]: diff --git a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/port_channel_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/port_channel_interfaces.py index 8bf6833ef81..2646a491228 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/port_channel_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/port_channel_interfaces.py @@ -8,9 +8,10 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, get, short_esi_to_route_target, strip_null_from_data -from ....j2filters import range_expand -from ...interface_descriptions import InterfaceDescriptionData +from pyavd._eos_designs.interface_descriptions.models import InterfaceDescriptionData +from pyavd._utils import append_if_not_duplicate, get, short_esi_to_route_target, strip_null_from_data +from pyavd.j2filters import range_expand + from .utils import UtilsMixin if TYPE_CHECKING: @@ -20,13 +21,14 @@ class PortChannelInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def port_channel_interfaces(self: AvdStructuredConfigConnectedEndpoints) -> list | None: """ - Return structured config for port_channel_interfaces + Return structured config for port_channel_interfaces. Duplicate checks following these rules: - Silently ignore duplicate port-channels if they contain _exactly_ the same configuration @@ -60,7 +62,10 @@ def port_channel_interfaces(self: AvdStructuredConfigConnectedEndpoints) -> list port_channel_subinterface_name = f"Port-Channel{channel_group_id}.{subinterface['number']}" port_channel_subinterface_config = self._get_port_channel_subinterface_cfg( - subinterface, adapter, port_channel_subinterface_name, channel_group_id + subinterface, + adapter, + port_channel_subinterface_name, + channel_group_id, ) append_if_not_duplicate( list_of_dicts=port_channel_interfaces, @@ -110,12 +115,13 @@ def port_channel_interfaces(self: AvdStructuredConfigConnectedEndpoints) -> list return None def _get_port_channel_interface_cfg( - self: AvdStructuredConfigConnectedEndpoints, adapter: dict, port_channel_interface_name: str, channel_group_id: int, connected_endpoint: dict + self: AvdStructuredConfigConnectedEndpoints, + adapter: dict, + port_channel_interface_name: str, + channel_group_id: int, + connected_endpoint: dict, ) -> dict: - """ - Return structured_config for one port_channel_interface - """ - + """Return structured_config for one port_channel_interface.""" peer = connected_endpoint["name"] adapter_description = get(adapter, "description") adapter_port_channel_description = get(adapter, "port_channel.description") @@ -133,7 +139,7 @@ def _get_port_channel_interface_cfg( peer=peer, description=adapter_description, port_channel_description=adapter_port_channel_description, - ) + ), ), "type": port_channel_type, "shutdown": not get(adapter, "port_channel.enabled", default=True), @@ -164,7 +170,7 @@ def _get_port_channel_interface_cfg( "spanning_tree_bpdufilter": adapter.get("spanning_tree_bpdufilter"), "spanning_tree_bpduguard": adapter.get("spanning_tree_bpduguard"), "storm_control": self._get_adapter_storm_control(adapter), - } + }, ) # EVPN A/A @@ -185,17 +191,19 @@ def _get_port_channel_interface_cfg( { "lacp_fallback_mode": lacp_fallback_mode, "lacp_fallback_timeout": get(adapter, "port_channel.lacp_fallback.timeout", default=90), - } + }, ) return strip_null_from_data(port_channel_interface, strip_values_tuple=(None, "")) def _get_port_channel_subinterface_cfg( - self: AvdStructuredConfigConnectedEndpoints, subinterface: dict, adapter: dict, port_channel_subinterface_name: str, channel_group_id: int + self: AvdStructuredConfigConnectedEndpoints, + subinterface: dict, + adapter: dict, + port_channel_subinterface_name: str, + channel_group_id: int, ) -> dict: - """ - Return structured_config for one port_channel_interface (subinterface) - """ + """Return structured_config for one port_channel_interface (subinterface).""" # Common port_channel_interface settings port_channel_interface = { "name": port_channel_subinterface_name, @@ -205,7 +213,7 @@ def _get_port_channel_subinterface_cfg( "client": { "dot1q": { "vlan": get(subinterface, "encapsulation_vlan.client_dot1q", default=subinterface["number"]), - } + }, }, "network": { "client": True, diff --git a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/utils.py b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/utils.py index d414d8a75d2..cfdf3806def 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/utils.py +++ b/python-avd/pyavd/_eos_designs/structured_config/connected_endpoints/utils.py @@ -8,9 +8,9 @@ from hashlib import sha256 from typing import TYPE_CHECKING -from ...._errors import AristaAvdError -from ...._utils import get, get_item, short_esi_to_route_target -from ....j2filters import convert_dicts +from pyavd._errors import AristaAvdError +from pyavd._utils import get, get_item, short_esi_to_route_target +from pyavd.j2filters import convert_dicts if TYPE_CHECKING: from . import AvdStructuredConfigConnectedEndpoints @@ -19,14 +19,14 @@ class UtilsMixin: """ Mixin Class with internal functions. + Class should only be used as Mixin to a AvdStructuredConfig class or other Mixins. """ @cached_property def _filtered_connected_endpoints(self: AvdStructuredConfigConnectedEndpoints) -> list: """ - Return list of endpoints defined under one of the keys in "connected_endpoints_keys" - which are connected to this switch. + Return list of endpoints defined under one of the keys in "connected_endpoints_keys" which are connected to this switch. Adapters are filtered to contain only the ones connected to this switch. """ @@ -48,11 +48,12 @@ def _filtered_connected_endpoints(self: AvdStructuredConfigConnectedEndpoints) - nodes_length = len(adapter_settings["switches"]) endpoint_ports = adapter_settings.get("endpoint_ports") if len(adapter_settings["switch_ports"]) != nodes_length or (endpoint_ports is not None and len(endpoint_ports) != nodes_length): - raise AristaAvdError( + msg = ( f"Length of lists 'switches', 'switch_ports', 'endpoint_ports' (if used) did not match on adapter {adapter_index} on" f" connected_endpoint '{connected_endpoint['name']}' under '{connected_endpoints_key['key']}'." " Notice that some or all of these variables could be inherited from 'port_profiles'" ) + raise AristaAvdError(msg) filtered_adapters.append(adapter_settings) @@ -62,17 +63,14 @@ def _filtered_connected_endpoints(self: AvdStructuredConfigConnectedEndpoints) - **connected_endpoint, "adapters": filtered_adapters, "type": connected_endpoints_key["type"], - } + }, ) return filtered_connected_endpoints @cached_property def _filtered_network_ports(self: AvdStructuredConfigConnectedEndpoints) -> list: - """ - Return list of endpoints defined under "network_ports" - which are connected to this switch. - """ + """Return list of endpoints defined under "network_ports" which are connected to this switch.""" filtered_network_ports = [] for network_port in get(self._hostvars, "network_ports", default=[]): network_port_settings = self.shared_utils.get_merged_adapter_settings(network_port) @@ -86,16 +84,19 @@ def _filtered_network_ports(self: AvdStructuredConfigConnectedEndpoints) -> list def _match_regexes(self: AvdStructuredConfigConnectedEndpoints, regexes: list, value: str) -> bool: """ Match a list of regexes with the supplied value. - Regex must match the full value to pass, so regex is wrapped in ^$ + + Regex must match the full value to pass, so regex is wrapped in ^$. """ return any(re.match(rf"^{regex}$", value) for regex in regexes) def _get_short_esi( - self: AvdStructuredConfigConnectedEndpoints, adapter: dict, channel_group_id: int, short_esi: str = None, hash_extra_value: str = "" + self: AvdStructuredConfigConnectedEndpoints, + adapter: dict, + channel_group_id: int, + short_esi: str | None = None, + hash_extra_value: str = "", ) -> str | None: - """ - Return short_esi for one adapter - """ + """Return short_esi for one adapter.""" if len(set(adapter["switches"])) < 2 or not self.shared_utils.overlay_evpn or not self.shared_utils.overlay_vtep: # Only configure ESI for multi-homing. return None @@ -118,23 +119,20 @@ def _get_short_esi( short_esi = re.sub(r"([0-9a-f]{4})", "\\1:", esi_hash)[:14] if len(short_esi.split(":")) != 3: - raise AristaAvdError(f"Invalid 'short_esi': '{short_esi}' on connected endpoints adapter. Must be in the format xxxx:xxxx:xxxx") + msg = f"Invalid 'short_esi': '{short_esi}' on connected endpoints adapter. Must be in the format xxxx:xxxx:xxxx" + raise AristaAvdError(msg) return short_esi def _get_adapter_trunk_groups(self: AvdStructuredConfigConnectedEndpoints, adapter: dict, connected_endpoint: dict) -> dict | None: - """ - Return trunk_groups for one adapter - """ + """Return trunk_groups for one adapter.""" if self.shared_utils.enable_trunk_groups and "trunk" in adapter.get("mode", ""): return get(adapter, "trunk_groups", required=True, org_key=f"'trunk_groups' for the connected_endpoint {connected_endpoint['name']}") return None def _get_adapter_storm_control(self: AvdStructuredConfigConnectedEndpoints, adapter: dict) -> dict | None: - """ - Return storm_control for one adapter - """ + """Return storm_control for one adapter.""" if self.shared_utils.platform_settings_feature_support_interface_storm_control: return get(adapter, "storm_control") @@ -146,12 +144,10 @@ def _get_adapter_evpn_ethernet_segment_cfg( short_esi: str, node_index: int, connected_endpoint: dict, - default_df_algo: str = None, - default_redundancy: str = None, + default_df_algo: str | None = None, + default_redundancy: str | None = None, ) -> dict | None: - """ - Return evpn_ethernet_segment_cfg for one adapter - """ + """Return evpn_ethernet_segment_cfg for one adapter.""" if short_esi is None: return None @@ -192,9 +188,7 @@ def _get_adapter_evpn_ethernet_segment_cfg( return evpn_ethernet_segment def _get_adapter_link_tracking_groups(self: AvdStructuredConfigConnectedEndpoints, adapter: dict) -> list | None: - """ - Return link_tracking_groups for one adapter - """ + """Return link_tracking_groups for one adapter.""" if self.shared_utils.link_tracking_groups is None or get(adapter, "link_tracking.enabled") is not True: return None @@ -202,13 +196,11 @@ def _get_adapter_link_tracking_groups(self: AvdStructuredConfigConnectedEndpoint { "name": get(adapter, "link_tracking.name", default=self.shared_utils.link_tracking_groups[0]["name"]), "direction": "downstream", - } + }, ] def _get_adapter_ptp(self: AvdStructuredConfigConnectedEndpoints, adapter: dict) -> dict | None: - """ - Return ptp for one adapter - """ + """Return ptp for one adapter.""" if get(adapter, "ptp.enabled") is not True: return None @@ -228,31 +220,31 @@ def _get_adapter_ptp(self: AvdStructuredConfigConnectedEndpoints, adapter: dict) return ptp_config def _get_adapter_poe(self: AvdStructuredConfigConnectedEndpoints, adapter: dict) -> dict | None: - """ - Return poe settings for one adapter - """ + """Return poe settings for one adapter.""" if self.shared_utils.platform_settings_feature_support_poe: return get(adapter, "poe") return None def _get_adapter_phone(self: AvdStructuredConfigConnectedEndpoints, adapter: dict, connected_endpoint: dict) -> dict | None: - """ - Return phone settings for one adapter - """ + """Return phone settings for one adapter.""" if (adapter_phone_vlan := get(adapter, "phone_vlan")) is None: return None # Verify that "mode" is set to "trunk phone" if get(adapter, "mode") != "trunk phone": - raise AristaAvdError(f"Setting 'phone_vlan' requires 'mode: trunk phone' to be set on connected endpoint '{connected_endpoint['name']}'.") + msg = f"Setting 'phone_vlan' requires 'mode: trunk phone' to be set on connected endpoint '{connected_endpoint['name']}'." + raise AristaAvdError(msg) # Verify that "vlans" is not set, since data vlan is picked up from 'native_vlan'. if get(adapter, "vlans") is not None: - raise AristaAvdError( + msg = ( "With 'phone_vlan' and 'mode: trunk phone' the data VLAN is set via 'native_vlan' instead of 'vlans'. Found 'vlans' on connected endpoint" f" '{connected_endpoint['name']}'." ) + raise AristaAvdError( + msg, + ) return { "vlan": adapter_phone_vlan, diff --git a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/__init__.py index 3c7da0f0276..ba4f82211bd 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from ...avdfacts import AvdFacts +from pyavd._eos_designs.avdfacts import AvdFacts + from .ethernet_interfaces import EthernetInterfacesMixin from .port_channel_interfaces import PortChannelInterfacesMixin from .router_bgp import RouterBgpMixin @@ -31,9 +32,7 @@ class AvdStructuredConfigCoreInterfacesAndL3Edge( """ def render(self) -> dict: - """ - Render structured configs for core_interfaces and l3_Edge - """ + """Render structured configs for core_interfaces and l3_Edge.""" result_list = [] for data_model in DATA_MODELS: diff --git a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/ethernet_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/ethernet_interfaces.py index fe9c3b884be..3f65c2c663d 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/ethernet_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/ethernet_interfaces.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate +from pyavd._utils import append_if_not_duplicate + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,14 +17,13 @@ class EthernetInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ethernet_interfaces(self: AvdStructuredConfigCoreInterfacesAndL3Edge) -> list | None: - """ - Return structured config for ethernet_interfaces - """ + """Return structured config for ethernet_interfaces.""" ethernet_interfaces = [] for p2p_link in self._filtered_p2p_links: diff --git a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/port_channel_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/port_channel_interfaces.py index 3e2f36227ac..eaee2094b74 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/port_channel_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/port_channel_interfaces.py @@ -15,14 +15,13 @@ class PortChannelInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def port_channel_interfaces(self: AvdStructuredConfigCoreInterfacesAndL3Edge) -> list | None: - """ - Return structured config for port_channel_interfaces - """ + """Return structured config for port_channel_interfaces.""" port_channel_interfaces = [] for p2p_link in self._filtered_p2p_links: if p2p_link["data"]["port_channel_id"] is None: diff --git a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/router_bgp.py b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/router_bgp.py index 9b88c143b69..a246b8aee10 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/router_bgp.py +++ b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/router_bgp.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdMissingVariableError +from pyavd._errors import AristaAvdMissingVariableError + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,15 +17,13 @@ class RouterBgpMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_bgp(self: AvdStructuredConfigCoreInterfacesAndL3Edge) -> dict | None: - """ - Return structured config for router_bgp - """ - + """Return structured config for router_bgp.""" if not self.shared_utils.underlay_bgp: return None @@ -35,7 +34,8 @@ def router_bgp(self: AvdStructuredConfigCoreInterfacesAndL3Edge) -> dict | None: continue if p2p_link["data"]["bgp_as"] is None or p2p_link["data"]["peer_bgp_as"] is None: - raise AristaAvdMissingVariableError(f"{self.data_model}.p2p_links.[].as or {self.data_model}.p2p_links_profiles.[].as") + msg = f"{self.data_model}.p2p_links.[].as or {self.data_model}.p2p_links_profiles.[].as" + raise AristaAvdMissingVariableError(msg) neighbor = { "remote_as": p2p_link["data"]["peer_bgp_as"], @@ -51,7 +51,8 @@ def router_bgp(self: AvdStructuredConfigCoreInterfacesAndL3Edge) -> dict | None: # Regular BGP Neighbors if p2p_link["data"]["ip"] is None or p2p_link["data"]["peer_ip"] is None: - raise AristaAvdMissingVariableError(f"{self.data_model}.p2p_links.[].ip, .subnet or .ip_pool") + msg = f"{self.data_model}.p2p_links.[].ip, .subnet or .ip_pool" + raise AristaAvdMissingVariableError(msg) neighbor["bfd"] = p2p_link.get("bfd") if p2p_link["data"]["bgp_as"] != self.shared_utils.bgp_as: diff --git a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/router_ospf.py b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/router_ospf.py index e8abeeb74b7..d2db164ca24 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/router_ospf.py +++ b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/router_ospf.py @@ -15,15 +15,13 @@ class RouterOspfMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_ospf(self: AvdStructuredConfigCoreInterfacesAndL3Edge) -> dict | None: - """ - Return structured config for router_ospf - """ - + """Return structured config for router_ospf.""" if not self.shared_utils.underlay_ospf: return None @@ -36,8 +34,8 @@ def router_ospf(self: AvdStructuredConfigCoreInterfacesAndL3Edge) -> dict | None { "id": self.shared_utils.underlay_ospf_process_id, "no_passive_interfaces": no_passive_interfaces, - } - ] + }, + ], } return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/utils.py b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/utils.py index c856b7eaf55..cffd8db20e5 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/utils.py +++ b/python-avd/pyavd/_eos_designs/structured_config/core_interfaces_and_l3_edge/utils.py @@ -10,9 +10,9 @@ from itertools import islice from typing import TYPE_CHECKING -from ...._errors import AristaAvdMissingVariableError -from ...._utils import default, get, get_item, merge -from ....j2filters import convert_dicts +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import default, get, get_item, merge +from pyavd.j2filters import convert_dicts if TYPE_CHECKING: from . import AvdStructuredConfigCoreInterfacesAndL3Edge @@ -21,7 +21,8 @@ class UtilsMixin: """ Mixin Class with internal functions. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property @@ -44,10 +45,10 @@ def _p2p_links_sflow(self: AvdStructuredConfigCoreInterfacesAndL3Edge) -> bool | def _filtered_p2p_links(self: AvdStructuredConfigCoreInterfacesAndL3Edge) -> list: """ Returns a filtered list of p2p_links, which only contains links with our hostname. + For each links any referenced profiles are applied and IP addresses are resolved from pools or subnets. """ - if not (p2p_links := self._p2p_links): return [] @@ -69,9 +70,7 @@ def _filtered_p2p_links(self: AvdStructuredConfigCoreInterfacesAndL3Edge) -> lis return p2p_links def _apply_p2p_links_profile(self: AvdStructuredConfigCoreInterfacesAndL3Edge, target_dict: dict) -> dict: - """ - Apply a profile to a p2p_link - """ + """Apply a profile to a p2p_link.""" if "profile" not in target_dict: # Nothing to do return target_dict @@ -101,7 +100,7 @@ def _resolve_p2p_ips(self: AvdStructuredConfigCoreInterfacesAndL3Edge, p2p_link: return p2p_link prefix_size = int(ip_pool.get("prefix_size", 31)) link_id = int(p2p_link["id"]) - subnet = list(islice(ip_network(ip_pool_subnet).subnets(new_prefix=prefix_size), link_id - 1, link_id))[0] + subnet = next(iter(islice(ip_network(ip_pool_subnet).subnets(new_prefix=prefix_size), link_id - 1, link_id))) # hosts() return an iterator of all hosts in subnet. # islice() return a generator with only the first two iterations of hosts. @@ -112,6 +111,7 @@ def _resolve_p2p_ips(self: AvdStructuredConfigCoreInterfacesAndL3Edge, p2p_link: def _get_p2p_data(self: AvdStructuredConfigCoreInterfacesAndL3Edge, p2p_link: dict) -> dict: """ Parses p2p_link data model and extracts information which is easier to parse. + Returns: { peer: @@ -132,10 +132,7 @@ def _get_p2p_data(self: AvdStructuredConfigCoreInterfacesAndL3Edge, p2p_link: di peer_index = (index + 1) % 2 peer = p2p_link["nodes"][peer_index] peer_facts = self.shared_utils.get_peer_facts(peer, required=False) - if peer_facts is None: - peer_type = "other" - else: - peer_type = peer_facts.get("type", "other") + peer_type = "other" if peer_facts is None else peer_facts.get("type", "other") # Set ip or fallback to list with None values ip = get(p2p_link, "ip", default=[None, None]) @@ -186,7 +183,7 @@ def _get_p2p_data(self: AvdStructuredConfigCoreInterfacesAndL3Edge, p2p_link: di } for index, interface in enumerate(member_interfaces) ], - } + }, ) return data @@ -198,19 +195,20 @@ def _get_p2p_data(self: AvdStructuredConfigCoreInterfacesAndL3Edge, p2p_link: di "peer_interface": p2p_link["interfaces"][peer_index], "port_channel_id": None, "port_channel_members": [], - } + }, ) return data - raise AristaAvdMissingVariableError(f"{self.data_model}.p2p_links must have either 'interfaces' or 'port_channel' with correct members set.") + msg = f"{self.data_model}.p2p_links must have either 'interfaces' or 'port_channel' with correct members set." + raise AristaAvdMissingVariableError(msg) def _get_common_interface_cfg(self: AvdStructuredConfigCoreInterfacesAndL3Edge, p2p_link: dict) -> dict: """ Return partial structured_config for one p2p_link. + Covers common config that is applicable to both port-channels and ethernet interfaces. This config will only be used on the main interface - so not port-channel members. """ - index = p2p_link["nodes"].index(self.shared_utils.hostname) peer = p2p_link["data"]["peer"] peer_interface = p2p_link["data"]["peer_interface"] @@ -244,7 +242,7 @@ def _get_common_interface_cfg(self: AvdStructuredConfigCoreInterfacesAndL3Edge, { "ospf_network_point_to_point": True, "ospf_area": self.shared_utils.underlay_ospf_area, - } + }, ) if self.shared_utils.underlay_isis: @@ -258,7 +256,7 @@ def _get_common_interface_cfg(self: AvdStructuredConfigCoreInterfacesAndL3Edge, "isis_circuit_type": default(p2p_link.get("isis_circuit_type"), self.shared_utils.isis_default_circuit_type), "isis_authentication_mode": p2p_link.get("isis_authentication_mode"), "isis_authentication_key": p2p_link.get("isis_authentication_key"), - } + }, ) if p2p_link.get("macsec_profile"): @@ -280,8 +278,8 @@ def _get_common_interface_cfg(self: AvdStructuredConfigCoreInterfacesAndL3Edge, "ldp": { "interface": True, "igp_sync": True, - } - } + }, + }, ) return interface_cfg @@ -289,10 +287,10 @@ def _get_common_interface_cfg(self: AvdStructuredConfigCoreInterfacesAndL3Edge, def _get_ethernet_cfg(self: AvdStructuredConfigCoreInterfacesAndL3Edge, p2p_link: dict) -> dict: """ Return partial structured_config for one p2p_link. + Covers config that is only applicable to ethernet interfaces. This config will only be used on both main interfaces and port-channel members. """ - ethernet_cfg = {"speed": p2p_link.get("speed")} if get(p2p_link, "ptp.enabled") is not True: @@ -313,6 +311,7 @@ def _get_ethernet_cfg(self: AvdStructuredConfigCoreInterfacesAndL3Edge, p2p_link def _get_port_channel_member_cfg(self: AvdStructuredConfigCoreInterfacesAndL3Edge, p2p_link: dict, member: dict) -> dict: """ Return partial structured_config for one p2p_link. + Covers config for ethernet interfaces that are port-channel members. TODO: Change description for members to be the physical peer interface instead of port-channel diff --git a/python-avd/pyavd/_eos_designs/structured_config/custom_structured_configuration/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/custom_structured_configuration/__init__.py index d22c9a0219d..cf67bd4c83c 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/custom_structured_configuration/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/custom_structured_configuration/__init__.py @@ -5,8 +5,8 @@ from functools import cached_property -from ...._utils import get -from ...avdfacts import AvdFacts +from pyavd._eos_designs.avdfacts import AvdFacts +from pyavd._utils import get CUSTOM_STRUCTURED_CONFIGURATION_EXEMPT_KEYS = ["custom_structured_configuration_prefix", "custom_structured_configuration_list_merge"] @@ -22,9 +22,7 @@ class AvdStructuredConfigCustomStructuredConfiguration(AvdFacts): @cached_property def _custom_structured_configuration_prefix(self) -> list: - """ - Reads custom_structured_configuration_prefix from hostvars and converts to list if necessary - """ + """Reads custom_structured_configuration_prefix from hostvars and converts to list if necessary.""" custom_structured_configuration_prefix = get(self._hostvars, "custom_structured_configuration_prefix", default=["custom_structured_configuration_"]) if not isinstance(custom_structured_configuration_prefix, list): return [custom_structured_configuration_prefix] @@ -90,8 +88,8 @@ def _router_bgp_peer_groups(self) -> list: { "router_bgp": { "peer_groups": struct_cfgs, - } - } + }, + }, ] def _router_bgp_vrfs(self) -> list: @@ -105,8 +103,8 @@ def _router_bgp_vrfs(self) -> list: { "router_bgp": { "vrfs": struct_cfgs, - } - } + }, + }, ] def _router_bgp_vlans(self) -> list: @@ -120,8 +118,8 @@ def _router_bgp_vlans(self) -> list: { "router_bgp": { "vlans": struct_cfgs, - } - } + }, + }, ] def _custom_structured_configurations(self) -> list[dict]: @@ -132,7 +130,7 @@ def _custom_structured_configurations(self) -> list[dict]: { # Disable black to prevent whitespace before colon PEP8 E203 # fmt: off - str(key)[len(prefix):]: self._hostvars[key] + str(key)[len(prefix) :]: self._hostvars[key] # fmt: on for key in self._hostvars if str(key).startswith(prefix) and key not in CUSTOM_STRUCTURED_CONFIGURATION_EXEMPT_KEYS @@ -148,7 +146,6 @@ def render(self) -> list[dict]: get_structured_config will merge this list into a single dict. """ - struct_cfgs = self._struct_cfg() struct_cfgs.extend(self._struct_cfgs()) struct_cfgs.extend(self._ethernet_interfaces()) @@ -157,8 +154,6 @@ def render(self) -> list[dict]: struct_cfgs.extend(self._router_bgp_peer_groups()) struct_cfgs.extend(self._router_bgp_vrfs()) struct_cfgs.extend(self._router_bgp_vlans()) - # struct_cfgs = [struct_cfg for struct_cfg in struct_cfgs if struct_cfg is not None] struct_cfgs.extend(self._custom_structured_configurations()) - # raise Exception(struct_cfgs) return struct_cfgs diff --git a/python-avd/pyavd/_eos_designs/structured_config/flows/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/flows/__init__.py index 007d1b2c108..3b2ad62e422 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/flows/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/flows/__init__.py @@ -5,14 +5,16 @@ from functools import cached_property -from ...._errors import AristaAvdMissingVariableError -from ...._utils import get, get_item, strip_null_from_data -from ....j2filters import natural_sort -from ...avdfacts import AvdFacts +from pyavd._eos_designs.avdfacts import AvdFacts +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import get, get_item, strip_null_from_data +from pyavd.j2filters import natural_sort class AvdStructuredConfigFlows(AvdFacts): """ + Structured config for sflow and flow_tracker. + This class must be rendered after all other eos_designs modules since it relies on detecting sflow from the interface structured config generated by the other modules. @@ -35,8 +37,7 @@ def sflow(self) -> dict | None: destinations = get(self._hostvars, "sflow_settings.destinations") if destinations is None: - # TODO: - # AVD5.0.0 raise an error if sflow is enabled on an interface but there are no destinations configured. + # TODO: AVD5.0.0 raise an error if sflow is enabled on an interface but there are no destinations configured. # This cannot be implemented today since it would be breaking for already released support for sflow on interfaces. return None @@ -60,8 +61,9 @@ def sflow(self) -> dict | None: elif vrf == "use_mgmt_interface_vrf": if (self.shared_utils.mgmt_ip is None) and (self.shared_utils.ipv6_mgmt_ip is None): + msg = "Unable to configure sFlow source-interface with 'use_mgmt_interface_vrf' since 'mgmt_ip' or 'ipv6_mgmt_ip' are not set." raise AristaAvdMissingVariableError( - "Unable to configure sFlow source-interface with 'use_mgmt_interface_vrf' since 'mgmt_ip' or 'ipv6_mgmt_ip' are not set." + msg, ) vrf = self.shared_utils.mgmt_interface_vrf @@ -70,14 +72,17 @@ def sflow(self) -> dict | None: elif vrf == "use_inband_mgmt_vrf": # Check for missing interface if self.shared_utils.inband_mgmt_interface is None: + msg = "Unable to configure sFlow source-interface with 'use_inband_mgmt_vrf' since 'inband_mgmt_interface' is not set." raise AristaAvdMissingVariableError( - "Unable to configure sFlow source-interface with 'use_inband_mgmt_vrf' since 'inband_mgmt_interface' is not set." + msg, ) # self.shared_utils.inband_mgmt_vrf returns None for the default VRF, but here we need "default" to avoid duplicates. vrf = self.shared_utils.inband_mgmt_vrf or "default" source_interface = get( - get_item(sflow_settings_vrfs, "name", vrf, default={}), "source_interface", default=self.shared_utils.inband_mgmt_interface + get_item(sflow_settings_vrfs, "name", vrf, default={}), + "source_interface", + default=self.shared_utils.inband_mgmt_interface, ) else: @@ -90,7 +95,7 @@ def sflow(self) -> dict | None: { "destination": destination.get("destination"), "port": destination.get("port"), - } + }, ) sflow["source_interface"] = source_interface @@ -100,7 +105,7 @@ def sflow(self) -> dict | None: { "destination": destination.get("destination"), "port": destination.get("port"), - } + }, ) sflow_vrfs[vrf]["source_interface"] = source_interface @@ -121,23 +126,20 @@ def _enable_sflow(self) -> bool: if get(interface, "sflow.enable") is True: return True - for interface in get(self._hostvars, "port_channel_interfaces", default=[]): - if get(interface, "sflow.enable") is True: - return True - - return False + return any(get(interface, "sflow.enable") is True for interface in get(self._hostvars, "port_channel_interfaces", default=[])) @cached_property def _default_flow_tracker(self) -> dict: """ - Following configuration will be rendered based on the inputs: + Following configuration will be rendered based on the inputs. + tracker FLOW-TRACKER record export on inactive timeout 70000 record export on interval 300000 exporter ayush_exporter collector 127.0.0.1 local interface Loopback0 - template interval 3600000 + template interval 3600000. Depending on the flow tracker type, some other default values like sample, no shutdown will be added in further method @@ -165,9 +167,7 @@ def resolve_flow_tracker_by_type(self, tracker_settings: dict) -> dict: @cached_property def flow_tracking(self) -> dict | None: - """ - Return structured config for flow_tracking - """ + """Return structured config for flow_tracking.""" configured_trackers = self._get_enabled_flow_trackers() if not configured_trackers: return None @@ -221,7 +221,7 @@ def _get_enabled_flow_trackers(self) -> bool: for interface_type in ["ethernet_interfaces", "port_channel_interfaces", "dps_interfaces"]: for interface in get(self._hostvars, interface_type, default=[]): if tracker := get(interface, "flow_tracker"): - for trackerType, trackerName in tracker.items(): - trackers[trackerType][trackerName] = True + for tracker_type, tracker_name in tracker.items(): + trackers[tracker_type][tracker_name] = True return trackers[self.shared_utils.flow_tracking_type] diff --git a/python-avd/pyavd/_eos_designs/structured_config/inband_management/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/inband_management/__init__.py index 3d1c67abe82..91a45f49d88 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/inband_management/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/inband_management/__init__.py @@ -6,10 +6,10 @@ from functools import cached_property from ipaddress import ip_network -from ...._errors import AristaAvdMissingVariableError -from ...._utils import get, strip_empties_from_dict -from ....j2filters import natural_sort -from ...avdfacts import AvdFacts +from pyavd._eos_designs.avdfacts import AvdFacts +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import get, strip_empties_from_dict +from pyavd.j2filters import natural_sort class AvdStructuredConfigInbandManagement(AvdFacts): @@ -33,9 +33,7 @@ def vlans(self) -> list | None: @cached_property def vlan_interfaces(self) -> list | None: - """ - VLAN interfaces can be our own management interface and/or SVIs created on behalf of child switches using us as uplink_switch. - """ + """VLAN interfaces can be our own management interface and/or SVIs created on behalf of child switches using us as uplink_switch.""" if not self._inband_management_parent_vlans and not (self.shared_utils.configure_inband_mgmt or self.shared_utils.configure_inband_mgmt_ipv6): return None @@ -44,6 +42,7 @@ def vlan_interfaces(self) -> list | None: if self._inband_management_parent_vlans: return [self.get_parent_svi_cfg(vlan, subnet["ipv4"], subnet["ipv6"]) for vlan, subnet in self._inband_management_parent_vlans.items()] + return None @cached_property def _inband_mgmt_ipv6_parent(self) -> bool: @@ -72,8 +71,8 @@ def static_routes(self) -> list | None: "destination_address_prefix": "0.0.0.0/0", "gateway": self.shared_utils.inband_mgmt_gateway, "vrf": self.shared_utils.inband_mgmt_vrf, - } - ) + }, + ), ] @cached_property @@ -87,8 +86,8 @@ def ipv6_static_routes(self) -> list | None: "destination_address_prefix": "::/0", "gateway": self.shared_utils.inband_mgmt_ipv6_gateway, "vrf": self.shared_utils.inband_mgmt_vrf, - } - ) + }, + ), ] @cached_property @@ -107,7 +106,8 @@ def ip_virtual_router_mac_address(self) -> str | None: return None if self.shared_utils.virtual_router_mac_address is None: - raise AristaAvdMissingVariableError("'virtual_router_mac_address' must be set for inband management parent.") + msg = "'virtual_router_mac_address' must be set for inband management parent." + raise AristaAvdMissingVariableError(msg) return str(self.shared_utils.virtual_router_mac_address).lower() @cached_property @@ -154,7 +154,7 @@ def prefix_lists(self) -> list | None: { "name": "PL-L2LEAF-INBAND-MGMT", "sequence_numbers": sequence_numbers, - } + }, ] @cached_property @@ -185,7 +185,7 @@ def ipv6_prefix_lists(self) -> list | None: { "name": "IPv6-PL-L2LEAF-INBAND-MGMT", "sequence_numbers": sequence_numbers, - } + }, ] @cached_property @@ -256,7 +256,7 @@ def get_local_inband_mgmt_interface_cfg(self) -> dict: "ipv6_enable": None if not self.shared_utils.configure_inband_mgmt_ipv6 else True, "ipv6_address": self.shared_utils.inband_mgmt_ipv6_address, "type": "inband_mgmt", - } + }, ) def get_parent_svi_cfg(self, vlan: int, subnet: str | None, ipv6_subnet: str | None) -> dict: diff --git a/python-avd/pyavd/_eos_designs/structured_config/metadata/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/metadata/__init__.py index 61685b0588e..f0c41aa03c0 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/metadata/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/metadata/__init__.py @@ -5,15 +5,17 @@ from functools import cached_property -from ...._utils import strip_empties_from_dict -from ...avdfacts import AvdFacts +from pyavd._eos_designs.avdfacts import AvdFacts +from pyavd._utils import strip_empties_from_dict + from .cv_pathfinder import CvPathfinderMixin from .cv_tags import CvTagsMixin class AvdStructuredConfigMetadata(AvdFacts, CvTagsMixin, CvPathfinderMixin): """ - This returns the metadata data structure as per the below example + This returns the metadata data structure as per the below example. + { "metadata": { "platform": "7050X3", @@ -46,7 +48,7 @@ class AvdStructuredConfigMetadata(AvdFacts, CvTagsMixin, CvPathfinderMixin): }, "cv_pathfinder": {} } - } + }. """ @cached_property diff --git a/python-avd/pyavd/_eos_designs/structured_config/metadata/cv_pathfinder.py b/python-avd/pyavd/_eos_designs/structured_config/metadata/cv_pathfinder.py index f21957332d6..3d8f64bdbff 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/metadata/cv_pathfinder.py +++ b/python-avd/pyavd/_eos_designs/structured_config/metadata/cv_pathfinder.py @@ -6,8 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdError -from ...._utils import get, get_all, get_item, strip_empties_from_list +from pyavd._errors import AristaAvdError +from pyavd._utils import get, get_all, get_item, strip_empties_from_list if TYPE_CHECKING: from . import AvdStructuredConfigMetadata @@ -16,13 +16,15 @@ class CvPathfinderMixin: """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ def _cv_pathfinder(self: AvdStructuredConfigMetadata) -> dict | None: """ Generate metadata for CV Pathfinder feature. - Only relevant for cv_pathfinder routers + + Only relevant for cv_pathfinder routers. Metadata for "applications" and "internet_exit_policies" is generated in the network services module, since all the required data was readily available in there. @@ -94,7 +96,10 @@ def _metadata_pathgroups(self: AvdStructuredConfigMetadata) -> list: def _metadata_regions(self: AvdStructuredConfigMetadata) -> list: regions = get( - self._hostvars, "cv_pathfinder_regions", required=True, org_key="'cv_pathfinder_regions' key must be set when 'wan_mode' is 'cv-pathfinder'." + self._hostvars, + "cv_pathfinder_regions", + required=True, + org_key="'cv_pathfinder_regions' key must be set when 'wan_mode' is 'cv-pathfinder'.", ) return [ { @@ -119,7 +124,7 @@ def _metadata_regions(self: AvdStructuredConfigMetadata) -> list: } for site in region["sites"] ], - } + }, ], } for region in regions @@ -134,10 +139,7 @@ def _metadata_pathfinder_vtep_ips(self: AvdStructuredConfigMetadata) -> list: ] def _metadata_vrfs(self: AvdStructuredConfigMetadata) -> list: - """ - Extracting metadata for VRFs by parsing the generated structured config - and flatten it a bit (like hiding load-balance policies) - """ + """Extracting metadata for VRFs by parsing the generated structured config and flatten it a bit (like hiding load-balance policies).""" if (avt_vrfs := get(self._hostvars, "router_adaptive_virtual_topology.vrfs")) is None: return [] @@ -154,12 +156,15 @@ def _metadata_vrfs(self: AvdStructuredConfigMetadata) -> list: for path_group in lb_policy["path_groups"] if path_group["name"] != self.shared_utils.wan_ha_path_group_name ): - raise AristaAvdError( + msg = ( "At least one path-group must be configured with preference '1' or 'preferred' for " f"load-balance policy {lb_policy['name']}' to use CloudVision integration. " "If this is an auto-generated policy, ensure that at least one default_preference " "for a non excluded path-group is set to 'preferred' (or unset as this is the default)." ) + raise AristaAvdError( + msg, + ) return strip_empties_from_list( [ @@ -196,22 +201,24 @@ def _metadata_vrfs(self: AvdStructuredConfigMetadata) -> list: } for vrf in avt_vrfs for avt_policy in [get_item(avt_policies, "name", vrf["policy"], required=True)] - ] + ], ) @cached_property def _wan_virtual_topologies_vrfs(self: AvdStructuredConfigMetadata) -> list[dict]: """ Unfiltered list of VRFs found under wan_virtual_topologies. + Used to find VNI for each VRF used in cv_pathfinder. """ return get(self._hostvars, "wan_virtual_topologies.vrfs", default=[]) - def _get_vni_for_vrf_name(self: AvdStructuredConfigMetadata, vrf_name: str): + def _get_vni_for_vrf_name(self: AvdStructuredConfigMetadata, vrf_name: str) -> int: if (vrf := get_item(self._wan_virtual_topologies_vrfs, "name", vrf_name)) is None or (wan_vni := vrf.get("wan_vni")) is None: if vrf_name == "default": return 1 - raise AristaAvdError(f"Unable to find the WAN VNI for VRF {vrf_name} during generation of cv_pathfinder metadata.") + msg = f"Unable to find the WAN VNI for VRF {vrf_name} during generation of cv_pathfinder metadata." + raise AristaAvdError(msg) return wan_vni diff --git a/python-avd/pyavd/_eos_designs/structured_config/metadata/cv_tags.py b/python-avd/pyavd/_eos_designs/structured_config/metadata/cv_tags.py index c440521a51d..d48b14937ed 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/metadata/cv_tags.py +++ b/python-avd/pyavd/_eos_designs/structured_config/metadata/cv_tags.py @@ -4,10 +4,10 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from ...._errors import AristaAvdError -from ...._utils import default, get, get_item, strip_empties_from_dict, strip_empties_from_list +from pyavd._errors import AristaAvdError +from pyavd._utils import default, get, get_item, strip_empties_from_dict, strip_empties_from_list if TYPE_CHECKING: from . import AvdStructuredConfigMetadata @@ -41,7 +41,8 @@ class CvTagsMixin: """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property @@ -49,9 +50,7 @@ def _generate_cv_tags(self: AvdStructuredConfigMetadata) -> dict: return get(self._hostvars, "generate_cv_tags", default={}) def _cv_tags(self: AvdStructuredConfigMetadata) -> dict | None: - """ - Generate the data structure `metadata.cv_tags`. - """ + """Generate the data structure `metadata.cv_tags`.""" if not self._generate_cv_tags and not self.shared_utils.is_cv_pathfinder_router: return None @@ -64,15 +63,13 @@ def _cv_tags(self: AvdStructuredConfigMetadata) -> dict | None: return strip_empties_from_dict(cv_tags) or None @staticmethod - def _tag_dict(name: str, value) -> dict | None: + def _tag_dict(name: str, value: Any) -> dict | None: if value is None: return None return {"name": name, "value": str(value)} def _get_topology_hints(self: AvdStructuredConfigMetadata) -> list: - """ - Return list of topology_hint tags. - """ + """Return list of topology_hint tags.""" if get(self._generate_cv_tags, "topology_hints") is not True: return [] @@ -84,19 +81,20 @@ def _get_topology_hints(self: AvdStructuredConfigMetadata) -> list: self._tag_dict("topology_hint_pod", self.shared_utils.pod_name), self._tag_dict("topology_hint_type", get(self._hostvars, "cv_tags_topology_type", default=default_type_hint)), self._tag_dict("topology_hint_rack", default(self.shared_utils.rack, self.shared_utils.group)), - ] + ], ) def _get_cv_pathfinder_device_tags(self: AvdStructuredConfigMetadata) -> list: """ - Return list of device_tags for cv_pathfinder solution + Return list of device_tags for cv_pathfinder solution. + Example: [ {"name": "Region", "value": }, {"name": "Zone", "value": <"-ZONE" for pathfinder clients>}, {"name": "Site", "value": }, {"name": "PathfinderSet", "value": }, {"name": "Role", "value": <'pathfinder', 'edge', 'transit region' or 'transit zone'>} - ] + ]. """ if not self.shared_utils.is_cv_pathfinder_router: return [] @@ -112,25 +110,26 @@ def _get_cv_pathfinder_device_tags(self: AvdStructuredConfigMetadata) -> list: [ self._tag_dict("Zone", self.shared_utils.wan_zone["name"]), self._tag_dict("Site", self.shared_utils.wan_site["name"]), - ] + ], ) return strip_empties_from_list(device_tags) def _get_device_tags(self: AvdStructuredConfigMetadata) -> list: - """ - Return list of device_tags - """ + """Return list of device_tags.""" if not (tags_to_generate := get(self._generate_cv_tags, "device_tags")): return [] device_tags = [] for generate_tag in tags_to_generate: if generate_tag["name"] in INVALID_CUSTOM_DEVICE_TAGS: - raise AristaAvdError( + msg = ( f"The CloudVision tag name 'generate_cv_tags.device_tags[name={generate_tag['name']}] is invalid. " "System Tags cannot be overridden. Try using a different name for this tag." ) + raise AristaAvdError( + msg, + ) # Get value from either 'value' key, structured config based on the 'data_path' key or raise. if get(generate_tag, "value") is not None: @@ -138,12 +137,16 @@ def _get_device_tags(self: AvdStructuredConfigMetadata) -> list: elif get(generate_tag, "data_path") is not None: value = get(self._hostvars, generate_tag["data_path"]) if type(value) in [list, dict]: - raise AristaAvdError( + msg = ( f"'generate_cv_tags.device_tags[name={generate_tag['name']}].data_path' ({generate_tag['data_path']}) " f"points to a variable of type {type(value).__name__}. This is not supported for cloudvision tag data_paths." ) + raise AristaAvdError( + msg, + ) else: - raise AristaAvdError(f"'generate_cv_tags.device_tags[name={generate_tag['name']}]' is missing either a static 'value' or a dynamic 'data_path'") + msg = f"'generate_cv_tags.device_tags[name={generate_tag['name']}]' is missing either a static 'value' or a dynamic 'data_path'" + raise AristaAvdError(msg) # Silently ignoring empty values since structured config may vary between devices. if value: @@ -152,9 +155,7 @@ def _get_device_tags(self: AvdStructuredConfigMetadata) -> list: return device_tags def _get_interface_tags(self: AvdStructuredConfigMetadata) -> list: - """ - Return list of interface_tags - """ + """Return list of interface_tags.""" if not (tags_to_generate := get(self._generate_cv_tags, "interface_tags", default=[])) and not self.shared_utils.is_cv_pathfinder_router: return [] @@ -168,13 +169,17 @@ def _get_interface_tags(self: AvdStructuredConfigMetadata) -> list: elif get(generate_tag, "data_path") is not None: value = get(ethernet_interface, generate_tag["data_path"]) if type(value) in [list, dict]: - raise AristaAvdError( + msg = ( f"'generate_cv_tags.interface_tags[name={generate_tag['name']}].data_path' ({generate_tag['data_path']}) " f"points to a variable of type {type(value).__name__}. This is not supported for cloudvision tag data_paths." ) + raise AristaAvdError( + msg, + ) else: + msg = f"'generate_cv_tags.interface_tags[name={generate_tag['name']}]' is missing either a static 'value' or a dynamic 'data_path'" raise AristaAvdError( - f"'generate_cv_tags.interface_tags[name={generate_tag['name']}]' is missing either a static 'value' or a dynamic 'data_path'" + msg, ) # Silently ignoring empty values since structured config may vary between devices. @@ -191,12 +196,13 @@ def _get_interface_tags(self: AvdStructuredConfigMetadata) -> list: def _get_cv_pathfinder_interface_tags(self: AvdStructuredConfigMetadata, ethernet_interface: dict) -> list: """ - Return list of device_tags for cv_pathfinder solution + Return list of device_tags for cv_pathfinder solution. + Example: [ {"name": "Type", <"lan" or "wan">}, {"name": "Carrier", }, {"name": "Circuit", } - ] + ]. """ if ethernet_interface["name"] in self._wan_interface_names: wan_interface = get_item(self.shared_utils.wan_interfaces, "name", ethernet_interface["name"], required=True) @@ -205,11 +211,11 @@ def _get_cv_pathfinder_interface_tags(self: AvdStructuredConfigMetadata, etherne self._tag_dict("Type", "wan"), self._tag_dict("Carrier", get(wan_interface, "wan_carrier")), self._tag_dict("Circuit", get(wan_interface, "wan_circuit_id")), - ] + ], ) return [self._tag_dict("Type", "lan")] @cached_property - def _wan_interface_names(self: AvdStructuredConfigMetadata): + def _wan_interface_names(self: AvdStructuredConfigMetadata) -> list: return [wan_interface["name"] for wan_interface in self.shared_utils.wan_interfaces] diff --git a/python-avd/pyavd/_eos_designs/structured_config/mlag/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/mlag/__init__.py index 0f67b7bb831..1eda14b66f6 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/mlag/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/mlag/__init__.py @@ -5,27 +5,25 @@ from functools import cached_property -from ...._utils import default, get, strip_empties_from_dict -from ....j2filters import list_compress -from ...avdfacts import AvdFacts -from ...interface_descriptions import InterfaceDescriptionData +from pyavd._eos_designs.avdfacts import AvdFacts +from pyavd._eos_designs.interface_descriptions.models import InterfaceDescriptionData +from pyavd._utils import default, get, strip_empties_from_dict +from pyavd.j2filters import list_compress class AvdStructuredConfigMlag(AvdFacts): - def render(self): - """ - Wrap class render function with a check for mlag is True - """ + def render(self) -> dict: + """Wrap class render function with a check for mlag is True.""" if self.shared_utils.mlag is True: return super().render() return {} @cached_property - def _trunk_groups_mlag_name(self): + def _trunk_groups_mlag_name(self) -> str: return get(self.shared_utils.trunk_groups, "mlag.name", required=True) @cached_property - def _trunk_groups_mlag_l3_name(self): + def _trunk_groups_mlag_l3_name(self) -> str: return get(self.shared_utils.trunk_groups, "mlag_l3.name", required=True) @cached_property @@ -46,7 +44,7 @@ def vlans(self) -> list: "tenant": "system", "name": "LEAF_PEER_L3", "trunk_groups": [self._trunk_groups_mlag_l3_name], - } + }, ) vlans.append( @@ -55,19 +53,18 @@ def vlans(self) -> list: "tenant": "system", "name": "MLAG_PEER", "trunk_groups": [self._trunk_groups_mlag_name], - } + }, ) return vlans @cached_property def vlan_interfaces(self) -> list | None: """ - Return list with VLAN Interfaces used for MLAG + Return list with VLAN Interfaces used for MLAG. May return both the main MLAG VLAN as well as a dedicated L3 VLAN Can also combine L3 configuration on the main MLAG VLAN """ - # Create Main MLAG VLAN Interface main_vlan_interface_name = f"Vlan{self.shared_utils.mlag_peer_vlan}" main_vlan_interface = { @@ -95,7 +92,7 @@ def vlan_interfaces(self) -> list | None: { "ospf_network_point_to_point": True, "ospf_area": self.shared_utils.underlay_ospf_area, - } + }, ) elif self.shared_utils.underlay_routing_protocol == "isis": @@ -105,7 +102,7 @@ def vlan_interfaces(self) -> list | None: "isis_bfd": get(self._hostvars, "underlay_isis_bfd"), "isis_metric": 50, "isis_network_point_to_point": True, - } + }, ) if self.shared_utils.underlay_multicast: @@ -142,16 +139,13 @@ def vlan_interfaces(self) -> list | None: ] @cached_property - def port_channel_interfaces(self): - """ - Return dict with one Port Channel Interface used for MLAG Peer Link - """ - + def port_channel_interfaces(self) -> list: + """Return dict with one Port Channel Interface used for MLAG Peer Link.""" port_channel_interface_name = f"Port-Channel{self.shared_utils.mlag_port_channel_id}" port_channel_interface = { "name": port_channel_interface_name, "description": self.shared_utils.interface_descriptions.mlag_port_channel_interface( - InterfaceDescriptionData(shared_utils=self.shared_utils, interface=port_channel_interface_name) + InterfaceDescriptionData(shared_utils=self.shared_utils, interface=port_channel_interface_name), ), "type": "switched", "shutdown": False, @@ -189,11 +183,8 @@ def port_channel_interfaces(self): return [strip_empties_from_dict(port_channel_interface)] @cached_property - def ethernet_interfaces(self): - """ - Return dict with Ethernet Interfaces used for MLAG Peer Link - """ - + def ethernet_interfaces(self) -> list: + """Return dict with Ethernet Interfaces used for MLAG Peer Link.""" if not (mlag_interfaces := self.shared_utils.mlag_interfaces): return None @@ -205,7 +196,7 @@ def ethernet_interfaces(self): "peer_interface": mlag_interface, "peer_type": "mlag_peer", "description": self.shared_utils.interface_descriptions.mlag_ethernet_interface( - InterfaceDescriptionData(shared_utils=self.shared_utils, interface=mlag_interface, peer_interface=mlag_interface) + InterfaceDescriptionData(shared_utils=self.shared_utils, interface=mlag_interface, peer_interface=mlag_interface), ), "type": "port-channel-member", "shutdown": False, @@ -222,10 +213,8 @@ def ethernet_interfaces(self): return ethernet_interfaces @cached_property - def mlag_configuration(self): - """ - Return Structured Config for MLAG Configuration - """ + def mlag_configuration(self) -> dict: + """Return Structured Config for MLAG Configuration.""" mlag_configuration = { "domain_id": get(self.shared_utils.switch_data_combined, "mlag_domain_id", default=self.shared_utils.group), "local_interface": f"Vlan{self.shared_utils.mlag_peer_vlan}", @@ -246,20 +235,20 @@ def mlag_configuration(self): "vrf": self.shared_utils.mgmt_interface_vrf, }, "dual_primary_detection_delay": 5, - } + }, ) return strip_empties_from_dict(mlag_configuration) @cached_property - def route_maps(self): + def route_maps(self) -> list[dict] | None: """ - Return dict with one route-map - Origin Incomplete for MLAG iBGP learned routes + Return list of route-maps. + + Origin Incomplete for MLAG iBGP learned routes. TODO: Partially duplicated in network_services. Should be moved to a common class """ - if not (self.shared_utils.mlag_l3 is True and self.shared_utils.mlag_ibgp_origin_incomplete is True and self.shared_utils.underlay_bgp): return None @@ -272,20 +261,19 @@ def route_maps(self): "type": "permit", "set": ["origin incomplete"], "description": "Make routes learned over MLAG Peer-link less preferred on spines to ensure optimal routing", - } + }, ], - } + }, ] @cached_property - def router_bgp(self): + def router_bgp(self) -> dict | None: """ - Return structured config for router bgp + Return structured config for router bgp. Peer group and underlay MLAG iBGP peering is created only for BGP underlay. For other underlay protocols the MLAG peer-group may be created as part of the network services logic. """ - if not (self.shared_utils.mlag_l3 is True and self.shared_utils.underlay_bgp): return None @@ -304,7 +292,7 @@ def router_bgp(self): "peer": self.shared_utils.mlag_peer, "remote_as": self.shared_utils.bgp_as, "description": self.shared_utils.mlag_peer, - } + }, ] else: @@ -315,15 +303,14 @@ def router_bgp(self): "peer_group": peer_group_name, "peer": self.shared_utils.mlag_peer, "description": self.shared_utils.mlag_peer, - } + }, ] return strip_empties_from_dict(router_bgp) def _router_bgp_mlag_peer_group(self) -> dict: """ - Return a partial router_bgp structured_config covering the MLAG peer_group - and associated address_family activations + Return a partial router_bgp structured_config covering the MLAG peer_group and associated address_family activations. TODO: Duplicated in network_services. Should be moved to a common class """ @@ -352,8 +339,8 @@ def _router_bgp_mlag_peer_group(self) -> dict: { "name": peer_group_name, "activate": True, - } - ] + }, + ], } address_family_ipv4_peer_group = {"name": peer_group_name, "activate": True} diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/__init__.py index 8ae81886bc5..cf3e33e879f 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from ...avdfacts import AvdFacts +from pyavd._eos_designs.avdfacts import AvdFacts + from .application_traffic_recognition import ApplicationTrafficRecognitionMixin from .dps_interfaces import DpsInterfacesMixin from .eos_cli import EosCliMixin @@ -94,10 +95,11 @@ class AvdStructuredConfigNetworkServices( def render(self) -> dict: """ - Wrap class render function with a check if one of the following vars are True + Wrap class render function with a check if one of the following vars are True. + - node_type_keys.[].network_services_l2 - node_type_keys.[].network_services_l3 - - node_type_keys.[].network_services_l1 + - node_type_keys.[].network_services_l1. """ if self.shared_utils.any_network_services: return super().render() diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/application_traffic_recognition.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/application_traffic_recognition.py index b71fb29c449..4b96cd7ebea 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/application_traffic_recognition.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/application_traffic_recognition.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, get, get_item, strip_empties_from_dict +from pyavd._utils import append_if_not_duplicate, get, get_item, strip_empties_from_dict + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,14 +17,13 @@ class ApplicationTrafficRecognitionMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def application_traffic_recognition(self: AvdStructuredConfigNetworkServices) -> dict | None: - """ - Return structured config for application_traffic_recognition if wan router - """ + """Return structured config for application_traffic_recognition if wan router.""" if not self.shared_utils.is_wan_router: return None @@ -48,9 +48,10 @@ def _wan_cp_app_src_prefix(self: AvdStructuredConfigNetworkServices) -> str: def _generate_control_plane_application_profile(self: AvdStructuredConfigNetworkServices, app_dict: dict) -> None: """ - Generate an application profile using a single application matching: + Generate an application profile using a single application matching. + * the device Pathfinders vtep_ips as destination for non Pathfinders. - * the device Pathfinder vtep_ip as source + * the device Pathfinder vtep_ip as source. Create a structure as follow. If any object already exist, it is kept as defined by user and override the defaults. @@ -96,9 +97,9 @@ def _generate_control_plane_application_profile(self: AvdStructuredConfigNetwork "applications": [ { "name": self._wan_control_plane_application, - } + }, ], - } + }, ) # Adding the application ipv4_applications = get(app_dict, "applications.ipv4_applications", []) @@ -109,7 +110,7 @@ def _generate_control_plane_application_profile(self: AvdStructuredConfigNetwork { "name": self._wan_control_plane_application, "dest_prefix_set_name": self._wan_cp_app_dst_prefix, - } + }, ) # Adding the field-set based on the connected Pathfinder router-ids ipv4_prefixes_field_sets = get(app_dict, "field_sets.ipv4_prefixes", []) @@ -120,17 +121,17 @@ def _generate_control_plane_application_profile(self: AvdStructuredConfigNetwork { "name": self._wan_cp_app_dst_prefix, "prefix_values": pathfinder_vtep_ips, - } + }, ) elif self.shared_utils.is_wan_server: app_dict.setdefault("applications", {}).setdefault("ipv4_applications", []).append( { "name": self._wan_control_plane_application, "src_prefix_set_name": self._wan_cp_app_src_prefix, - } + }, ) app_dict.setdefault("field_sets", {}).setdefault("ipv4_prefixes", []).append( - {"name": self._wan_cp_app_src_prefix, "prefix_values": [f"{self.shared_utils.vtep_ip}/32"]} + {"name": self._wan_cp_app_src_prefix, "prefix_values": [f"{self.shared_utils.vtep_ip}/32"]}, ) def _filtered_application_classification(self: AvdStructuredConfigNetworkServices) -> dict: @@ -145,10 +146,11 @@ def _filtered_application_classification(self: AvdStructuredConfigNetworkService # Application profiles first application_profiles = [] - def _append_object_to_list_of_dicts(path: str, obj_name: str, list_of_dicts: list, message: str | None = None, required=True) -> None: + def _append_object_to_list_of_dicts(path: str, obj_name: str, list_of_dicts: list, message: str | None = None, *, required: bool = True) -> None: """ - Helper function - Technically impossible to get a duplicate, just reusing the method when the same application is used in multiple places + Helper function. + + Technically impossible to get a duplicate, just reusing the method when the same application is used in multiple places. """ if ( obj := get_item( diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/dps_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/dps_interfaces.py index 5c7cf749e69..0c3d4fa129f 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/dps_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/dps_interfaces.py @@ -15,13 +15,14 @@ class DpsInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def dps_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Returns structured config for dps_interfaces + Returns structured config for dps_interfaces. Only used for WAN devices """ @@ -38,7 +39,7 @@ def dps_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None: if self.shared_utils.vtep_loopback.lower().startswith("dps"): dps1["ip_address"] = f"{self.shared_utils.vtep_ip}/32" - # TODO do IPv6 when needed - for now no easy way in AVD to detect if this is needed + # TODO: do IPv6 when needed - for now no easy way in AVD to detect if this is needed # When needed - need a default value if different than IPv4 if (dps_flow := self.shared_utils.get_flow_tracker(None, "dps_interfaces")) is not None: diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/eos_cli.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/eos_cli.py index df0fe1a2d86..619bef247ec 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/eos_cli.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/eos_cli.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get +from pyavd._utils import get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,15 +17,13 @@ class EosCliMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def eos_cli(self: AvdStructuredConfigNetworkServices) -> str | None: - """ - Return existing eos_cli plus any eos_cli from VRFs - """ - + """Return existing eos_cli plus any eos_cli from VRFs.""" if not self.shared_utils.network_services_l3: return None @@ -32,10 +31,7 @@ def eos_cli(self: AvdStructuredConfigNetworkServices) -> str | None: if (eos_cli := get(self._hostvars, "eos_cli")) is not None: eos_clis.append(eos_cli) - for tenant in self.shared_utils.filtered_tenants: - for vrf in tenant["vrfs"]: - if (eos_cli := vrf.get("raw_eos_cli")) is not None: - eos_clis.append(eos_cli) + eos_clis.extend(vrf["raw_eos_cli"] for tenant in self.shared_utils.filtered_tenants for vrf in tenant["vrfs"] if vrf.get("raw_eos_cli") is not None) if eos_clis: return "\n".join(eos_clis) diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/ethernet_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/ethernet_interfaces.py index dfb6251ef4a..32a04eeb9f1 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/ethernet_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/ethernet_interfaces.py @@ -7,9 +7,10 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdError -from ...._utils import append_if_not_duplicate, get -from ....j2filters import natural_sort +from pyavd._errors import AristaAvdError +from pyavd._utils import append_if_not_duplicate, get +from pyavd.j2filters import natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -19,17 +20,17 @@ class EthernetInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ethernet_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Return structured config for ethernet_interfaces + Return structured config for ethernet_interfaces. Only used with L3 or L1 network services """ - if not (self.shared_utils.network_services_l3 or self.shared_utils.network_services_l1 or self.shared_utils.l3_interfaces): return None @@ -48,10 +49,11 @@ def ethernet_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None or len(l3_interface["ip_addresses"]) != nodes_length or ("descriptions" in l3_interface and "description" not in l3_interface and len(l3_interface["descriptions"]) != nodes_length) ): - raise AristaAvdError( + msg = ( "Length of lists 'interfaces', 'nodes', 'ip_addresses' and 'descriptions' (if used) must match for l3_interfaces for" f" {vrf['name']} in {tenant['name']}" ) + raise AristaAvdError(msg) for node_index, node_name in enumerate(l3_interface["nodes"]): if node_name != self.shared_utils.hostname: @@ -80,7 +82,7 @@ def ethernet_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None { "access_group_in": get(self._l3_interface_acls, f"{interface_name}..ipv4_acl_in..name", separator=".."), "access_group_out": get(self._l3_interface_acls, f"{interface_name}..ipv4_acl_out..name", separator=".."), - } + }, ) if "." in interface_name: @@ -123,7 +125,7 @@ def ethernet_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None "id": ospf_key["id"], "hash_algorithm": ospf_key.get("hash_algorithm", "sha512"), "key": ospf_key["key"], - } + }, ) if ospf_keys: @@ -132,16 +134,22 @@ def ethernet_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None if get(l3_interface, "pim.enabled"): if not vrf.get("_evpn_l3_multicast_enabled"): - raise AristaAvdError( + msg = ( f"'pim: enabled' set on l3_interface '{interface_name}' on '{self.shared_utils.hostname}' requires evpn_l3_multicast:" f" enabled: true under VRF '{vrf['name']}' or Tenant '{tenant['name']}'" ) + raise AristaAvdError( + msg, + ) if not vrf.get("_pim_rp_addresses"): - raise AristaAvdError( + msg = ( f"'pim: enabled' set on l3_interface '{interface_name}' on '{self.shared_utils.hostname}' requires at least one RP" f" defined in pim_rp_addresses under VRF '{vrf['name']}' or Tenant '{tenant['name']}'" ) + raise AristaAvdError( + msg, + ) interface["pim"] = {"ipv4": {"sparse_mode": True}} @@ -252,27 +260,27 @@ def ethernet_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None subif_parent_interface_names = subif_parent_interface_names.difference(eth_int["name"] for eth_int in ethernet_interfaces) if subif_parent_interface_names: - for interface_name in natural_sort(subif_parent_interface_names): - ethernet_interfaces.append( - { - "name": interface_name, - "type": "routed", - "peer_type": "l3_interface", - "shutdown": False, - } - ) - - for internet_exit_policy in self._filtered_internet_exit_policies: - for connection in internet_exit_policy.get("connections", []): - if connection["type"] == "ethernet": - ethernet_interfaces.append( - { - "name": connection["source_interface"], - "ip_nat": { - "service_profile": self.get_internet_exit_nat_profile_name(internet_exit_policy["type"]), - }, - } - ) + ethernet_interfaces.extend( + { + "name": interface_name, + "type": "routed", + "peer_type": "l3_interface", + "shutdown": False, + } + for interface_name in natural_sort(subif_parent_interface_names) + ) + + ethernet_interfaces.extend( + { + "name": connection["source_interface"], + "ip_nat": { + "service_profile": self.get_internet_exit_nat_profile_name(internet_exit_policy["type"]), + }, + } + for internet_exit_policy in self._filtered_internet_exit_policies + for connection in internet_exit_policy.get("connections", []) + if connection["type"] == "ethernet" + ) if ethernet_interfaces: return ethernet_interfaces diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_access_lists.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_access_lists.py index 2dfe0bac99c..65cdcf5ed31 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_access_lists.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_access_lists.py @@ -6,9 +6,10 @@ from functools import cached_property from typing import TYPE_CHECKING, Literal -from ...._errors import AristaAvdError -from ...._utils import append_if_not_duplicate, get -from ....j2filters import natural_sort +from pyavd._errors import AristaAvdError +from pyavd._utils import append_if_not_duplicate, get +from pyavd.j2filters import natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -18,7 +19,8 @@ class IpAccesslistsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property @@ -32,7 +34,7 @@ def _acl_internet_exit_zscaler(self: AvdStructuredConfigNetworkServices) -> dict "protocol": "ip", "source": "any", "destination": "any", - } + }, ], } @@ -56,7 +58,7 @@ def _acl_internet_exit_direct(self: AvdStructuredConfigNetworkServices) -> dict "protocol": "ip", "source": interface_ip.split("/", maxsplit=1)[0], "destination": "any", - } + }, ) entries.append( { @@ -65,13 +67,14 @@ def _acl_internet_exit_direct(self: AvdStructuredConfigNetworkServices) -> dict "protocol": "ip", "source": "any", "destination": "any", - } + }, ) return { "name": self.get_internet_exit_nat_acl_name("direct"), "entries": entries, } + return None def _acl_internet_exit_user_defined(self: AvdStructuredConfigNetworkServices, internet_exit_policy_type: Literal["zscaler", "direct"]) -> dict | None: acl_name = self.get_internet_exit_nat_acl_name(internet_exit_policy_type) @@ -87,7 +90,8 @@ def _acl_internet_exit_user_defined(self: AvdStructuredConfigNetworkServices, in # TODO: We still have one nat for all interfaces, need to also add logic to make nat per interface # if acl needs substitution - raise AristaAvdError(f"ipv4_acls[name={acl_name}] field substitution is not supported for internet exit access lists") + msg = f"ipv4_acls[name={acl_name}] field substitution is not supported for internet exit access lists" + raise AristaAvdError(msg) def _acl_internet_exit(self: AvdStructuredConfigNetworkServices, internet_exit_policy_type: Literal["zscaler", "direct"]) -> dict | None: acls = self._acl_internet_exit_user_defined(internet_exit_policy_type) @@ -102,9 +106,7 @@ def _acl_internet_exit(self: AvdStructuredConfigNetworkServices, internet_exit_p @cached_property def ip_access_lists(self: AvdStructuredConfigNetworkServices) -> list | None: - """ - Return structured config for ip_access_lists. - """ + """Return structured config for ip_access_lists.""" ip_access_lists = [] if self._svi_acls: for interface_acls in self._svi_acls.values(): diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_igmp_snooping.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_igmp_snooping.py index ce0b14f51ef..adf8b8eec0d 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_igmp_snooping.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_igmp_snooping.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, default, get +from pyavd._utils import append_if_not_duplicate, default, get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,15 +17,13 @@ class IpIgmpSnoopingMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ip_igmp_snooping(self: AvdStructuredConfigNetworkServices) -> dict | None: - """ - Return structured config for ip_igmp_snooping - """ - + """Return structured config for ip_igmp_snooping.""" if not self.shared_utils.network_services_l2: return None @@ -62,9 +61,9 @@ def ip_igmp_snooping(self: AvdStructuredConfigNetworkServices) -> dict | None: return ip_igmp_snooping - def _ip_igmp_snooping_vlan(self: AvdStructuredConfigNetworkServices, vlan, tenant) -> dict: + def _ip_igmp_snooping_vlan(self: AvdStructuredConfigNetworkServices, vlan: dict, tenant: dict) -> dict: """ - ip_igmp_snooping logic for one vlan + ip_igmp_snooping logic for one vlan. Can be used for both svis and l2vlans """ diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_nat.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_nat.py index f648390d681..8c682ef7595 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_nat.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_nat.py @@ -16,14 +16,13 @@ class IpNatMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ip_nat(self: AvdStructuredConfigNetworkServices) -> dict | None: - """ - Returns structured config for ip_nat - """ + """Returns structured config for ip_nat.""" if not self.shared_utils.is_cv_pathfinder_client: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_security.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_security.py index 613282a2d30..0a7644f52d5 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_security.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_security.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get, strip_null_from_data +from pyavd._utils import get, strip_null_from_data + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,15 +17,13 @@ class IpSecurityMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ip_security(self: AvdStructuredConfigNetworkServices) -> dict | None: - """ - ip_security set based on cv_pathfinder_internet_exit_policies - """ - + """ip_security set based on cv_pathfinder_internet_exit_policies.""" if not self._filtered_internet_exit_policies: return None @@ -48,7 +47,7 @@ def ip_security(self: AvdStructuredConfigNetworkServices) -> dict | None: "ike_lifetime": 24, "encryption": "aes256", "dh_group": 24, - } + }, ) ip_security["sa_policies"].append( { @@ -59,7 +58,7 @@ def ip_security(self: AvdStructuredConfigNetworkServices) -> dict | None: "integrity": "sha256", "encryption": "aes256" if encrypt_traffic else "disabled", }, - } + }, ) ip_security["profiles"].append( { @@ -73,7 +72,7 @@ def ip_security(self: AvdStructuredConfigNetworkServices) -> dict | None: "action": "clear", }, "connection": "start", - } + }, ) return strip_null_from_data(ip_security) or None diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_virtual_router_mac_address.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_virtual_router_mac_address.py index 315efa3d595..93b653e9339 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_virtual_router_mac_address.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/ip_virtual_router_mac_address.py @@ -15,14 +15,13 @@ class IpVirtualRouterMacAddressMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ip_virtual_router_mac_address(self: AvdStructuredConfigNetworkServices) -> str | None: - """ - Return structured config for ip_virtual_router_mac_address - """ + """Return structured config for ip_virtual_router_mac_address.""" if self.shared_utils.network_services_l2 and self.shared_utils.network_services_l3 and self.shared_utils.virtual_router_mac_address is not None: return str(self.shared_utils.virtual_router_mac_address).lower() diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/ipv6_static_routes.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/ipv6_static_routes.py index 1fe4e3ced9a..bb27ab6d4fa 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/ipv6_static_routes.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/ipv6_static_routes.py @@ -15,19 +15,19 @@ class Ipv6StaticRoutesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ipv6_static_routes(self: AvdStructuredConfigNetworkServices) -> list[dict] | None: """ - Returns structured config for ipv6_static_routes + Returns structured config for ipv6_static_routes. Consist of - ipv6 static_routes defined under the vrfs - static routes added automatically for VARPv6 with prefixes """ - if not self.shared_utils.network_services_l3: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/loopback_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/loopback_interfaces.py index bb9b02d8272..ea749eb2733 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/loopback_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/loopback_interfaces.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, get, get_item +from pyavd._utils import append_if_not_duplicate, get, get_item + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,18 +17,18 @@ class LoopbackInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def loopback_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Return structured config for loopback_interfaces + Return structured config for loopback_interfaces. Used for Tenant vrf loopback interfaces This function is also called from virtual_source_nat_vrfs to avoid duplicate logic """ - if not self.shared_utils.network_services_l3: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/metadata.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/metadata.py index 48c7cba033f..1b826d988d2 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/metadata.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/metadata.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get, get_all, strip_empties_from_list, strip_null_from_data +from pyavd._utils import get, get_all, strip_empties_from_list, strip_null_from_data + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,13 +17,14 @@ class MetadataMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def metadata(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - Generate metadata.cv_pathfinder for CV Pathfinder routers + Generate metadata.cv_pathfinder for CV Pathfinder routers. Pathfinders will always have applications since we have the default control plane apps. Edge routers may have internet_exit_policies but not applications. @@ -34,7 +36,7 @@ def metadata(self: AvdStructuredConfigNetworkServices) -> dict | None: { "internet_exit_policies": self.get_cv_pathfinder_metadata_internet_exit_policies(), "applications": self.get_cv_pathfinder_metadata_applications(), - } + }, ) if not cv_pathfinder_metadata: return None @@ -42,9 +44,7 @@ def metadata(self: AvdStructuredConfigNetworkServices) -> dict | None: return {"cv_pathfinder": cv_pathfinder_metadata} def get_cv_pathfinder_metadata_internet_exit_policies(self: AvdStructuredConfigNetworkServices) -> dict | None: - """ - Generate metadata.cv_pathfinder.internet_exit_policies if available. - """ + """Generate metadata.cv_pathfinder.internet_exit_policies if available.""" if not self._filtered_internet_exit_policies: return None @@ -71,7 +71,7 @@ def get_cv_pathfinder_metadata_internet_exit_policies(self: AvdStructuredConfigN "fqdn": ufqdn, "vpn_type": "UFQDN", "pre_shared_key": ipsec_key, - } + }, ], "tunnels": [ { @@ -80,15 +80,13 @@ def get_cv_pathfinder_metadata_internet_exit_policies(self: AvdStructuredConfigN } for connection in internet_exit_policy["connections"] ], - } + }, ) return strip_empties_from_list(internet_exit_polices, (None, [], {})) def get_cv_pathfinder_metadata_applications(self: AvdStructuredConfigNetworkServices) -> dict | None: - """ - Generate metadata.cv_pathfinder.applications if available. - """ + """Generate metadata.cv_pathfinder.applications if available.""" if not self.shared_utils.is_cv_pathfinder_server or self.application_traffic_recognition is None: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/monitor_connectivity.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/monitor_connectivity.py index dce1dce62c4..dbd446c9d98 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/monitor_connectivity.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/monitor_connectivity.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, strip_empties_from_dict +from pyavd._utils import append_if_not_duplicate, strip_empties_from_dict + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,13 +17,14 @@ class MonitorConnectivityMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def monitor_connectivity(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - Return structured config for monitor_connectivity + Return structured config for monitor_connectivity. Only used for CV Pathfinder edge routers today """ @@ -35,17 +37,14 @@ def monitor_connectivity(self: AvdStructuredConfigNetworkServices) -> dict | Non for policy in self._filtered_internet_exit_policies: for connection in policy["connections"]: - if connection["type"] == "tunnel": - interface_name = f"Tunnel{connection['tunnel_id']}" - else: - interface_name = connection["source_interface"] + interface_name = f"Tunnel{connection['tunnel_id']}" if connection["type"] == "tunnel" else connection["source_interface"] interface_set_name = f"SET-{self.shared_utils.sanitize_interface_name(interface_name)}" interface_sets.append( { "name": interface_set_name, "interfaces": interface_name, - } + }, ) host = { diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/patch_panel.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/patch_panel.py index 2ca30e71727..6333e5d60b5 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/patch_panel.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/patch_panel.py @@ -7,8 +7,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, get -from ....j2filters import natural_sort +from pyavd._utils import append_if_not_duplicate, get +from pyavd.j2filters import natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -18,15 +19,13 @@ class PatchPanelMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def patch_panel(self: AvdStructuredConfigNetworkServices) -> dict | None: - """ - Return structured config for patch_panel - """ - + """Return structured config for patch_panel.""" if not self.shared_utils.network_services_l1: return None @@ -67,7 +66,7 @@ def patch_panel(self: AvdStructuredConfigNetworkServices) -> dict | None: "id": "2", "type": "pseudowire", "endpoint": f"bgp vpws {tenant['name']} pseudowire {point_to_point_service['name']}_{subif['number']}", - } + }, ) append_if_not_duplicate( list_of_dicts=patches, @@ -85,7 +84,7 @@ def patch_panel(self: AvdStructuredConfigNetworkServices) -> dict | None: "id": "1", "type": "interface", "endpoint": f"{interface}", - } + }, ], } if point_to_point_service.get("type") == "vpws-pseudowire": @@ -94,7 +93,7 @@ def patch_panel(self: AvdStructuredConfigNetworkServices) -> dict | None: "id": "2", "type": "pseudowire", "endpoint": f"bgp vpws {tenant['name']} pseudowire {point_to_point_service['name']}", - } + }, ) append_if_not_duplicate( list_of_dicts=patches, diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/port_channel_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/port_channel_interfaces.py index 1a9d3878421..de4323c7a83 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/port_channel_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/port_channel_interfaces.py @@ -7,8 +7,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, get, short_esi_to_route_target -from ....j2filters import natural_sort +from pyavd._utils import append_if_not_duplicate, get, short_esi_to_route_target +from pyavd.j2filters import natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -18,17 +19,17 @@ class PortChannelInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def port_channel_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Return structured config for port_channel_interfaces + Return structured config for port_channel_interfaces. Only used with L1 network services """ - if not self.shared_utils.network_services_l1: return None @@ -62,18 +63,17 @@ def port_channel_interfaces(self: AvdStructuredConfigNetworkServices) -> list | "peer_type": "system", "shutdown": False, } - if (short_esi := get(endpoint, "port_channel.short_esi")) is not None: - if len(short_esi.split(":")) == 3: - parent_interface.update( - { - "evpn_ethernet_segment": { - "identifier": f"{self.shared_utils.evpn_short_esi_prefix}{short_esi}", - "route_target": short_esi_to_route_target(short_esi), - } - } - ) - if port_channel_mode == "active": - parent_interface["lacp_id"] = short_esi.replace(":", ".") + if (short_esi := get(endpoint, "port_channel.short_esi")) is not None and len(short_esi.split(":")) == 3: + parent_interface.update( + { + "evpn_ethernet_segment": { + "identifier": f"{self.shared_utils.evpn_short_esi_prefix}{short_esi}", + "route_target": short_esi_to_route_target(short_esi), + }, + }, + ) + if port_channel_mode == "active": + parent_interface["lacp_id"] = short_esi.replace(":", ".") subif_parent_interfaces.append(parent_interface) @@ -118,18 +118,17 @@ def port_channel_interfaces(self: AvdStructuredConfigNetworkServices) -> list | "receive": False, } - if (short_esi := get(endpoint, "port_channel.short_esi")) is not None: - if len(short_esi.split(":")) == 3: - interface.update( - { - "evpn_ethernet_segment": { - "identifier": f"{self.shared_utils.evpn_short_esi_prefix}{short_esi}", - "route_target": short_esi_to_route_target(short_esi), - } - } - ) - if port_channel_mode == "active": - interface["lacp_id"] = short_esi.replace(":", ".") + if (short_esi := get(endpoint, "port_channel.short_esi")) is not None and len(short_esi.split(":")) == 3: + interface.update( + { + "evpn_ethernet_segment": { + "identifier": f"{self.shared_utils.evpn_short_esi_prefix}{short_esi}", + "route_target": short_esi_to_route_target(short_esi), + }, + }, + ) + if port_channel_mode == "active": + interface["lacp_id"] = short_esi.replace(":", ".") append_if_not_duplicate( list_of_dicts=port_channel_interfaces, diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/prefix_lists.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/prefix_lists.py index df1463a221a..a4bf84e500a 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/prefix_lists.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/prefix_lists.py @@ -7,7 +7,8 @@ from ipaddress import IPv4Network from typing import TYPE_CHECKING -from ....j2filters import natural_sort +from pyavd.j2filters import natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,17 +18,17 @@ class PrefixListsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def prefix_lists(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Return structured config for prefix_lists + Return structured config for prefix_lists. Covers EVPN services in VRF "default" and redistribution of connected to BGP """ - # Get prefix-lists from EVPN services in VRF "default" (if any) prefix_lists = self._prefix_lists_vrf_default() @@ -46,9 +47,7 @@ def prefix_lists(self: AvdStructuredConfigNetworkServices) -> list | None: return None def _prefix_lists_vrf_default(self: AvdStructuredConfigNetworkServices) -> list: - """ - prefix_lists for EVPN services in VRF "default" - """ + """prefix_lists for EVPN services in VRF "default".""" if not self._vrf_default_evpn: return [] @@ -75,9 +74,7 @@ def _prefix_lists_vrf_default(self: AvdStructuredConfigNetworkServices) -> list: @cached_property def _mlag_ibgp_peering_subnets_without_redistribution(self: AvdStructuredConfigNetworkServices) -> list: - """ - Return sorted list of MLAG peerings for VRFs where MLAG iBGP peering should not be redistributed - """ + """Return sorted list of MLAG peerings for VRFs where MLAG iBGP peering should not be redistributed.""" mlag_prefixes = set() for tenant in self.shared_utils.filtered_tenants: for vrf in tenant["vrfs"]: diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/route_maps.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/route_maps.py index e50894caf2d..008af9685c5 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/route_maps.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/route_maps.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, strip_empties_from_list +from pyavd._utils import append_if_not_duplicate, strip_empties_from_list + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,13 +17,14 @@ class RouteMapsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def route_maps(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Return structured config for route_maps + Return structured config for route_maps. Contains two parts. - Route-maps for tenant bgp peers set_ipv4_next_hop parameter @@ -45,10 +47,7 @@ def route_maps(self: AvdStructuredConfigNetworkServices) -> list | None: continue route_map_name = f"RM-{vrf['name']}-{bgp_peer['ip_address']}-SET-NEXT-HOP-OUT" - if ipv4_next_hop is not None: - set_action = f"ip next-hop {ipv4_next_hop}" - else: - set_action = f"ipv6 next-hop {ipv6_next_hop}" + set_action = f"ip next-hop {ipv4_next_hop}" if ipv4_next_hop is not None else f"ipv6 next-hop {ipv6_next_hop}" route_map = { "name": route_map_name, @@ -85,7 +84,7 @@ def route_maps(self: AvdStructuredConfigNetworkServices) -> list | None: @cached_property def _route_maps_vrf_default(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Route-maps for EVPN services in VRF "default" + Route-maps for EVPN services in VRF "default". Called from main route_maps function @@ -103,15 +102,16 @@ def _route_maps_vrf_default(self: AvdStructuredConfigNetworkServices) -> list | self._bgp_underlay_peers_route_map(), self._redistribute_connected_to_bgp_route_map(), self._redistribute_static_to_bgp_route_map(), - ] + ], ) return route_maps or None def _bgp_mlag_peer_group_route_map(self: AvdStructuredConfigNetworkServices) -> dict: """ - Return dict with one route-map - Origin Incomplete for MLAG iBGP learned routes + Return dict with one route-map. + + Origin Incomplete for MLAG iBGP learned routes. TODO: Partially duplicated from mlag. Should be moved to a common class """ @@ -123,14 +123,15 @@ def _bgp_mlag_peer_group_route_map(self: AvdStructuredConfigNetworkServices) -> "type": "permit", "set": ["origin incomplete"], "description": "Make routes learned over MLAG Peer-link less preferred on spines to ensure optimal routing", - } + }, ], } def _connected_to_bgp_vrfs_route_map(self: AvdStructuredConfigNetworkServices) -> dict: """ - Return dict with one route-map - Filter MLAG peer subnets for redistribute connected for overlay VRFs + Return dict with one route-map. + + Filter MLAG peer subnets for redistribute connected for overlay VRFs. """ return { "name": "RM-CONN-2-BGP-VRFS", @@ -149,9 +150,10 @@ def _connected_to_bgp_vrfs_route_map(self: AvdStructuredConfigNetworkServices) - def _evpn_export_vrf_default_route_map(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - Match the following prefixes to be exported in EVPN for VRF default: + Match the following prefixes to be exported in EVPN for VRF default. + * SVI subnets in VRF default - * Static routes subnets in VRF default + * Static routes subnets in VRF default. * for WAN routers, all the routes matching the SOO (which includes the two above) """ @@ -162,17 +164,17 @@ def _evpn_export_vrf_default_route_map(self: AvdStructuredConfigNetworkServices) "sequence": 10, "type": "permit", "match": ["extcommunity ECL-EVPN-SOO"], - } + }, ) else: - # TODO refactor existing behavior to SoO? + # TODO: refactor existing behavior to SoO? if self._vrf_default_ipv4_subnets: sequence_numbers.append( { "sequence": 10, "type": "permit", "match": ["ip address prefix-list PL-SVI-VRF-DEFAULT"], - } + }, ) if self._vrf_default_ipv4_static_routes["static_routes"]: @@ -181,7 +183,7 @@ def _evpn_export_vrf_default_route_map(self: AvdStructuredConfigNetworkServices) "sequence": 20, "type": "permit", "match": ["ip address prefix-list PL-STATIC-VRF-DEFAULT"], - } + }, ) if not sequence_numbers: @@ -207,7 +209,7 @@ def _bgp_underlay_peers_route_map(self: AvdStructuredConfigNetworkServices) -> d "sequence": 10, "type": "deny", "match": ["ip address prefix-list PL-SVI-VRF-DEFAULT"], - } + }, ) if self._vrf_default_ipv4_static_routes["static_routes"]: @@ -216,7 +218,7 @@ def _bgp_underlay_peers_route_map(self: AvdStructuredConfigNetworkServices) -> d "sequence": 15, "type": "deny", "match": ["ip address prefix-list PL-STATIC-VRF-DEFAULT"], - } + }, ) if not sequence_numbers: @@ -233,7 +235,7 @@ def _bgp_underlay_peers_route_map(self: AvdStructuredConfigNetworkServices) -> d def _redistribute_connected_to_bgp_route_map(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - Append network services relevant entries to the route-map used to redistribute connected subnets in BGP + Append network services relevant entries to the route-map used to redistribute connected subnets in BGP. sequence 10 is set in underlay and sequence 20 in inband management, so avoid setting those here """ @@ -260,9 +262,7 @@ def _redistribute_connected_to_bgp_route_map(self: AvdStructuredConfigNetworkSer return {"name": "RM-CONN-2-BGP", "sequence_numbers": sequence_numbers} def _redistribute_static_to_bgp_route_map(self: AvdStructuredConfigNetworkServices) -> dict | None: - """ - Append network services relevant entries to the route-map used to redistribute static routes to BGP - """ + """Append network services relevant entries to the route-map used to redistribute static routes to BGP.""" if not (self.shared_utils.wan_role and self._vrf_default_ipv4_static_routes["redistribute_in_overlay"]): return None @@ -274,6 +274,6 @@ def _redistribute_static_to_bgp_route_map(self: AvdStructuredConfigNetworkServic "type": "permit", "match": ["ip address prefix-list PL-STATIC-VRF-DEFAULT"], "set": [f"extcommunity soo {self.shared_utils.evpn_soo} additive"], - } + }, ], } diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_adaptive_virtual_topology.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_adaptive_virtual_topology.py index a2ec4868658..ffed157b70e 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_adaptive_virtual_topology.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_adaptive_virtual_topology.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, get, get_item, strip_empties_from_dict +from pyavd._utils import append_if_not_duplicate, get, get_item, strip_empties_from_dict + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,14 +17,13 @@ class RouterAdaptiveVirtualTopologyMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_adaptive_virtual_topology(self: AvdStructuredConfigNetworkServices) -> dict | None: - """ - Return structured config for profiles, policies and VRFs for router adaptive-virtual-topology (AVT) - """ + """Return structured config for profiles, policies and VRFs for router adaptive-virtual-topology (AVT).""" if not self.shared_utils.is_cv_pathfinder_router: return None @@ -36,9 +36,7 @@ def router_adaptive_virtual_topology(self: AvdStructuredConfigNetworkServices) - return strip_empties_from_dict(router_adaptive_virtual_topology) def _cv_pathfinder_wan_vrfs(self: AvdStructuredConfigNetworkServices) -> list: - """ - Return a list of WAN VRFs based on filtered tenants and the AVT. - """ + """Return a list of WAN VRFs based on filtered tenants and the AVT.""" # For CV Pathfinder, it is required to go through all the AVT profiles in the policy to assign an ID. wan_vrfs = [] @@ -59,14 +57,14 @@ def _cv_pathfinder_wan_vrfs(self: AvdStructuredConfigNetworkServices) -> list: { "name": get(match, "avt_profile", required=True), "id": get(match, "id", required=True), - } + }, ) if (default_match := policy.get("default_match")) is not None: wan_vrf["profiles"].append( { "name": get(default_match, "avt_profile", required=True), "id": get(default_match, "id", required=True), - } + }, ) wan_vrfs.append(wan_vrf) @@ -75,8 +73,7 @@ def _cv_pathfinder_wan_vrfs(self: AvdStructuredConfigNetworkServices) -> list: def _cv_pathfinder_policies(self: AvdStructuredConfigNetworkServices) -> list: """ - Build and return the CV Pathfinder policies based on the computed - _filtered_wan_policies. + Build and return the CV Pathfinder policies based on the computed _filtered_wan_policies. It loops though the different match statements to build the appropriate entries by popping the load_balance_policy and id keys. @@ -105,9 +102,7 @@ def _cv_pathfinder_policies(self: AvdStructuredConfigNetworkServices) -> list: return policies def _cv_pathfinder_profiles(self: AvdStructuredConfigNetworkServices) -> list: - """ - Return a list of router adaptive-virtual-topology profiles for this router. - """ + """Return a list of router adaptive-virtual-topology profiles for this router.""" profiles = [] for policy in self._filtered_wan_policies: for match in policy.get("matches", []): @@ -116,7 +111,9 @@ def _cv_pathfinder_profiles(self: AvdStructuredConfigNetworkServices) -> list: "load_balance_policy": match["load_balance_policy"]["name"], } if (internet_exit_policy_name := match["internet_exit_policy_name"]) is not None and get_item( - self._filtered_internet_exit_policies, "name", internet_exit_policy_name + self._filtered_internet_exit_policies, + "name", + internet_exit_policy_name, ) is not None: profile["internet_exit_policy"] = internet_exit_policy_name @@ -133,7 +130,9 @@ def _cv_pathfinder_profiles(self: AvdStructuredConfigNetworkServices) -> list: "load_balance_policy": default_match["load_balance_policy"]["name"], } if (internet_exit_policy_name := default_match["internet_exit_policy_name"]) is not None and get_item( - self._filtered_internet_exit_policies, "name", internet_exit_policy_name + self._filtered_internet_exit_policies, + "name", + internet_exit_policy_name, ) is not None: profile["internet_exit_policy"] = internet_exit_policy_name diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_bgp.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_bgp.py index ef826600f83..e545deda6c3 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_bgp.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_bgp.py @@ -9,9 +9,10 @@ from re import fullmatch as re_fullmatch from typing import TYPE_CHECKING -from ...._errors import AristaAvdMissingVariableError -from ...._utils import append_if_not_duplicate, default, get, get_item, merge, strip_empties_from_dict -from ....j2filters import list_compress, natural_sort +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import append_if_not_duplicate, default, get, get_item, merge, strip_empties_from_dict +from pyavd.j2filters import list_compress, natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -21,19 +22,19 @@ class RouterBgpMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_bgp(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - Return the structured config for router_bgp + Return the structured config for router_bgp. Changing legacy behavior is to only render this on vtep or mpls_ler by instead skipping vlans/bundles if not vtep or mpls_ler TODO: Fix so this also works for L2LS with VRFs """ - if not self.shared_utils.bgp: return None @@ -52,18 +53,16 @@ def router_bgp(self: AvdStructuredConfigNetworkServices) -> dict | None: merge(router_bgp, self._router_bgp_mlag_peer_group()) # Strip None values from vlan before returning - router_bgp = {key: value for key, value in router_bgp.items() if value is not None} - return router_bgp + return {key: value for key, value in router_bgp.items() if value is not None} def _router_bgp_peer_groups(self: AvdStructuredConfigNetworkServices) -> dict: """ - Return the structured config for router_bgp.peer_groups + Return the structured config for router_bgp.peer_groups. Covers two areas: - bgp_peer_groups defined under the vrf including ipv4/ipv6 address_families. - adding route-map to the underlay peer-group in case of services in vrf default """ - if not self.shared_utils.network_services_l3: return {} @@ -75,13 +74,13 @@ def _router_bgp_peer_groups(self: AvdStructuredConfigNetworkServices) -> dict: if not (vrf["bgp_peers"] or vrf.get("bgp_peer_groups")): continue - vrf_peer_peergroups = set(peer["peer_group"] for peer in vrf["bgp_peers"] if "peer_group" in peer) + vrf_peer_peergroups = {peer["peer_group"] for peer in vrf["bgp_peers"] if "peer_group" in peer} peer_groups.extend( [ peer_group for peer_group in vrf.get("bgp_peer_groups", []) if (self.shared_utils.hostname in peer_group.get("nodes", []) or peer_group["name"] in vrf_peer_peergroups) - ] + ], ) peer_peergroups.update(vrf_peer_peergroups) @@ -90,13 +89,12 @@ def _router_bgp_peer_groups(self: AvdStructuredConfigNetworkServices) -> dict: peer_group for peer_group in tenant.get("bgp_peer_groups", []) if (self.shared_utils.hostname in peer_group.get("nodes", []) or peer_group["name"] in peer_peergroups) - ] + ], ) router_bgp = {"peer_groups": []} if peer_groups: for peer_group in peer_groups: - peer_group.pop("nodes", None) for af in ["address_family_ipv4", "address_family_ipv6"]: if not (af_peer_group := peer_group.pop(af, None)): @@ -124,7 +122,7 @@ def _router_bgp_peer_groups(self: AvdStructuredConfigNetworkServices) -> dict: "name": self.shared_utils.bgp_peer_groups["ipv4_underlay_peers"]["name"], "type": "ipv4", "route_map_out": "RM-BGP-UNDERLAY-PEERS-OUT", - } + }, ) if router_bgp["peer_groups"]: @@ -135,7 +133,7 @@ def _router_bgp_peer_groups(self: AvdStructuredConfigNetworkServices) -> dict: @cached_property def _router_bgp_vrfs(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Return structured config for router_bgp.vrfs + Return structured config for router_bgp.vrfs. TODO: Optimize this to allow bgp VRF config without overlays (vtep or mpls) """ @@ -194,7 +192,7 @@ def _router_bgp_vrfs(self: AvdStructuredConfigNetworkServices) -> list | None: "activate": True, "prefix_list_in": bgp_peer.pop("prefix_list_in", None), "prefix_list_out": bgp_peer.pop("prefix_list_out", None), - } + }, ) append_if_not_duplicate( @@ -253,9 +251,7 @@ def _router_bgp_vrfs(self: AvdStructuredConfigNetworkServices) -> list | None: return vrfs or None def _update_router_bgp_vrf_evpn_or_mpls_cfg(self: AvdStructuredConfigNetworkServices, bgp_vrf: dict, vrf: dict, vrf_address_families: list) -> None: - """ - In-place update EVPN/MPLS part of structured config for *one* VRF under router_bgp.vrfs - """ + """In-place update EVPN/MPLS part of structured config for *one* VRF under router_bgp.vrfs.""" vrf_name = vrf["name"] bgp_vrf["rd"] = self.get_vrf_rd(vrf) @@ -300,9 +296,7 @@ def _update_router_bgp_vrf_evpn_or_mpls_cfg(self: AvdStructuredConfigNetworkServ bgp_vrf["evpn_multicast_address_family"] = {"ipv4": {"transit": evpn_multicast_transit_mode}} def _update_router_bgp_vrf_mlag_neighbor_cfg(self: AvdStructuredConfigNetworkServices, bgp_vrf: dict, vrf: dict, tenant: dict, vlan_id: int) -> None: - """ - In-place update MLAG neighbor part of structured config for *one* VRF under router_bgp.vrfs - """ + """In-place update MLAG neighbor part of structured config for *one* VRF under router_bgp.vrfs.""" if not self._mlag_ibgp_peering_redistribute(vrf, tenant): bgp_vrf["redistribute_routes"][0]["route_map"] = "RM-CONN-2-BGP-VRFS" @@ -314,7 +308,7 @@ def _update_router_bgp_vrf_mlag_neighbor_cfg(self: AvdStructuredConfigNetworkSer "peer_group": self.shared_utils.bgp_peer_groups["mlag_ipv4_underlay_peer"]["name"], "remote_as": self.shared_utils.bgp_as, "description": self.shared_utils.mlag_peer, - } + }, ) else: if (mlag_ibgp_peering_ipv4_pool := vrf.get("mlag_ibgp_peering_ipv4_pool")) is None: @@ -328,7 +322,7 @@ def _update_router_bgp_vrf_mlag_neighbor_cfg(self: AvdStructuredConfigNetworkSer { "ip_address": ip_address, "peer_group": self.shared_utils.bgp_peer_groups["mlag_ipv4_underlay_peer"]["name"], - } + }, ) if self.shared_utils.underlay_rfc5549: bgp_vrf.setdefault("address_family_ipv4", {}).setdefault("neighbors", []).append( @@ -337,9 +331,9 @@ def _update_router_bgp_vrf_mlag_neighbor_cfg(self: AvdStructuredConfigNetworkSer "next_hop": { "address_family_ipv6": { "enabled": False, - } + }, }, - } + }, ) def _router_bgp_sorted_vlans_and_svis_lists(self: AvdStructuredConfigNetworkServices) -> dict: @@ -354,12 +348,12 @@ def _router_bgp_sorted_vlans_and_svis_lists(self: AvdStructuredConfigNetworkServ bundle_groups = itertools_groupby(sorted_vlan_list, self._get_vlan_aware_bundle_name_tuple_for_l2vlans) for vlan_aware_bundle_name_tuple, l2vlans in bundle_groups: bundle_name, is_evpn_vlan_bundle = vlan_aware_bundle_name_tuple - l2vlans = list(l2vlans) + l2vlans_list = list(l2vlans) if is_evpn_vlan_bundle: - l2vlans_bundle_dict[bundle_name] = l2vlans + l2vlans_bundle_dict[bundle_name] = l2vlans_list else: - l2vlans_non_bundle_list[bundle_name] = l2vlans + l2vlans_non_bundle_list[bundle_name] = l2vlans_list # For SVIs vrf_svis_bundle_dict = {} @@ -371,12 +365,12 @@ def _router_bgp_sorted_vlans_and_svis_lists(self: AvdStructuredConfigNetworkServ bundle_groups_svis = itertools_groupby(sorted_svi_list, self._get_vlan_aware_bundle_name_tuple_for_svis) for vlan_aware_bundle_name_tuple, svis in bundle_groups_svis: bundle_name, is_evpn_vlan_bundle = vlan_aware_bundle_name_tuple - svis = list(svis) + svis_list = list(svis) if is_evpn_vlan_bundle: - vrf_svis_bundle_dict[vrf["name"]][bundle_name] = svis + vrf_svis_bundle_dict[vrf["name"]][bundle_name] = svis_list else: - vrf_svis_non_bundle_dict[vrf["name"]] = svis + vrf_svis_non_bundle_dict[vrf["name"]] = svis_list tenant_svis_l2vlans_dict[tenant["name"]]["svi_bundle"] = vrf_svis_bundle_dict tenant_svis_l2vlans_dict[tenant["name"]]["svi_non_bundle"] = vrf_svis_non_bundle_dict @@ -385,10 +379,8 @@ def _router_bgp_sorted_vlans_and_svis_lists(self: AvdStructuredConfigNetworkServ return tenant_svis_l2vlans_dict - def _router_bgp_vlans(self: AvdStructuredConfigNetworkServices, tenant_svis_l2vlans_dict) -> list | None: - """ - Return structured config for router_bgp.vlans - """ + def _router_bgp_vlans(self: AvdStructuredConfigNetworkServices, tenant_svis_l2vlans_dict: dict) -> list | None: + """Return structured config for router_bgp.vlans.""" if not ( self.shared_utils.network_services_l2 and "evpn" in self.shared_utils.overlay_address_families @@ -431,10 +423,8 @@ def _router_bgp_vlans(self: AvdStructuredConfigNetworkServices, tenant_svis_l2vl return vlans or None - def _router_bgp_vlans_vlan(self: AvdStructuredConfigNetworkServices, vlan, tenant, vrf) -> dict | None: - """ - Return structured config for one given vlan under router_bgp.vlans - """ + def _router_bgp_vlans_vlan(self: AvdStructuredConfigNetworkServices, vlan: dict, tenant: dict, vrf: dict) -> dict | None: + """Return structured config for one given vlan under router_bgp.vlans.""" if vlan.get("vxlan") is False: return None @@ -451,7 +441,7 @@ def _router_bgp_vlans_vlan(self: AvdStructuredConfigNetworkServices, vlan, tenan } if ( self.shared_utils.evpn_gateway_vxlan_l2 - and default(vlan.get("evpn_l2_multi_domain"), vrf.get("evpn_l2_multi_domain"), tenant.get("evpn_l2_multi_domain"), True) is True + and default(vlan.get("evpn_l2_multi_domain"), vrf.get("evpn_l2_multi_domain"), tenant.get("evpn_l2_multi_domain"), True) is True # noqa: FBT003 ): bgp_vlan["rd_evpn_domain"] = {"domain": "remote", "rd": vlan_rd} bgp_vlan["route_targets"]["import_export_evpn_domains"] = [{"domain": "remote", "route_target": vlan_rt}] @@ -461,8 +451,7 @@ def _router_bgp_vlans_vlan(self: AvdStructuredConfigNetworkServices, vlan, tenan bgp_vlan["redistribute_routes"].append("igmp") # Strip None values from vlan before returning - bgp_vlan = {key: value for key, value in bgp_vlan.items() if value is not None} - return bgp_vlan + return {key: value for key, value in bgp_vlan.items() if value is not None} @cached_property def _evpn_vlan_bundles(self) -> list: @@ -473,9 +462,7 @@ def _evpn_vlan_aware_bundles(self) -> bool: return get(self._hostvars, "evpn_vlan_aware_bundles", default=False) def _get_vlan_aware_bundle_name_tuple_for_l2vlans(self: AvdStructuredConfigNetworkServices, vlan: dict) -> tuple[str, bool] | None: - """ - Return a tuple with string with the vlan-aware-bundle name for one VLAN and a boolean saying if this is a evpn_vlan_bundle. - """ + """Return a tuple with string with the vlan-aware-bundle name for one VLAN and a boolean saying if this is a evpn_vlan_bundle.""" if vlan.get("evpn_vlan_bundle") is not None: return (str(vlan.get("evpn_vlan_bundle")), True) return (str(vlan.get("name")), False) @@ -483,6 +470,7 @@ def _get_vlan_aware_bundle_name_tuple_for_l2vlans(self: AvdStructuredConfigNetwo def _get_vlan_aware_bundle_name_tuple_for_svis(self: AvdStructuredConfigNetworkServices, vlan: dict) -> tuple[str, bool] | None: """ Return a tuple with string with the vlan-aware-bundle name for one VLAN and a boolean saying if this is a evpn_vlan_bundle. + If no bundle is configured, it will return an empty string as name, since the calling function will then get all svis without bundle grouped under "". """ @@ -491,20 +479,19 @@ def _get_vlan_aware_bundle_name_tuple_for_svis(self: AvdStructuredConfigNetworkS return ("", False) def _get_evpn_vlan_bundle(self: AvdStructuredConfigNetworkServices, vlan: dict, bundle_name: str) -> dict: - """ - Return an evpn_vlan_bundle dict if it exists, else raise an exception. - """ + """Return an evpn_vlan_bundle dict if it exists, else raise an exception.""" if (evpn_vlan_bundle := get_item(self._evpn_vlan_bundles, "name", bundle_name)) is None: - raise AristaAvdMissingVariableError( + msg = ( "The 'evpn_vlan_bundle' of the svis/l2vlans must be defined in the common 'evpn_vlan_bundles' setting. First occurrence seen for svi/l2vlan" f" {vlan['id']} in Tenant '{vlan['tenant']}' and evpn_vlan_bundle '{vlan['evpn_vlan_bundle']}'." ) + raise AristaAvdMissingVariableError( + msg, + ) return evpn_vlan_bundle def _get_svi_l2vlan_bundle(self: AvdStructuredConfigNetworkServices, evpn_vlan_bundle: dict, tenant: dict, vlans: list) -> dict | None: - """ - Return an bundle config for a svi or l2vlan. - """ + """Return an bundle config for a svi or l2vlan.""" bundle = self._router_bgp_vlan_aware_bundle( name=evpn_vlan_bundle["name"], vlans=vlans, @@ -527,11 +514,8 @@ def _get_svi_l2vlan_bundle(self: AvdStructuredConfigNetworkServices, evpn_vlan_b return None - def _router_bgp_vlan_aware_bundles(self: AvdStructuredConfigNetworkServices, tenant_svis_l2vlans_dict) -> list | None: - """ - Return structured config for router_bgp.vlan_aware_bundles - """ - + def _router_bgp_vlan_aware_bundles(self: AvdStructuredConfigNetworkServices, tenant_svis_l2vlans_dict: dict) -> list | None: + """Return structured config for router_bgp.vlan_aware_bundles.""" if not self.shared_utils.network_services_l2 or not self.shared_utils.overlay_evpn: return None @@ -592,7 +576,7 @@ def _router_bgp_vlan_aware_bundles(self: AvdStructuredConfigNetworkServices, ten ) # L2VLANs and SVIs which have an evpn_vlan_bundle defined - for bundle_name, bundle_dict in l2vlan_svi_vlan_aware_bundles.items(): + for bundle_dict in l2vlan_svi_vlan_aware_bundles.values(): evpn_vlan_bundle = bundle_dict["evpn_vlan_bundle"] l2vlans_svis = bundle_dict["l2vlan_svis"] @@ -614,9 +598,7 @@ def _router_bgp_vlan_aware_bundles(self: AvdStructuredConfigNetworkServices, ten return bundles or None def _router_bgp_vlan_aware_bundles_vrf(self: AvdStructuredConfigNetworkServices, vrf: dict, tenant: dict, vlans: list[dict]) -> dict | None: - """ - Return structured config for one vrf under router_bgp.vlan_aware_bundles - """ + """Return structured config for one vrf under router_bgp.vlan_aware_bundles.""" return self._router_bgp_vlan_aware_bundle( name=vrf["name"], vlans=vlans, @@ -627,11 +609,18 @@ def _router_bgp_vlan_aware_bundles_vrf(self: AvdStructuredConfigNetworkServices, ) def _router_bgp_vlan_aware_bundle( - self: AvdStructuredConfigNetworkServices, name: str, vlans: list, rd: str, rt: str, evpn_l2_multi_domain: bool, tenant: dict + self: AvdStructuredConfigNetworkServices, + name: str, + vlans: list, + rd: str, + rt: str, + evpn_l2_multi_domain: bool, + tenant: dict, ) -> dict | None: """ Return structured config for one vlan-aware-bundle. - Used for VRFs and bundles defined under "evpn_vlan_bundles" referred by l2vlans and SVIs + + Used for VRFs and bundles defined under "evpn_vlan_bundles" referred by l2vlans and SVIs. """ vlans = [vlan for vlan in vlans if vlan.get("vxlan") is not False] if not vlans: @@ -659,6 +648,7 @@ def _router_bgp_vlan_aware_bundle( def _rt_admin_subfield(self: AvdStructuredConfigNetworkServices) -> str | None: """ Return a string with the route-target admin subfield unless set to "vrf_id" or "vrf_vni" or "id". + Returns None if not set, since the calling functions will use per-vlan numbers by default. """ @@ -674,28 +664,32 @@ def _rt_admin_subfield(self: AvdStructuredConfigNetworkServices) -> str | None: return None - def get_vlan_mac_vrf_id(self: AvdStructuredConfigNetworkServices, vlan, tenant) -> int: + def get_vlan_mac_vrf_id(self: AvdStructuredConfigNetworkServices, vlan: dict, tenant: dict) -> int: mac_vrf_id_base = default(tenant.get("mac_vrf_id_base"), tenant.get("mac_vrf_vni_base")) if mac_vrf_id_base is None: - raise AristaAvdMissingVariableError( + msg = ( "'rt_override' or 'vni_override' or 'mac_vrf_id_base' or 'mac_vrf_vni_base' must be set. " f"Unable to set EVPN RD/RT for vlan {vlan['id']} in Tenant '{vlan['tenant']}'" ) + raise AristaAvdMissingVariableError( + msg, + ) return mac_vrf_id_base + int(vlan["id"]) - def get_vlan_mac_vrf_vni(self: AvdStructuredConfigNetworkServices, vlan, tenant) -> int: + def get_vlan_mac_vrf_vni(self: AvdStructuredConfigNetworkServices, vlan: dict, tenant: dict) -> int: mac_vrf_vni_base = default(tenant.get("mac_vrf_vni_base"), tenant.get("mac_vrf_id_base")) if mac_vrf_vni_base is None: - raise AristaAvdMissingVariableError( + msg = ( "'rt_override' or 'vni_override' or 'mac_vrf_id_base' or 'mac_vrf_vni_base' must be set. " f"Unable to set EVPN RD/RT for vlan {vlan['id']} in Tenant '{vlan['tenant']}'" ) + raise AristaAvdMissingVariableError( + msg, + ) return mac_vrf_vni_base + int(vlan["id"]) - def get_vlan_rd(self: AvdStructuredConfigNetworkServices, vlan, tenant) -> str: - """ - Return a string with the route-destinguisher for one VLAN - """ + def get_vlan_rd(self: AvdStructuredConfigNetworkServices, vlan: dict, tenant: dict) -> str: + """Return a string with the route-destinguisher for one VLAN.""" rd_override = default(vlan.get("rd_override"), vlan.get("rt_override"), vlan.get("vni_override")) if ":" in str(rd_override): @@ -713,9 +707,7 @@ def get_vlan_rd(self: AvdStructuredConfigNetworkServices, vlan, tenant) -> str: return f"{self.shared_utils.overlay_rd_type_admin_subfield}:{assigned_number_subfield}" def get_vlan_rt(self: AvdStructuredConfigNetworkServices, vlan: dict, tenant: dict) -> str: - """ - Return a string with the route-target for one VLAN - """ + """Return a string with the route-target for one VLAN.""" rt_override = default(vlan.get("rt_override"), vlan.get("vni_override")) if ":" in str(rt_override): @@ -747,6 +739,7 @@ def get_vlan_rt(self: AvdStructuredConfigNetworkServices, vlan: dict, tenant: di def _vrf_rt_admin_subfield(self: AvdStructuredConfigNetworkServices) -> str | None: """ Return a string with the VRF route-target admin subfield unless set to "vrf_id" or "vrf_vni" or "id". + Returns None if not set, since the calling functions will use per-vrf numbers by default. """ @@ -762,10 +755,8 @@ def _vrf_rt_admin_subfield(self: AvdStructuredConfigNetworkServices) -> str | No return None - def get_vrf_rd(self: AvdStructuredConfigNetworkServices, vrf) -> str: - """ - Return a string with the route-destinguisher for one VRF - """ + def get_vrf_rd(self: AvdStructuredConfigNetworkServices, vrf: dict) -> str: + """Return a string with the route-destinguisher for one VRF.""" rd_override = default(vrf.get("rd_override")) if ":" in str(rd_override): @@ -777,9 +768,7 @@ def get_vrf_rd(self: AvdStructuredConfigNetworkServices, vrf) -> str: return f"{self.shared_utils.overlay_rd_type_vrf_admin_subfield}:{self.shared_utils.get_vrf_id(vrf)}" def get_vrf_rt(self: AvdStructuredConfigNetworkServices, vrf: dict) -> str: - """ - Return a string with the route-target for one VRF - """ + """Return a string with the route-target for one VRF.""" rt_override = default(vrf.get("rt_override")) if ":" in str(rt_override): @@ -798,14 +787,15 @@ def get_vrf_rt(self: AvdStructuredConfigNetworkServices, vrf: dict) -> str: return f"{admin_subfield}:{self.shared_utils.get_vrf_id(vrf)}" - def get_vlan_aware_bundle_rd(self: AvdStructuredConfigNetworkServices, id: int, tenant: dict, is_vrf: bool, rd_override: str | None = None) -> str: - """ - Return a string with the route-destinguisher for one VLAN Aware Bundle - """ - if is_vrf: - admin_subfield = self.shared_utils.overlay_rd_type_vrf_admin_subfield - else: - admin_subfield = self.shared_utils.overlay_rd_type_admin_subfield + def get_vlan_aware_bundle_rd( + self: AvdStructuredConfigNetworkServices, + id: int, # noqa: A002 + tenant: dict, + is_vrf: bool, + rd_override: str | None = None, + ) -> str: + """Return a string with the route-destinguisher for one VLAN Aware Bundle.""" + admin_subfield = self.shared_utils.overlay_rd_type_vrf_admin_subfield if is_vrf else self.shared_utils.overlay_rd_type_admin_subfield if rd_override is not None: if ":" in str(rd_override): @@ -817,11 +807,14 @@ def get_vlan_aware_bundle_rd(self: AvdStructuredConfigNetworkServices, id: int, return f"{admin_subfield}:{bundle_number}" def get_vlan_aware_bundle_rt( - self: AvdStructuredConfigNetworkServices, id: int, vni: int, tenant: dict, is_vrf: bool, rt_override: str | None = None + self: AvdStructuredConfigNetworkServices, + id: int, # noqa: A002 + vni: int, + tenant: dict, + is_vrf: bool, + rt_override: str | None = None, ) -> str: - """ - Return a string with the route-target for one VLAN Aware Bundle - """ + """Return a string with the route-target for one VLAN Aware Bundle.""" if rt_override is not None and ":" in str(rt_override): return rt_override @@ -843,7 +836,7 @@ def get_vlan_aware_bundle_rt( @cached_property def _router_bgp_redistribute_routes(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Return structured config for router_bgp.redistribute_routes + Return structured config for router_bgp.redistribute_routes. Add redistribute static to default if either "redistribute_in_overlay" is set or "redistribute_in_underlay" and underlay protocol is BGP. @@ -864,10 +857,7 @@ def _router_bgp_redistribute_routes(self: AvdStructuredConfigNetworkServices) -> @cached_property def _router_bgp_vpws(self: AvdStructuredConfigNetworkServices) -> list[dict] | None: - """ - Return structured config for router_bgp.vpws - """ - + """Return structured config for router_bgp.vpws.""" if not (self.shared_utils.network_services_l1 and self.shared_utils.overlay_ler and self.shared_utils.overlay_evpn_mpls): return None @@ -901,7 +891,7 @@ def _router_bgp_vpws(self: AvdStructuredConfigNetworkServices) -> list[dict] | N "name": f"{point_to_point_service['name']}_{subif_number}", "id_local": int(endpoint["id"]) + subif_number, "id_remote": int(remote_endpoint["id"]) + subif_number, - } + }, ) else: @@ -910,7 +900,7 @@ def _router_bgp_vpws(self: AvdStructuredConfigNetworkServices) -> list[dict] | N "name": f"{point_to_point_service['name']}", "id_local": int(endpoint["id"]), "id_remote": int(remote_endpoint["id"]), - } + }, ) if pseudowires: @@ -923,7 +913,7 @@ def _router_bgp_vpws(self: AvdStructuredConfigNetworkServices) -> list[dict] | N "rd": rd, "route_targets": {"import_export": rt}, "pseudowires": pseudowires, - } + }, ) if vpws: @@ -933,8 +923,7 @@ def _router_bgp_vpws(self: AvdStructuredConfigNetworkServices) -> list[dict] | N def _router_bgp_mlag_peer_group(self: AvdStructuredConfigNetworkServices) -> dict: """ - Return a partial router_bgp structured_config covering the MLAG peer_group - and associated address_family activations + Return a partial router_bgp structured_config covering the MLAG peer_group and associated address_family activations. TODO: Partially duplicated from mlag. Should be moved to a common class """ @@ -963,8 +952,8 @@ def _router_bgp_mlag_peer_group(self: AvdStructuredConfigNetworkServices) -> dic { "name": peer_group_name, "activate": True, - } - ] + }, + ], } address_family_ipv4_peer_group = {"name": peer_group_name, "activate": True} diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_internet_exit.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_internet_exit.py index 05d87ebc8da..edf1959178c 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_internet_exit.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_internet_exit.py @@ -7,7 +7,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get +from pyavd._utils import get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,13 +18,14 @@ class RouterInternetExitMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_internet_exit(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - Return structured config for router_internet_exit + Return structured config for router_internet_exit. Only used for CV Pathfinder edge routers today """ diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_isis.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_isis.py index c65f9734686..e33114449c6 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_isis.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_isis.py @@ -15,19 +15,19 @@ class RouterIsisMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_isis(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - return structured config for router_isis + return structured config for router_isis. Used for non-EVPN where underlay_routing_protocol is ISIS, static routes in VRF "default" should be redistributed into ISIS unless specifically disabled under the vrf. """ - if ( self.shared_utils.network_services_l3 and self._vrf_default_ipv4_static_routes["redistribute_in_underlay"] diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_multicast.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_multicast.py index c528193aae5..a674af68839 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_multicast.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_multicast.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, get +from pyavd._utils import append_if_not_duplicate, get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,17 +17,17 @@ class RouterMulticastMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_multicast(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - return structured config for router_multicast + return structured config for router_multicast. Used to enable multicast routing on the VRF. """ - if not self.shared_utils.network_services_l3: return None @@ -36,7 +37,11 @@ def router_multicast(self: AvdStructuredConfigNetworkServices) -> dict | None: if get(vrf, "_evpn_l3_multicast_enabled"): vrf_config = {"name": vrf["name"], "ipv4": {"routing": True}} append_if_not_duplicate( - list_of_dicts=vrfs, primary_key="name", new_dict=vrf_config, context="Router Multicast for VRFs", context_keys=["name"] + list_of_dicts=vrfs, + primary_key="name", + new_dict=vrf_config, + context="Router Multicast for VRFs", + context_keys=["name"], ) if vrfs: diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_ospf.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_ospf.py index 92f40946ec8..0607474e92b 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_ospf.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_ospf.py @@ -6,8 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdMissingVariableError -from ...._utils import append_if_not_duplicate, default, get +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import append_if_not_duplicate, default, get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,18 +18,18 @@ class RouterOspfMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_ospf(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - return structured config for router_ospf + return structured config for router_ospf. If we have static_routes in default VRF and not EPVN, and underlay is OSPF Then add redistribute static to the underlay OSPF process. """ - if not self.shared_utils.network_services_l3: return None @@ -57,7 +58,8 @@ def router_ospf(self: AvdStructuredConfigNetworkServices) -> dict | None: process_id = default(get(vrf, "ospf.process_id"), vrf.get("vrf_id")) if not process_id: - raise AristaAvdMissingVariableError(f"'ospf.process_id' or 'vrf_id' under vrf '{vrf['name']}") + msg = f"'ospf.process_id' or 'vrf_id' under vrf '{vrf['name']}" + raise AristaAvdMissingVariableError(msg) process = { "id": process_id, @@ -87,7 +89,11 @@ def router_ospf(self: AvdStructuredConfigNetworkServices) -> dict | None: process = {key: value for key, value in process.items() if value is not None} append_if_not_duplicate( - list_of_dicts=ospf_processes, primary_key="id", new_dict=process, context="OSPF Processes defined under network services", context_keys="id" + list_of_dicts=ospf_processes, + primary_key="id", + new_dict=process, + context="OSPF Processes defined under network services", + context_keys="id", ) # If we have static_routes in default VRF and not EPVN, and underlay is OSPF diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_path_selection.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_path_selection.py index 3032db940e0..7c08ecdfaea 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_path_selection.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_path_selection.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, get, strip_empties_from_dict +from pyavd._utils import append_if_not_duplicate, get, strip_empties_from_dict + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,15 +17,13 @@ class RouterPathSelectionMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_path_selection(self: AvdStructuredConfigNetworkServices) -> dict | None: - """ - Return structured config for router path-selection (DPS) - """ - + """Return structured config for router path-selection (DPS).""" if not self.shared_utils.is_wan_router: return None @@ -41,15 +40,13 @@ def router_path_selection(self: AvdStructuredConfigNetworkServices) -> dict | No { "policies": self._autovpn_policies(), "vrfs": vrfs, - } + }, ) return strip_empties_from_dict(router_path_selection) def _wan_load_balance_policies(self: AvdStructuredConfigNetworkServices) -> list: - """ - Return a list of load balance policies - """ + """Return a list of load balance policies.""" load_balance_policies = [] for policy in self._filtered_wan_policies: for match in policy.get("matches", []): @@ -72,9 +69,7 @@ def _wan_load_balance_policies(self: AvdStructuredConfigNetworkServices) -> list return load_balance_policies def _autovpn_policies(self: AvdStructuredConfigNetworkServices) -> list: - """ - Return a list of policies for AutoVPN - """ + """Return a list of policies for AutoVPN.""" policies = [] for policy in self._filtered_wan_policies: autovpn_policy = {"name": policy["name"], "rules": []} @@ -84,7 +79,7 @@ def _autovpn_policies(self: AvdStructuredConfigNetworkServices) -> list: "id": 10 * index, "application_profile": match["application_profile"], "load_balance": match["load_balance_policy"]["name"], - } + }, ) if (default_match := policy.get("default_match")) is not None: autovpn_policy["default_match"] = {"load_balance": default_match["load_balance_policy"]["name"]} diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_pim_sparse_mode.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_pim_sparse_mode.py index 1a4be6e597f..a86c850fa4d 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_pim_sparse_mode.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_pim_sparse_mode.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, get +from pyavd._utils import append_if_not_duplicate, get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,17 +17,17 @@ class RouterPimSparseModeMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_pim_sparse_mode(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - return structured config for router_pim + return structured config for router_pim. Used for to configure RPs on the VRF """ - if not self.shared_utils.network_services_l3: return None @@ -41,7 +42,11 @@ def router_pim_sparse_mode(self: AvdStructuredConfigNetworkServices) -> dict | N }, } append_if_not_duplicate( - list_of_dicts=vrfs, primary_key="name", new_dict=vrf_config, context="Router PIM Sparse-Mode for VRFs", context_keys=["name"] + list_of_dicts=vrfs, + primary_key="name", + new_dict=vrf_config, + context="Router PIM Sparse-Mode for VRFs", + context_keys=["name"], ) if vrfs: return {"vrfs": vrfs} diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_service_insertion.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_service_insertion.py index adb9e86127a..004eecd07d2 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/router_service_insertion.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/router_service_insertion.py @@ -15,13 +15,14 @@ class RouterServiceInsertionMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_service_insertion(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - Return structured config for router_service_insertion + Return structured config for router_service_insertion. Only used for CV Pathfinder edge routers today """ diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/spanning_tree.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/spanning_tree.py index a1d8a36c8b5..74c08520ab5 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/spanning_tree.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/spanning_tree.py @@ -6,8 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get -from ....j2filters import list_compress +from pyavd._utils import get +from pyavd.j2filters import list_compress + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,14 +18,13 @@ class SpanningTreeMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def spanning_tree(self: AvdStructuredConfigNetworkServices) -> dict | None: - """ - spanning_tree priorities set per VLAN if spanning_tree mode is "rapid-pvst" - """ + """spanning_tree priorities set per VLAN if spanning_tree mode is "rapid-pvst".""" if not self.shared_utils.network_services_l2: return None @@ -62,5 +62,5 @@ def spanning_tree(self: AvdStructuredConfigNetworkServices) -> dict | None: "priority": priority, } for priority, vlans in vlan_stp_priorities.items() - ] + ], } diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/standard_access_lists.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/standard_access_lists.py index 82f0737d267..6c2932e766b 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/standard_access_lists.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/standard_access_lists.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, default, get +from pyavd._utils import append_if_not_duplicate, default, get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,17 +17,17 @@ class StandardAccessListsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def standard_access_lists(self: AvdStructuredConfigNetworkServices) -> list | None: """ - return structured config for standard_access_lists + return structured config for standard_access_lists. Used for to configure ACLs used by multicast RPs in each VRF """ - if not self.shared_utils.network_services_l3: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/static_routes.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/static_routes.py index 9ad652f9b8d..75e430f233f 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/static_routes.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/static_routes.py @@ -16,19 +16,19 @@ class StaticRoutesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def static_routes(self: AvdStructuredConfigNetworkServices) -> list[dict] | None: """ - Returns structured config for static_routes + Returns structured config for static_routes. Consist of - static_routes defined under the vrfs - static routes added automatically for VARP with prefixes """ - if not self.shared_utils.network_services_l3: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/struct_cfgs.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/struct_cfgs.py index ad0438ca3d5..2976a5711d5 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/struct_cfgs.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/struct_cfgs.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate +from pyavd._utils import append_if_not_duplicate + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,15 +17,13 @@ class StructCfgsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def struct_cfgs(self: AvdStructuredConfigNetworkServices) -> list | None: - """ - Return the combined structured config from VRFs - """ - + """Return the combined structured config from VRFs.""" if not self.shared_utils.network_services_l3: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/tunnel_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/tunnel_interfaces.py index 17f35ac9747..e1121b91773 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/tunnel_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/tunnel_interfaces.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate +from pyavd._utils import append_if_not_duplicate + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,13 +17,14 @@ class TunnelInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def tunnel_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Return structured config for tunnel_interfaces + Return structured config for tunnel_interfaces. Only used for CV Pathfinder edge routers today """ diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/utils.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/utils.py index 826e77f0059..509bc38e488 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/utils.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/utils.py @@ -5,12 +5,13 @@ import ipaddress from functools import cached_property -from typing import TYPE_CHECKING, Literal, Tuple +from typing import TYPE_CHECKING, Literal + +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import default, get, get_item +from pyavd._utils.password_utils.password import simple_7_encrypt +from pyavd.j2filters import natural_sort, range_expand -from ...._errors import AristaAvdError, AristaAvdMissingVariableError -from ...._utils import default, get, get_item -from ...._utils.password_utils.password import simple_7_encrypt -from ....j2filters import natural_sort, range_expand from .utils_zscaler import UtilsZscalerMixin if TYPE_CHECKING: @@ -20,7 +21,8 @@ class UtilsMixin(UtilsZscalerMixin): """ Mixin Class with internal functions. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property @@ -41,9 +43,7 @@ def _local_endpoint_trunk_groups(self: AvdStructuredConfigNetworkServices) -> se @cached_property def _vrf_default_evpn(self: AvdStructuredConfigNetworkServices) -> bool: - """ - Return boolean telling if VRF "default" is running EVPN or not. - """ + """Return boolean telling if VRF "default" is running EVPN or not.""" if not (self.shared_utils.network_services_l3 and self.shared_utils.overlay_vtep and self.shared_utils.overlay_evpn): return False @@ -53,16 +53,15 @@ def _vrf_default_evpn(self: AvdStructuredConfigNetworkServices) -> bool: if "evpn" in vrf_default.get("address_families", ["evpn"]): if self.shared_utils.underlay_filter_peer_as: - raise AristaAvdError("'underlay_filter_peer_as' cannot be used while there are EVPN services in the default VRF.") + msg = "'underlay_filter_peer_as' cannot be used while there are EVPN services in the default VRF." + raise AristaAvdError(msg) return True return False @cached_property def _vrf_default_ipv4_subnets(self: AvdStructuredConfigNetworkServices) -> list[str]: - """ - Return list of ipv4 subnets in VRF "default" - """ + """Return list of ipv4 subnets in VRF "default".""" subnets = [] for tenant in self.shared_utils.filtered_tenants: if (vrf_default := get_item(tenant["vrfs"], "name", "default")) is None: @@ -84,7 +83,7 @@ def _vrf_default_ipv4_static_routes(self: AvdStructuredConfigNetworkServices) -> """ Finds static routes defined under VRF "default" and find out if they should be redistributed in underlay and/or overlay. - Returns + Returns: ------- dict static_routes: [] @@ -127,20 +126,21 @@ def _vrf_default_ipv4_static_routes(self: AvdStructuredConfigNetworkServices) -> "redistribute_in_overlay": redistribute_in_overlay, } - def _mlag_ibgp_peering_enabled(self: AvdStructuredConfigNetworkServices, vrf, tenant) -> bool: + def _mlag_ibgp_peering_enabled(self: AvdStructuredConfigNetworkServices, vrf: dict, tenant: dict) -> bool: """ - Returns True if mlag ibgp_peering is enabled - False otherwise + Returns True if mlag ibgp_peering is enabled. + + False otherwise. """ if not self.shared_utils.mlag_l3 or not self.shared_utils.network_services_l3: return False - mlag_ibgp_peering: bool = default(vrf.get("enable_mlag_ibgp_peering_vrfs"), tenant.get("enable_mlag_ibgp_peering_vrfs"), True) + mlag_ibgp_peering: bool = default(vrf.get("enable_mlag_ibgp_peering_vrfs"), tenant.get("enable_mlag_ibgp_peering_vrfs"), True) # noqa: FBT003 return vrf["name"] != "default" and mlag_ibgp_peering - def _mlag_ibgp_peering_vlan_vrf(self: AvdStructuredConfigNetworkServices, vrf, tenant) -> int | None: + def _mlag_ibgp_peering_vlan_vrf(self: AvdStructuredConfigNetworkServices, vrf: dict, tenant: dict) -> int | None: """ - MLAG IBGP Peering VLANs per VRF + MLAG IBGP Peering VLANs per VRF. Performs all relevant checks if MLAG IBGP Peering is enabled Returns None if peering is not enabled @@ -154,28 +154,31 @@ def _mlag_ibgp_peering_vlan_vrf(self: AvdStructuredConfigNetworkServices, vrf, t base_vlan = self.shared_utils.mlag_ibgp_peering_vrfs_base_vlan vrf_id = vrf.get("vrf_id", vrf.get("vrf_vni")) if vrf_id is None: + msg = f"Unable to assign MLAG VRF Peering VLAN for vrf {vrf['name']}.Set either 'mlag_ibgp_peering_vlan' or 'vrf_id' or 'vrf_vni' on the VRF" raise AristaAvdMissingVariableError( - f"Unable to assign MLAG VRF Peering VLAN for vrf {vrf['name']}.Set either 'mlag_ibgp_peering_vlan' or 'vrf_id' or 'vrf_vni' on the VRF" + msg, ) vlan_id = base_vlan + int(vrf_id) - 1 return vlan_id - def _mlag_ibgp_peering_redistribute(self: AvdStructuredConfigNetworkServices, vrf, tenant) -> bool: + def _mlag_ibgp_peering_redistribute(self: AvdStructuredConfigNetworkServices, vrf: dict, tenant: dict) -> bool: """ Returns True if MLAG IBGP Peering subnet should be redistributed for the given vrf/tenant. + False otherwise. Does _not_ include checks if the peering is enabled at all, so that should be checked first. """ - return default(vrf.get("redistribute_mlag_ibgp_peering_vrfs"), tenant.get("redistribute_mlag_ibgp_peering_vrfs"), True) is True + return default(vrf.get("redistribute_mlag_ibgp_peering_vrfs"), tenant.get("redistribute_mlag_ibgp_peering_vrfs"), True) is True # noqa: FBT003 @cached_property def _configure_bgp_mlag_peer_group(self: AvdStructuredConfigNetworkServices) -> bool: """ Flag set during creating of BGP VRFs if an MLAG peering is needed. + Decides if MLAG BGP peer-group should be configured. - Catches cases where underlay is not BGP but we still need MLAG iBGP peering + Catches cases where underlay is not BGP but we still need MLAG iBGP peering. """ if self.shared_utils.underlay_bgp or (bgp_vrfs := self._router_bgp_vrfs) is None: return False @@ -191,9 +194,7 @@ def _configure_bgp_mlag_peer_group(self: AvdStructuredConfigNetworkServices) -> @cached_property def _filtered_wan_vrfs(self: AvdStructuredConfigNetworkServices) -> list: - """ - Loop through all the VRFs defined under `wan_virtual_topologies.vrfs` and returns a list of mode - """ + """Loop through all the VRFs defined under `wan_virtual_topologies.vrfs` and returns a list of mode.""" wan_vrfs = [] for vrf in get(self._hostvars, "wan_virtual_topologies.vrfs", []): @@ -203,7 +204,10 @@ def _filtered_wan_vrfs(self: AvdStructuredConfigNetworkServices) -> list: "name": vrf_name, "policy": get(vrf, "policy", default=self._default_wan_policy_name), "wan_vni": get( - vrf, "wan_vni", required=True, org_key=f"Required `wan_vni` is missing for VRF {vrf_name} under `wan_virtual_topologies.vrfs`." + vrf, + "wan_vni", + required=True, + org_key=f"Required `wan_vni` is missing for VRF {vrf_name} under `wan_virtual_topologies.vrfs`.", ), } @@ -217,7 +221,7 @@ def _filtered_wan_vrfs(self: AvdStructuredConfigNetworkServices) -> list: "policy": f"{self._default_wan_policy_name}-WITH-CP", "wan_vni": 1, "original_policy": self._default_wan_policy_name, - } + }, ) else: vrf_default["original_policy"] = vrf_default["policy"] @@ -227,9 +231,7 @@ def _filtered_wan_vrfs(self: AvdStructuredConfigNetworkServices) -> list: @cached_property def _wan_virtual_topologies_policies(self: AvdStructuredConfigNetworkServices) -> list: - """ - This function parses the input data and append the default-policy if not already present - """ + """This function parses the input data and append the default-policy if not already present.""" policies = get(self._hostvars, "wan_virtual_topologies.policies", default=[]) # If not overwritten, inject the default policy in case it is required for one of the VRFs if get_item(policies, "name", self._default_wan_policy_name) is None: @@ -284,7 +286,8 @@ def _filtered_wan_policies(self: AvdStructuredConfigNetworkServices) -> list: def _update_policy_match_statements(self: AvdStructuredConfigNetworkServices, policy: dict) -> None: """ - Update the policy dict with two keys: `matches` and `default_match` + Update the policy dict with two keys: `matches` and `default_match`. + For each match (or default_match), the load_balancing policy is resolved and if it is empty the match statement is not included. """ @@ -301,7 +304,8 @@ def _update_policy_match_statements(self: AvdStructuredConfigNetworkServices, po policy["name"], ) ) is None: - raise AristaAvdError("The WAN control-plane load-balance policy is empty. Make sure at least one path-group can be used in the policy") + msg = "The WAN control-plane load-balance policy is empty. Make sure at least one path-group can be used in the policy" + raise AristaAvdError(msg) matches.append( { "application_profile": self._wan_control_plane_application_profile_name, @@ -311,7 +315,7 @@ def _update_policy_match_statements(self: AvdStructuredConfigNetworkServices, po "dscp": get(control_plane_virtual_topology, "dscp"), "load_balance_policy": load_balance_policy, "id": 254, - } + }, ) for application_virtual_topology in get(policy, "application_virtual_topologies", []): @@ -354,11 +358,14 @@ def _update_policy_match_statements(self: AvdStructuredConfigNetworkServices, po "dscp": get(application_virtual_topology, "dscp"), "load_balance_policy": load_balance_policy, "id": profile_id, - } + }, ) default_virtual_topology = get( - policy, "default_virtual_topology", required=True, org_key=f"wan_virtual_topologies.policies[{policy['profile_prefix']}].default_virtual_toplogy" + policy, + "default_virtual_topology", + required=True, + org_key=f"wan_virtual_topologies.policies[{policy['profile_prefix']}].default_virtual_toplogy", ) # Separating default_match as it is used differently default_match = None @@ -379,11 +386,14 @@ def _update_policy_match_statements(self: AvdStructuredConfigNetworkServices, po load_balance_policy_name = self.shared_utils.generate_lb_policy_name(name) load_balance_policy = self._generate_wan_load_balance_policy(load_balance_policy_name, default_virtual_topology, context_path) if not load_balance_policy: - raise AristaAvdError( + msg = ( f"The `default_virtual_topology` path-groups configuration for `wan_virtual_toplogies.policies[{policy['name']}]` produces " "an empty load-balancing policy. Make sure at least one path-group present on the device is allowed in the " "`default_virtual_topology` path-groups." ) + raise AristaAvdError( + msg, + ) application_profile = get(default_virtual_topology, "application_profile", default="default") default_match = { @@ -398,10 +408,13 @@ def _update_policy_match_statements(self: AvdStructuredConfigNetworkServices, po if not matches and not default_match: # The policy is empty but should be assigned to a VRF - raise AristaAvdError( + msg = ( f"The policy `wan_virtual_toplogies.policies[{policy['name']}]` cannot match any traffic but is assigned to a VRF. " "Make sure at least one path-group present on the device is used in the policy." ) + raise AristaAvdError( + msg, + ) policy["matches"] = matches policy["default_match"] = default_match @@ -409,6 +422,7 @@ def _update_policy_match_statements(self: AvdStructuredConfigNetworkServices, po def _generate_wan_load_balance_policy(self: AvdStructuredConfigNetworkServices, name: str, input_dict: dict, context_path: str) -> dict | None: """ Generate and return a router path-selection load-balance policy. + If HA is enabled, inject the HA path-group with priority 1. Attrs: @@ -464,7 +478,7 @@ def _generate_wan_load_balance_policy(self: AvdStructuredConfigNetworkServices, # The policy is empty return None - # TODO for now adding LAN_HA only if the path-groups list is not empty + # TODO: for now adding LAN_HA only if the path-groups list is not empty # then need to add the logic to check HA peer path-group and maybe return a policy with LAN_HA only. if self.shared_utils.wan_ha or self.shared_utils.is_cv_pathfinder_server: # Adding HA path-group with priority 1 - it does not count as an entry with priority 1 @@ -493,23 +507,26 @@ def _path_group_preference_to_eos_priority(self: AvdStructuredConfigNetworkServi failed_conversion = True if failed_conversion or not 1 <= priority <= 65535: - raise AristaAvdError( + msg = ( f"Invalid value '{path_group_preference}' for Path-Group preference - should be either 'preferred', " f"'alternate' or an integer[1-65535] for {context_path}." ) + raise AristaAvdError( + msg, + ) return priority @cached_property def _default_wan_policy_name(self: AvdStructuredConfigNetworkServices) -> str: - """ - TODO make this configurable - """ + """TODO: make this configurable.""" return "DEFAULT-POLICY" @cached_property def _default_policy_path_group_names(self: AvdStructuredConfigNetworkServices) -> list: """ + Return a list of path group names for the default policy. + Return the list of path-groups to consider when generating a default policy with AVD whether for the default policy or the special Control-plane policy. """ @@ -518,22 +535,24 @@ def _default_policy_path_group_names(self: AvdStructuredConfigNetworkServices) - } if not path_group_names.intersection(self.shared_utils.wan_local_path_group_names): # No common path-group between this device local path-groups and the available path-group for the default policy - raise AristaAvdError( + msg = ( f"Unable to generate the default WAN policy as none of the device local path-groups {self.shared_utils.wan_local_path_group_names} " "is eligible to be included. Make sure that at least one path-group for the device is not configured with " "`excluded_from_default_policy: true` under `wan_path_groups`." ) + raise AristaAvdError( + msg, + ) return natural_sort(path_group_names) @cached_property def _default_wan_policy(self: AvdStructuredConfigNetworkServices) -> dict: """ + Returning policy containing all path groups not excluded from default policy. + If no policy is defined for a VRF under 'wan_virtual_topologies.vrfs', a default policy named DEFAULT-POLICY is used where all traffic is matched in the default category and distributed amongst all path-groups. - - Returning policy containing all path groups not excluded from default policy. """ - return { "name": self._default_wan_policy_name, "default_virtual_topology": {"path_groups": [{"names": self._default_policy_path_group_names}]}, @@ -541,7 +560,7 @@ def _default_wan_policy(self: AvdStructuredConfigNetworkServices) -> dict: def _default_profile_name(self: AvdStructuredConfigNetworkServices, profile_name: str, application_profile: str) -> str: """ - Helper function to consistently return the default name of a profile + Helper function to consistently return the default name of a profile. Returns {profile_name}-{application_profile} """ @@ -564,24 +583,18 @@ def _wan_control_plane_virtual_topology(self: AvdStructuredConfigNetworkServices @cached_property def _wan_control_plane_profile_name(self: AvdStructuredConfigNetworkServices) -> str: - """ - Control plane profile name - """ + """Control plane profile name.""" vrf_default_policy_name = get(get_item(self._filtered_wan_vrfs, "name", "default"), "original_policy") return get(self._wan_control_plane_virtual_topology, "name", default=f"{vrf_default_policy_name}-CONTROL-PLANE") @cached_property def _wan_control_plane_application_profile_name(self: AvdStructuredConfigNetworkServices) -> str: - """ - Control plane application profile name - """ + """Control plane application profile name.""" return get(self._hostvars, "wan_virtual_topologies.control_plane_virtual_topology.application_profile", default="APP-PROFILE-CONTROL-PLANE") @cached_property def _local_path_groups_connected_to_pathfinder(self: AvdStructuredConfigNetworkServices) -> list: - """ - Return list of names of local path_groups connected to pathfinder - """ + """Return list of names of local path_groups connected to pathfinder.""" return [ path_group["name"] for path_group in self.shared_utils.wan_local_path_groups @@ -591,11 +604,13 @@ def _local_path_groups_connected_to_pathfinder(self: AvdStructuredConfigNetworkS @cached_property def _svi_acls(self: AvdStructuredConfigNetworkServices) -> dict[str, dict[str, dict]] | None: """ - Returns a dict of - : { - "ipv4_acl_in": , - "ipv4_acl_out": , - } + Returns a dict of SVI ACLs. + + : { + "ipv4_acl_in": , + "ipv4_acl_out": , + } + Only contains interfaces with ACLs and only the ACLs that are set, so use `get(self._svi_acls, f"{interface_name}.ipv4_acl_in")` to get the value. """ @@ -640,8 +655,9 @@ def get_internet_exit_nat_acl_name(self: AvdStructuredConfigNetworkServices, int return f"ACL-{self.get_internet_exit_nat_profile_name(internet_exit_policy_type)}" def get_internet_exit_nat_pool_and_profile( - self: AvdStructuredConfigNetworkServices, internet_exit_policy_type: Literal["zscaler", "direct"] - ) -> Tuple[dict | None, dict | None]: + self: AvdStructuredConfigNetworkServices, + internet_exit_policy_type: Literal["zscaler", "direct"], + ) -> tuple[dict | None, dict | None]: if internet_exit_policy_type == "zscaler": pool = { "name": "PORT-ONLY-POOL", @@ -650,7 +666,7 @@ def get_internet_exit_nat_pool_and_profile( { "first_port": 1500, "last_port": 65535, - } + }, ], } @@ -662,8 +678,8 @@ def get_internet_exit_nat_pool_and_profile( "access_list": self.get_internet_exit_nat_acl_name(internet_exit_policy_type), "pool_name": "PORT-ONLY-POOL", "nat_type": "pool", - } - ] + }, + ], }, } return pool, profile @@ -676,30 +692,31 @@ def get_internet_exit_nat_pool_and_profile( { "access_list": self.get_internet_exit_nat_acl_name(internet_exit_policy_type), "nat_type": "overload", - } - ] + }, + ], }, } return None, profile + return None @cached_property def _filtered_internet_exit_policy_types(self: AvdStructuredConfigNetworkServices) -> list: - return sorted(set(internet_exit_policy["type"] for internet_exit_policy in self._filtered_internet_exit_policies)) + return sorted({internet_exit_policy["type"] for internet_exit_policy in self._filtered_internet_exit_policies}) @cached_property def _l3_interface_acls(self: AvdStructuredConfigNetworkServices) -> dict | None: """ Returns a dict of interfaces and ACLs set on the interfaces. - { - : { - "ipv4_acl_in": , - "ipv4_acl_out": , - } + + { + : { + "ipv4_acl_in": , + "ipv4_acl_out": , } + } Only contains interfaces with ACLs and only the ACLs that are set, so use `get(self._l3_interface_acls, f"{interface_name}..ipv4_acl_in", separator="..")` to get the value. """ - if not self.shared_utils.network_services_l3: return None @@ -816,12 +833,11 @@ def get_internet_exit_connections(self: AvdStructuredConfigNetworkServices, inte if policy_type == "zscaler": return self.get_zscaler_internet_exit_connections(internet_exit_policy) - raise AristaAvdError(f"Unsupported type '{policy_type}' found in cv_pathfinder_internet_exit[name={policy_name}].") + msg = f"Unsupported type '{policy_type}' found in cv_pathfinder_internet_exit[name={policy_name}]." + raise AristaAvdError(msg) def get_direct_internet_exit_connections(self: AvdStructuredConfigNetworkServices, internet_exit_policy: dict) -> list: - """ - Return a list of connections (dicts) for the given internet_exit_policy of type direct. - """ + """Return a list of connections (dicts) for the given internet_exit_policy of type direct.""" if get(internet_exit_policy, "type") != "direct": return [] @@ -834,18 +850,23 @@ def get_direct_internet_exit_connections(self: AvdStructuredConfigNetworkService continue if not wan_interface.get("peer_ip"): - raise AristaAvdMissingVariableError( + msg = ( f"{wan_interface['name']} peer_ip needs to be set. When using wan interface " "for direct type internet exit, peer_ip is used for nexthop, and connectivity monitoring." ) + raise AristaAvdMissingVariableError( + msg, + ) # wan interface ip will be used for acl, hence raise error if ip is not available - if (ip_address := wan_interface.get("ip_address")) == "dhcp": - if not (ip_address := wan_interface.get("dhcp_ip")): - raise AristaAvdMissingVariableError( - f"{wan_interface['name']} 'dhcp_ip' needs to be set. When using WAN interface for 'direct' type Internet exit, " - "'dhcp_ip' is used in the NAT ACL." - ) + if (ip_address := wan_interface.get("ip_address")) == "dhcp" and not (ip_address := wan_interface.get("dhcp_ip")): + msg = ( + f"{wan_interface['name']} 'dhcp_ip' needs to be set. When using WAN interface for 'direct' type Internet exit, " + "'dhcp_ip' is used in the NAT ACL." + ) + raise AristaAvdMissingVariableError( + msg, + ) sanitized_interface_name = self.shared_utils.sanitize_interface_name(wan_interface["name"]) connections.append( @@ -859,15 +880,13 @@ def get_direct_internet_exit_connections(self: AvdStructuredConfigNetworkService "source_interface": wan_interface["name"], "description": f"Internet Exit {internet_exit_policy['name']}", "exit_group": f"{internet_exit_policy['name']}", - } + }, ) return connections def get_zscaler_internet_exit_connections(self: AvdStructuredConfigNetworkServices, internet_exit_policy: dict) -> list: - """ - Return a list of connections (dicts) for the given internet_exit_policy of type zscaler. - """ + """Return a list of connections (dicts) for the given internet_exit_policy of type zscaler.""" if get(internet_exit_policy, "type") != "zscaler": return [] @@ -897,10 +916,13 @@ def get_zscaler_internet_exit_connections(self: AvdStructuredConfigNetworkServic tunnel_interface_numbers = get(interface_policy_config, "tunnel_interface_numbers") if tunnel_interface_numbers is None: - raise AristaAvdMissingVariableError( + msg = ( f"{wan_interface['name']}.cv_pathfinder_internet_exit.policies[{internet_exit_policy['name']}]." "tunnel_interface_numbers needs to be set, when using wan interface for zscaler type internet exit." ) + raise AristaAvdMissingVariableError( + msg, + ) tunnel_id_range = range_expand(tunnel_interface_numbers) @@ -932,15 +954,13 @@ def get_zscaler_internet_exit_connections(self: AvdStructuredConfigNetworkServic "exit_group": f"{policy_name}_{suffix}", "preference": zscaler_endpoint_key, "suffix": suffix, - } + }, ) return connections def _get_ipsec_credentials(self: AvdStructuredConfigNetworkServices, internet_exit_policy: dict) -> tuple[str, str]: - """ - Returns ufqdn, shared_key based on various details from the given internet_exit_policy. - """ + """Returns ufqdn, shared_key based on various details from the given internet_exit_policy.""" policy_name = internet_exit_policy["name"] domain_name = get(internet_exit_policy, "zscaler.domain_name", required=True) ipsec_key_salt = get(internet_exit_policy, "zscaler.ipsec_key_salt", required=True) @@ -951,10 +971,11 @@ def _get_ipsec_credentials(self: AvdStructuredConfigNetworkServices, internet_ex def _generate_ipsec_key(self: AvdStructuredConfigNetworkServices, name: str, salt: str) -> str: """ Build a secret containing various components for this policy and device. + Run type-7 obfuscation using a algorithmic salt so we ensure the same key every time. TODO: Maybe introduce some formatting with max length of each element, since the keys can be come very very long. """ - secret = "_".join((self.shared_utils.hostname, name, salt)) + secret = f"{self.shared_utils.hostname}_{name}_{salt}" type_7_salt = sum(salt.encode("utf-8")) % 16 return simple_7_encrypt(secret, type_7_salt) diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/utils_zscaler.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/utils_zscaler.py index 983bbdb03a3..de50545ab65 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/utils_zscaler.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/utils_zscaler.py @@ -8,14 +8,15 @@ from logging import getLogger from typing import TYPE_CHECKING -from ...._cv.client import CVClient -from ...._cv.workflows.models import CVDevice -from ...._cv.workflows.verify_devices_on_cv import verify_devices_in_cloudvision_inventory -from ...._errors import AristaAvdError -from ...._utils import get +from pyavd._cv.client import CVClient +from pyavd._cv.workflows.models import CVDevice +from pyavd._cv.workflows.verify_devices_on_cv import verify_devices_in_cloudvision_inventory +from pyavd._errors import AristaAvdError +from pyavd._utils import get if TYPE_CHECKING: - from ...._cv.api.arista.swg.v1 import Location, VpnEndpoint + from pyavd._cv.api.arista.swg.v1 import Location, VpnEndpoint + from . import AvdStructuredConfigNetworkServices LOGGER = getLogger(__name__) @@ -24,7 +25,8 @@ class UtilsZscalerMixin: """ Mixin Class with internal functions. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property @@ -40,7 +42,7 @@ def _zscaler_endpoints(self: AvdStructuredConfigNetworkServices) -> dict: return asyncio.run(self._generate_zscaler_endpoints()) or {} - async def _generate_zscaler_endpoints(self: AvdStructuredConfigNetworkServices): + async def _generate_zscaler_endpoints(self: AvdStructuredConfigNetworkServices) -> dict: """ Call CloudVision SWG APIs to generate the zscaler_endpoints model. @@ -64,16 +66,21 @@ async def _generate_zscaler_endpoints(self: AvdStructuredConfigNetworkServices): async with CVClient(servers=[cv_server], token=cv_token) as cv_client: cv_device = CVDevice(self.shared_utils.hostname, self.shared_utils.serial_number, self.shared_utils.system_mac_address) cv_inventory_devices: list[CVDevice] = await verify_devices_in_cloudvision_inventory( - devices=[cv_device], skip_missing_devices=True, warnings=[], cv_client=cv_client + devices=[cv_device], + skip_missing_devices=True, + warnings=[], + cv_client=cv_client, ) if not cv_inventory_devices: - raise AristaAvdError(f"{context} but could not find '{self.shared_utils.hostname}' on the server '{cv_server}'.") + msg = f"{context} but could not find '{self.shared_utils.hostname}' on the server '{cv_server}'." + raise AristaAvdError(msg) if len(cv_inventory_devices) > 1: + msg = ( + f"{context} but found more than one device named '{self.shared_utils.hostname}' on the server '{cv_server}'. " + "Set 'serial_number' for the device in AVD vars, to ensure a unique match." + ) raise AristaAvdError( - ( - f"{context} but found more than one device named '{self.shared_utils.hostname}' on the server '{cv_server}'. " - "Set 'serial_number' for the device in AVD vars, to ensure a unique match." - ) + msg, ) device_id: str = cv_inventory_devices[0].serial_number request_time, _ = await cv_client.set_swg_device(device_id=device_id, service="zscaler", location=wan_site_location) @@ -89,7 +96,8 @@ async def _generate_zscaler_endpoints(self: AvdStructuredConfigNetworkServices): }, } if not getattr(cv_endpoint_status, "vpn_endpoints", None) or not getattr(cv_endpoint_status.vpn_endpoints, "values", None): - raise AristaAvdError(f"{context} but did not get any IPsec Tunnel endpoints back from the Zscaler API.") + msg = f"{context} but did not get any IPsec Tunnel endpoints back from the Zscaler API." + raise AristaAvdError(msg) for key in ("primary", "secondary", "tertiary"): if key in cv_endpoint_status.vpn_endpoints.values: diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/virtual_source_nat_vrfs.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/virtual_source_nat_vrfs.py index f339276b836..f5615fbc82c 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/virtual_source_nat_vrfs.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/virtual_source_nat_vrfs.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate +from pyavd._utils import append_if_not_duplicate + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,13 +17,14 @@ class VirtualSourceNatVrfsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def virtual_source_nat_vrfs(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Return structured config for virtual_source_nat_vrfs + Return structured config for virtual_source_nat_vrfs. Only used by VTEPs with L2 and L3 services Using data from loopback_interfaces to avoid duplicating logic diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/vlan_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/vlan_interfaces.py index 9881b7914ae..268b5238ce7 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/vlan_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/vlan_interfaces.py @@ -6,8 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdMissingVariableError -from ...._utils import append_if_not_duplicate, default, get, strip_empties_from_dict +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import append_if_not_duplicate, default, get, strip_empties_from_dict + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,17 +18,17 @@ class VlanInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def vlan_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None: """ - Return structured config for vlan_interfaces + Return structured config for vlan_interfaces. Consist of svis and mlag peering vlans from filtered tenants """ - if not (self.shared_utils.network_services_l2 and self.shared_utils.network_services_l3): return None @@ -65,20 +66,21 @@ def vlan_interfaces(self: AvdStructuredConfigNetworkServices) -> list | None: return None - def _get_vlan_interface_config_for_svi(self: AvdStructuredConfigNetworkServices, svi, vrf) -> dict: - def _check_virtual_router_mac_address(vlan_interface_config: dict, variables: list): + def _get_vlan_interface_config_for_svi(self: AvdStructuredConfigNetworkServices, svi: dict, vrf: dict) -> dict: + def _check_virtual_router_mac_address(vlan_interface_config: dict, variables: list) -> None: """ + Error if virtual router mac address is required but missing. + Check if any variable in the list of variables is not None in vlan_interface_config - and if it is the case, raise an Exception if virtual_router_mac_address is None + and if it is the case, raise an Exception if virtual_router_mac_address is None. NOTE: SVI settings are also used for subinterfaces for uplink_type: 'lan'. So any changes here may also be needed in underlay.utils.UtilsMixin._get_l2_as_subint(). """ if any(vlan_interface_config.get(var) for var in variables) and self.shared_utils.virtual_router_mac_address is None: quoted_vars = [f"'{var}'" for var in variables] - raise AristaAvdMissingVariableError( - f"'virtual_router_mac_address' must be set for node '{self.shared_utils.hostname}' when using {' or '.join(quoted_vars)} under 'svi'" - ) + msg = f"'virtual_router_mac_address' must be set for node '{self.shared_utils.hostname}' when using {' or '.join(quoted_vars)} under 'svi'" + raise AristaAvdMissingVariableError(msg) interface_name = f"Vlan{svi['id']}" vlan_interface_config = { @@ -116,10 +118,11 @@ def _check_virtual_router_mac_address(vlan_interface_config: dict, variables: li if "ip_address_virtual" in vlan_interface_config: if (vrf_diagnostic_loopback := get(vrf, "vtep_diagnostic.loopback")) is None: - raise AristaAvdMissingVariableError( + msg = ( f"No vtep_diagnostic loopback defined on VRF '{vrf['name']}' in Tenant '{svi['tenant']}'." "This is required when 'l3_multicast' is enabled on the VRF and ip_address_virtual is used on an SVI in that VRF." ) + raise AristaAvdMissingVariableError(msg) pim_config_ipv4["local_interface"] = f"Loopback{vrf_diagnostic_loopback}" if pim_config_ipv4: @@ -143,7 +146,7 @@ def _check_virtual_router_mac_address(vlan_interface_config: dict, variables: li if vlan_interface_config.get("ipv6_address_virtuals"): # If any anycast IPs are set, we also enable link-local IPv6 per best practice, unless specifically disabled with 'ipv6_enable: false' - vlan_interface_config["ipv6_enable"] = default(vlan_interface_config["ipv6_enable"], True) + vlan_interface_config["ipv6_enable"] = default(vlan_interface_config["ipv6_enable"], True) # noqa: FBT003 if vrf["name"] != "default": vlan_interface_config["vrf"] = vrf["name"] @@ -153,10 +156,11 @@ def _check_virtual_router_mac_address(vlan_interface_config: dict, variables: li return strip_empties_from_dict(vlan_interface_config) - def _get_vlan_interface_config_for_mlag_peering(self: AvdStructuredConfigNetworkServices, vrf) -> dict: + def _get_vlan_interface_config_for_mlag_peering(self: AvdStructuredConfigNetworkServices, vrf: dict) -> dict: """ Build config for MLAG peering SVI for the given SVI. - Called from vlan_interfaces and prefix_lists + + Called from vlan_interfaces and prefix_lists. """ vlan_interface_config = { "tenant": vrf["tenant"], diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/vlans.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/vlans.py index 80da7e8c09f..1e74c6623d7 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/vlans.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/vlans.py @@ -6,8 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate -from ....j2filters import natural_sort +from pyavd._utils import append_if_not_duplicate +from pyavd.j2filters import natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,7 +18,8 @@ class VlansMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property @@ -30,7 +32,6 @@ def vlans(self: AvdStructuredConfigNetworkServices) -> list | None: This function also detects duplicate vlans and raise an error in case of duplicates between SVIs in all VRFs and L2VLANs deployed on this device. """ - if not self.shared_utils.network_services_l2: return None @@ -72,7 +73,12 @@ def vlans(self: AvdStructuredConfigNetworkServices) -> list | None: for l2vlan in tenant["l2vlans"]: vlan = self._get_vlan_config(l2vlan) append_if_not_duplicate( - list_of_dicts=vlans, primary_key="id", new_dict=vlan, context="L2VLANs", context_keys=["id", "name", "tenant"], ignore_keys={"tenant"} + list_of_dicts=vlans, + primary_key="id", + new_dict=vlan, + context="L2VLANs", + context_keys=["id", "name", "tenant"], + ignore_keys={"tenant"}, ) if vlans: @@ -80,9 +86,9 @@ def vlans(self: AvdStructuredConfigNetworkServices) -> list | None: return None - def _get_vlan_config(self: AvdStructuredConfigNetworkServices, vlan) -> dict: + def _get_vlan_config(self: AvdStructuredConfigNetworkServices, vlan: dict) -> dict: """ - Return structured config for one given vlan + Return structured config for one given vlan. Can be used for svis and l2vlans """ diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/vrfs.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/vrfs.py index 9f02180eadf..11d269c353f 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/vrfs.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/vrfs.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate +from pyavd._utils import append_if_not_duplicate + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,7 +17,8 @@ class VrfsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property @@ -29,7 +31,6 @@ def vrfs(self: AvdStructuredConfigNetworkServices) -> list | None: This function also detects duplicate vrfs and raise an error in case of duplicates between all Tenants deployed on this device. """ - if not self.shared_utils.network_services_l3: return None @@ -72,7 +73,7 @@ def vrfs(self: AvdStructuredConfigNetworkServices) -> list | None: return None - def _has_ipv6(self: AvdStructuredConfigNetworkServices, vrf) -> bool: + def _has_ipv6(self: AvdStructuredConfigNetworkServices, vrf: dict) -> bool: """ Return bool if IPv6 is configured in the given VRF. diff --git a/python-avd/pyavd/_eos_designs/structured_config/network_services/vxlan_interface.py b/python-avd/pyavd/_eos_designs/structured_config/network_services/vxlan_interface.py index f70c552f647..f11ea2fe992 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/network_services/vxlan_interface.py +++ b/python-avd/pyavd/_eos_designs/structured_config/network_services/vxlan_interface.py @@ -6,9 +6,10 @@ from functools import cached_property from typing import TYPE_CHECKING, NoReturn -from ...._errors import AristaAvdError, AristaAvdMissingVariableError -from ...._utils import append_if_not_duplicate, default, get, get_item, unique -from ....j2filters import natural_sort, range_expand +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._utils import append_if_not_duplicate, default, get, get_item, unique +from pyavd.j2filters import natural_sort, range_expand + from .utils import UtilsMixin if TYPE_CHECKING: @@ -18,13 +19,14 @@ class VxlanInterfaceMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def vxlan_interface(self: AvdStructuredConfigNetworkServices) -> dict | None: """ - Returns structured config for vxlan_interface + Returns structured config for vxlan_interface. Only used for VTEPs and for WAN @@ -117,13 +119,11 @@ def vxlan_interface(self: AvdStructuredConfigNetworkServices) -> dict | None: "Vxlan1": { "description": f"{self.shared_utils.hostname}_VTEP", "vxlan": vxlan, - } + }, } def _get_vxlan_interface_config_for_vrf(self: AvdStructuredConfigNetworkServices, vrf: dict, tenant: dict, vrfs: list, vlans: list, vnis: list) -> None: - """ - In place updates of the vlans, vnis and vrfs list - """ + """In place updates of the vlans, vnis and vrfs list.""" if self.shared_utils.network_services_l2: for svi in vrf["svis"]: if vlan := self._get_vxlan_interface_config_for_vlan(svi, tenant): @@ -169,7 +169,7 @@ def _get_vxlan_interface_config_for_vrf(self: AvdStructuredConfigNetworkServices ) # NOTE: this can never be None here, it would be caught previously in the code - id = default( + vrf_id = default( vrf.get("vrf_id"), vrf.get("vrf_vni"), ) @@ -187,7 +187,10 @@ def _get_vxlan_interface_config_for_vrf(self: AvdStructuredConfigNetworkServices ) underlay_l3_mcast_group_ipv4_pool_offset = get(tenant, "evpn_l3_multicast.evpn_underlay_l3_multicast_group_ipv4_pool_offset", default=0) vrf_data["multicast_group"] = self.shared_utils.ip_addressing.evpn_underlay_l3_multicast_group( - underlay_l3_multicast_group_ipv4_pool, vni, id, underlay_l3_mcast_group_ipv4_pool_offset + underlay_l3_multicast_group_ipv4_pool, + vni, + vrf_id, + underlay_l3_mcast_group_ipv4_pool_offset, ) # Duplicate check is not done on the actual list of vlans, but instead on our local "vnis" list. @@ -208,9 +211,9 @@ def _get_vxlan_interface_config_for_vrf(self: AvdStructuredConfigNetworkServices context_keys=["name", "vni"], ) - def _get_vxlan_interface_config_for_vlan(self: AvdStructuredConfigNetworkServices, vlan, tenant) -> dict: + def _get_vxlan_interface_config_for_vlan(self: AvdStructuredConfigNetworkServices, vlan: dict, tenant: dict) -> dict: """ - vxlan_interface logic for one vlan + vxlan_interface logic for one vlan. Can be used for both svis and l2vlans """ @@ -237,7 +240,9 @@ def _get_vxlan_interface_config_for_vlan(self: AvdStructuredConfigNetworkService ) underlay_l2_multicast_group_ipv4_pool_offset = get(tenant, "evpn_l2_multicast.underlay_l2_multicast_group_ipv4_pool_offset", default=0) vxlan_interface_vlan["multicast_group"] = self.shared_utils.ip_addressing.evpn_underlay_l2_multicast_group( - underlay_l2_multicast_group_ipv4_pool, vlan_id, underlay_l2_multicast_group_ipv4_pool_offset + underlay_l2_multicast_group_ipv4_pool, + vlan_id, + underlay_l2_multicast_group_ipv4_pool_offset, ) if self.shared_utils.overlay_her and self._overlay_her_flood_list_per_vni: @@ -270,7 +275,8 @@ def _overlay_her_flood_lists(self: AvdStructuredConfigNetworkServices) -> dict[l overlay_her_flood_list_scope = get(self._hostvars, "overlay_her_flood_list_scope") if overlay_her_flood_list_scope == "dc" and self.shared_utils.dc_name is None: - raise AristaAvdMissingVariableError("'dc_name' is required with 'overlay_her_flood_list_scope: dc'") + msg = "'dc_name' is required with 'overlay_her_flood_list_scope: dc'" + raise AristaAvdMissingVariableError(msg) for peer in self.shared_utils.all_fabric_devices: if peer == self.shared_utils.hostname: diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/__init__.py index f21e7ddd076..ae56a650d21 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from ...avdfacts import AvdFacts +from pyavd._eos_designs.avdfacts import AvdFacts + from .cvx import CvxMixin from .ip_extcommunity_lists import IpExtCommunityListsMixin from .ip_security import IpSecurityMixin @@ -46,11 +47,12 @@ class AvdStructuredConfigOverlay( def render(self) -> dict: """ - Wrap class render function with a check if one of the following vars are True: + Wrap class render function with a check if one of the following vars are True. + - overlay_cvx - overlay_evpn - overlay_vpn_ipv4 - - overlay_vpn_ipv6 + - overlay_vpn_ipv6. """ if any( [ @@ -59,7 +61,7 @@ def render(self) -> dict: self.shared_utils.overlay_vpn_ipv4, self.shared_utils.overlay_vpn_ipv6, self.shared_utils.is_wan_router, - ] + ], ): return super().render() return {} diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/cvx.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/cvx.py index c32a8e77649..daf9ecb5cc4 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/cvx.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/cvx.py @@ -7,7 +7,8 @@ from ipaddress import ip_interface from typing import TYPE_CHECKING -from ...._utils import get +from pyavd._utils import get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,14 +18,13 @@ class CvxMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def cvx(self: AvdStructuredConfigOverlay) -> dict | None: - """ - Detect if this is a CVX server for overlay and configure service & peer hosts accordingly. - """ + """Detect if this is a CVX server for overlay and configure service & peer hosts accordingly.""" if not self.shared_utils.overlay_cvx: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/ip_extcommunity_lists.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/ip_extcommunity_lists.py index 7111bee46d4..26b3855858e 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/ip_extcommunity_lists.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/ip_extcommunity_lists.py @@ -15,14 +15,13 @@ class IpExtCommunityListsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ip_extcommunity_lists(self: AvdStructuredConfigOverlay) -> list | None: - """ - Return structured config for ip_extcommunity_lists - """ + """Return structured config for ip_extcommunity_lists.""" if self.shared_utils.overlay_routing_protocol != "ibgp": return None @@ -39,7 +38,7 @@ def ip_extcommunity_lists(self: AvdStructuredConfigOverlay) -> list | None: "extcommunities": f"soo {self.shared_utils.evpn_soo}", }, ], - } + }, ] return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/ip_security.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/ip_security.py index 247a92301dd..7b6584d31ae 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/ip_security.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/ip_security.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get, strip_null_from_data +from pyavd._utils import get, strip_null_from_data + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,18 +17,19 @@ class IpSecurityMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ip_security(self: AvdStructuredConfigOverlay) -> dict | None: """ - ip_security set based on wan_ipsec_profiles data_model + ip_security set based on wan_ipsec_profiles data_model. If `data_plane` is not configured, `control_plane` data is used for both Data Plane and Control Plane. """ - # TODO - in future, the default algo/dh groups value must be clarified + # TODO: - in future, the default algo/dh groups value must be clarified if not self.shared_utils.is_wan_router: return None @@ -45,13 +47,8 @@ def ip_security(self: AvdStructuredConfigOverlay) -> dict | None: return strip_null_from_data(ip_security) def _append_data_plane(self: AvdStructuredConfigOverlay, ip_security: dict, data_plane_config: dict) -> None: - """ - In place update of ip_security - """ - if self.shared_utils.wan_ha_ipsec: - ike_policy_name = get(data_plane_config, "ike_policy_name", default="DP-IKE-POLICY") - else: - ike_policy_name = None + """In place update of ip_security.""" + ike_policy_name = get(data_plane_config, "ike_policy_name", default="DP-IKE-POLICY") if self.shared_utils.wan_ha_ipsec else None sa_policy_name = get(data_plane_config, "sa_policy_name", default="DP-SA-POLICY") profile_name = get(data_plane_config, "profile_name", default="DP-PROFILE") key = get(data_plane_config, "shared_key", required=True) @@ -67,7 +64,7 @@ def _append_data_plane(self: AvdStructuredConfigOverlay, ip_security: dict, data def _append_control_plane(self: AvdStructuredConfigOverlay, ip_security: dict, control_plane_config: dict) -> None: """ - In place update of ip_security for control plane data + In place update of ip_security for control plane data. expected to be called AFTER _append_data_plane """ @@ -85,9 +82,7 @@ def _append_control_plane(self: AvdStructuredConfigOverlay, ip_security: dict, c ip_security["key_controller"] = self._key_controller(profile_name) def _ike_policy(self: AvdStructuredConfigOverlay, name: str) -> dict | None: - """ - Return an IKE policy - """ + """Return an IKE policy.""" return { "name": name, "local_id": self.shared_utils.vtep_ip, @@ -95,20 +90,20 @@ def _ike_policy(self: AvdStructuredConfigOverlay, name: str) -> dict | None: def _sa_policy(self: AvdStructuredConfigOverlay, name: str) -> dict | None: """ - Return an SA policy + Return an SA policy. By default using aes256gcm128 as GCM variants give higher performance. """ sa_policy = {"name": name} if self.shared_utils.is_cv_pathfinder_router: - # TODO, provide options to change this cv_pathfinder_wide + # TODO: provide options to change this cv_pathfinder_wide sa_policy["esp"] = {"encryption": "aes256gcm128"} sa_policy["pfs_dh_group"] = 14 return sa_policy def _profile(self: AvdStructuredConfigOverlay, profile_name: str, ike_policy_name: str | None, sa_policy_name: str, key: str) -> dict | None: """ - Return one IPsec Profile + Return one IPsec Profile. The expectation is that potential None values are stripped later. @@ -130,9 +125,7 @@ def _profile(self: AvdStructuredConfigOverlay, profile_name: str, ike_policy_nam } def _key_controller(self: AvdStructuredConfigOverlay, profile_name: str) -> dict | None: - """ - Return a key_controller structure if the device is not a RR or pathfinder - """ + """Return a key_controller structure if the device is not a RR or pathfinder.""" if self.shared_utils.is_wan_server: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/management_cvx.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/management_cvx.py index 525228122d7..a54fc734e9a 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/management_cvx.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/management_cvx.py @@ -7,7 +7,8 @@ from ipaddress import ip_interface from typing import TYPE_CHECKING -from ...._utils import get +from pyavd._utils import get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,7 +18,8 @@ class ManagementCvxMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/management_security.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/management_security.py index e309c209f2b..54293c7f8ce 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/management_security.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/management_security.py @@ -15,7 +15,8 @@ class ManagementSecurityMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property @@ -38,6 +39,6 @@ def management_security(self: AvdStructuredConfigOverlay) -> dict | None: }, "trust_certificate": {"certificates": ["aristaDeviceCertProvisionerDefaultRootCA.crt"]}, "tls_versions": "1.2", - } - ] + }, + ], } diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/route_maps.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/route_maps.py index b28dd9b7f0f..786a6544e60 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/route_maps.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/route_maps.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ....j2filters import natural_sort +from pyavd.j2filters import natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,15 +17,13 @@ class RouteMapsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def route_maps(self: AvdStructuredConfigOverlay) -> list | None: - """ - Return structured config for route_maps - """ - + """Return structured config for route_maps.""" if self.shared_utils.overlay_cvx: return None @@ -32,7 +31,7 @@ def route_maps(self: AvdStructuredConfigOverlay) -> list | None: if self.shared_utils.overlay_routing_protocol == "ebgp": if self.shared_utils.evpn_prevent_readvertise_to_server is True: - remote_asns = natural_sort(set(rs_dict.get("bgp_as") for route_server, rs_dict in self._evpn_route_servers.items())) + remote_asns = natural_sort({rs_dict.get("bgp_as") for route_server, rs_dict in self._evpn_route_servers.items()}) for remote_asn in remote_asns: route_map_name = f"RM-EVPN-FILTER-AS{remote_asn}" route_maps.append( @@ -49,7 +48,7 @@ def route_maps(self: AvdStructuredConfigOverlay) -> list | None: "type": "permit", }, ], - } + }, ) elif self.shared_utils.overlay_routing_protocol == "ibgp" and self.shared_utils.overlay_vtep and self.shared_utils.evpn_role != "server": @@ -68,7 +67,7 @@ def route_maps(self: AvdStructuredConfigOverlay) -> list | None: "type": "permit", }, ], - } + }, ) route_maps.append( @@ -81,7 +80,7 @@ def route_maps(self: AvdStructuredConfigOverlay) -> list | None: "set": [f"extcommunity soo {self.shared_utils.evpn_soo} additive"], }, ], - } + }, ) if self.shared_utils.wan_ha: @@ -96,7 +95,7 @@ def route_maps(self: AvdStructuredConfigOverlay) -> list | None: "set": ["tag 50"], }, ], - } + }, ) route_maps.append( { @@ -116,7 +115,7 @@ def route_maps(self: AvdStructuredConfigOverlay) -> list | None: "set": ["local-preference 75"], }, ], - } + }, ) if route_maps: diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_adaptive_virtual_topology.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_adaptive_virtual_topology.py index fcaffb3db9b..9d712f7272c 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_adaptive_virtual_topology.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_adaptive_virtual_topology.py @@ -15,14 +15,13 @@ class RouterAdaptiveVirtualTopologyMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_adaptive_virtual_topology(self: AvdStructuredConfigOverlay) -> dict | None: - """ - Return structured config for router adaptive-virtual-topology (AVT) - """ + """Return structured config for router adaptive-virtual-topology (AVT).""" if not self.shared_utils.is_cv_pathfinder_router: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bfd.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bfd.py index 9fc0150d081..45ab29b62cb 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bfd.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bfd.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import strip_empties_from_dict +from pyavd._utils import strip_empties_from_dict + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,14 +17,13 @@ class RouterBfdMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_bfd(self: AvdStructuredConfigOverlay) -> dict | None: - """ - return structured config for router_bfd - """ + """Return structured config for router_bfd.""" if self.shared_utils.bfd_multihop is None: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bgp.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bgp.py index 3ffcf68427e..620baac3058 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bgp.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_bgp.py @@ -7,9 +7,10 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdError -from ...._utils import default, get, get_item, strip_empties_from_dict -from ....j2filters import natural_sort +from pyavd._errors import AristaAvdError +from pyavd._utils import default, get, get_item, strip_empties_from_dict +from pyavd.j2filters import natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -19,14 +20,13 @@ class RouterBgpMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_bgp(self: AvdStructuredConfigOverlay) -> dict | None: - """ - Return the structured config for router_bgp - """ + """Return the structured config for router_bgp.""" if self.shared_utils.overlay_cvx: return None @@ -50,15 +50,14 @@ def router_bgp(self: AvdStructuredConfigOverlay) -> dict | None: return strip_empties_from_dict(router_bgp, strip_values_tuple=(None, "")) def _bgp_cluster_id(self: AvdStructuredConfigOverlay) -> str | None: - if self.shared_utils.overlay_routing_protocol == "ibgp": - if self.shared_utils.evpn_role == "server" or self.shared_utils.mpls_overlay_role == "server": - return get(self.shared_utils.switch_data_combined, "bgp_cluster_id", default=self.shared_utils.router_id) + if self.shared_utils.overlay_routing_protocol == "ibgp" and ( + self.shared_utils.evpn_role == "server" or self.shared_utils.mpls_overlay_role == "server" + ): + return get(self.shared_utils.switch_data_combined, "bgp_cluster_id", default=self.shared_utils.router_id) return None def _bgp_listen_ranges(self: AvdStructuredConfigOverlay) -> list | None: - """ - Generate listen-ranges. Currently only supported for WAN RR. - """ + """Generate listen-ranges. Currently only supported for WAN RR.""" if not self.shared_utils.is_wan_server: return None @@ -72,7 +71,11 @@ def _bgp_listen_ranges(self: AvdStructuredConfigOverlay) -> list | None: ] or None def _generate_base_peer_group( - self: AvdStructuredConfigOverlay, pg_type: str, pg_name: str, maximum_routes: int = 0, update_source: str = "Loopback0" + self: AvdStructuredConfigOverlay, + pg_type: str, + pg_name: str, + maximum_routes: int = 0, + update_source: str = "Loopback0", ) -> dict: return { "name": self.shared_utils.bgp_peer_groups[pg_name]["name"], @@ -86,7 +89,6 @@ def _generate_base_peer_group( } def _peer_groups(self: AvdStructuredConfigOverlay) -> list | None: - """ """ peer_groups = [] if self.shared_utils.overlay_routing_protocol == "ebgp": @@ -106,7 +108,7 @@ def _peer_groups(self: AvdStructuredConfigOverlay) -> list | None: { **self._generate_base_peer_group("evpn", "evpn_overlay_core"), "ebgp_multihop": self.shared_utils.evpn_ebgp_gateway_multihop, - } + }, ) elif self.shared_utils.overlay_routing_protocol == "ibgp": @@ -134,7 +136,7 @@ def _peer_groups(self: AvdStructuredConfigOverlay) -> list | None: { **self._generate_base_peer_group("wan", "wan_overlay_peers", update_source=self.shared_utils.vtep_loopback), **peer_group_config, - } + }, ) else: # EVPN OVERLAY peer group - also in EBGP.. @@ -144,7 +146,7 @@ def _peer_groups(self: AvdStructuredConfigOverlay) -> list | None: { **self._generate_base_peer_group("evpn", "evpn_overlay_peers"), **peer_group_config, - } + }, ) # RR Overlay peer group rendered either for MPLS route servers @@ -159,7 +161,7 @@ def _peer_groups(self: AvdStructuredConfigOverlay) -> list | None: "ttl_maximum_hops": self.shared_utils.bgp_peer_groups["wan_rr_overlay_peers"]["ttl_maximum_hops"], "bfd_timers": get(self.shared_utils.bgp_peer_groups["wan_rr_overlay_peers"], "bfd_timers"), "route_reflector_client": True, - } + }, ) peer_groups.append(wan_rr_overlay_peer_group) @@ -170,27 +172,26 @@ def _peer_groups(self: AvdStructuredConfigOverlay) -> list | None: **self._generate_base_peer_group("mpls", "ipvpn_gateway_peers"), "local_as": self._ipvpn_gateway_local_as, "maximum_routes": get(self.shared_utils.switch_data_combined, "ipvpn_gateway.maximum_routes", default=0), - } + }, ) return peer_groups def _address_family_ipv4(self: AvdStructuredConfigOverlay) -> dict: - """ - deactivate the relevant peer_groups in address_family_ipv4 - """ + """Deactivate the relevant peer_groups in address_family_ipv4.""" peer_groups = [] if self.shared_utils.is_wan_router: peer_groups.append({"name": self.shared_utils.bgp_peer_groups["wan_overlay_peers"]["name"], "activate": False}) - # TODO no elif + # TODO: no elif elif self.shared_utils.overlay_evpn_vxlan is True: peer_groups.append({"name": self.shared_utils.bgp_peer_groups["evpn_overlay_peers"]["name"], "activate": False}) - if self.shared_utils.overlay_routing_protocol == "ebgp": - if self.shared_utils.evpn_gateway_vxlan_l2 is True or self.shared_utils.evpn_gateway_vxlan_l3 is True: - peer_groups.append({"name": self.shared_utils.bgp_peer_groups["evpn_overlay_core"]["name"], "activate": False}) + if self.shared_utils.overlay_routing_protocol == "ebgp" and ( + self.shared_utils.evpn_gateway_vxlan_l2 is True or self.shared_utils.evpn_gateway_vxlan_l3 is True + ): + peer_groups.append({"name": self.shared_utils.bgp_peer_groups["evpn_overlay_core"]["name"], "activate": False}) if self.shared_utils.overlay_routing_protocol == "ibgp": if self.shared_utils.overlay_mpls is True: @@ -208,7 +209,6 @@ def _address_family_ipv4(self: AvdStructuredConfigOverlay) -> dict: return {"peer_groups": peer_groups} def _address_family_evpn(self: AvdStructuredConfigOverlay) -> dict: - """ """ address_family_evpn = {} peer_groups = [] @@ -227,7 +227,7 @@ def _address_family_evpn(self: AvdStructuredConfigOverlay) -> dict: "name": self.shared_utils.bgp_peer_groups["evpn_overlay_core"]["name"], "domain_remote": True, "activate": True, - } + }, ) if self.shared_utils.evpn_gateway_vxlan_l3 is True: @@ -235,11 +235,11 @@ def _address_family_evpn(self: AvdStructuredConfigOverlay) -> dict: "next_hop_self_received_evpn_routes": { "enable": True, "inter_domain": self.shared_utils.evpn_gateway_vxlan_l3_inter_domain, - } + }, } if self.shared_utils.overlay_routing_protocol == "ibgp": - # TODO - assess this condition + # TODO: - assess this condition if self.shared_utils.overlay_evpn_mpls is True and self.shared_utils.overlay_evpn_vxlan is not True: overlay_peer_group_name = self.shared_utils.bgp_peer_groups["mpls_overlay_peers"]["name"] peer_groups.append({"name": overlay_peer_group_name, "activate": True}) @@ -248,14 +248,17 @@ def _address_family_evpn(self: AvdStructuredConfigOverlay) -> dict: address_family_evpn["neighbor_default"]["next_hop_self_source_interface"] = "Loopback0" # partly duplicate with ebgp - if self.shared_utils.overlay_vtep is True and self.shared_utils.evpn_role != "server": - if (peer_group := get_item(peer_groups, "name", overlay_peer_group_name)) is not None: - peer_group.update( - { - "route_map_in": "RM-EVPN-SOO-IN", - "route_map_out": "RM-EVPN-SOO-OUT", - } - ) + if ( + self.shared_utils.overlay_vtep is True + and self.shared_utils.evpn_role != "server" + and (peer_group := get_item(peer_groups, "name", overlay_peer_group_name)) is not None + ): + peer_group.update( + { + "route_map_in": "RM-EVPN-SOO-IN", + "route_map_out": "RM-EVPN-SOO-OUT", + }, + ) if self._is_mpls_server is True: peer_groups.append({"name": self.shared_utils.bgp_peer_groups["rr_overlay_peers"]["name"], "activate": True}) @@ -290,14 +293,14 @@ def _address_family_evpn(self: AvdStructuredConfigOverlay) -> dict: address_family_evpn["neighbor_default"] = { "next_hop_self_received_evpn_routes": { "enable": True, - } + }, } address_family_evpn["neighbors"] = [{"ip_address": self._wan_ha_peer_vtep_ip(), "activate": True}] return address_family_evpn def _address_family_ipv4_sr_te(self: AvdStructuredConfigOverlay) -> dict | None: - """Generate structured config for IPv4 SR-TE address family""" + """Generate structured config for IPv4 SR-TE address family.""" if not self.shared_utils.is_cv_pathfinder_router: return None @@ -306,7 +309,7 @@ def _address_family_ipv4_sr_te(self: AvdStructuredConfigOverlay) -> dict | None: { "name": self.shared_utils.bgp_peer_groups["wan_overlay_peers"]["name"], "activate": True, - } + }, ], } @@ -316,7 +319,7 @@ def _address_family_ipv4_sr_te(self: AvdStructuredConfigOverlay) -> dict | None: return address_family_ipv4_sr_te def _address_family_link_state(self: AvdStructuredConfigOverlay) -> dict | None: - """Generate structured config for link-state address family""" + """Generate structured config for link-state address family.""" if not self.shared_utils.is_cv_pathfinder_router: return None @@ -325,7 +328,7 @@ def _address_family_link_state(self: AvdStructuredConfigOverlay) -> dict | None: { "name": self.shared_utils.bgp_peer_groups["wan_overlay_peers"]["name"], "activate": True, - } + }, ], } @@ -335,8 +338,8 @@ def _address_family_link_state(self: AvdStructuredConfigOverlay) -> dict | None: { "missing_policy": { "direction_out_action": "deny", - } - } + }, + }, ) else: # other roles are transit / edge address_family_link_state["path_selection"] = {"roles": {"producer": True}} @@ -347,7 +350,6 @@ def _address_family_link_state(self: AvdStructuredConfigOverlay) -> dict | None: return address_family_link_state def _address_family_path_selection(self: AvdStructuredConfigOverlay) -> dict | None: - """ """ if not self.shared_utils.is_wan_router: return None @@ -356,7 +358,7 @@ def _address_family_path_selection(self: AvdStructuredConfigOverlay) -> dict | N { "name": self.shared_utils.bgp_peer_groups["wan_overlay_peers"]["name"], "activate": True, - } + }, ], "bgp": {"additional_paths": {"receive": True, "send": {"any": True}}}, } @@ -368,8 +370,7 @@ def _address_family_path_selection(self: AvdStructuredConfigOverlay) -> dict | N def _address_family_rtc(self: AvdStructuredConfigOverlay) -> dict | None: """ - Activate EVPN OVERLAY peer group and EVPN OVERLAY CORE peer group (if present) - in address_family_rtc + Activate EVPN OVERLAY peer group and EVPN OVERLAY CORE peer group (if present) in address_family_rtc. if the evpn_role is server, enable default_route_target only """ @@ -386,7 +387,7 @@ def _address_family_rtc(self: AvdStructuredConfigOverlay) -> dict | None: if self.shared_utils.overlay_routing_protocol == "ebgp": if self.shared_utils.evpn_gateway_vxlan_l2 is True or self.shared_utils.evpn_gateway_vxlan_l3 is True: core_peer_group = {"name": self.shared_utils.bgp_peer_groups["evpn_overlay_core"]["name"], "activate": True} - # TODO (@Claus) told me to remove this + # TODO: (@Claus) told me to remove this if self.shared_utils.evpn_role == "server": core_peer_group["default_route_target"] = {"only": True} peer_groups.append(core_peer_group) @@ -403,9 +404,8 @@ def _address_family_rtc(self: AvdStructuredConfigOverlay) -> dict | None: mpls_peer_group["default_route_target"] = {"only": True} peer_groups.append(mpls_peer_group) - if self.shared_utils.overlay_evpn_vxlan is True: - if self.shared_utils.evpn_role == "server" or self.shared_utils.mpls_overlay_role == "server": - evpn_overlay_peers["default_route_target"] = {"only": True} + if self.shared_utils.overlay_evpn_vxlan is True and (self.shared_utils.evpn_role == "server" or self.shared_utils.mpls_overlay_role == "server"): + evpn_overlay_peers["default_route_target"] = {"only": True} peer_groups.append(evpn_overlay_peers) address_family_rtc["peer_groups"] = peer_groups @@ -414,7 +414,8 @@ def _address_family_rtc(self: AvdStructuredConfigOverlay) -> dict | None: def _address_family_vpn_ipvx(self: AvdStructuredConfigOverlay, version: int) -> dict | None: if version not in [4, 6]: - raise AristaAvdError("_address_family_vpn_ipvx should be called with version 4 or 6 only") + msg = "_address_family_vpn_ipvx should be called with version 4 or 6 only" + raise AristaAvdError(msg) if (version == 4 and self.shared_utils.overlay_vpn_ipv4 is not True) or (version == 6 and self.shared_utils.overlay_vpn_ipv6 is not True): return None @@ -445,12 +446,12 @@ def _address_family_vpn_ipvx(self: AvdStructuredConfigOverlay, version: int) -> return address_family_vpn_ipvx def _create_neighbor(self: AvdStructuredConfigOverlay, ip_address: str, name: str, peer_group: str, remote_as: str | None = None) -> dict: - """ """ neighbor = {"ip_address": ip_address, "peer_group": peer_group, "peer": name, "description": name} if self.shared_utils.overlay_routing_protocol == "ebgp": if remote_as is None: - raise AristaAvdError("Configuring eBGP neighbor without a remote_as") + msg = "Configuring eBGP neighbor without a remote_as" + raise AristaAvdError(msg) neighbor["remote_as"] = remote_as @@ -462,13 +463,15 @@ def _create_neighbor(self: AvdStructuredConfigOverlay, ip_address: str, name: st return neighbor def _neighbors(self: AvdStructuredConfigOverlay) -> list | None: - """ """ neighbors = [] if self.shared_utils.overlay_routing_protocol == "ebgp": for route_server, data in natural_sort(self._evpn_route_servers.items()): neighbor = self._create_neighbor( - data["ip_address"], route_server, self.shared_utils.bgp_peer_groups["evpn_overlay_peers"]["name"], remote_as=data["bgp_as"] + data["ip_address"], + route_server, + self.shared_utils.bgp_peer_groups["evpn_overlay_peers"]["name"], + remote_as=data["bgp_as"], ) if self.shared_utils.evpn_prevent_readvertise_to_server is True: @@ -477,13 +480,19 @@ def _neighbors(self: AvdStructuredConfigOverlay) -> list | None: for route_client, data in natural_sort(self._evpn_route_clients.items()): neighbor = self._create_neighbor( - data["ip_address"], route_client, self.shared_utils.bgp_peer_groups["evpn_overlay_peers"]["name"], remote_as=data["bgp_as"] + data["ip_address"], + route_client, + self.shared_utils.bgp_peer_groups["evpn_overlay_peers"]["name"], + remote_as=data["bgp_as"], ) neighbors.append(neighbor) for gw_remote_peer, data in natural_sort(self._evpn_gateway_remote_peers.items()): neighbor = self._create_neighbor( - data["ip_address"], gw_remote_peer, self.shared_utils.bgp_peer_groups["evpn_overlay_core"]["name"], remote_as=data["bgp_as"] + data["ip_address"], + gw_remote_peer, + self.shared_utils.bgp_peer_groups["evpn_overlay_core"]["name"], + remote_as=data["bgp_as"], ) neighbors.append(neighbor) @@ -517,10 +526,11 @@ def _neighbors(self: AvdStructuredConfigOverlay) -> list | None: if self.shared_utils.is_wan_client: if not self._ip_in_listen_ranges(self.shared_utils.vtep_ip, self.shared_utils.wan_listen_ranges): - raise AristaAvdError( + msg = ( f"{self.shared_utils.vtep_loopback} IP {self.shared_utils.vtep_ip} is not in the Route Reflector listen range prefixes" " 'bgp_peer_groups.wan_overlay_peers.listen_range_prefixes'." ) + raise AristaAvdError(msg) for wan_route_server, data in self.shared_utils.filtered_wan_route_servers.items(): neighbor = self._create_neighbor(data["vtep_ip"], wan_route_server, self.shared_utils.bgp_peer_groups["wan_overlay_peers"]["name"]) neighbors.append(neighbor) @@ -547,7 +557,10 @@ def _neighbors(self: AvdStructuredConfigOverlay) -> list | None: for ipvpn_gw_peer, data in natural_sort(self._ipvpn_gateway_remote_peers.items()): neighbor = self._create_neighbor( - data["ip_address"], ipvpn_gw_peer, self.shared_utils.bgp_peer_groups["ipvpn_gateway_peers"]["name"], remote_as=data["bgp_as"] + data["ip_address"], + ipvpn_gw_peer, + self.shared_utils.bgp_peer_groups["ipvpn_gateway_peers"]["name"], + remote_as=data["bgp_as"], ) # Add ebgp_multihop if the gw peer is an ebgp peer. if data["bgp_as"] != default(self._ipvpn_gateway_local_as, self.shared_utils.bgp_as): @@ -561,9 +574,7 @@ def _neighbors(self: AvdStructuredConfigOverlay) -> list | None: return None def _ip_in_listen_ranges(self: AvdStructuredConfigOverlay, source_ip: str, listen_range_prefixes: list) -> bool: - """ - Check if our source IP is in any of the listen range prefixes - """ + """Check if our source IP is in any of the listen range prefixes.""" source_ip = ipaddress.ip_address(source_ip) return any(source_ip in ipaddress.ip_network(prefix) for prefix in listen_range_prefixes) @@ -572,6 +583,6 @@ def _bgp_overlay_dpath(self: AvdStructuredConfigOverlay) -> dict | None: return { "bestpath": { "d_path": True, - } + }, } return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_path_selection.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_path_selection.py index b865fb00885..fec0c716147 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_path_selection.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_path_selection.py @@ -6,8 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdError -from ...._utils import get, get_item, strip_empties_from_dict +from pyavd._errors import AristaAvdError +from pyavd._utils import get, get_item, strip_empties_from_dict + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,15 +18,13 @@ class RouterPathSelectionMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_path_selection(self: AvdStructuredConfigOverlay) -> dict | None: - """ - Return structured config for router path-selection (DPS) - """ - + """Return structured config for router path-selection (DPS).""" if not self.shared_utils.is_wan_router: return None @@ -41,30 +40,21 @@ def router_path_selection(self: AvdStructuredConfigOverlay) -> dict | None: @cached_property def _cp_ipsec_profile_name(self: AvdStructuredConfigOverlay) -> str: - """ - Returns the IPsec profile name to use for Control-Plane - """ + """Returns the IPsec profile name to use for Control-Plane.""" return get(self._hostvars, "wan_ipsec_profiles.control_plane.profile_name", default="CP-PROFILE") @cached_property def _dp_ipsec_profile_name(self: AvdStructuredConfigOverlay) -> str: - """ - Returns the IPsec profile name to use for Data-Plane - """ - # TODO need to use CP one if 'wan_ipsec_profiles.data_plane' not present + """Returns the IPsec profile name to use for Data-Plane.""" + # TODO: need to use CP one if 'wan_ipsec_profiles.data_plane' not present return get(self._hostvars, "wan_ipsec_profiles.data_plane.profile_name", default="DP-PROFILE") def _get_path_groups(self: AvdStructuredConfigOverlay) -> list: - """ - Generate the required path-groups locally - """ + """Generate the required path-groups locally.""" path_groups = [] - if self.shared_utils.is_wan_server: - # Configure all path-groups on Pathfinders and AutoVPN RRs - path_groups_to_configure = self.shared_utils.wan_path_groups - else: - path_groups_to_configure = self.shared_utils.wan_local_path_groups + # Configure all path-groups on Pathfinders and AutoVPN RRs. Otherwise only configure the local path-groups + path_groups_to_configure = self.shared_utils.wan_path_groups if self.shared_utils.is_wan_server else self.shared_utils.wan_local_path_groups local_path_groups_names = [path_group["name"] for path_group in self.shared_utils.wan_local_path_groups] @@ -94,10 +84,11 @@ def _get_path_groups(self: AvdStructuredConfigOverlay) -> list: path_group_data["keepalive"] = {"auto": True} else: if not (interval.isdigit() and 50 <= int(interval) <= 60000): - raise AristaAvdError( + msg = ( f"Invalid value '{interval}' for dps_keepalive.interval - " f"should be either 'auto', or an integer[50-60000] for wan_path_groups[{pg_name}]" ) + raise AristaAvdError(msg) path_group_data["keepalive"] = { "interval": int(interval), "failure_threshold": get(keepalive, "failure_threshold", default=5), @@ -111,9 +102,7 @@ def _get_path_groups(self: AvdStructuredConfigOverlay) -> list: return path_groups def _generate_ha_path_group(self: AvdStructuredConfigOverlay) -> dict: - """ - Called only when self.shared_utils.wan_ha is True or on Pathfinders - """ + """Called only when self.shared_utils.wan_ha is True or on Pathfinders.""" ha_path_group = { "name": self.shared_utils.wan_ha_path_group_name, "id": self._get_path_group_id(self.shared_utils.wan_ha_path_group_name), @@ -132,9 +121,9 @@ def _generate_ha_path_group(self: AvdStructuredConfigOverlay) -> dict: "router_ip": self._wan_ha_peer_vtep_ip(), "name": self.shared_utils.wan_ha_peer, "ipv4_addresses": [ip_address.split("/")[0] for ip_address in self.shared_utils.wan_ha_peer_ip_addresses], - } + }, ], - } + }, ) if self.shared_utils.wan_ha_ipsec: ha_path_group["ipsec_profile"] = self._dp_ipsec_profile_name @@ -142,20 +131,19 @@ def _generate_ha_path_group(self: AvdStructuredConfigOverlay) -> dict: return ha_path_group def _wan_ha_interfaces(self: AvdStructuredConfigOverlay) -> list: - """ - Return list of interfaces for HA - """ + """Return list of interfaces for HA.""" return [uplink for uplink in self.shared_utils.get_switch_fact("uplinks") if get(uplink, "vrf") is None] def _wan_ha_peer_vtep_ip(self: AvdStructuredConfigOverlay) -> str: - """ """ peer_facts = self.shared_utils.get_peer_facts(self.shared_utils.wan_ha_peer, required=True) return get(peer_facts, "vtep_ip", required=True) def _get_path_group_id(self: AvdStructuredConfigOverlay, path_group_name: str, config_id: int | None = None) -> int: """ - TODO - implement algorithm to auto assign IDs - cf internal documentation - TODO - also implement algorithm for cross connects on public path_groups + Get path group id. + + TODO: - implement algorithm to auto assign IDs - cf internal documentation + TODO: - also implement algorithm for cross connects on public path_groups. """ if path_group_name == self.shared_utils.wan_ha_path_group_name: return 65535 @@ -165,7 +153,7 @@ def _get_path_group_id(self: AvdStructuredConfigOverlay, path_group_name: str, c def _get_local_interfaces_for_path_group(self: AvdStructuredConfigOverlay, path_group_name: str) -> list | None: """ - Generate the router_path_selection.local_interfaces list + Generate the router_path_selection.local_interfaces list. For AUTOVPN clients, configure the stun server profiles as appropriate """ @@ -184,9 +172,7 @@ def _get_local_interfaces_for_path_group(self: AvdStructuredConfigOverlay, path_ return local_interfaces def _get_dynamic_peers(self: AvdStructuredConfigOverlay, disable_ipsec: bool) -> dict | None: - """ - TODO support ip_local ? - """ + """TODO: support ip_local ?""" if not self.shared_utils.is_wan_client: return None @@ -196,27 +182,26 @@ def _get_dynamic_peers(self: AvdStructuredConfigOverlay, disable_ipsec: bool) -> return dynamic_peers def _get_static_peers_for_path_group(self: AvdStructuredConfigOverlay, path_group_name: str) -> list | None: - """ - Retrieves the static peers to configure for a given path-group based on the connected nodes. - """ + """Retrieves the static peers to configure for a given path-group based on the connected nodes.""" if not self.shared_utils.is_wan_router: return None static_peers = [] for wan_route_server_name, wan_route_server in self.shared_utils.filtered_wan_route_servers.items(): if (path_group := get_item(get(wan_route_server, "wan_path_groups", default=[]), "name", path_group_name)) is not None: - ipv4_addresses = [] + ipv4_addresses = [ + # TODO: - removing mask using split but maybe a helper is clearer + public_ip.split("/")[0] + for interface_dict in get(path_group, "interfaces", required=True) + if (public_ip := interface_dict.get("public_ip")) is not None + ] - for interface_dict in get(path_group, "interfaces", required=True): - if (public_ip := interface_dict.get("public_ip")) is not None: - # TODO - removing mask using split but maybe a helper is clearer - ipv4_addresses.append(public_ip.split("/")[0]) static_peers.append( { "router_ip": get(wan_route_server, "vtep_ip", required=True), "name": wan_route_server_name, "ipv4_addresses": ipv4_addresses, - } + }, ) return static_peers diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_traffic_engineering.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_traffic_engineering.py index ea55c939a7b..91afd0551ef 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/router_traffic_engineering.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/router_traffic_engineering.py @@ -15,15 +15,13 @@ class RouterTrafficEngineering(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_traffic_engineering(self: AvdStructuredConfigOverlay) -> dict | None: - """ - Return structured config for router traffic-engineering - """ - + """Return structured config for router traffic-engineering.""" if not self.shared_utils.is_cv_pathfinder_router: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/stun.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/stun.py index bbd3bd3de9c..b05c950db23 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/stun.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/stun.py @@ -7,7 +7,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import strip_empties_from_dict +from pyavd._utils import strip_empties_from_dict + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,15 +18,13 @@ class StunMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def stun(self: AvdStructuredConfigOverlay) -> dict | None: - """ - Return structured config for stun - """ - + """Return structured config for stun.""" if not self.shared_utils.is_wan_router: return None @@ -37,7 +36,6 @@ def stun(self: AvdStructuredConfigOverlay) -> dict | None: "ssl_profile": self.shared_utils.wan_stun_dtls_profile_name, } - if self.shared_utils.is_wan_client: - if server_profiles := list(itertools.chain.from_iterable(self._stun_server_profiles.values())): - stun["client"] = {"server_profiles": server_profiles} + if self.shared_utils.is_wan_client and (server_profiles := list(itertools.chain.from_iterable(self._stun_server_profiles.values()))): + stun["client"] = {"server_profiles": server_profiles} return strip_empties_from_dict(stun) or None diff --git a/python-avd/pyavd/_eos_designs/structured_config/overlay/utils.py b/python-avd/pyavd/_eos_designs/structured_config/overlay/utils.py index e87a6d9c60c..e6a7d4ea551 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/overlay/utils.py +++ b/python-avd/pyavd/_eos_designs/structured_config/overlay/utils.py @@ -6,8 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get -from ....j2filters import natural_sort +from pyavd._utils import get +from pyavd.j2filters import natural_sort if TYPE_CHECKING: from . import AvdStructuredConfigOverlay @@ -16,13 +16,14 @@ class UtilsMixin: """ Mixin Class with internal functions. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def _avd_overlay_peers(self: AvdStructuredConfigOverlay) -> list: """ - Returns a list of overlay peers for the device + Returns a list of overlay peers for the device. This cannot be loaded in shared_utils since it will not be calculated until EosDesignsFacts has been rendered and shared_utils are shared between EosDesignsFacts and AvdStructuredConfig classes like this one. @@ -49,7 +50,7 @@ def _evpn_gateway_remote_peers(self: AvdStructuredConfigOverlay) -> dict: else: # Server not found in inventory, adding manually - # TODO - what if the values are None - this is not handled by the template today + # TODO: - what if the values are None - this is not handled by the template today bgp_as = str(_as) if (_as := gw_remote_peer_dict.get("bgp_as")) else None ip_address = gw_remote_peer_dict.get("ip_address") @@ -107,14 +108,10 @@ def _is_mpls_server(self: AvdStructuredConfigOverlay) -> bool: return self.shared_utils.mpls_overlay_role == "server" or (self.shared_utils.evpn_role == "server" and self.shared_utils.overlay_evpn_mpls) def _is_peer_mpls_client(self: AvdStructuredConfigOverlay, peer_facts: dict) -> bool: - return peer_facts.get("mpls_overlay_role", None) == "client" or ( - peer_facts.get("evpn_role", None) == "client" and get(peer_facts, "overlay.evpn_mpls") is True - ) + return peer_facts.get("mpls_overlay_role") == "client" or (peer_facts.get("evpn_role") == "client" and get(peer_facts, "overlay.evpn_mpls") is True) def _is_peer_mpls_server(self: AvdStructuredConfigOverlay, peer_facts: dict) -> bool: - return peer_facts.get("mpls_overlay_role", None) == "server" or ( - peer_facts.get("evpn_role", None) == "server" and get(peer_facts, "overlay.evpn_mpls") is True - ) + return peer_facts.get("mpls_overlay_role") == "server" or (peer_facts.get("evpn_role") == "server" and get(peer_facts, "overlay.evpn_mpls") is True) @cached_property def _ipvpn_gateway_local_as(self: AvdStructuredConfigOverlay) -> str | None: @@ -229,7 +226,9 @@ def _mpls_rr_peers(self: AvdStructuredConfigOverlay) -> dict: continue if self.shared_utils.hostname in peer_facts.get("mpls_route_reflectors", []) and avd_peer not in get( - self._hostvars, "switch.mpls_route_reflectors", default=[] + self._hostvars, + "switch.mpls_route_reflectors", + default=[], ): self._append_peer(mpls_rr_peers, avd_peer, peer_facts) @@ -237,14 +236,14 @@ def _mpls_rr_peers(self: AvdStructuredConfigOverlay) -> dict: def _append_peer(self: AvdStructuredConfigOverlay, peers_dict: dict, peer_name: str, peer_facts: dict) -> None: """ - Retrieve bgp_as and "overlay.peering_address" from peer_facts and append - a new peer to peers_dict + Retrieve bgp_as and "overlay.peering_address" from peer_facts and append a new peer to peers_dict. + { peer_name: { "bgp_as": bgp_as, "ip_address": overlay.peering_address, } - } + }. """ bgp_as = peer_facts.get("bgp_as") peers_dict[peer_name] = { @@ -263,7 +262,7 @@ def _is_wan_server_with_peers(self: AvdStructuredConfigOverlay) -> bool: def _stun_server_profile_name(self: AvdStructuredConfigOverlay, wan_route_server_name: str, path_group_name: str, interface_name: str) -> str: """ - Return a string to use as the name of the stun server_profile + Return a string to use as the name of the stun server_profile. `/` are not allowed, `.` are allowed so Ethernet1/1.1 is transformed into Ethernet1_1.1 @@ -273,9 +272,7 @@ def _stun_server_profile_name(self: AvdStructuredConfigOverlay, wan_route_server @cached_property def _stun_server_profiles(self: AvdStructuredConfigOverlay) -> dict: - """ - Return a dictionary of _stun_server_profiles with ip_address per local path_group - """ + """Return a dictionary of _stun_server_profiles with ip_address per local path_group.""" stun_server_profiles = {} for wan_route_server, data in self.shared_utils.filtered_wan_route_servers.items(): for path_group in data.get("wan_path_groups", []): @@ -290,6 +287,5 @@ def _stun_server_profiles(self: AvdStructuredConfigOverlay) -> dict: return stun_server_profiles def _wan_ha_peer_vtep_ip(self) -> str: - """ """ peer_facts = self.shared_utils.get_peer_facts(self.shared_utils.wan_ha_peer, required=True) return get(peer_facts, "vtep_ip", required=True) diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/__init__.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/__init__.py index a2b2ee029e3..a03a27b87b4 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/__init__.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from ...avdfacts import AvdFacts +from pyavd._eos_designs.avdfacts import AvdFacts + from .agents import AgentsMixin from .as_path import AsPathMixin from .dhcp_server import DhcpServerMixin diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/agents.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/agents.py index 5b317cd3171..6f91162087f 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/agents.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/agents.py @@ -15,20 +15,16 @@ class AgentsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def agents(self: AvdStructuredConfigUnderlay) -> list | None: - """ - Return structured config for agents - """ - + """Return structured config for agents.""" if not self.shared_utils.is_wan_router: return None - agents = [ + return [ {"name": "KernelFib", "environment_variables": [{"name": "KERNELFIB_PROGRAM_ALL_ECMP", "value": "1"}]}, ] - - return agents diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/as_path.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/as_path.py index 7b195afd924..c3ad8772756 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/as_path.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/as_path.py @@ -15,14 +15,13 @@ class AsPathMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def as_path(self: AvdStructuredConfigUnderlay) -> dict | None: - """ - Return structured config for as_path. - """ + """Return structured config for as_path.""" if self.shared_utils.underlay_routing_protocol != "ebgp": return None @@ -37,7 +36,7 @@ def as_path(self: AvdStructuredConfigUnderlay) -> dict | None: "match": self.shared_utils.bgp_as, }, ], - } + }, ) if access_lists: diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/ethernet_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/ethernet_interfaces.py index a874e075827..f69b61b41aa 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/ethernet_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/ethernet_interfaces.py @@ -6,10 +6,11 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdError -from ...._utils import append_if_not_duplicate, get -from ....j2filters import encrypt, natural_sort -from ...interface_descriptions import InterfaceDescriptionData +from pyavd._eos_designs.interface_descriptions.models import InterfaceDescriptionData +from pyavd._errors import AristaAvdError +from pyavd._utils import append_if_not_duplicate, get +from pyavd.j2filters import encrypt, natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -19,14 +20,13 @@ class EthernetInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def ethernet_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: - """ - Return structured config for ethernet_interfaces - """ + """Return structured config for ethernet_interfaces.""" ethernet_interfaces = [] for link in self._underlay_links: @@ -38,7 +38,7 @@ def ethernet_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: link_type=link["type"], peer=link["peer"], peer_interface=link["peer_interface"], - ) + ), ) ethernet_interface = { "name": link["interface"], @@ -63,7 +63,7 @@ def ethernet_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: "link_tracking_groups": link.get("link_tracking_groups"), "sflow": link.get("sflow"), "flow_tracker": link.get("flow_tracker"), - } + }, ) # PTP @@ -114,19 +114,20 @@ def ethernet_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: "hash_algorithm": ospf_key.get("hash_algorithm", "sha512"), "key": encrypt( ospf_key["key"], - passwd_type="ospf_message_digest", # NOSONAR + passwd_type="ospf_message_digest", # NOSONAR # noqa: S106 key=ethernet_interface["name"], hash_algorithm=ospf_key.get("hash_algorithm", "sha512"), key_id=ospf_key["id"], ), - } + }, ) if len(ospf_keys) > 0: ethernet_interface["ospf_authentication"] = "message-digest" ethernet_interface["ospf_message_digest_keys"] = ospf_keys else: - raise AristaAvdError("'underlay_ospf_authentication.enabled' is True but no message-digest keys with both key and ID are defined.") + msg = "'underlay_ospf_authentication.enabled' is True but no message-digest keys with both key and ID are defined." + raise AristaAvdError(msg) if self.shared_utils.underlay_isis is True: ethernet_interface.update( @@ -136,7 +137,7 @@ def ethernet_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: "isis_metric": self.shared_utils.isis_default_metric, "isis_network_point_to_point": True, "isis_circuit_type": self.shared_utils.isis_default_circuit_type, - } + }, ) if link.get("underlay_multicast") is True: @@ -166,7 +167,7 @@ def ethernet_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: "id": int(channel_group_id), "mode": "active", }, - } + }, ) if get(link, "inband_ztp_vlan"): ethernet_interface.update({"mode": "access", "vlans": link["inband_ztp_vlan"]}) @@ -182,7 +183,7 @@ def ethernet_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: "link_tracking_groups": link.get("link_tracking_groups"), "spanning_tree_portfast": link.get("spanning_tree_portfast"), "flow_tracker": link.get("flow_tracker"), - } + }, ) # Remove None values @@ -206,7 +207,7 @@ def ethernet_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: peer=link["peer"], peer_interface=subinterface["peer_interface"], vrf=subinterface["vrf"], - ) + ), ) ethernet_subinterface = { "name": subinterface["interface"], @@ -214,7 +215,7 @@ def ethernet_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: "peer_interface": subinterface["peer_interface"], "peer_type": link["peer_type"], "vrf": subinterface["vrf"], - # TODO - for now reusing the encapsulation as it is hardcoded to the VRF ID which is used as + # TODO: - for now reusing the encapsulation as it is hardcoded to the VRF ID which is used as # subinterface name "description": description, "shutdown": self.shared_utils.shutdown_interfaces_towards_undeployed_peers and not link["peer_is_deployed"], @@ -288,7 +289,9 @@ def ethernet_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: # WAN HA interfaces for direct connection if self.shared_utils.use_uplinks_for_wan_ha is False: direct_wan_ha_links_flow_tracker = get( - self.shared_utils.switch_data_combined, "wan_ha.flow_tracker", default=self.shared_utils.get_flow_tracker(None, "direct_wan_ha_links") + self.shared_utils.switch_data_combined, + "wan_ha.flow_tracker", + default=self.shared_utils.get_flow_tracker(None, "direct_wan_ha_links"), ) for index, interface in enumerate(get(self.shared_utils.switch_data_combined, "wan_ha.ha_interfaces", required=True)): ha_interface = { diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/ip_access_lists.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/ip_access_lists.py index 6d33e3d92f9..549b4855b52 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/ip_access_lists.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/ip_access_lists.py @@ -6,8 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate -from ....j2filters import natural_sort +from pyavd._utils import append_if_not_duplicate +from pyavd.j2filters import natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,7 +18,8 @@ class IpAccesslistsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/loopback_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/loopback_interfaces.py index 6b5ea9c3fbc..3e2c0eb2a6a 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/loopback_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/loopback_interfaces.py @@ -6,9 +6,10 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdMissingVariableError -from ...._utils import get -from ...interface_descriptions.models import InterfaceDescriptionData +from pyavd._eos_designs.interface_descriptions.models import InterfaceDescriptionData +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -18,14 +19,13 @@ class LoopbackInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def loopback_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: - """ - Return structured config for loopback_interfaces - """ + """Return structured config for loopback_interfaces.""" if not self.shared_utils.underlay_router: return None @@ -34,7 +34,9 @@ def loopback_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: loopback0 = { "name": "Loopback0", "description": self.shared_utils.interface_descriptions.router_id_loopback_interface( - InterfaceDescriptionData(shared_utils=self.shared_utils, interface="Loopback0", description=get(self._hostvars, "overlay_loopback_description")) + InterfaceDescriptionData( + shared_utils=self.shared_utils, interface="Loopback0", description=get(self._hostvars, "overlay_loopback_description") + ), ), "shutdown": False, "ip_address": f"{self.shared_utils.router_id}/32", @@ -90,14 +92,14 @@ def loopback_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: # Underlay Multicast RP Loopbacks if self.shared_utils.underlay_multicast_rp_interfaces is not None: - for underlay_multicast_rp_interface in self.shared_utils.underlay_multicast_rp_interfaces: - loopback_interfaces.append(underlay_multicast_rp_interface) + loopback_interfaces.extend(self.shared_utils.underlay_multicast_rp_interfaces) return loopback_interfaces @cached_property - def _node_sid(self: AvdStructuredConfigUnderlay): + def _node_sid(self: AvdStructuredConfigUnderlay) -> str: if self.shared_utils.id is None: - raise AristaAvdMissingVariableError(f"'id' is not set on '{self.shared_utils.hostname}' and is required to set node SID") + msg = f"'id' is not set on '{self.shared_utils.hostname}' and is required to set node SID" + raise AristaAvdMissingVariableError(msg) node_sid_base = int(get(self.shared_utils.switch_data_combined, "node_sid_base", 0)) return self.shared_utils.id + node_sid_base diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/mpls.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/mpls.py index 390d4dd2433..2f90eb76fa3 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/mpls.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/mpls.py @@ -15,14 +15,13 @@ class MplsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def mpls(self: AvdStructuredConfigUnderlay) -> dict | None: - """ - Return structured config for mpls - """ + """Return structured config for mpls.""" if self.shared_utils.underlay_mpls is not True: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/port_channel_interfaces.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/port_channel_interfaces.py index b75a76cbca4..32d739b44dc 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/port_channel_interfaces.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/port_channel_interfaces.py @@ -6,8 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get, short_esi_to_route_target -from ...interface_descriptions.models import InterfaceDescriptionData +from pyavd._eos_designs.interface_descriptions.models import InterfaceDescriptionData +from pyavd._utils import get, short_esi_to_route_target + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,14 +18,13 @@ class PortChannelInterfacesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def port_channel_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: - """ - Return structured config for port_channel_interfaces - """ + """Return structured config for port_channel_interfaces.""" port_channel_interfaces = [] port_channel_list = [] for link in self._underlay_links: @@ -47,7 +47,7 @@ def port_channel_interfaces(self: AvdStructuredConfigUnderlay) -> list | None: peer=link["peer"], peer_channel_group_id=link["peer_channel_group_id"], port_channel_description=link.get("channel_description"), - ) + ), ), "type": "switched", "shutdown": False, diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/prefix_lists.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/prefix_lists.py index 5952d18bd3b..d85893ded8c 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/prefix_lists.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/prefix_lists.py @@ -8,7 +8,8 @@ from ipaddress import ip_network from typing import TYPE_CHECKING -from ...._utils import get, get_item +from pyavd._utils import get, get_item + from .utils import UtilsMixin if TYPE_CHECKING: @@ -18,14 +19,13 @@ class PrefixListsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def prefix_lists(self: AvdStructuredConfigUnderlay) -> list | None: - """ - Return structured config for prefix_lists - """ + """Return structured config for prefix_lists.""" if self.shared_utils.underlay_bgp is not True and not self.shared_utils.is_wan_router: return None @@ -73,17 +73,15 @@ def prefix_lists(self: AvdStructuredConfigUnderlay) -> list | None: prefix_lists_in_use = set() for neighbor in self.shared_utils.l3_interfaces_bgp_neighbors: - if prefix_list_in := get(neighbor, "ipv4_prefix_list_in"): - if prefix_list_in not in prefix_lists_in_use: - pfx_list = self._get_prefix_list(prefix_list_in) - prefix_lists.append(pfx_list) - prefix_lists_in_use.add(prefix_list_in) - - if prefix_list_out := get(neighbor, "ipv4_prefix_list_out"): - if prefix_list_out not in prefix_lists_in_use: - pfx_list = self._get_prefix_list(prefix_list_out) - prefix_lists.append(pfx_list) - prefix_lists_in_use.add(prefix_list_out) + if (prefix_list_in := get(neighbor, "ipv4_prefix_list_in")) and prefix_list_in not in prefix_lists_in_use: + pfx_list = self._get_prefix_list(prefix_list_in) + prefix_lists.append(pfx_list) + prefix_lists_in_use.add(prefix_list_in) + + if (prefix_list_out := get(neighbor, "ipv4_prefix_list_out")) and prefix_list_out not in prefix_lists_in_use: + pfx_list = self._get_prefix_list(prefix_list_out) + prefix_lists.append(pfx_list) + prefix_lists_in_use.add(prefix_list_out) # P2P-LINKS needed for L3 inband ZTP p2p_links_sequence_numbers = [] @@ -106,14 +104,12 @@ def prefix_lists(self: AvdStructuredConfigUnderlay) -> list | None: return prefix_lists - def _get_prefix_list(self, name: str): + def _get_prefix_list(self, name: str) -> dict: return get_item(self.shared_utils.ipv4_prefix_list_catalog, "name", name, required=True, var_name=f"ipv4_prefix_list_catalog[name={name}]") @cached_property def ipv6_prefix_lists(self: AvdStructuredConfigUnderlay) -> list | None: - """ - Return structured config for IPv6 prefix_lists - """ + """Return structured config for IPv6 prefix_lists.""" if self.shared_utils.underlay_bgp is not True: return None @@ -128,5 +124,5 @@ def ipv6_prefix_lists(self: AvdStructuredConfigUnderlay) -> list | None: # IPv6 - PL-LOOPBACKS-EVPN-OVERLAY-V6 return [ - {"name": "PL-LOOPBACKS-EVPN-OVERLAY-V6", "sequence_numbers": [{"sequence": 10, "action": f"permit {self.shared_utils.loopback_ipv6_pool} eq 128"}]} + {"name": "PL-LOOPBACKS-EVPN-OVERLAY-V6", "sequence_numbers": [{"sequence": 10, "action": f"permit {self.shared_utils.loopback_ipv6_pool} eq 128"}]}, ] diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/route_maps.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/route_maps.py index fd43d39925a..8ec5aa4547b 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/route_maps.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/route_maps.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get +from pyavd._utils import get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,13 +17,14 @@ class RouteMapsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def route_maps(self: AvdStructuredConfigUnderlay) -> list | None: """ - Return structured config for route_maps + Return structured config for route_maps. Contains two parts. - Route map for connected routes redistribution in BGP @@ -52,7 +54,7 @@ def route_maps(self: AvdStructuredConfigUnderlay) -> list | None: "sequence": 30, "type": "permit", "match": ["ipv6 address prefix-list PL-LOOPBACKS-EVPN-OVERLAY-V6"], - } + }, ) if self.shared_utils.underlay_multicast_rp_interfaces is not None: @@ -61,7 +63,7 @@ def route_maps(self: AvdStructuredConfigUnderlay) -> list | None: "sequence": 40, "type": "permit", "match": ["ip address prefix-list PL-LOOPBACKS-PIM-RP"], - } + }, ) if self.shared_utils.wan_ha and self.shared_utils.use_uplinks_for_wan_ha: @@ -70,7 +72,7 @@ def route_maps(self: AvdStructuredConfigUnderlay) -> list | None: "sequence": 50, "type": "permit", "match": ["ip address prefix-list PL-WAN-HA-PREFIXES"], - } + }, ) add_p2p_links = False @@ -116,7 +118,7 @@ def route_maps(self: AvdStructuredConfigUnderlay) -> list | None: "type": "permit", }, ], - } + }, ) # Route-map IN and OUT for SOO, rendered for WAN routers @@ -145,7 +147,7 @@ def route_maps(self: AvdStructuredConfigUnderlay) -> list | None: "description": "Deny other routes from the HA peer", "match": ["as-path ASPATH-WAN"], }, - ] + ], ) route_maps.append({"name": "RM-BGP-UNDERLAY-PEERS-IN", "sequence_numbers": sequence_numbers}) diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_bgp.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_bgp.py index 3877292aec5..50b26c660e3 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_bgp.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_bgp.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import append_if_not_duplicate, get, strip_empties_from_dict +from pyavd._utils import append_if_not_duplicate, get, strip_empties_from_dict + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,15 +17,13 @@ class RouterBgpMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_bgp(self: AvdStructuredConfigUnderlay) -> dict | None: - """ - Return the structured config for router_bgp - """ - + """Return the structured config for router_bgp.""" if not self.shared_utils.underlay_bgp: if self.shared_utils.is_wan_router: # Configure redistribute connected with or without route-map in case it the underlay is not BGP. @@ -57,14 +56,14 @@ def router_bgp(self: AvdStructuredConfigUnderlay) -> dict | None: router_bgp["peer_groups"] = [strip_empties_from_dict(peer_group)] # Address Families - # TODO - see if it makes sense to extract logic in method + # TODO: - see if it makes sense to extract logic in method address_family_ipv4_peer_group = {"activate": True} if self.shared_utils.underlay_rfc5549 is True: address_family_ipv4_peer_group["next_hop"] = {"address_family_ipv6": {"enabled": True, "originate": True}} router_bgp["address_family_ipv4"] = { - "peer_groups": [{"name": self.shared_utils.bgp_peer_groups["ipv4_underlay_peers"]["name"], **address_family_ipv4_peer_group}] + "peer_groups": [{"name": self.shared_utils.bgp_peer_groups["ipv4_underlay_peers"]["name"], **address_family_ipv4_peer_group}], } if self.shared_utils.underlay_ipv6 is True: @@ -89,7 +88,7 @@ def router_bgp(self: AvdStructuredConfigUnderlay) -> dict | None: "remote_as": link["peer_bgp_as"], "peer": link["peer"], "description": "_".join([link["peer"], link["peer_interface"]]), - } + }, ) if "subinterfaces" in link: @@ -107,9 +106,9 @@ def router_bgp(self: AvdStructuredConfigUnderlay) -> dict | None: "name": subinterface["interface"], "peer_group": self.shared_utils.bgp_peer_groups["ipv4_underlay_peers"]["name"], "remote_as": link["peer_bgp_as"], - # TODO - implement some centralized way to generate these descriptions + # TODO: - implement some centralized way to generate these descriptions "description": f"{'_'.join([link['peer'], subinterface['peer_interface']])}_vrf_{subinterface['vrf']}", - } + }, ) if neighbor_interfaces: @@ -193,9 +192,7 @@ def router_bgp(self: AvdStructuredConfigUnderlay) -> dict | None: @cached_property def _router_bgp_redistribute_routes(self: AvdStructuredConfigUnderlay) -> list: - """ - Return structured config for router_bgp.redistribute_routes - """ + """Return structured config for router_bgp.redistribute_routes.""" if self.shared_utils.overlay_routing_protocol == "none" or not self.shared_utils.underlay_filter_redistribute_connected: return [{"source_protocol": "connected"}] diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_isis.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_isis.py index d558f53e9f1..73cc274c8bd 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_isis.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_isis.py @@ -6,8 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdMissingVariableError -from ...._utils import get +from pyavd._errors import AristaAvdMissingVariableError +from pyavd._utils import get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,14 +18,13 @@ class RouterIsisMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_isis(self: AvdStructuredConfigUnderlay) -> dict | None: - """ - return structured config for router_isis - """ + """Return structured config for router_isis.""" if self.shared_utils.underlay_isis is not True: return None @@ -46,7 +46,7 @@ def router_isis(self: AvdStructuredConfigUnderlay) -> dict | None: "local_convergence": { "delay": get(self._hostvars, "isis_ti_lfa.local_convergence_delay", default="10000"), "protected_prefixes": True, - } + }, } ti_lfa_protection = get(self._hostvars, "isis_ti_lfa.protection") if ti_lfa_protection == "link": @@ -62,7 +62,7 @@ def router_isis(self: AvdStructuredConfigUnderlay) -> dict | None: router_isis["advertise"] = { "passive_only": get(self._hostvars, "isis_advertise_passive_only", default=False), } - # TODO - enabling IPv6 only in SR cases as per existing behavior + # TODO: - enabling IPv6 only in SR cases as per existing behavior # but this could probably be taken out if self.shared_utils.underlay_ipv6 is True: router_isis["address_family_ipv6"] = {"enabled": True, "maximum_paths": get(self._hostvars, "isis_maximum_paths", default=4)} @@ -83,9 +83,8 @@ def _isis_net(self: AvdStructuredConfigUnderlay) -> str | None: return None if self.shared_utils.id is None: - raise AristaAvdMissingVariableError( - f"'id' is not set on '{self.shared_utils.hostname}' and is required to set ISIS NET address using the node ID" - ) + msg = f"'id' is not set on '{self.shared_utils.hostname}' and is required to set ISIS NET address using the node ID" + raise AristaAvdMissingVariableError(msg) system_id = f"{isis_system_id_prefix}.{self.shared_utils.id:04d}" else: system_id = self.ipv4_to_isis_system_id(self.shared_utils.router_id) diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_msdp.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_msdp.py index c37cb8e91e1..74437b785b1 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_msdp.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_msdp.py @@ -6,8 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get, get_item -from ....j2filters import natural_sort +from pyavd._utils import get, get_item +from pyavd.j2filters import natural_sort + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,17 +18,17 @@ class RouterMsdpMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_msdp(self: AvdStructuredConfigUnderlay) -> dict | None: """ - return structured config for router_msdp + return structured config for router_msdp. Used for to configure multicast anycast RPs for the underlay """ - if self.shared_utils.underlay_multicast_rps is None: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_ospf.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_ospf.py index 06ef57beda8..b143c800bda 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_ospf.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_ospf.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import default, get +from pyavd._utils import default, get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,14 +17,13 @@ class RouterOspfMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_ospf(self: AvdStructuredConfigUnderlay) -> dict | None: - """ - return structured config for router_ospf - """ + """Return structured config for router_ospf.""" if self.shared_utils.underlay_ospf is not True: return None diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_pim_sparse_mode.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_pim_sparse_mode.py index 95b01390f3d..deccb9d595a 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/router_pim_sparse_mode.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/router_pim_sparse_mode.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get, get_item +from pyavd._utils import get, get_item + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,17 +17,17 @@ class RouterPimSparseModeMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def router_pim_sparse_mode(self: AvdStructuredConfigUnderlay) -> dict | None: """ - return structured config for router_pim_sparse_mode + return structured config for router_pim_sparse_mode. Used for to configure multicast RPs for the underlay """ - if self.shared_utils.underlay_multicast_rps is None: return None @@ -59,14 +60,14 @@ def router_pim_sparse_mode(self: AvdStructuredConfigUnderlay) -> dict | None: } for node in nodes ], - } + }, ) if rp_addresses: router_pim_sparse_mode = { "ipv4": { "rp_addresses": rp_addresses, - } + }, } if anycast_rps: router_pim_sparse_mode["ipv4"]["anycast_rps"] = anycast_rps diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/standard_access_lists.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/standard_access_lists.py index d777d802052..94550007260 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/standard_access_lists.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/standard_access_lists.py @@ -15,17 +15,17 @@ class StandardAccessListsMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def standard_access_lists(self: AvdStructuredConfigUnderlay) -> list | None: """ - return structured config for standard_access_lists + return structured config for standard_access_lists. Used for to configure ACLs used by multicast RPs for the underlay """ - if self.shared_utils.underlay_multicast_rps is None: return None @@ -44,7 +44,7 @@ def standard_access_lists(self: AvdStructuredConfigUnderlay) -> list | None: } for index, group in enumerate(rp_entry["groups"]) ], - } + }, ) if standard_access_lists: diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/static_routes.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/static_routes.py index fa15767f536..4a20e1d1f55 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/static_routes.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/static_routes.py @@ -6,7 +6,8 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get +from pyavd._utils import get + from .utils import UtilsMixin if TYPE_CHECKING: @@ -16,18 +17,18 @@ class StaticRoutesMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def static_routes(self: AvdStructuredConfigUnderlay) -> list[dict] | None: """ - Returns structured config for static_routes + Returns structured config for static_routes. Consist of - static_routes configured under node type l3 interfaces """ - static_routes = [] for l3_interface in self.shared_utils.l3_interfaces: diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/utils.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/utils.py index b4a6a6fac19..2d89f1adce3 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/utils.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/utils.py @@ -6,10 +6,10 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._errors import AristaAvdError -from ...._utils import default, get, get_item, strip_empties_from_dict -from ....j2filters import natural_sort, range_expand -from ...interface_descriptions import InterfaceDescriptionData +from pyavd._eos_designs.interface_descriptions.models import InterfaceDescriptionData +from pyavd._errors import AristaAvdError +from pyavd._utils import default, get, get_item, strip_empties_from_dict +from pyavd.j2filters import natural_sort, range_expand if TYPE_CHECKING: from . import AvdStructuredConfigUnderlay @@ -18,13 +18,14 @@ class UtilsMixin: """ Mixin Class with internal functions. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def _avd_peers(self: AvdStructuredConfigUnderlay) -> list: """ - Returns a list of peers + Returns a list of peers. This cannot be loaded in shared_utils since it will not be calculated until EosDesignsFacts has been rendered and shared_utils are shared between EosDesignsFacts and AvdStructuredConfig classes like this one. @@ -33,9 +34,7 @@ def _avd_peers(self: AvdStructuredConfigUnderlay) -> list: @cached_property def _underlay_filter_peer_as_route_maps_asns(self: AvdStructuredConfigUnderlay) -> list: - """ - Filtered ASNs - """ + """Filtered ASNs.""" if self.shared_utils.underlay_filter_peer_as is False: return [] @@ -44,9 +43,7 @@ def _underlay_filter_peer_as_route_maps_asns(self: AvdStructuredConfigUnderlay) @cached_property def _underlay_links(self: AvdStructuredConfigUnderlay) -> list: - """ - Returns the list of underlay links for this device - """ + """Returns the list of underlay links for this device.""" underlay_links = [] underlay_links.extend(self._uplinks) if self.shared_utils.fabric_sflow_uplinks is not None: @@ -118,9 +115,7 @@ def _underlay_links(self: AvdStructuredConfigUnderlay) -> list: @cached_property def _underlay_vlan_trunk_groups(self: AvdStructuredConfigUnderlay) -> list: - """ - Returns a list of trunk groups to configure on the underlay link - """ + """Returns a list of trunk groups to configure on the underlay link.""" if self.shared_utils.enable_trunk_groups is not True: return [] @@ -137,7 +132,7 @@ def _underlay_vlan_trunk_groups(self: AvdStructuredConfigUnderlay) -> list: { "vlan_list": uplink["vlans"], "trunk_groups": peer_trunk_groups, - } + }, ) if trunk_groups: @@ -150,14 +145,9 @@ def _uplinks(self: AvdStructuredConfigUnderlay) -> list: return get(self._hostvars, "switch.uplinks") def _get_l3_interface_cfg(self: AvdStructuredConfigUnderlay, l3_interface: dict) -> dict | None: - """ - Returns structured_configuration for one L3 interface - """ + """Returns structured_configuration for one L3 interface.""" interface_name = get(l3_interface, "name", required=True, org_key=f"...[node={self.shared_utils.hostname}].l3_interfaces[].name]") - if "." in interface_name: - iface_type = "l3dot1q" - else: - iface_type = "routed" + iface_type = "l3dot1q" if "." in interface_name else "routed" interface_description = l3_interface.get("description") if not interface_description: @@ -169,10 +159,10 @@ def _get_l3_interface_cfg(self: AvdStructuredConfigUnderlay, l3_interface: dict) peer_interface=l3_interface.get("peer_interface"), wan_carrier=l3_interface.get("wan_carrier"), wan_circuit_id=l3_interface.get("wan_circuit_id"), - ) + ), ) - # TODO catch if ip_address is not valid or not dhcp + # TODO: catch if ip_address is not valid or not dhcp ip_address = get( l3_interface, "ip_address", @@ -204,21 +194,22 @@ def _get_l3_interface_cfg(self: AvdStructuredConfigUnderlay, l3_interface: dict) if ip_address == "dhcp" and l3_interface.get("dhcp_accept_default_route", True): interface["dhcp_client_accept_default_route"] = True - if self.shared_utils.is_wan_router and (wan_carrier_name := l3_interface.get("wan_carrier")) is not None and interface["access_group_in"] is None: - if not get(get_item(self.shared_utils.wan_carriers, "name", wan_carrier_name, default={}), "trusted"): - raise AristaAvdError( - ( - "'ipv4_acl_in' must be set on WAN interfaces where 'wan_carrier' is set, unless the carrier is configured as 'trusted' " - f"under 'wan_carriers'. 'ipv4_acl_in' is missing on interface '{interface_name}'." - ) - ) + if ( + self.shared_utils.is_wan_router + and (wan_carrier_name := l3_interface.get("wan_carrier")) is not None + and interface["access_group_in"] is None + and not get(get_item(self.shared_utils.wan_carriers, "name", wan_carrier_name, default={}), "trusted") + ): + msg = ( + "'ipv4_acl_in' must be set on WAN interfaces where 'wan_carrier' is set, unless the carrier is configured as 'trusted' " + f"under 'wan_carriers'. 'ipv4_acl_in' is missing on interface '{interface_name}'." + ) + raise AristaAvdError(msg) return strip_empties_from_dict(interface) def _get_l3_uplink_with_l2_as_subint(self: AvdStructuredConfigUnderlay, link: dict) -> tuple[dict, list[dict]]: - """ - Return a tuple with main uplink interface, list of subinterfaces representing each SVI. - """ + """Return a tuple with main uplink interface, list of subinterfaces representing each SVI.""" vlans = [int(vlan) for vlan in range_expand(link["vlans"])] # Main interface @@ -241,16 +232,18 @@ def _get_l3_uplink_with_l2_as_subint(self: AvdStructuredConfigUnderlay, link: di main_interface.pop("description", None) if (mtu := main_interface.get("mtu", 1500)) != self.shared_utils.p2p_uplinks_mtu: - raise AristaAvdError( + msg = ( f"MTU '{self.shared_utils.p2p_uplinks_mtu}' set for 'p2p_uplinks_mtu' conflicts with MTU '{mtu}' " f"set on SVI for uplink_native_vlan '{link['native_vlan']}'." "Either adjust the MTU on the SVI or p2p_uplinks_mtu or change/remove the uplink_native_vlan setting." ) + raise AristaAvdError(msg) return main_interface, [interface for interface in interfaces if interface["name"] != link["interface"]] def _get_l2_as_subint(self: AvdStructuredConfigUnderlay, link: dict, svi: dict, vrf: dict) -> dict: """ Return structured config for one subinterface representing the given SVI. + Only supports static IPs or VRRP. """ svi_id = int(svi["id"]) @@ -275,20 +268,21 @@ def _get_l2_as_subint(self: AvdStructuredConfigUnderlay, link: dict, svi: dict, "flow_tracker": link.get("flow_tracker"), } if (mtu := subinterface["mtu"]) is not None and subinterface["mtu"] > self.shared_utils.p2p_uplinks_mtu: - raise AristaAvdError( + msg = ( f"MTU '{self.shared_utils.p2p_uplinks_mtu}' set for 'p2p_uplinks_mtu' must be larger or equal to MTU '{mtu}' " f"set on the SVI '{svi_id}'." "Either adjust the MTU on the SVI or p2p_uplinks_mtu." ) + raise AristaAvdError(msg) # Only set VRRPv4 if ip_address is set if subinterface["ip_address"] is not None: - # TODO in separate PR adding VRRP support for SVIs + # TODO: in separate PR adding VRRP support for SVIs pass # Only set VRRPv6 if ipv6_address is set if subinterface["ipv6_address"] is not None: - # TODO in separate PR adding VRRP support for SVIs + # TODO: in separate PR adding VRRP support for SVIs pass # Adding IP helpers and OSPF via a common function also used for SVIs on L3 switches. @@ -299,11 +293,13 @@ def _get_l2_as_subint(self: AvdStructuredConfigUnderlay, link: dict, svi: dict, @cached_property def _l3_interface_acls(self: AvdStructuredConfigUnderlay) -> dict[str, dict[str, dict]]: """ - Returns a dict of - : { - "ipv4_acl_in": , - "ipv4_acl_out": , - } + Return dict of l3 interface ACLs. + + : { + "ipv4_acl_in": , + "ipv4_acl_out": , + } + Only contains interfaces with ACLs and only the ACLs that are set, so use `get(self._l3_interface_acls, f"{interface_name}.ipv4_acl_in")` to get the value. """ diff --git a/python-avd/pyavd/_eos_designs/structured_config/underlay/vlans.py b/python-avd/pyavd/_eos_designs/structured_config/underlay/vlans.py index f4957538317..c7529d2d9bb 100644 --- a/python-avd/pyavd/_eos_designs/structured_config/underlay/vlans.py +++ b/python-avd/pyavd/_eos_designs/structured_config/underlay/vlans.py @@ -6,8 +6,9 @@ from functools import cached_property from typing import TYPE_CHECKING -from ...._utils import get, get_item -from ....j2filters import natural_sort, range_expand +from pyavd._utils import get, get_item +from pyavd.j2filters import natural_sort, range_expand + from .utils import UtilsMixin if TYPE_CHECKING: @@ -17,13 +18,14 @@ class VlansMixin(UtilsMixin): """ Mixin Class used to generate structured config for one key. - Class should only be used as Mixin to a AvdStructuredConfig class + + Class should only be used as Mixin to a AvdStructuredConfig class. """ @cached_property def vlans(self: AvdStructuredConfigUnderlay) -> list | None: """ - Return structured config for vlans + Return structured config for vlans. This function goes through all the underlay trunk groups and returns an inverted dict where the key is the vlan ID and the value is the list of @@ -31,9 +33,8 @@ def vlans(self: AvdStructuredConfigUnderlay) -> list | None: The function also creates uplink_native_vlan for this switch or downstream switches. """ - vlans = [] - # TODO - can probably do this with sets but need list in the end so not sure it is worth it + # TODO: - can probably do this with sets but need list in the end so not sure it is worth it for vlan_trunk_group in self._underlay_vlan_trunk_groups: for vlan in range_expand(vlan_trunk_group["vlan_list"]): if (found_vlan := get_item(vlans, "id", int(vlan))) is None: @@ -52,16 +53,16 @@ def vlans(self: AvdStructuredConfigUnderlay) -> list | None: # Add configuration for uplink or peer's uplink_native_vlan if it is not defined as part of network services switch_vlans = range_expand(get(self._hostvars, "switch.vlans")) uplink_native_vlans = natural_sort( - set(link["native_vlan"] for link in self._underlay_links if "native_vlan" in link and str(link["native_vlan"]) not in switch_vlans) + {link["native_vlan"] for link in self._underlay_links if "native_vlan" in link and str(link["native_vlan"]) not in switch_vlans}, + ) + vlans.extend( + { + "id": int(peer_uplink_native_vlan), + "name": "NATIVE", + "state": "suspend", + } + for peer_uplink_native_vlan in uplink_native_vlans ) - for peer_uplink_native_vlan in uplink_native_vlans: - vlans.append( - { - "id": int(peer_uplink_native_vlan), - "name": "NATIVE", - "state": "suspend", - } - ) if vlans: return vlans diff --git a/python-avd/pyavd/_errors/__init__.py b/python-avd/pyavd/_errors/__init__.py index 23b79312886..d75b1bdec6c 100644 --- a/python-avd/pyavd/_errors/__init__.py +++ b/python-avd/pyavd/_errors/__init__.py @@ -5,11 +5,11 @@ class AristaAvdError(Exception): - def __init__(self, message="An Error has occurred in an arista.avd plugin"): + def __init__(self, message: str = "An Error has occurred in an arista.avd plugin") -> None: self.message = message super().__init__(self.message) - def _json_path_to_string(self, json_path): + def _json_path_to_string(self, json_path: list) -> str: path = "" for index, elem in enumerate(json_path): if isinstance(elem, int): @@ -27,7 +27,7 @@ class AristaAvdMissingVariableError(AristaAvdError): class AvdSchemaError(AristaAvdError): - def __init__(self, message="Schema Error", error=None): + def __init__(self, message: str = "Schema Error", error: jsonschema.ValidationError | None = None) -> None: if isinstance(error, jsonschema.SchemaError): self.message = f"'Schema Error: {self._json_path_to_string(error.absolute_path)}': {error.message}" else: @@ -36,7 +36,7 @@ def __init__(self, message="Schema Error", error=None): class AvdValidationError(AristaAvdError): - def __init__(self, message: str = "Schema Error", error=None): + def __init__(self, message: str = "Schema Error", error: Exception | None = None) -> None: if isinstance(error, (jsonschema.ValidationError)): self.path = self._json_path_to_string(error.absolute_path) self.message = f"'Validation Error: {self.path}': {error.message}" @@ -45,8 +45,10 @@ def __init__(self, message: str = "Schema Error", error=None): super().__init__(self.message) -class AvdConversionWarning(AristaAvdError): - def __init__(self, message: str = "Data was converted to conform to schema", key=None, oldtype="unknown", newtype="unknown"): +class AvdConversionWarning(AristaAvdError): # noqa: N818 + def __init__( + self, message: str = "Data was converted to conform to schema", key: list | None = None, oldtype: str = "unknown", newtype: str = "unknown" + ) -> None: if key is not None: self.path = self._json_path_to_string(key) self.message = f"'Data Type Converted: {self.path} from '{oldtype}' to '{newtype}'" @@ -55,8 +57,17 @@ def __init__(self, message: str = "Data was converted to conform to schema", key super().__init__(self.message) -class AvdDeprecationWarning(AristaAvdError): - def __init__(self, key, new_key=None, remove_in_version=None, remove_after_date=None, url=None, removed=False): +class AvdDeprecationWarning(AristaAvdError): # noqa: N818 + def __init__( + self, + key: str, + new_key: str | None = None, + remove_in_version: str | None = None, + remove_after_date: str | None = None, + url: str | None = None, + *, + removed: bool = False, + ) -> None: messages = [] self.path = self._json_path_to_string(key) @@ -80,7 +91,7 @@ def __init__(self, key, new_key=None, remove_in_version=None, remove_after_date= class AristaAvdDuplicateDataError(AristaAvdError): - def __init__(self, context: str, context_item_a: str, context_item_b: str): + def __init__(self, context: str, context_item_a: str, context_item_b: str) -> None: self.message = ( f"Found duplicate objects with conflicting data while generating configuration for {context}. {context_item_a} conflicts with {context_item_b}." ) diff --git a/python-avd/pyavd/_schema/avddataconverter.py b/python-avd/pyavd/_schema/avddataconverter.py index 8cfe78a2a02..4d57097b5aa 100644 --- a/python-avd/pyavd/_schema/avddataconverter.py +++ b/python-avd/pyavd/_schema/avddataconverter.py @@ -3,11 +3,11 @@ # that can be found in the LICENSE file. from __future__ import annotations -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING, Any -from .._errors import AvdConversionWarning, AvdDeprecationWarning -from .._utils import get_all -from ..j2filters import convert_dicts +from pyavd._errors import AvdConversionWarning, AvdDeprecationWarning +from pyavd._utils import get_all +from pyavd.j2filters import convert_dicts SCHEMA_TO_PY_TYPE_MAP = { "str": str, @@ -24,15 +24,15 @@ } if TYPE_CHECKING: + from collections.abc import Generator + from .avdschema import AvdSchema class AvdDataConverter: - """ - AvdDataConverter is used to convert AVD Data Types based on schema options. - """ + """AvdDataConverter is used to convert AVD Data Types based on schema options.""" - def __init__(self, avdschema: AvdSchema): + def __init__(self, avdschema: AvdSchema) -> None: self._avdschema = avdschema # We run through all the regular keys first, to ensure that all data has been converted @@ -44,9 +44,10 @@ def __init__(self, avdschema: AvdSchema): "deprecation": self.deprecation, } - def convert_data(self, data, schema: dict, path: list[str] = None) -> Generator: + def convert_data(self, data: Any, schema: dict, path: list[str] | None = None) -> Generator: """ Perform in-place conversion of data according to the provided schema. + Main entry function which is recursively called from the child functions performing the actual conversion of keys/items. """ if path is None: @@ -60,10 +61,8 @@ def convert_data(self, data, schema: dict, path: list[str] = None) -> Generator: # Converters will do inplace update of data. Any returns will be yielded conversion messages. yield from converter(schema[key], data, schema, path) - def convert_keys(self, keys: dict, data: dict, _, path: list[str]): - """ - This function performs conversion on each key with the relevant subschema - """ + def convert_keys(self, keys: dict, data: dict, _schema: dict, path: list[str]) -> Generator: + """This function performs conversion on each key with the relevant subschema.""" if not isinstance(data, dict): return @@ -74,18 +73,19 @@ def convert_keys(self, keys: dict, data: dict, _, path: list[str]): # Perform type conversion of the data for the child key if required based on "convert_types" if "convert_types" in childschema: - yield from self.convert_types(childschema["convert_types"], data, key, childschema, path + [key]) + yield from self.convert_types(childschema["convert_types"], data, key, childschema, [*path, key]) # Convert to lower case if set in schema and value is a string if childschema.get("convert_to_lower_case") and isinstance(data[key], str): data[key] = data[key].lower() - yield from self.convert_data(data[key], childschema, path + [key]) + yield from self.convert_data(data[key], childschema, [*path, key]) - def convert_dynamic_keys(self, dynamic_keys: dict, data: dict, schema: dict, path: list[str]): + def convert_dynamic_keys(self, dynamic_keys: dict, data: dict, schema: dict, path: list[str]) -> Generator: """ This function resolves "dynamic_keys" by looking in the actual data. - Then calls convert_keys to performs conversion on each resolved key with the relevant subschema + + Then calls convert_keys to performs conversion on each resolved key with the relevant subschema. """ if not isinstance(data, dict): return @@ -100,30 +100,29 @@ def convert_dynamic_keys(self, dynamic_keys: dict, data: dict, schema: dict, pat # Reuse convert_keys to perform the actual conversion on the resolved dynamic keys yield from self.convert_keys(keys, data, schema, path) - def convert_items(self, items: dict, data: list, _, path: list[str]): - """ - This function performs conversion on each item with the items subschema - """ + def convert_items(self, items: dict, data: list, _schema: dict, path: list[str]) -> Generator: + """This function performs conversion on each item with the items subschema.""" if not isinstance(data, list): return for index, item in enumerate(data): # Perform type conversion of the items data if required based on "convert_types" if "convert_types" in items: - yield from self.convert_types(items["convert_types"], data, index, items, path + [index]) + yield from self.convert_types(items["convert_types"], data, index, items, [*path, index]) # Convert to lower case if set in schema and item is a string if items.get("convert_to_lower_case") and isinstance(item, str): data[index] = item.lower() # Dive in to child items/schema - yield from self.convert_data(item, items, path + [index]) + yield from self.convert_data(item, items, [*path, index]) - def convert_types(self, convert_types: list, data: dict | list, index: str | int, schema: dict, path: list[str]): + def convert_types(self, convert_types: list, data: dict | list, index: str | int, schema: dict, path: list[str]) -> Generator: """ This function performs type conversion if necessary on a single data instance. + It is invoked for child keys during "keys" conversion and for child items during - "items" conversion + "items" conversion. "data" is either the parent dict or the parent list. "index" is either the key of the parent dict or the index of the parent list. @@ -140,10 +139,13 @@ def convert_types(self, convert_types: list, data: dict | list, index: str | int value = data[index] # For simple conversions, skip conversion if the value is of the correct type - if schema_type in SIMPLE_CONVERTERS and isinstance(value, SCHEMA_TO_PY_TYPE_MAP.get(schema_type)): - # Avoid corner case where we want to convert bool to int. Bool is a subclass of Int so it passes the check above. - if not (schema_type == "int" and isinstance(value, bool)): - return + # Avoid corner case where we want to convert bool to int. Bool is a subclass of Int so it passes the check above. + if ( + schema_type in SIMPLE_CONVERTERS + and isinstance(value, SCHEMA_TO_PY_TYPE_MAP.get(schema_type)) + and not (schema_type == "int" and isinstance(value, bool)) + ): + return for convert_type in convert_types: if isinstance(value, SCHEMA_TO_PY_TYPE_MAP.get(convert_type)): @@ -191,9 +193,10 @@ def convert_types(self, convert_types: list, data: dict | list, index: str | int yield AvdConversionWarning(key=path, oldtype=convert_type, newtype=schema_type) - def deprecation(self, deprecation: dict, _, __, path: list): + def deprecation(self, deprecation: dict, _data: Any, _schema: dict, path: list) -> Generator: """ - deprecation: + deprecation. + warning: bool, default = True new_key: str removed: bool @@ -201,7 +204,6 @@ def deprecation(self, deprecation: dict, _, __, path: list): remove_after_date: str url: str """ - if not deprecation.get("warning", True): return diff --git a/python-avd/pyavd/_schema/avdschema.py b/python-avd/pyavd/_schema/avdschema.py index c6294790513..b235dbfb910 100644 --- a/python-avd/pyavd/_schema/avdschema.py +++ b/python-avd/pyavd/_schema/avdschema.py @@ -3,14 +3,20 @@ # that can be found in the LICENSE file. from __future__ import annotations +from typing import TYPE_CHECKING, Any, NoReturn + import jsonschema from deepmerge import always_merger -from .._errors import AristaAvdError, AvdSchemaError, AvdValidationError +from pyavd._errors import AristaAvdError, AvdSchemaError, AvdValidationError + from .avddataconverter import AvdDataConverter from .avdvalidator import AvdValidator from .store import create_store +if TYPE_CHECKING: + from collections.abc import Generator + DEFAULT_SCHEMA = { "type": "dict", "allow_other_keys": True, @@ -20,8 +26,9 @@ class AvdSchema: """ AvdSchema takes either a schema as dict or the ID of a builtin schema. + If none of them are set, a default "dummy" schema will be loaded. - schema -> schema_id -> DEFAULT_SCHEMA + schema -> schema_id -> DEFAULT_SCHEMA. Parameters ---------- @@ -33,21 +40,22 @@ class AvdSchema: Force loading the YAML schema files into the store. By default schemas are loaded from pickled files. """ - def __init__(self, schema: dict = None, schema_id: str = None, load_store_from_yaml=False): + def __init__(self, schema: dict | None = None, schema_id: str | None = None, load_store_from_yaml: bool = False) -> None: self.store = create_store(load_from_yaml=load_store_from_yaml) self._schema_validator = jsonschema.Draft7Validator(self.store["avd_meta_schema"]) self.load_schema(schema, schema_id) - def validate_schema(self, schema: dict): + def validate_schema(self, schema: dict) -> Generator: validation_errors = self._schema_validator.iter_errors(schema) for validation_error in validation_errors: yield self._error_handler(validation_error) - def load_schema(self, schema: dict = None, schema_id: str = None): + def load_schema(self, schema: dict | None = None, schema_id: str | None = None) -> None: """ Load schema from dict or the ID of a builtin schema. + If none of them are set, a default "dummy" schema will be loaded. - schema -> schema_id -> DEFAULT_SCHEMA + schema -> schema_id -> DEFAULT_SCHEMA. Parameters ---------- @@ -63,7 +71,8 @@ def load_schema(self, schema: dict = None, schema_id: str = None): raise validation_error elif schema_id: if schema_id not in self.store: - raise AristaAvdError(f"Schema id {schema_id} not found in store. Must be one of {self.store.keys()}") + msg = f"Schema id {schema_id} not found in store. Must be one of {self.store.keys()}" + raise AristaAvdError(msg) schema = self.store[schema_id] else: @@ -74,16 +83,17 @@ def load_schema(self, schema: dict = None, schema_id: str = None): self._validator = AvdValidator(schema, self.store) self._dataconverter = AvdDataConverter(self) except Exception as e: - raise AristaAvdError("An error occurred during creation of the validator") from e + msg = "An error occurred during creation of the validator" + raise AristaAvdError(msg) from e - def extend_schema(self, schema: dict): + def extend_schema(self, schema: dict) -> NoReturn: for validation_error in self.validate_schema(schema): raise validation_error always_merger.merge(self._schema, schema) for validation_error in self.validate_schema(self._schema): raise validation_error - def validate(self, data): + def validate(self, data: Any) -> Generator: validation_errors = self._validator.iter_errors(data) try: @@ -92,7 +102,7 @@ def validate(self, data): except Exception as error: # pylint: disable=broad-exception-caught yield self._error_handler(error) - def convert(self, data): + def convert(self, data: Any) -> Generator: conversion_errors = self._dataconverter.convert_data(data, self._schema) try: @@ -101,7 +111,7 @@ def convert(self, data): except Exception as error: # pylint: disable=broad-exception-caught yield self._error_handler(error) - def _error_handler(self, error: Exception): + def _error_handler(self, error: Exception) -> Exception: if isinstance(error, AristaAvdError): return error if isinstance(error, jsonschema.ValidationError): @@ -110,11 +120,11 @@ def _error_handler(self, error: Exception): return AvdSchemaError(error=error) return error - def subschema(self, datapath: list): + def subschema(self, datapath: list) -> dict: """ Takes datapath elements as a list and returns the subschema for this datapath. - Example + Example: ------- Data model: a: @@ -156,23 +166,22 @@ def subschema(self, datapath: list): subschema(['a'], ) >> raises AvdSchemaError """ - if not isinstance(datapath, list): - raise AvdSchemaError(f"The datapath argument must be a list. Got {type(datapath)}") + msg = f"The datapath argument must be a list. Got {type(datapath)}" + raise AvdSchemaError(msg) schema = self._schema - def recursive_function(datapath, schema): - """ - Walk through schema following the datapath - """ + def recursive_function(datapath: list, schema: dict) -> dict: + """Walk through schema following the datapath.""" if len(datapath) == 0: return schema # More items in datapath, so we run recursively with recursive_function key = datapath[0] if not isinstance(key, str): - raise AvdSchemaError(f"All datapath items must be strings. Got {type(key)}") + msg = f"All datapath items must be strings. Got {type(key)}" + raise AvdSchemaError(msg) if schema["type"] == "dict": if key in schema.get("keys", []): @@ -184,6 +193,7 @@ def recursive_function(datapath, schema): return recursive_function(datapath[1:], schema["items"]["keys"][key]) # Falling through here in case the schema is not covering the requested datapath - raise AvdSchemaError(f"The datapath '{datapath}' could not be found in the schema") + msg = f"The datapath '{datapath}' could not be found in the schema" + raise AvdSchemaError(msg) return recursive_function(datapath, schema) diff --git a/python-avd/pyavd/_schema/avdvalidator.py b/python-avd/pyavd/_schema/avdvalidator.py index 0fd61a4de95..a75c1474055 100644 --- a/python-avd/pyavd/_schema/avdvalidator.py +++ b/python-avd/pyavd/_schema/avdvalidator.py @@ -2,21 +2,23 @@ # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. from collections import ChainMap +from collections.abc import Generator +from typing import Any, NoReturn import jsonschema import jsonschema._types import jsonschema.validators -from .._utils import get_all, get_all_with_path, get_indices_of_duplicate_items - # Special handling of jsonschema <4.18 vs. >=4.18 try: import jsonschema._validators as jsonschema_validators except ImportError: import jsonschema._keywords as jsonschema_validators +from pyavd._utils import get_all, get_all_with_path, get_indices_of_duplicate_items + -def _unique_keys_validator(validator, unique_keys: list[str], instance: list, _schema: dict): +def _unique_keys_validator(validator: object, unique_keys: list[str], instance: list, _schema: dict) -> Generator: if not validator.is_type(unique_keys, "list"): return @@ -32,7 +34,7 @@ def _unique_keys_validator(validator, unique_keys: list[str], instance: list, _s continue # Separate all paths and values - paths, values = zip(*paths_and_values) + paths, values = zip(*paths_and_values, strict=False) key = unique_key.split(".")[-1] is_nested_key = unique_key != key @@ -47,7 +49,7 @@ def _unique_keys_validator(validator, unique_keys: list[str], instance: list, _s ) -def _primary_key_validator(validator, primary_key: str, instance: list, schema: dict): +def _primary_key_validator(validator: object, primary_key: str, instance: list, schema: dict) -> Generator: if not validator.is_type(primary_key, "str"): return @@ -65,15 +67,16 @@ def _primary_key_validator(validator, primary_key: str, instance: list, schema: yield from _unique_keys_validator(validator, [primary_key], instance, schema) -def _keys_validator(validator, keys: dict, instance: dict, schema: dict): +def _keys_validator(validator: object, keys: dict, instance: dict, schema: dict) -> Generator: """ - This function validates each key with the relevant subschema + This function validates each key with the relevant subschema. + It also includes various child key validations, which can only be implemented with access to the parent "keys" instance. - Expand dynamic_keys - Validate "allow_other_keys" (default is false) - Validate "required" under child keys - - Expand "dynamic_valid_values" under child keys (don't perform validation) + - Expand "dynamic_valid_values" under child keys (don't perform validation). """ if not validator.is_type(instance, "object"): return @@ -119,40 +122,38 @@ def _keys_validator(validator, keys: dict, instance: dict, schema: dict): ) -def _dynamic_keys_validator(validator, _dynamic_keys: dict, instance: dict, schema: dict): - """ - This function triggers the regular "keys" validator in case only dynamic_keys is set. - """ +def _dynamic_keys_validator(validator: object, _dynamic_keys: dict, instance: dict, schema: dict) -> Generator: + """This function triggers the regular "keys" validator in case only dynamic_keys is set.""" if "keys" not in schema: yield from _keys_validator(validator, {}, instance, schema) -def _ref_validator(validator, ref, instance: dict, schema: dict): - raise NotImplementedError("$ref must be resolved before using AvdValidator") +def _ref_validator(_validator: object, _ref: str, _instance: dict, _schema: dict) -> NoReturn: + msg = "$ref must be resolved before using AvdValidator" + raise NotImplementedError(msg) -def _valid_values_validator(_validator, valid_values, instance, _schema: dict): - """ - This function validates if the instance conforms to the "valid_values" - """ +def _valid_values_validator(_validator: object, valid_values: list, instance: Any, _schema: dict) -> Generator: + """This function validates if the instance conforms to the "valid_values".""" if instance not in valid_values: yield jsonschema.ValidationError(f"'{instance}' is not one of {valid_values}") -def _is_dict(_validator, instance): +def _is_dict(_validator: object, instance: Any) -> bool: return isinstance(instance, (dict, ChainMap)) class AvdValidator: - def __new__(cls, schema: dict, store: dict): + def __new__(cls, schema: dict, store: dict) -> object: """ AvdSchemaValidator is used to validate AVD Data. + It uses a combination of our own validators and builtin jsonschema validators mapped to our own keywords. We have extra type checkers not covered by the AVD_META_SCHEMA (array, boolean etc) since the same TypeChecker is used by the validators themselves. """ - ValidatorClass = jsonschema.validators.create( + validator_cls = jsonschema.validators.create( meta_schema=store["avd_meta_schema"], validators={ "$ref": _ref_validator, @@ -186,8 +187,7 @@ def __new__(cls, schema: dict, store: dict): "bool": jsonschema._types.is_bool, "list": jsonschema._types.is_array, "int": jsonschema._types.is_integer, - } + }, ), - # version="0.1", ) - return ValidatorClass(schema) + return validator_cls(schema) diff --git a/python-avd/pyavd/_schema/store.py b/python-avd/pyavd/_schema/store.py index ca6a1f75c0d..e717ecd7bee 100644 --- a/python-avd/pyavd/_schema/store.py +++ b/python-avd/pyavd/_schema/store.py @@ -2,18 +2,20 @@ # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. from functools import lru_cache +from pathlib import Path from pickle import load from .constants import PICKLED_SCHEMAS @lru_cache -def create_store(load_from_yaml=False): +def create_store(*, load_from_yaml: bool = False) -> dict: if load_from_yaml: - raise NotImplementedError("'load_from_yaml' not supported for create_store under PyAVD") + msg = "'load_from_yaml' not supported for create_store under PyAVD" + raise NotImplementedError(msg) store = {} - for id, schema_file in PICKLED_SCHEMAS.items(): - with open(schema_file, "rb") as file: - store[id] = load(file) + for schema_id, schema_file in PICKLED_SCHEMAS.items(): + with Path(schema_file).open("rb") as file: + store[schema_id] = load(file) # noqa: S301 return store diff --git a/python-avd/pyavd/_utils/append_if_not_duplicate.py b/python-avd/pyavd/_utils/append_if_not_duplicate.py index de75dfd6552..c6619700bea 100644 --- a/python-avd/pyavd/_utils/append_if_not_duplicate.py +++ b/python-avd/pyavd/_utils/append_if_not_duplicate.py @@ -3,7 +3,8 @@ # that can be found in the LICENSE file. from __future__ import annotations -from .._errors import AristaAvdDuplicateDataError +from pyavd._errors import AristaAvdDuplicateDataError + from .compare_dicts import compare_dicts from .get import get from .get_item import get_item @@ -44,7 +45,7 @@ def append_if_not_duplicate( Often if is relevant to ignore the 'tenant' key so duplicate configs across multiple tenants can be ignored since tenant is not part of the output config. - Raises + Raises: ------ AristaAvdDuplicateDataError If a duplicate is found. diff --git a/python-avd/pyavd/_utils/batch.py b/python-avd/pyavd/_utils/batch.py index 6978cbdc4e0..6429ebc0a9c 100644 --- a/python-avd/pyavd/_utils/batch.py +++ b/python-avd/pyavd/_utils/batch.py @@ -7,13 +7,11 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Generator, Iterable + from collections.abc import Generator, Iterable def batch(iterable: Iterable, size: int) -> Generator[Iterable]: - """ - Returns a Generator of lists containing 'size' items. The final list may be shorter. - """ + """Returns a Generator of lists containing 'size' items. The final list may be shorter.""" iterator = iter(iterable) while batch := list(islice(iterator, size)): yield batch diff --git a/python-avd/pyavd/_utils/compare_dicts.py b/python-avd/pyavd/_utils/compare_dicts.py index a73a76b5242..037282befd4 100644 --- a/python-avd/pyavd/_utils/compare_dicts.py +++ b/python-avd/pyavd/_utils/compare_dicts.py @@ -8,7 +8,7 @@ def compare_dicts(dict1: dict, dict2: dict, ignore_keys: set[str] | None = None) """ Efficient comparison of dicts, where we can ignore certain keys. - Returns + Returns: ------- bool Do dict1 and dict2 match diff --git a/python-avd/pyavd/_utils/default.py b/python-avd/pyavd/_utils/default.py index b9c660d9b64..994c52989f3 100644 --- a/python-avd/pyavd/_utils/default.py +++ b/python-avd/pyavd/_utils/default.py @@ -1,9 +1,12 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -def default(*values): +from typing import Any + + +def default(*values: Any) -> Any: """ - Accepts any number of arguments. Return the first value which is not None + Accepts any number of arguments. Return the first value which is not None. Last resort is to return None. @@ -12,12 +15,11 @@ def default(*values): *values : any One or more values to test - Returns + Returns: ------- any First value which is not None """ - for value in values: if value is not None: return value diff --git a/python-avd/pyavd/_utils/get.py b/python-avd/pyavd/_utils/get.py index 457e60bfa97..ba5fae75d55 100644 --- a/python-avd/pyavd/_utils/get.py +++ b/python-avd/pyavd/_utils/get.py @@ -1,10 +1,12 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from .._errors import AristaAvdMissingVariableError +from typing import Any +from pyavd._errors import AristaAvdMissingVariableError -def get(dictionary, key, default=None, required=False, org_key=None, separator="."): + +def get(dictionary: dict, key: str, default: Any = None, required: bool = False, org_key: str | None = None, separator: str = ".") -> Any: """ Get a value from a dictionary or nested dictionaries. @@ -27,17 +29,16 @@ def get(dictionary, key, default=None, required=False, org_key=None, separator=" String to use as the separator parameter in the split function. Useful in cases when the key can contain variables with "." inside (e.g. hostnames) - Returns + Returns: ------- any Value or default value - Raises + Raises: ------ AristaAvdMissingVariableError If the key is not found and required == True """ - if org_key is None: org_key = key keys = str(key).split(separator) @@ -53,7 +54,9 @@ def get(dictionary, key, default=None, required=False, org_key=None, separator=" return value -def get_v2(dict_or_object, key_or_attribute, default=None, required=False, org_key=None, separator="."): +def get_v2( + dict_or_object: dict | object, key_or_attribute: str, default: Any = None, required: bool = False, org_key: str | None = None, separator: str = "." +) -> Any: """ Get a value from a dictionary or object or nested dictionaries and objects. @@ -76,24 +79,20 @@ def get_v2(dict_or_object, key_or_attribute, default=None, required=False, org_k String to use as the separator parameter in the split function. Useful in cases when the key can contain variables with "." inside (e.g. hostnames) - Returns + Returns: ------- any Value or default value - Raises + Raises: ------ AristaAvdMissingVariableError If the key is not found and required == True """ - if org_key is None: org_key = key_or_attribute keys = str(key_or_attribute).split(separator) - if callable(getattr(dict_or_object, "get", None)): - value = dict_or_object.get(keys[0]) - else: - value = getattr(dict_or_object, keys[0], None) + value = dict_or_object.get(keys[0]) if callable(getattr(dict_or_object, "get", None)) else getattr(dict_or_object, keys[0], None) if value is None: if required is True: diff --git a/python-avd/pyavd/_utils/get_all.py b/python-avd/pyavd/_utils/get_all.py index d0e882b6a17..3cc85dd7d55 100644 --- a/python-avd/pyavd/_utils/get_all.py +++ b/python-avd/pyavd/_utils/get_all.py @@ -3,12 +3,15 @@ # that can be found in the LICENSE file. from __future__ import annotations -from typing import Any, Generator +from typing import TYPE_CHECKING, Any -from .._errors import AristaAvdMissingVariableError +from pyavd._errors import AristaAvdMissingVariableError +if TYPE_CHECKING: + from collections.abc import Generator -def get_all(data, path: str, required: bool = False, org_path=None): + +def get_all(data: Any, path: str, required: bool = False, org_path: str | None = None) -> list: """ Get all values from data matching a data path. @@ -26,17 +29,16 @@ def get_all(data, path: str, required: bool = False, org_path=None): org_path : str Internal variable used for raising exception with the full path even when called recursively - Returns + Returns: ------- list [ any ] List of values matching data path or empty list if no matches are found. - Raises + Raises: ------ AristaAvdMissingVariableError If the path is not found and required == True """ - if org_path is None: org_path = path @@ -65,7 +67,7 @@ def get_all(data, path: str, required: bool = False, org_path=None): return [] -def get_all_with_path(data, path: str, _current_path: list[str | int] | None = None) -> Generator[tuple[list[str | int], Any], None, None]: +def get_all_with_path(data: Any, path: str, _current_path: list[str | int] | None = None) -> Generator[tuple[list[str | int], Any], None, None]: """ Get all values from data matching a data path including the path they were found in. @@ -81,7 +83,7 @@ def get_all_with_path(data, path: str, _current_path: list[str | int] | None = N _current_path : list[str|int] Internal variable used for tracking the full path even when called recursively - Returns + Returns: ------- Generator yielding Tuples (, ) for all values from data matching a data path. diff --git a/python-avd/pyavd/_utils/get_indices_of_duplicate_items.py b/python-avd/pyavd/_utils/get_indices_of_duplicate_items.py index 15fcaa82cce..45f7285c0be 100644 --- a/python-avd/pyavd/_utils/get_indices_of_duplicate_items.py +++ b/python-avd/pyavd/_utils/get_indices_of_duplicate_items.py @@ -2,13 +2,12 @@ # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. from collections import defaultdict -from typing import Any, Generator +from collections.abc import Generator +from typing import Any def get_indices_of_duplicate_items(values: list) -> Generator[tuple[Any, list[int]], None, None]: - """ - Returns a Generator of Tuples with (, []) - """ + """Returns a Generator of Tuples with (, []).""" counters = defaultdict(list) for index, item in enumerate(values): counters[item].append(index) diff --git a/python-avd/pyavd/_utils/get_ip_from_pool.py b/python-avd/pyavd/_utils/get_ip_from_pool.py index 9f66413ad54..9f424f06398 100644 --- a/python-avd/pyavd/_utils/get_ip_from_pool.py +++ b/python-avd/pyavd/_utils/get_ip_from_pool.py @@ -3,7 +3,7 @@ # that can be found in the LICENSE file. import ipaddress -from .._errors import AristaAvdError +from pyavd._errors import AristaAvdError def get_ip_from_pool(pool: str, prefixlen: int, subnet_offset: int, ip_offset: int) -> str: @@ -19,17 +19,18 @@ def get_ip_from_pool(pool: str, prefixlen: int, subnet_offset: int, ip_offset: i Returns: IP address without mask """ - pool_network = ipaddress.ip_network(pool, strict=False) prefixlen_diff = prefixlen - pool_network.prefixlen try: subnet_size = (int(pool_network.hostmask) + 1) >> prefixlen_diff except ValueError as e: - raise AristaAvdError(f"Prefix length {prefixlen} is smaller than pool network prefix length {pool_network.prefixlen}") from e + msg = f"Prefix length {prefixlen} is smaller than pool network prefix length {pool_network.prefixlen}" + raise AristaAvdError(msg) from e if (subnet_offset + 1) * subnet_size > pool_network.num_addresses: - raise AristaAvdError(f"Unable to get {subnet_offset + 1} /{prefixlen} subnets from pool {pool}") + msg = f"Unable to get {subnet_offset + 1} /{prefixlen} subnets from pool {pool}" + raise AristaAvdError(msg) subnet = ipaddress.ip_network((int(pool_network.network_address) + subnet_offset * subnet_size, prefixlen)) @@ -38,13 +39,14 @@ def get_ip_from_pool(pool: str, prefixlen: int, subnet_offset: int, ip_offset: i # This is a regular subnet. Skip the network address and raise if we hit the broadcast address. # >= because ip_offset is 0-based. if ip_offset >= (subnet_size - 2): - raise IndexError + raise IndexError # noqa: TRY301 ip = subnet[ip_offset + 1] else: # This is a linknet (/31 or /127) or a single IP (/32 or /128) ip = subnet[ip_offset] except IndexError as e: - raise AristaAvdError(f"Unable to get {ip_offset + 1} hosts in subnet {subnet} taken from pool {pool}") from e + msg = f"Unable to get {ip_offset + 1} hosts in subnet {subnet} taken from pool {pool}" + raise AristaAvdError(msg) from e return str(ip) diff --git a/python-avd/pyavd/_utils/get_item.py b/python-avd/pyavd/_utils/get_item.py index b63ef2831c0..7ac26439cd6 100644 --- a/python-avd/pyavd/_utils/get_item.py +++ b/python-avd/pyavd/_utils/get_item.py @@ -1,21 +1,24 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -from .._errors import AristaAvdMissingVariableError +from typing import Any + +from pyavd._errors import AristaAvdMissingVariableError def get_item( list_of_dicts: list, - key, - value, - default=None, - required=False, - case_sensitive=False, # pylint: disable=unused-argument - var_name=None, - custom_error_msg=None, -): + key: Any, + value: Any, + default: Any = None, + *, + required: bool = False, + _case_sensitive: bool = False, + var_name: str | None = None, + custom_error_msg: str | None = None, +) -> Any: """ - Get one dictionary from a list of dictionaries by matching the given key and value + Get one dictionary from a list of dictionaries by matching the given key and value. Returns the supplied default value or None if there is no match and "required" is False. @@ -40,17 +43,16 @@ def get_item( custom_error_msg : str Custom error message to raise when required is True and the value is not found - Returns + Returns: ------- any Dict or default value - Raises + Raises: ------ AristaAvdMissingVariableError If the key and value is not found and "required" == True """ - if var_name is None: var_name = key diff --git a/python-avd/pyavd/_utils/groupby.py b/python-avd/pyavd/_utils/groupby.py index 38e215a2761..e9c2e965eec 100644 --- a/python-avd/pyavd/_utils/groupby.py +++ b/python-avd/pyavd/_utils/groupby.py @@ -1,15 +1,15 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. +from collections.abc import Iterator from itertools import groupby as itergroupby +from typing import Any -def groupby(list_of_dictionaries: list, key: str): - """ - Group list of dictionaries by key - """ +def groupby(list_of_dictionaries: list, key: str) -> Iterator: + """Group list of dictionaries by key.""" - def getkey(dictionary: dict): + def getkey(dictionary: dict) -> Any: return dictionary.get(key) sorted_list = sorted(list_of_dictionaries, key=getkey) diff --git a/python-avd/pyavd/_utils/load_python_class.py b/python-avd/pyavd/_utils/load_python_class.py index 0ee2e8c6469..4cf6a055e0b 100644 --- a/python-avd/pyavd/_utils/load_python_class.py +++ b/python-avd/pyavd/_utils/load_python_class.py @@ -5,13 +5,12 @@ import importlib -from .._errors import AristaAvdError, AristaAvdMissingVariableError +from pyavd._errors import AristaAvdError, AristaAvdMissingVariableError def load_python_class(module_path: str, class_name: str, parent_class: type | None = None) -> type: """ - Load Python Class via importlib - + Load Python Class via importlib. Parameters ---------- @@ -22,12 +21,12 @@ def load_python_class(module_path: str, class_name: str, parent_class: type | No parent_class : type Class from which the imported class must inherit if present - Returns + Returns: ------- type The loaded Class (and not an instance of the Class) - Raises + Raises: ------ AristaAvdMissingVariableError If module_path or class_name are not present @@ -37,9 +36,11 @@ def load_python_class(module_path: str, class_name: str, parent_class: type | No If the loaded Class is not inheriting from the optional parent_class """ if not module_path: - raise AristaAvdMissingVariableError("Cannot load a python class without the module_path set.") + msg = "Cannot load a python class without the module_path set." + raise AristaAvdMissingVariableError(msg) if not class_name: - raise AristaAvdMissingVariableError("Cannot load a python class without the class_name set.") + msg = "Cannot load a python class without the class_name set." + raise AristaAvdMissingVariableError(msg) try: cls = getattr(importlib.import_module(module_path), class_name) @@ -47,6 +48,7 @@ def load_python_class(module_path: str, class_name: str, parent_class: type | No raise AristaAvdError(imp_exc) from imp_exc if parent_class is not None and not issubclass(cls, parent_class): - raise AristaAvdError(f"{cls} is not a subclass of {parent_class} class") + msg = f"{cls} is not a subclass of {parent_class} class" + raise AristaAvdError(msg) return cls diff --git a/python-avd/pyavd/_utils/merge/__init__.py b/python-avd/pyavd/_utils/merge/__init__.py index 9dcf1467bb0..b49d5fac8ec 100644 --- a/python-avd/pyavd/_utils/merge/__init__.py +++ b/python-avd/pyavd/_utils/merge/__init__.py @@ -4,32 +4,33 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from deepmerge import Merger from .mergeonschema import MergeOnSchema if TYPE_CHECKING: - from ..._schema.avdschema import AvdSchema + from pyavd._schema.avdschema import AvdSchema -def _strategy_keep(_config, _path, base, nxt): - """prefer base, otherwise nxt""" +def _strategy_keep(_config: object, _path: list, base: Any, nxt: Any) -> Any: + """Prefer base, otherwise nxt.""" if base is not None: return base return nxt -def _strategy_prepend_unique(_config, _path, base, nxt): - """prepend nxt items without duplicates in base to base.""" +def _strategy_prepend_unique(_config: object, _path: list, base: list, nxt: list) -> list: + """Prepend nxt items without duplicates in base to base.""" nxt_as_set = set(nxt) return nxt + [n for n in base if n not in nxt_as_set] -def _strategy_must_match(_config, path, base, nxt): +def _strategy_must_match(_config: object, path: list, base: Any, nxt: Any) -> Any: if base != nxt: - raise ValueError(f"Values of {'.'.join(path)} do not match: {base} != {nxt}") + msg = f"Values of {'.'.join(path)} do not match: {base} != {nxt}" + raise ValueError(msg) return base @@ -43,9 +44,17 @@ def _strategy_must_match(_config, path, base, nxt): } -def merge(base, *nxt_list, recursive=True, list_merge="append", same_key_strategy="override", destructive_merge=True, schema: AvdSchema = None): +def merge( + base: Any, + *nxt_list: list[Any], + recursive: bool = True, + list_merge: str = "append", + same_key_strategy: str = "override", + destructive_merge: bool = True, + schema: AvdSchema = None, +) -> Any: """ - Merge two or more data sets using deepmerge + Merge two or more data sets using deepmerge. Parameters ---------- @@ -71,12 +80,12 @@ def merge(base, *nxt_list, recursive=True, list_merge="append", same_key_strateg schema : AvdSchema, optional An instance of AvdSchema can be passed to merge, to allow merging lists of dictionaries using the "primary_key" defined in the schema. """ - if not destructive_merge: base = deepcopy(base) if list_merge not in MAP_ANSIBLE_LIST_MERGE_TO_DEEPMERGE_LIST_STRATEGY: - raise ValueError(f"merge: 'list_merge' argument can only be equal to one of {list(MAP_ANSIBLE_LIST_MERGE_TO_DEEPMERGE_LIST_STRATEGY.keys())}") + msg = f"merge: 'list_merge' argument can only be equal to one of {list(MAP_ANSIBLE_LIST_MERGE_TO_DEEPMERGE_LIST_STRATEGY.keys())}" + raise ValueError(msg) list_strategies = [MAP_ANSIBLE_LIST_MERGE_TO_DEEPMERGE_LIST_STRATEGY.get(list_merge, "append")] @@ -104,11 +113,12 @@ def merge(base, *nxt_list, recursive=True, list_merge="append", same_key_strateg if isinstance(nxt, list): for nxt_item in nxt: if not destructive_merge: - nxt_item = deepcopy(nxt_item) - merger.merge(base, nxt_item) + merger.merge(base, deepcopy(nxt_item)) + else: + merger.merge(base, nxt_item) + elif not destructive_merge: + merger.merge(base, deepcopy(nxt)) else: - if not destructive_merge: - nxt = deepcopy(nxt) merger.merge(base, nxt) return base diff --git a/python-avd/pyavd/_utils/merge/mergeonschema.py b/python-avd/pyavd/_utils/merge/mergeonschema.py index c8c884f6af3..f46b8724814 100644 --- a/python-avd/pyavd/_utils/merge/mergeonschema.py +++ b/python-avd/pyavd/_utils/merge/mergeonschema.py @@ -13,19 +13,19 @@ class MergeOnSchema: """ - MergeOnSchema provides the method "strategy" to be used as - list merge strategy with the deepmerge library. + MergeOnSchema provides the method "strategy" to be used as list merge strategy with the deepmerge library. The class is needed to allow a schema to be passed along to the method. """ - def __init__(self, schema: AvdSchema = None): + def __init__(self, schema: AvdSchema = None) -> None: self.schema = schema - def strategy(self, config, path: list, base: list, nxt: list): + def strategy(self, config: object, path: list, base: list, nxt: list) -> list: """ - The argument "config" should be an instance of deepmerge.Merger, - but Ansible sanity test breaks type hinting with imported libs + Custom strategy to merge lists on schema primary key. + + The argument "config" should be an instance of deepmerge.Merger, but Ansible sanity test breaks type hinting with imported libs. """ # Skip if no schema is supplied if not self.schema: @@ -68,10 +68,8 @@ def strategy(self, config, path: list, base: list, nxt: list): base[base_index] = config.value_strategy(path, base_item, nxt_item) except Exception as e: - raise RuntimeError( - f"An issue occurred while trying to do schema-based deepmerge for the schema path {path} using primary key '{primary_key}'" - ) from e - + msg = f"An issue occurred while trying to do schema-based deepmerge for the schema path {path} using primary key '{primary_key}'" + raise RuntimeError(msg) from e # If all nxt items got merged, we can just return the updated base. if len(merged_nxt_indexes) == len(nxt): return base @@ -84,11 +82,9 @@ def strategy(self, config, path: list, base: list, nxt: list): del nxt[merged_nxt_index] except Exception as e: - raise RuntimeError( + msg = ( f"An issue occurred after schema-based deepmerge for the schema path {path} using primary key '{primary_key}', " f"while preparing remaining items with to be merged with regular strategies. Merged indexes were {merged_nxt_indexes}" - ) from e - - # Since we did inplace updates of both nxt and base, we return STRATEGY_END - # so deepmerge will run the next strategy on the remaining nxt items. + ) + raise RuntimeError(msg) from e return STRATEGY_END diff --git a/python-avd/pyavd/_utils/password_utils/password.py b/python-avd/pyavd/_utils/password_utils/password.py index 45f1e291a97..c2919ab4c5f 100644 --- a/python-avd/pyavd/_utils/password_utils/password.py +++ b/python-avd/pyavd/_utils/password_utils/password.py @@ -1,17 +1,17 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -""" -Used by Encrypt / Decrypt filters -""" +"""Used by Encrypt / Decrypt filters.""" + from __future__ import annotations import random +from typing import Any from .password_utils import cbc_decrypt, cbc_encrypt -def _validate_password_and_key(password: str, key: str) -> None: +def _validate_password_and_key(password: Any, key: str) -> None: """ Validates the password and key values. @@ -24,13 +24,16 @@ def _validate_password_and_key(password: str, key: str) -> None: TypeError: If the password is not of type `str`. """ if not key: - raise ValueError("Key is required for encryption") + msg = "Key is required for encryption" + raise ValueError(msg) if not password: - raise ValueError("Password is required for encryption") + msg = "Password is required for encryption" + raise ValueError(msg) if not isinstance(password, str): - raise TypeError(f"Password MUST be of type 'str' but is of type {type(password)}") + msg = f"Password MUST be of type 'str' but is of type {type(password)}" + raise TypeError(msg) ############## @@ -83,13 +86,14 @@ def ospf_simple_decrypt(password: str, key: str) -> str: try: return cbc_decrypt(key_b, data).decode() except Exception as exc: - raise ValueError("OSPF password decryption failed - check the input parameters") from exc + msg = "OSPF password decryption failed - check the input parameters" + raise ValueError(msg) from exc OSPF_MESSAGE_DIGEST_HASH_ALGORITHMS = ["md5", "sha1", "sha256", "sha384", "sha512"] -def ospf_message_digest_encrypt(password: str, key: str, hash_algorithm: str = None, key_id: str = None) -> str: +def ospf_message_digest_encrypt(password: str, key: str, hash_algorithm: str | None = None, key_id: str | None = None) -> str: """ Encrypt a password for Message Digest Keys. @@ -110,9 +114,11 @@ def ospf_message_digest_encrypt(password: str, key: str, hash_algorithm: str = N """ _validate_password_and_key(password, key) if hash_algorithm is None or key_id is None: - raise ValueError("For OSPF message digest keys, both hash_algorithm and key_id are required") + msg = "For OSPF message digest keys, both hash_algorithm and key_id are required" + raise ValueError(msg) if hash_algorithm not in OSPF_MESSAGE_DIGEST_HASH_ALGORITHMS: - raise ValueError(f"For OSPF message digest keys, `hash_algorithm` must be in {OSPF_MESSAGE_DIGEST_HASH_ALGORITHMS}") + msg = f"For OSPF message digest keys, `hash_algorithm` must be in {OSPF_MESSAGE_DIGEST_HASH_ALGORITHMS}" + raise ValueError(msg) data = bytes(password, encoding="UTF-8") key_b = bytes(f"{key}_{hash_algorithm}Key_{key_id}", encoding="UTF-8") @@ -120,7 +126,7 @@ def ospf_message_digest_encrypt(password: str, key: str, hash_algorithm: str = N return cbc_encrypt(key_b, data).decode() -def ospf_message_digest_decrypt(password: str, key: str, hash_algorithm: str = None, key_id: str = None) -> str: +def ospf_message_digest_decrypt(password: str, key: str, hash_algorithm: str | None = None, key_id: str | None = None) -> str: """ Decrypt a password for Message Digest Keys. @@ -142,9 +148,11 @@ def ospf_message_digest_decrypt(password: str, key: str, hash_algorithm: str = N """ _validate_password_and_key(password, key) if hash_algorithm is None or key_id is None: - raise ValueError("For OSPF message digest keys, both hash_algorithm and key_id are required") + msg = "For OSPF message digest keys, both hash_algorithm and key_id are required" + raise ValueError(msg) if hash_algorithm not in OSPF_MESSAGE_DIGEST_HASH_ALGORITHMS: - raise ValueError(f"For OSPF message digest keys, `hash_algorithm` must be in {OSPF_MESSAGE_DIGEST_HASH_ALGORITHMS}") + msg = f"For OSPF message digest keys, `hash_algorithm` must be in {OSPF_MESSAGE_DIGEST_HASH_ALGORITHMS}" + raise ValueError(msg) data = bytes(password, encoding="UTF-8") key_b = bytes(f"{key}_{hash_algorithm}Key_{key_id}", encoding="UTF-8") @@ -152,13 +160,14 @@ def ospf_message_digest_decrypt(password: str, key: str, hash_algorithm: str = N try: return cbc_decrypt(key_b, data).decode() except Exception as exc: - raise ValueError("OSPF password decryption failed - check the input parameters") from exc + msg = "OSPF password decryption failed - check the input parameters" + raise ValueError(msg) from exc ############## # BGP ############## -def bgp_encrypt(password: str, key) -> str: +def bgp_encrypt(password: str, key: str) -> str: """ Encrypts a password for BGP (Border Gateway Protocol) authentication. @@ -181,7 +190,7 @@ def bgp_encrypt(password: str, key) -> str: return cbc_encrypt(key, data).decode() -def bgp_decrypt(password: str, key) -> str: +def bgp_decrypt(password: str, key: str) -> str: """ Decrypts a password for BGP (Border Gateway Protocol) authentication. @@ -205,7 +214,8 @@ def bgp_decrypt(password: str, key) -> str: try: return cbc_decrypt(key, data).decode() except Exception as exc: - raise ValueError("BGP password decryption failed - check the input parameters") from exc + msg = "BGP password decryption failed - check the input parameters" + raise ValueError(msg) from exc ############## @@ -224,7 +234,7 @@ def bgp_decrypt(password: str, key) -> str: } -def _validate_isis_args(password: str, key: str, mode: str): +def _validate_isis_args(password: str, key: str, mode: str) -> None: """ Validates the arguments for ISIS (Intermediate System to Intermediate System) encryption/decryption. @@ -241,19 +251,24 @@ def _validate_isis_args(password: str, key: str, mode: str): ValueError: If `mode` is empty or missing. """ if not password: - raise ValueError("Password is required for encryption/decryption") + msg = "Password is required for encryption/decryption" + raise ValueError(msg) if not isinstance(password, str): - raise TypeError(f"Password MUST be of type 'str' but is of type {type(password)}") + msg = f"Password MUST be of type 'str' but is of type {type(password)}" + raise TypeError(msg) if not isinstance(key, str): - raise TypeError(f"Key MUST be of type 'str' but is of type {type(key)}") + msg = f"Key MUST be of type 'str' but is of type {type(key)}" + raise TypeError(msg) if not isinstance(mode, str): - raise TypeError(f"Mode MUST be a string with one of the following options: {list(_ISIS_MODE_MAP)}. Got '{mode}'.") + msg = f"Mode MUST be a string with one of the following options: {list(_ISIS_MODE_MAP)}. Got '{mode}'." + raise TypeError(msg) if not mode: - raise ValueError("Mode is required for encryption/decryption") + msg = "Mode is required for encryption/decryption" + raise ValueError(msg) def _get_isis_key(key: str, mode: str) -> bytes: @@ -316,7 +331,8 @@ def isis_decrypt(password: str, key: str, mode: str) -> str: try: return cbc_decrypt(_get_isis_key(key, mode), data).decode() except Exception as exc: - raise ValueError("ISIS password decryption failed - check the input parameters") from exc + msg = "ISIS password decryption failed - check the input parameters" + raise ValueError(msg) from exc ############### @@ -353,6 +369,6 @@ def simple_7_encrypt(data: str, salt: int | None = None) -> str: """ if salt is None: # Accepting SonarLint issue: Pseudo random is ok since this is simply creating a visible salt - salt = random.randint(0, 15) # NOSONAR + salt = random.randint(0, 15) # NOSONAR # noqa: S311 cleartext = data.encode("UTF-8") return f"{salt:02}" + bytearray(char ^ (SIMPLE_7_SEED[(salt + i) % 53]) for i, char in enumerate(cleartext)).hex().upper() diff --git a/python-avd/pyavd/_utils/password_utils/password_utils.py b/python-avd/pyavd/_utils/password_utils/password_utils.py index b6c6e074b97..c7b6df1b413 100644 --- a/python-avd/pyavd/_utils/password_utils/password_utils.py +++ b/python-avd/pyavd/_utils/password_utils/password_utils.py @@ -156,7 +156,7 @@ ENC_SIG = b"\x4c\x88\xbb" -def des_setparity(key): +def des_setparity(key: bytes) -> bytes: res = b"" for b in key: pos = b & 0x7F @@ -164,7 +164,7 @@ def des_setparity(key): return res -def hashkey(pw) -> bytes: +def hashkey(pw: bytes) -> bytes: result = bytearray(SEED) for idx, b in enumerate(pw): @@ -186,7 +186,6 @@ def cbc_encrypt(key: bytes, data: bytes) -> bytes: Returns: bytes: The encrypted data, encoded in base64. """ - hashed_key = hashkey(key) padding = (8 - ((len(data) + 4) % 8)) % 8 ciphertext = ENC_SIG + bytes([padding * 16 + 0xE]) + data + bytes(padding) @@ -214,7 +213,6 @@ def cbc_decrypt(key: bytes, data: bytes) -> bytes: Raises: ValueError: If the decrypted data is invalid or the length of the provided data is not a multiple of the block length. """ - data = base64.b64decode(data) hashed_key = hashkey(key) @@ -227,7 +225,8 @@ def cbc_decrypt(key: bytes, data: bytes) -> bytes: # Checking the decrypted string pad = result[3] >> 4 if result[:3] != ENC_SIG or pad >= 8 or len(result[4:]) < pad: - raise ValueError("Invalid Encrypted String") + msg = "Invalid Encrypted String" + raise ValueError(msg) password_len = len(result) - pad return result[4:password_len] @@ -235,7 +234,8 @@ def cbc_decrypt(key: bytes, data: bytes) -> bytes: def cbc_check_password(key: bytes, data: bytes) -> bool: """ Verify if an encrypted password is decryptable. - It does not return the password but only raises an error if the password cannot be decrypted + + It does not return the password but only raises an error if the password cannot be decrypted. Args: key (bytes): The decryption key, which should be the peer group name or neighbor IP with '_passwd' suffix. @@ -248,6 +248,7 @@ def cbc_check_password(key: bytes, data: bytes) -> bool: try: cbc_decrypt(key, data) - return True except Exception: return False + + return True diff --git a/python-avd/pyavd/_utils/replace_or_append_item.py b/python-avd/pyavd/_utils/replace_or_append_item.py index 3dc52847252..6af313c8f9e 100644 --- a/python-avd/pyavd/_utils/replace_or_append_item.py +++ b/python-avd/pyavd/_utils/replace_or_append_item.py @@ -3,8 +3,7 @@ # that can be found in the LICENSE file. def replace_or_append_item(list_of_dicts: list, key: str, replacement_dict: dict) -> int: """ - In-place replace or append one dictionary to a list of dictionaries by matching the given key - with the value of replacement_dict[key] + In-place replace or append one dictionary to a list of dictionaries by matching the given key with the value of replacement_dict[key]. Parameters ---------- @@ -16,14 +15,14 @@ def replace_or_append_item(list_of_dicts: list, key: str, replacement_dict: dict Dictionary to replace / append. The value of 'key' in this dict is used to search for existing entries. - Returns + Returns: ------- int Index in list_of_dicts of replaced / appended entry """ - if key not in replacement_dict: - raise ValueError(f"The argument 'replacement_dict' does not contain the key {key}") + msg = f"The argument 'replacement_dict' does not contain the key {key}" + raise ValueError(msg) for index, list_item in enumerate(list_of_dicts): if not isinstance(list_item, dict): diff --git a/python-avd/pyavd/_utils/strip_empties.py b/python-avd/pyavd/_utils/strip_empties.py index 76ab0066ada..9affea880c1 100644 --- a/python-avd/pyavd/_utils/strip_empties.py +++ b/python-avd/pyavd/_utils/strip_empties.py @@ -3,8 +3,12 @@ # that can be found in the LICENSE file. from __future__ import annotations +from typing import TypeVar -def strip_null_from_data(data, strip_values_tuple=(None,)): +T = TypeVar("T") + + +def strip_null_from_data(data: T, strip_values_tuple: tuple = (None,)) -> T: """ strip_null_from_data Generic function to strip null entries regardless type of variable. @@ -13,7 +17,7 @@ def strip_null_from_data(data, strip_values_tuple=(None,)): data : Any Data to look for null content to strip out - Returns + Returns: ------- Any Cleaned data with no null. @@ -25,9 +29,9 @@ def strip_null_from_data(data, strip_values_tuple=(None,)): return data -def strip_empties_from_list(data, strip_values_tuple=(None, "", [], {})): +def strip_empties_from_list(data: list, strip_values_tuple: tuple = (None, "", [], {})) -> list: """ - strip_empties_from_list Remove entries with null value from a list + strip_empties_from_list Remove entries with null value from a list. Parameters ---------- @@ -36,7 +40,7 @@ def strip_empties_from_list(data, strip_values_tuple=(None, "", [], {})): strip_values_tuple : tuple, optional Value to remove from data, by default (None, "", [], {},) - Returns + Returns: ------- Any Cleaned list with no strip_values_tuple @@ -44,17 +48,20 @@ def strip_empties_from_list(data, strip_values_tuple=(None, "", [], {})): new_data = [] for v in data: if isinstance(v, dict): - v = strip_empties_from_dict(v, strip_values_tuple) + stripped_v = strip_empties_from_dict(v, strip_values_tuple) elif isinstance(v, list): - v = strip_empties_from_list(v, strip_values_tuple) - if v not in strip_values_tuple: - new_data.append(v) + stripped_v = strip_empties_from_list(v, strip_values_tuple) + else: + stripped_v = v + + if stripped_v not in strip_values_tuple: + new_data.append(stripped_v) return new_data -def strip_empties_from_dict(data, strip_values_tuple=(None, "", [], {})): +def strip_empties_from_dict(data: dict, strip_values_tuple: tuple = (None, "", [], {})) -> dict: """ - strip_empties_from_dict Remove entries with null value from a dict + strip_empties_from_dict Remove entries with null value from a dict. Parameters ---------- @@ -63,7 +70,7 @@ def strip_empties_from_dict(data, strip_values_tuple=(None, "", [], {})): strip_values_tuple : tuple, optional Value to remove from data, by default (None, "", [], {},) - Returns + Returns: ------- Any Cleaned dict with no strip_values_tuple @@ -71,9 +78,11 @@ def strip_empties_from_dict(data, strip_values_tuple=(None, "", [], {})): new_data = {} for k, v in data.items(): if isinstance(v, dict): - v = strip_empties_from_dict(v, strip_values_tuple) + stripped_v = strip_empties_from_dict(v, strip_values_tuple) elif isinstance(v, list): - v = strip_empties_from_list(v, strip_values_tuple) - if v not in strip_values_tuple: - new_data[k] = v + stripped_v = strip_empties_from_list(v, strip_values_tuple) + else: + stripped_v = v + if stripped_v not in strip_values_tuple: + new_data[k] = stripped_v return new_data diff --git a/python-avd/pyavd/_utils/template.py b/python-avd/pyavd/_utils/template.py index 814c02353c6..1046f037267 100644 --- a/python-avd/pyavd/_utils/template.py +++ b/python-avd/pyavd/_utils/template.py @@ -1,7 +1,7 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -def template(template_file, template_vars, templar): +def template(template_file: str, template_vars: dict, templar: object) -> str: """ Run Ansible Templar with template file. @@ -22,13 +22,14 @@ def template(template_file, template_vars, templar): searchpath : list of str List of Paths - Returns + Returns: ------- str The rendered template """ if templar is None: - raise NotImplementedError("Jinja Templating is not implemented in pyavd") + msg = "Jinja Templating is not implemented in pyavd" + raise NotImplementedError(msg) # We only get here when running from Ansible, so it is safe to import from ansible. # pylint: disable=import-outside-toplevel @@ -43,6 +44,4 @@ def template(template_file, template_vars, templar): j2template = to_text(j2template) with templar.set_temporary_context(available_variables=template_vars): - result = templar.template(j2template, convert_data=False, escape_backslashes=False) - - return result + return templar.template(j2template, convert_data=False, escape_backslashes=False) diff --git a/python-avd/pyavd/_utils/template_var.py b/python-avd/pyavd/_utils/template_var.py index 387f46fc31b..2d10897b9ed 100644 --- a/python-avd/pyavd/_utils/template_var.py +++ b/python-avd/pyavd/_utils/template_var.py @@ -1,12 +1,14 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. +from typing import Any + from .template import template -def template_var(template_file, template_vars, templar): +def template_var(template_file: str, template_vars: Any, templar: object) -> str: """ - Wrap "template" for single values like IP addresses + Wrap "template" for single values like IP addresses. The result is forced into a string and leading/trailing newlines and whitespaces are removed. @@ -21,7 +23,7 @@ def template_var(template_file, template_vars, templar): searchpath : list of str List of Paths - Returns + Returns: ------- str The rendered template diff --git a/python-avd/pyavd/_utils/unique.py b/python-avd/pyavd/_utils/unique.py index 0c94596ecc3..5a6b9e880f0 100644 --- a/python-avd/pyavd/_utils/unique.py +++ b/python-avd/pyavd/_utils/unique.py @@ -1,18 +1,17 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -def unique(in_list): +def unique(in_list: list) -> list: """ - Return list of unique items from the in_list + Return list of unique items from the in_list. Parameters ---------- in_list : list - Returns + Returns: ------- list Unique list items """ - return list(set(in_list)) diff --git a/python-avd/pyavd/avd_schema_tools.py b/python-avd/pyavd/avd_schema_tools.py index 530a0a7bc6c..8d70e535f57 100644 --- a/python-avd/pyavd/avd_schema_tools.py +++ b/python-avd/pyavd/avd_schema_tools.py @@ -3,26 +3,29 @@ # that can be found in the LICENSE file. from __future__ import annotations -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import Generator + from ._errors import AvdDeprecationWarning from .validation_result import ValidationResult class AvdSchemaTools: - """ - Tools that wrap the various schema components for easy use - """ + """Tools that wrap the various schema components for easy use.""" - def __init__(self, schema: dict = None, schema_id: str = None) -> None: + def __init__(self, schema: dict | None = None, schema_id: str | None = None) -> None: """ - Convert data according to the schema (convert_types) + Convert data according to the schema (convert_types). + The data conversion is done in-place (updating the original "data" dict). Args: + schema: + Optional AVD schema as dict schema_id: - Name of AVD Schema to use for conversion and validation. + Optional Name of AVD Schema to load from store """ # pylint: disable=import-outside-toplevel from ._schema.avdschema import AvdSchema @@ -33,7 +36,8 @@ def __init__(self, schema: dict = None, schema_id: str = None) -> None: def convert_data(self, data: dict) -> list[AvdDeprecationWarning]: """ - Convert data according to the schema (convert_types) + Convert data according to the schema (convert_types). + The data conversion is done in-place (updating the original "data" dict). Args: @@ -67,7 +71,7 @@ def convert_data(self, data: dict) -> list[AvdDeprecationWarning]: def validate_data(self, data: dict) -> ValidationResult: """ - Validate data according to the schema + Validate data according to the schema. Args: data: @@ -110,7 +114,7 @@ def validate_data(self, data: dict) -> ValidationResult: def convert_and_validate_data(self, data: dict) -> dict: """ - Convert and validate data according to the schema + Convert and validate data according to the schema. Returns dictionary to be compatible with Ansible plugin. Called from vendored "get_structured_config". @@ -118,7 +122,7 @@ def convert_and_validate_data(self, data: dict) -> dict: data: Input variables which are to be validated according to the schema. - Returns + Returns: dict : failed : bool True if data is invalid. Otherwise False. diff --git a/python-avd/pyavd/get_avd_facts.py b/python-avd/pyavd/get_avd_facts.py index 78b7093c524..7132141dc0c 100644 --- a/python-avd/pyavd/get_avd_facts.py +++ b/python-avd/pyavd/get_avd_facts.py @@ -27,14 +27,9 @@ def get_avd_facts(all_inputs: dict[str, dict]) -> dict[str, dict]: Returns: Nested dictionary with various internal "facts". The full dict must be given as argument to `pyavd.get_device_structured_config`: ```python - { - "avd_switch_facts": dict, - "avd_overlay_peers": dict, - "avd_topology_peers" : dict - } + {"avd_switch_facts": dict, "avd_overlay_peers": dict, "avd_topology_peers": dict} ``` """ - avd_switch_facts_instances = _create_avd_switch_facts_instances(all_inputs) avd_switch_facts = _render_avd_switch_facts(avd_switch_facts_instances) avd_overlay_peers, avd_topology_peers = _render_peer_facts(avd_switch_facts) @@ -94,9 +89,9 @@ def _create_avd_switch_facts_instances(all_inputs: dict[str, dict]) -> dict: return avd_switch_facts -def _render_avd_switch_facts(avd_switch_facts_instances: dict): +def _render_avd_switch_facts(avd_switch_facts_instances: dict) -> dict: """ - Run the render method on each EosDesignsFacts object + Run the render method on each EosDesignsFacts object. Args: avd_switch_facts_instances: Dictionary with instances of EosDesignsFacts per device. @@ -128,7 +123,7 @@ def _render_avd_switch_facts(avd_switch_facts_instances: dict): def _render_peer_facts(avd_switch_facts: dict) -> tuple[dict, dict]: """ - Build dicts of underlay and overlay peerings based on avd_switch_facts + Build dicts of underlay and overlay peerings based on avd_switch_facts. Args: avd_switch_facts: Nested Dictionaried with rendered "avd_switch_facts" per device. @@ -153,7 +148,6 @@ def _render_peer_facts(avd_switch_facts: dict) -> tuple[dict, dict]: List of switches having hostname2 as uplink_switch """ - avd_overlay_peers = {} avd_topology_peers = {} for hostname in avd_switch_facts: diff --git a/python-avd/pyavd/get_device_config.py b/python-avd/pyavd/get_device_config.py old mode 100755 new mode 100644 diff --git a/python-avd/pyavd/get_device_doc.py b/python-avd/pyavd/get_device_doc.py old mode 100755 new mode 100644 diff --git a/python-avd/pyavd/get_device_structured_config.py b/python-avd/pyavd/get_device_structured_config.py old mode 100755 new mode 100644 index ea79322152c..053ca166347 --- a/python-avd/pyavd/get_device_structured_config.py +++ b/python-avd/pyavd/get_device_structured_config.py @@ -49,6 +49,7 @@ def get_device_structured_config(hostname: str, inputs: dict, avd_facts: dict) - templar=None, ) if result.get("failed"): - raise AristaAvdError(f"{[str(error) for error in result['errors']]}") + msg = f"{[str(error) for error in result['errors']]}" + raise AristaAvdError(msg) return structured_config diff --git a/python-avd/pyavd/j2filters/add_md_toc.py b/python-avd/pyavd/j2filters/add_md_toc.py index 121eac3da29..49f332cf745 100644 --- a/python-avd/pyavd/j2filters/add_md_toc.py +++ b/python-avd/pyavd/j2filters/add_md_toc.py @@ -26,23 +26,26 @@ def add_md_toc(md_input: str, skip_lines: int = 0, toc_levels: int = 3, toc_mark TOC will be inserted or updated between two of these markers in the MD file default: '' - Returns + Returns: ------- str MD with added TOC """ - if not isinstance(skip_lines, int): - raise TypeError(f"add_md_toc 'skip_lines' argument must be an integer. Got '{skip_lines}'({type(skip_lines)}).") + msg = f"add_md_toc 'skip_lines' argument must be an integer. Got '{skip_lines}'({type(skip_lines)})." + raise TypeError(msg) if not isinstance(toc_levels, int) or toc_levels < 1: - raise TypeError(f"add_md_toc 'toc_levels' argument must be >0. Got '{toc_levels}'({type(skip_lines)}).") + msg = f"add_md_toc 'toc_levels' argument must be >0. Got '{toc_levels}'({type(skip_lines)})." + raise TypeError(msg) if not isinstance(toc_marker, str) or not toc_marker: - raise TypeError(f"add_md_toc 'toc_marker' argument must be a non-empty string. Got '{toc_marker}'({type(skip_lines)}).") + msg = f"add_md_toc 'toc_marker' argument must be a non-empty string. Got '{toc_marker}'({type(skip_lines)})." + raise TypeError(msg) if not isinstance(md_input, str): - raise TypeError(f"add_md_toc expects a string. Got {type(md_input)}.") + msg = f"add_md_toc expects a string. Got {type(md_input)}." + raise TypeError(msg) md_lines = md_input.split("\n") toc_marker_positions = [] @@ -75,14 +78,13 @@ def add_md_toc(md_input: str, skip_lines: int = 0, toc_levels: int = 3, toc_mark toc_lines.append(f"{prefix}[{text}](#{anchor_id})") if len(toc_marker_positions) != 2: - raise ValueError( - f"add_md_toc expects exactly two occurrences of the toc marker '{toc_marker}' on their own lines. Found {len(toc_marker_positions)} occurrences." - ) + msg = f"add_md_toc expects exactly two occurrences of the toc marker '{toc_marker}' on their own lines. Found {len(toc_marker_positions)} occurrences." + raise ValueError(msg) return "\n".join(md_lines[: toc_marker_positions[0]] + toc_lines + md_lines[toc_marker_positions[1] + 1 :]) -def _get_line_info(line: str, all_anchor_ids: list[str]) -> (int, str, str): +def _get_line_info(line: str, all_anchor_ids: list[str]) -> tuple[int, str, str]: """Split heading and return level, text and anchor_id. Since we know the line is already a heading, we can assume correct formatting. @@ -95,7 +97,7 @@ def _get_line_info(line: str, all_anchor_ids: list[str]) -> (int, str, str): all_anchor_ids: list List of existing anchor_ids - Returns + Returns: ------- int, str, str: The level of the heading, the text of the heading and the anchor_id for the heading. @@ -110,6 +112,7 @@ def _get_line_info(line: str, all_anchor_ids: list[str]) -> (int, str, str): def _get_anchor_id(text: str, all_anchor_ids: list[str]) -> str: """ Returns a unique anchor_id after adding it to 'all_anchor_ids'. + The logic here follow the auto-id generation algorithm of the MarkDown spec. Parameters @@ -119,7 +122,7 @@ def _get_anchor_id(text: str, all_anchor_ids: list[str]) -> str: all_anchor_ids: list List of existing anchor_ids - Returns + Returns: ------- str: The anchor ID for the text. diff --git a/python-avd/pyavd/j2filters/convert_dicts.py b/python-avd/pyavd/j2filters/convert_dicts.py index f78243d7d35..9fc267af957 100644 --- a/python-avd/pyavd/j2filters/convert_dicts.py +++ b/python-avd/pyavd/j2filters/convert_dicts.py @@ -8,6 +8,8 @@ def convert_dicts(dictionary: dict | list, primary_key: str = "name", secondary_key: str | None = None) -> list: """ + Convert dictionaries to lists. + The `arista.avd.convert_dicts` filter will convert a dictionary containing nested dictionaries to a list of dictionaries. It inserts the outer dictionary keys into each list item using the primary_key `name` (key name is configurable) and if there is a non-dictionary value,it inserts this value to @@ -46,7 +48,7 @@ def convert_dicts(dictionary: dict | list, primary_key: str = "name", secondary_ secondary_key : str, optional Name of secondary key used when inserting dictionary values which are list into items. - Returns + Returns: ------- any Returns list of dictionaries or input variable untouched if not a nested dictionary/list. @@ -61,13 +63,13 @@ def convert_dicts(dictionary: dict | list, primary_key: str = "name", secondary_ output.append({primary_key: element}) elif primary_key not in element and secondary_key is not None: # if element of nested dictionary is a dictionary but primary key is missing, insert primary and secondary keys. - for key in element: - output.append( - { - primary_key: key, - secondary_key: element[key], - } - ) + output.extend( + { + primary_key: key, + secondary_key: element[key], + } + for key in element + ) else: output.append(element) return output @@ -80,18 +82,17 @@ def convert_dicts(dictionary: dict | list, primary_key: str = "name", secondary_ { primary_key: key, secondary_key: dictionary[key], - } + }, ) + elif not isinstance(dictionary[key], dict): + # Not a nested dictionary + output.append({primary_key: key}) else: - if not isinstance(dictionary[key], dict): - # Not a nested dictionary - output.append({primary_key: key}) - else: - # Nested dictionary - output.append( - { - primary_key: key, - **dictionary[key], - } - ) + # Nested dictionary + output.append( + { + primary_key: key, + **dictionary[key], + }, + ) return output diff --git a/python-avd/pyavd/j2filters/decrypt.py b/python-avd/pyavd/j2filters/decrypt.py index 5894c7a4aa3..a9c62062b3a 100644 --- a/python-avd/pyavd/j2filters/decrypt.py +++ b/python-avd/pyavd/j2filters/decrypt.py @@ -3,10 +3,12 @@ # that can be found in the LICENSE file. from __future__ import annotations +from typing import Any + from pyavd._utils.password_utils import METHODS_DIR -def decrypt(value, passwd_type=None, key=None, **kwargs) -> str: +def decrypt(value: Any, passwd_type: str | None = None, key: str | None = None, **kwargs: dict[str, Any]) -> str: """ Umbrella function to execute the correct decrypt method based on the input type. @@ -24,9 +26,11 @@ def decrypt(value, passwd_type=None, key=None, **kwargs) -> str: KeyError: If `passwd_type` is not found in `METHODS_DIR`. """ if not passwd_type: - raise TypeError("type keyword must be present to use this test") + msg = "type keyword must be present to use this test" + raise TypeError(msg) try: decrypt_method = METHODS_DIR[passwd_type][1] except KeyError as exc: - raise KeyError(f"Type {passwd_type} is not supported for the decrypt filter") from exc + msg = f"Type {passwd_type} is not supported for the decrypt filter" + raise KeyError(msg) from exc return decrypt_method(str(value), key=key, **kwargs) diff --git a/python-avd/pyavd/j2filters/default.py b/python-avd/pyavd/j2filters/default.py index ecc328b2665..326a230c0c6 100644 --- a/python-avd/pyavd/j2filters/default.py +++ b/python-avd/pyavd/j2filters/default.py @@ -1,10 +1,14 @@ # Copyright (c) 2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. +from typing import TypeVar + from jinja2.runtime import Undefined +T = TypeVar("T") + -def default(*values): +def default(*values: list[T]) -> T | None: """ default will test value if defined and is not none. diff --git a/python-avd/pyavd/j2filters/encrypt.py b/python-avd/pyavd/j2filters/encrypt.py index 22077f908f5..e973bdcf368 100644 --- a/python-avd/pyavd/j2filters/encrypt.py +++ b/python-avd/pyavd/j2filters/encrypt.py @@ -3,10 +3,12 @@ # that can be found in the LICENSE file. from __future__ import annotations +from typing import Any + from pyavd._utils.password_utils import METHODS_DIR -def encrypt(value, passwd_type=None, key=None, **kwargs) -> str: +def encrypt(value: Any, passwd_type: str | None = None, key: str | None = None, **kwargs: Any) -> str: """ Umbrella function to execute the correct encrypt method based on the input type. @@ -24,9 +26,11 @@ def encrypt(value, passwd_type=None, key=None, **kwargs) -> str: KeyError: If `passwd_type` is not found in `METHODS_DIR`. """ if not passwd_type: - raise TypeError("type keyword must be present to use this test") + msg = "type keyword must be present to use this test" + raise TypeError(msg) try: encrypt_method = METHODS_DIR[passwd_type][0] except KeyError as exc: - raise KeyError(f"Type {passwd_type} is not supported for the encrypt filter") from exc + msg = f"Type {passwd_type} is not supported for the encrypt filter" + raise KeyError(msg) from exc return encrypt_method(str(value), key=key, **kwargs) diff --git a/python-avd/pyavd/j2filters/hide_passwords.py b/python-avd/pyavd/j2filters/hide_passwords.py index ca7bc09eaed..d13505b1c57 100644 --- a/python-avd/pyavd/j2filters/hide_passwords.py +++ b/python-avd/pyavd/j2filters/hide_passwords.py @@ -16,5 +16,6 @@ def hide_passwords(value: str, hide_passwords: bool = False) -> str: """ if not isinstance(hide_passwords, bool): - raise TypeError(f"{hide_passwords} in hide_passwords filter is not of type bool") + msg = f"{hide_passwords} in hide_passwords filter is not of type bool" + raise TypeError(msg) return "" if hide_passwords else value diff --git a/python-avd/pyavd/j2filters/is_in_filter.py b/python-avd/pyavd/j2filters/is_in_filter.py index fb92c5d4a73..22b3cff9c62 100644 --- a/python-avd/pyavd/j2filters/is_in_filter.py +++ b/python-avd/pyavd/j2filters/is_in_filter.py @@ -15,7 +15,7 @@ def is_in_filter(hostname: str, hostname_filter: list | None) -> bool: hostname_filter : list, optional Device filter, by default ['all'] - Returns + Returns: ------- boolean True if device hostname is part of filter. False if not. diff --git a/python-avd/pyavd/j2filters/list_compress.py b/python-avd/pyavd/j2filters/list_compress.py index 5ef482825b4..40af1555d13 100644 --- a/python-avd/pyavd/j2filters/list_compress.py +++ b/python-avd/pyavd/j2filters/list_compress.py @@ -9,8 +9,10 @@ def list_compress(list_to_compress: list[int]) -> str: """ Compresses a list of integers to a range string. + Args: list_to_compress (list): List of integers. + Returns: str: Compressed range string. @@ -20,10 +22,13 @@ def list_compress(list_to_compress: list[int]) -> str: list2: "{{ [1,2,3,7,8] | arista.avd.list_compress }}" -> "1-3,7-8" """ if not isinstance(list_to_compress, list): - raise TypeError(f"Value must be of type list, got {type(list_to_compress)}") + msg = f"Value must be of type list, got {type(list_to_compress)}" + raise TypeError(msg) if not all(isinstance(item, int) for item in list_to_compress): - raise TypeError(f"All elements of the list {list_to_compress} must be integers") + msg = f"All elements of the list {list_to_compress} must be integers" + raise TypeError(msg) - groups = (list(group) for key, group in groupby(sorted(list_to_compress), lambda element, iterator=count(): next(iterator) - element)) + counter = count() + groups = (list(group) for key, group in groupby(sorted(list_to_compress), lambda element, counter=counter: next(counter) - element)) return ",".join("-".join(map(str, (group[0], group[-1])[: len(group)])) for group in groups) diff --git a/python-avd/pyavd/j2filters/natural_sort.py b/python-avd/pyavd/j2filters/natural_sort.py index 1d331fbdb29..5c167c8e451 100644 --- a/python-avd/pyavd/j2filters/natural_sort.py +++ b/python-avd/pyavd/j2filters/natural_sort.py @@ -4,6 +4,7 @@ from __future__ import annotations import re +from typing import Any from jinja2.runtime import Undefined from jinja2.utils import Namespace @@ -15,12 +16,12 @@ def convert(text: str, ignore_case: bool) -> int | str: Converts the string to an integer if it is a digit, otherwise converts it to lower case if ignore_case is True. - Parameters - ---------- + Args: + ----- text (str): Input string. ignore_case (bool): If ignore_case is True, strings are applied lower() function. - Returns + Returns: ------- int | str: Converted string. """ @@ -32,25 +33,25 @@ def convert(text: str, ignore_case: bool) -> int | str: def natural_sort(iterable: list | dict | str | None, sort_key: str | None = None, *, strict: bool = True, ignore_case: bool = True) -> list: """Sorts an iterable in a natural (alphanumeric) order. - Parameters - ---------- + Args: + ----- iterable (list | dict | str | None): Input iterable. sort_key (str | None, optional): Key to sort by, defaults to None. strict (bool, optional): If strict is True, raise an error is the sort_key is missing. ignore_case (bool, optional): If ignore_case is True, strings are applied lower() function. - Returns + Returns: ------- list: Sorted iterable. - Raises + Raises: ------ KeyError, AttributeError: if strict=True and sort_key is not present in an item in the iterable. """ if isinstance(iterable, Undefined) or iterable is None: return [] - def alphanum_key(key): + def alphanum_key(key: Any) -> list: pattern = r"(\d+)" if sort_key is not None and isinstance(key, dict): if strict and sort_key not in key: diff --git a/python-avd/pyavd/j2filters/range_expand.py b/python-avd/pyavd/j2filters/range_expand.py index be86a9ecf81..d37a8afad6d 100644 --- a/python-avd/pyavd/j2filters/range_expand.py +++ b/python-avd/pyavd/j2filters/range_expand.py @@ -5,11 +5,87 @@ from __future__ import annotations import re +from dataclasses import dataclass +from typing import Any -def range_expand(range_to_expand): +@dataclass +class InterfaceData: + one_range: str + first_interface: int | None = None + last_interface: int | None = None + first_subinterface: int | None = None + last_subinterface: int | None = None + first_parent_interface: int | None = None + last_parent_interface: int | None = None + first_module: int | None = None + last_module: int | None = None + + +def expand_subinterfaces(interface_string: str, data: InterfaceData) -> list: + result = [] + if data.last_subinterface is not None: + if data.first_subinterface > data.last_subinterface: + msg = ( + f"Range {data.one_range} could not be expanded because the first subinterface {data.first_subinterface} is larger than last" + f" subinterface {data.last_subinterface} in the range." + ) + raise ValueError(msg) + result.extend(f"{interface_string}.{subinterface}" for subinterface in range(data.first_subinterface, data.last_subinterface + 1)) + else: + result.append(interface_string) + return result + + +def expand_interfaces(interface_string: str, data: InterfaceData) -> list: + result = [] + if data.first_interface > data.last_interface: + msg = ( + f"Range {data.one_range} could not be expanded because the first interface {data.first_interface} is larger than last interface" + f" {data.last_interface} in the range." + ) + raise ValueError(msg) + for interface in range(data.first_interface, data.last_interface + 1): + result.extend(expand_subinterfaces(f"{interface_string}{interface}", data)) + return result + + +def expand_parent_interfaces(interface_string: str, data: InterfaceData) -> list: + result = [] + if data.last_parent_interface: + if data.first_parent_interface > data.last_parent_interface: + msg = ( + f"Range {data.one_range} could not be expanded because the first interface {data.first_parent_interface} is larger than last" + f" interface {data.last_parent_interface} in the range." + ) + raise ValueError(msg) + for parent_interface in range(data.first_parent_interface, data.last_parent_interface + 1): + result.extend(expand_interfaces(f"{interface_string}{parent_interface}/", data)) + else: + result.extend(expand_interfaces(f"{interface_string}", data)) + return result + + +def expand_module(interface_string: str, data: InterfaceData) -> list: + result = [] + if data.last_module: + if data.first_module > data.last_module: + msg = ( + f"Range {data.one_range} could not be expanded because the first module {data.first_module} is larger than last module" + f" {data.last_module} in the range." + ) + raise ValueError(msg) + for module in range(data.first_module, data.last_module + 1): + result.extend(expand_parent_interfaces(f"{interface_string}{module}/", data)) + else: + result.extend(expand_parent_interfaces(f"{interface_string}", data)) + return result + + +def range_expand(range_to_expand: Any) -> list: if not isinstance(range_to_expand, (list, str)): - raise TypeError(f"value must be of type list or str, got {type(range_to_expand)}") + msg = f"value must be of type list or str, got {type(range_to_expand)}" + raise TypeError(msg) result = [] @@ -55,99 +131,27 @@ def range_expand(range_to_expand): if search_result: if len(search_result.groups()) == regex_groups: groups = search_result.groups() - first_module = last_module = None - first_parent_interface = last_parent_interface = None - first_interface = last_interface = None - first_subinterface = last_subinterface = None + data = InterfaceData(one_range=one_range) # Set prefix if found (otherwise use last set prefix) if groups[0]: prefix = groups[0] if groups[4]: - last_module = int(groups[4]) - if groups[3]: - first_module = int(groups[3]) - else: - first_module = last_module + data.last_module = int(groups[4]) + data.first_module = int(groups[3]) if groups[3] else data.last_module if groups[8]: - last_parent_interface = int(groups[8]) - if groups[7]: - first_parent_interface = int(groups[7]) - else: - first_parent_interface = last_parent_interface + data.last_parent_interface = int(groups[8]) + data.first_parent_interface = int(groups[7]) if groups[7] else data.last_parent_interface if groups[12]: - last_interface = int(groups[12]) - if groups[11]: - first_interface = int(groups[11]) - else: - first_interface = last_interface + data.last_interface = int(groups[12]) + data.first_interface = int(groups[11]) if groups[11] else data.last_interface if groups[16]: - last_subinterface = int(groups[16]) - if groups[15]: - first_subinterface = int(groups[15]) - else: - first_subinterface = last_subinterface - - def expand_subinterfaces(interface_string): - result = [] - if last_subinterface is not None: - if first_subinterface > last_subinterface: - raise ValueError( - f"Range {one_range} could not be expanded because the first subinterface {first_subinterface} is larger than last" - f" subinterface {last_subinterface} in the range." - ) - for subinterface in range(first_subinterface, last_subinterface + 1): - result.append(f"{interface_string}.{subinterface}") - else: - result.append(interface_string) - return result - - def expand_interfaces(interface_string): - result = [] - if first_interface > last_interface: - raise ValueError( - f"Range {one_range} could not be expanded because the first interface {first_interface} is larger than last interface" - f" {last_interface} in the range." - ) - for interface in range(first_interface, last_interface + 1): - for res in expand_subinterfaces(f"{interface_string}{interface}"): - result.append(res) - return result - - def expand_parent_interfaces(interface_string): - result = [] - if last_parent_interface: - if first_parent_interface > last_parent_interface: - raise ValueError( - f"Range {one_range} could not be expanded because the first interface {first_parent_interface} is larger than last" - f" interface {last_parent_interface} in the range." - ) - for parent_interface in range(first_parent_interface, last_parent_interface + 1): - for res in expand_interfaces(f"{interface_string}{parent_interface}/"): - result.append(res) - else: - for res in expand_interfaces(f"{interface_string}"): - result.append(res) - return result - - def expand_module(interface_string): - result = [] - if last_module: - if first_module > last_module: - raise ValueError( - f"Range {one_range} could not be expanded because the first module {first_module} is larger than last module" - f" {last_module} in the range." - ) - for module in range(first_module, last_module + 1): - for res in expand_parent_interfaces(f"{interface_string}{module}/"): - result.append(res) - else: - for res in expand_parent_interfaces(f"{interface_string}"): - result.append(res) - return result - - result.extend(expand_module(prefix)) + data.last_subinterface = int(groups[16]) + data.first_subinterface = int(groups[15]) if groups[15] else data.last_subinterface + + result.extend(expand_module(prefix, data)) else: - raise ValueError(f"Invalid range, got {one_range} and found {search_result.groups()}") + msg = f"Invalid range, got {one_range} and found {search_result.groups()}" + raise ValueError(msg) return result diff --git a/python-avd/pyavd/j2filters/snmp_hash.py b/python-avd/pyavd/j2filters/snmp_hash.py index 726c0615295..a757a245a0b 100644 --- a/python-avd/pyavd/j2filters/snmp_hash.py +++ b/python-avd/pyavd/j2filters/snmp_hash.py @@ -20,13 +20,15 @@ def _get_hash_object(auth_type: str) -> object: try: return hashlib.new(auth) except ValueError: - raise ValueError(f"{auth_type} is not a valid Auth algorithm for SNMPv3") from ValueError + msg = f"{auth_type} is not a valid Auth algorithm for SNMPv3" + raise ValueError(msg) from ValueError def _key_from_passphrase(passphrase: str, auth_type: str) -> str: """ - RFC 2574 section A.2 algorithm - https://www.rfc-editor.org/rfc/rfc2574.html#appendix-A2 + RFC 2574 section A.2 algorithm. + + https://www.rfc-editor.org/rfc/rfc2574.html#appendix-A2. :param passphrase: the passphrase to use to generate the key :param auth_type: a string in [md5|sha|sha224|sha256|sha384|sha512] @@ -45,7 +47,7 @@ def _key_from_passphrase(passphrase: str, auth_type: str) -> str: password_length = len(b_passphrase) while count < 1048576: cp = bytearray() - for _ in range(0, 64): + for _ in range(64): cp.append(b_passphrase[password_index % password_length]) password_index += 1 hash_object.update(cp) @@ -53,10 +55,11 @@ def _key_from_passphrase(passphrase: str, auth_type: str) -> str: return hash_object.hexdigest() -def _localize_passphrase(passphrase: str, auth_type: str, engine_id: str, priv_type: str = None) -> str: +def _localize_passphrase(passphrase: str, auth_type: str, engine_id: str, priv_type: str | None = None) -> str: """ - Key localization as described in RFC 2574, section 2.6 - https://www.rfc-editor.org/rfc/rfc2574.html#section-2.6 + Key localization as described in RFC 2574, section 2.6. + + https://www.rfc-editor.org/rfc/rfc2574.html#section-2.6. :param passphrase: the passphrase to localize, if priv_type is None it is the auth passphrase else it is the priv @@ -74,13 +77,13 @@ def _localize_passphrase(passphrase: str, auth_type: str, engine_id: str, priv_t :raises: AristaAvdError, when the auth_type or priv_type is not valid or if the engined_id is not a proper hexadecimal string """ - key = bytes.fromhex(_key_from_passphrase(passphrase, auth_type)) hash_object = _get_hash_object(auth_type) try: hash_object.update(key + bytes.fromhex(engine_id) + key) except ValueError as error: - raise ValueError(f"engine ID {engine_id} is not an hexadecimal string") from error + msg = f"engine ID {engine_id} is not an hexadecimal string" + raise ValueError(msg) from error localized_key = hash_object.hexdigest() if priv_type is not None: try: @@ -91,7 +94,8 @@ def _localize_passphrase(passphrase: str, auth_type: str, engine_id: str, priv_t # Truncate ithe key if required localized_key = localized_key[: _PRIV_KEY_LENGTH[priv_type] // 4] except KeyError as error: - raise ValueError(f"{priv_type} is not a valid Priv algorithm for SNMPv3") from error + msg = f"{priv_type} is not a valid Priv algorithm for SNMPv3" + raise ValueError(msg) from error return localized_key diff --git a/python-avd/pyavd/j2filters/status_render.py b/python-avd/pyavd/j2filters/status_render.py index 9583279afe6..51d481c9028 100644 --- a/python-avd/pyavd/j2filters/status_render.py +++ b/python-avd/pyavd/j2filters/status_render.py @@ -3,10 +3,15 @@ # that can be found in the LICENSE file. from __future__ import annotations +GH_CODE = { + "PASS": ":white_check_mark:", # Github MD code for Emoji checked box + "FAIL": ":x:", # GH MD code for Emoji Fail +} -def status_render(state_string, rendering): + +def status_render(state_string: str, rendering: str) -> str: """ - status_render Convert Text to EMOJI code + status_render Convert Text to EMOJI code. Parameters ---------- @@ -15,18 +20,11 @@ def status_render(state_string, rendering): rendering : string Markdown Flavor to use for Emoji rendering. - Returns + Returns: ------- str Value to render in markdown """ - # STATIC EMOJI CODE - GH_CODE = {} - # Github MD code for Emoji checked box - GH_CODE["PASS"] = ":white_check_mark:" - # GH MD code for Emoji Fail - GH_CODE["FAIL"] = ":x:" - if rendering == "github": return GH_CODE[state_string.upper()] return state_string diff --git a/python-avd/pyavd/j2tests/__init__.py b/python-avd/pyavd/j2tests/__init__.py new file mode 100644 index 00000000000..b17ca7c745d --- /dev/null +++ b/python-avd/pyavd/j2tests/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. diff --git a/python-avd/pyavd/j2tests/contains.py b/python-avd/pyavd/j2tests/contains.py index e7dcfbc697c..bf8ba0dca60 100644 --- a/python-avd/pyavd/j2tests/contains.py +++ b/python-avd/pyavd/j2tests/contains.py @@ -25,12 +25,12 @@ def contains(value: list[Any], test_value: Any | list[Any] = None) -> bool: test_value : single item or list of items Value(s) to test for in value - Returns + Returns: ------- boolean True if variable matches criteria, False in other cases. """ - # TODO - this will fail miserably if test_value is not hashable ! + # TODO: - this will fail miserably if test_value is not hashable ! if isinstance(value, Undefined) or value is None or not isinstance(value, list): # Invalid value - return false return False diff --git a/python-avd/pyavd/j2tests/defined.py b/python-avd/pyavd/j2tests/defined.py index f9a299e2eb6..c3e3e2984d5 100644 --- a/python-avd/pyavd/j2tests/defined.py +++ b/python-avd/pyavd/j2tests/defined.py @@ -9,13 +9,16 @@ from __future__ import annotations import warnings +from typing import Any from jinja2.runtime import Undefined -def defined(value, test_value=None, var_type=None, fail_action=None, var_name=None, run_tests=False): +def defined( + value: Any, test_value: Any = None, var_type: str | None = None, fail_action: str | None = None, var_name: str | None = None, *, run_tests: bool = False +) -> bool | tuple[bool, int]: """ - defined - Ansible test plugin to test if a variable is defined and not none + defined - Ansible test plugin to test if a variable is defined and not none. Arista.avd.defined will test value if defined and is not none and return true or false. If test_value is supplied, the value must also pass == test_value to return true. @@ -51,9 +54,9 @@ def defined(value, test_value=None, var_type=None, fail_action=None, var_name=No var_name : , optional Optional string to use as variable name in warning or error messages - Returns + Returns: ------- - boolean + bool True if variable matches criteria, False in other cases. """ if isinstance(value, Undefined) or value is None: @@ -62,16 +65,18 @@ def defined(value, test_value=None, var_type=None, fail_action=None, var_name=No warnings_count = {} if var_name is not None: warning_msg = f"{var_name} was expected but not set. Output may be incorrect or incomplete!" - warnings.warn(warning_msg) + warnings.warn(warning_msg) # noqa: B028 warnings_count["[WARNING]: " + warning_msg] = warnings_count.get("[WARNING]: " + warning_msg, 0) + 1 else: warning_msg = "A variable was expected but not set. Output may be incorrect or incomplete!" - warnings.warn(warning_msg) + warnings.warn(warning_msg) # noqa: B028 warnings_count["[WARNING]: " + warning_msg] = warnings_count.get("[WARNING]: " + warning_msg, 0) + 1 elif str(fail_action).lower() == "error": if var_name is not None: - raise ValueError(f"{var_name} was expected but not set!") - raise ValueError("A variable was expected but not set!") + msg = f"{var_name} was expected but not set!" + raise ValueError(msg) + msg = "A variable was expected but not set!" + raise ValueError(msg) if run_tests: return False, warnings_count return False @@ -82,16 +87,18 @@ def defined(value, test_value=None, var_type=None, fail_action=None, var_name=No warnings_count = {} if var_name is not None: warning_msg = f"{var_name} was set to {value} but we expected {test_value}. Output may be incorrect or incomplete!" - warnings.warn(warning_msg) + warnings.warn(warning_msg) # noqa: B028 warnings_count["[WARNING]: " + warning_msg] = warnings_count.get("[WARNING]: " + warning_msg, 0) + 1 else: warning_msg = f"A variable was set to {value} but we expected {test_value}. Output may be incorrect or incomplete!" - warnings.warn(warning_msg) + warnings.warn(warning_msg) # noqa: B028 warnings_count["[WARNING]: " + warning_msg] = warnings_count.get("[WARNING]: " + warning_msg, 0) + 1 elif str(fail_action).lower() == "error": if var_name is not None: - raise ValueError(f"{var_name} was set to {value} but we expected {test_value}!") - raise ValueError(f"A variable was set to {value} but we expected {test_value}!") + msg = f"{var_name} was set to {value} but we expected {test_value}!" + raise ValueError(msg) + msg = f"A variable was set to {value} but we expected {test_value}!" + raise ValueError(msg) if run_tests: return False, warnings_count return False @@ -101,16 +108,18 @@ def defined(value, test_value=None, var_type=None, fail_action=None, var_name=No warnings_count = {} if var_name is not None: warning_msg = f"{var_name} was a {type(value).__name__} but we expected a {str(var_type).lower()}. Output may be incorrect or incomplete!" - warnings.warn(warning_msg) + warnings.warn(warning_msg) # noqa: B028 warnings_count["[WARNING]: " + warning_msg] = warnings_count.get("[WARNING]: " + warning_msg, 0) + 1 else: warning_msg = f"A variable was a {type(value).__name__} but we expected a {str(var_type).lower()}. Output may be incorrect or incomplete!" - warnings.warn(warning_msg) + warnings.warn(warning_msg) # noqa: B028 warnings_count["[WARNING]: " + warning_msg] = warnings_count.get("[WARNING]: " + warning_msg, 0) + 1 elif str(fail_action).lower() == "error": if var_name is not None: - raise ValueError(f"{var_name} was a {type(value).__name__} but we expected a {str(var_type).lower()}!") - raise ValueError(f"A variable was a {type(value).__name__} but we expected a {str(var_type).lower()}!") + msg = f"{var_name} was a {type(value).__name__} but we expected a {str(var_type).lower()}!" + raise ValueError(msg) + msg = f"A variable was a {type(value).__name__} but we expected a {str(var_type).lower()}!" + raise ValueError(msg) if run_tests: return False, warnings_count return False diff --git a/python-avd/pyavd/templater.py b/python-avd/pyavd/templater.py index baa39a818b7..091d14c8c09 100644 --- a/python-avd/pyavd/templater.py +++ b/python-avd/pyavd/templater.py @@ -3,42 +3,46 @@ # that can be found in the LICENSE file. from __future__ import annotations -import os from pathlib import Path -from typing import Sequence +from typing import TYPE_CHECKING from jinja2 import ChoiceLoader, Environment, FileSystemLoader, ModuleLoader, StrictUndefined from .constants import JINJA2_EXTENSIONS, JINJA2_PRECOMPILED_TEMPLATE_PATH, JINJA2_TEMPLATE_PATHS, RUNNING_FROM_SRC +if TYPE_CHECKING: + import os + from collections.abc import Sequence + class Undefined(StrictUndefined): """ Allow nested checks for undefined instead of having to check on every level. + Example "{% if var.key.subkey is arista.avd.undefined %}" is ok. Without this it we would have to test every level, like "{% if var is arista.avd.undefined or var.key is arista.avd.undefined or var.key.subkey is arista.avd.undefined %}" """ - def __getattr__(self, _name): + def __getattr__(self, _name: str) -> Undefined: # Return original Undefined object to preserve the first failure context return self - def __getitem__(self, _key): + def __getitem__(self, _key: str) -> Undefined: # Return original Undefined object to preserve the first failure context return self - def __repr__(self): + def __repr__(self) -> str: return f"Undefined(hint={self._undefined_hint}, obj={self._undefined_obj}, name={self._undefined_name})" - def __contains__(self, _item): + def __contains__(self, _item: int) -> Undefined: # Return original Undefined object to preserve the first failure context return self class Templar: - def __init__(self, searchpaths: list[str] = None): + def __init__(self, searchpaths: list[str] | None = None) -> None: if not RUNNING_FROM_SRC: self.loader = ModuleLoader(JINJA2_PRECOMPILED_TEMPLATE_PATH) else: @@ -48,11 +52,11 @@ def __init__(self, searchpaths: list[str] = None): [ ModuleLoader(JINJA2_PRECOMPILED_TEMPLATE_PATH), FileSystemLoader(searchpaths), - ] + ], ) # Accepting SonarLint issue: No autoescaping is ok, since we are not using this for a website, so XSS is not applicable. - self.environment = Environment( # NOSONAR + self.environment = Environment( # NOSONAR # noqa: S701 extensions=JINJA2_EXTENSIONS, loader=self.loader, undefined=Undefined, @@ -99,20 +103,22 @@ def import_filters_and_tests(self) -> None: "arista.avd.range_expand": range_expand, "arista.avd.snmp_hash": snmp_hash, "arista.avd.status_render": status_render, - } + }, ) self.environment.tests.update( { "arista.avd.defined": defined, "arista.avd.contains": contains, - } + }, ) def render_template_from_file(self, template_file: str, template_vars: dict) -> str: return self.environment.get_template(template_file).render(template_vars) def compile_templates_in_paths(self, searchpaths: list[str]) -> None: - """Compile the Jinja2 templates in the path. + """ + Compile the Jinja2 templates in the path. + The FileSystemLoader tries to compile any file in the path no matter the extension so this uses a custom one. diff --git a/python-avd/pyavd/validate_inputs.py b/python-avd/pyavd/validate_inputs.py index 46385489e61..cb6652145e3 100644 --- a/python-avd/pyavd/validate_inputs.py +++ b/python-avd/pyavd/validate_inputs.py @@ -30,7 +30,7 @@ def validate_inputs(inputs: dict) -> ValidationResult: # pylint: enable=import-outside-toplevel # Initialize a global instance of eos_designs_schema_tools - global eos_designs_schema_tools + global eos_designs_schema_tools # noqa: PLW0603 TODO: improve code without global if eos_designs_schema_tools is None: eos_designs_schema_tools = AvdSchemaTools(schema_id=EOS_DESIGNS_SCHEMA_ID) diff --git a/python-avd/pyavd/validate_structured_config.py b/python-avd/pyavd/validate_structured_config.py index 71ae386b06c..102e159e3df 100644 --- a/python-avd/pyavd/validate_structured_config.py +++ b/python-avd/pyavd/validate_structured_config.py @@ -29,7 +29,7 @@ def validate_structured_config(structured_config: dict) -> ValidationResult: # pylint: enable=import-outside-toplevel # Initialize a global instance of eos_cli_config_gen_schema_tools - global eos_cli_config_gen_schema_tools + global eos_cli_config_gen_schema_tools # noqa: PLW0603 TODO: improve code to avoid globals if eos_cli_config_gen_schema_tools is None: eos_cli_config_gen_schema_tools = AvdSchemaTools(schema_id=EOS_CLI_CONFIG_GEN_SCHEMA_ID) diff --git a/python-avd/pyavd/validation_result.py b/python-avd/pyavd/validation_result.py index a429c53b654..c74f50188c0 100644 --- a/python-avd/pyavd/validation_result.py +++ b/python-avd/pyavd/validation_result.py @@ -11,7 +11,7 @@ class ValidationResult: """ - Object containing result of data validation + Object containing result of data validation. Attributes: failed: True if data is not valid according to the schema. Otherwise False. @@ -23,7 +23,7 @@ class ValidationResult: validation_errors: list[AvdValidationError] deprecation_warnings: list[AvdDeprecationWarning] - def __init__(self, failed: bool, validation_errors: list = None, deprecation_warnings: list = None): + def __init__(self, failed: bool, validation_errors: list | None = None, deprecation_warnings: list | None = None) -> None: self.failed = failed self.validation_errors = validation_errors or [] self.deprecation_warnings = deprecation_warnings or [] diff --git a/python-avd/pyproject.toml b/python-avd/pyproject.toml index 3ec659942fc..5b19429dc29 100644 --- a/python-avd/pyproject.toml +++ b/python-avd/pyproject.toml @@ -80,11 +80,17 @@ version = {attr = "pyavd.__version__"} [tool.black] line-length = 160 +force-exclude = '''pyavd/_cv/api/.*''' + [tool.isort] profile = "black" skip_gitignore = true line_length = 160 -known_third_party = ["pyavd"] +known_first_party = ["pyavd", "schema_tools"] + +extend_skip_glob = [ + "pyavd/_cv/api/**/*" +] [tool.coverage.run] branch = true @@ -123,3 +129,8 @@ exclude_also = [ "^blocks = {}", "^debug_info =.*", ] + +[tool.ruff] +# Extend the `pyproject.toml` file in the parent directory... +# Should not be needed, but the ruff vscode extension does not seem to respect the ruff behavior of ignoring files without tool.ruff +extend = "../pyproject.toml" diff --git a/python-avd/schema_tools/avdschemaresolver.py b/python-avd/schema_tools/avdschemaresolver.py index 6a0479761c8..1c7ab380434 100644 --- a/python-avd/schema_tools/avdschemaresolver.py +++ b/python-avd/schema_tools/avdschemaresolver.py @@ -4,18 +4,23 @@ from __future__ import annotations from copy import deepcopy +from typing import TYPE_CHECKING -from pyavd._utils import merge from referencing import Registry, Specification from referencing.exceptions import PointerToNowhere from referencing.jsonschema import DRAFT7, _legacy_anchor_in_dollar_id, _legacy_dollar_id, _maybe_in_subresource_crazy_items_dependencies +from pyavd._utils import merge + +if TYPE_CHECKING: + from collections.abc import Generator + class AvdSchemaResolver: - def __init__(self, base_schema_name: str, store: dict): + def __init__(self, base_schema_name: str, store: dict) -> None: self.resolver = self.create_resolver(store, base_uri=base_schema_name) - def resolve(self, resolved_schema: dict): + def resolve(self, resolved_schema: dict) -> dict: methods = { "items": self._items, "keys": self._keys, @@ -28,7 +33,7 @@ def resolve(self, resolved_schema: dict): return resolved_schema - def _keys(self, resolved_schema: dict): + def _keys(self, resolved_schema: dict) -> None: for key in resolved_schema["keys"]: # Resolve the child schema # Repeat in case new refs inherited from the first ref. @@ -37,7 +42,7 @@ def _keys(self, resolved_schema: dict): self.resolve(resolved_schema["keys"][key]) - def _dynamic_keys(self, resolved_schema: dict): + def _dynamic_keys(self, resolved_schema: dict) -> None: for key in resolved_schema["dynamic_keys"]: # Resolve the child schema # Repeat in case new refs inherited from the first ref. @@ -46,7 +51,7 @@ def _dynamic_keys(self, resolved_schema: dict): self.resolve(resolved_schema["dynamic_keys"][key]) - def _items(self, resolved_schema: dict): + def _items(self, resolved_schema: dict) -> None: # Resolve the child schema # Repeat in case new refs inherited from the first ref. while "$ref" in resolved_schema["items"]: @@ -54,29 +59,28 @@ def _items(self, resolved_schema: dict): self.resolve(resolved_schema["items"]) - def _ref_on_child(self, resolved_schema: dict): + def _ref_on_child(self, resolved_schema: dict) -> None: """ - This function resolves the $ref referenced schema, - then merges with any schema defined at the same level + This function resolves the $ref referenced schema, then merges with any schema defined at the same level. In place update of supplied resolved_schema """ try: resolved = self.resolver.lookup(resolved_schema["$ref"]) except PointerToNowhere: - raise RuntimeError( - ( - f"Unable to resolve $ref: '{resolved_schema['$ref']}'." - "Make sure to adhere to the strict format '^(eos_cli_config_gen|eos_designs)#(/[a-z$][a-z0-9_]*)*$'." - ) - ) from None + msg = ( + f"Unable to resolve $ref: '{resolved_schema['$ref']}'." + "Make sure to adhere to the strict format '^(eos_cli_config_gen|eos_designs)#(/[a-z$][a-z0-9_]*)*$'." + ) + raise RuntimeError(msg) from None ref_schema = deepcopy(resolved.contents) resolved_schema.pop("$ref") merge(resolved_schema, ref_schema, same_key_strategy="use_existing", list_merge="replace") - def create_resolver(self, store: dict, base_uri=""): + def create_resolver(self, store: dict, base_uri: str = "") -> object: """ Returns a resolver which can resolve "$ref" references across all AVD schemas. + The given "base_uri" can be used for relative references (currently not used in AVD). """ registry = self.create_registry(store) @@ -97,15 +101,13 @@ def create_registry(store: dict) -> Registry: from "referencing" which are also used for the builtin DRAFT7 specification. """ - def subresources(schema: dict): - """ - Generator of childschemas - """ - if "keys" in schema and schema["keys"]: + def subresources(schema: dict) -> Generator[dict, None, None]: + """Generator of childschemas.""" + if schema.get("keys"): yield from schema["keys"].values() - if "dynamic_keys" in schema and schema["dynamic_keys"]: + if schema.get("dynamic_keys"): yield from schema["dynamic_keys"].values() - if "$defs" in schema and schema["$defs"]: + if schema.get("$defs"): yield from schema["$defs"].values() if "items" in schema: yield schema["items"] diff --git a/python-avd/schema_tools/constants.py b/python-avd/schema_tools/constants.py index 0ace3ebfb7b..aaf8bf7b72e 100644 --- a/python-avd/schema_tools/constants.py +++ b/python-avd/schema_tools/constants.py @@ -2,6 +2,7 @@ # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. """AVD schematools constants.""" + from pathlib import Path REPO_ROOT = Path(__file__).parents[2] diff --git a/python-avd/schema_tools/generate_docs/mdtabsgen.py b/python-avd/schema_tools/generate_docs/mdtabsgen.py index 438f2b2d6ba..73224ccd4c9 100644 --- a/python-avd/schema_tools/generate_docs/mdtabsgen.py +++ b/python-avd/schema_tools/generate_docs/mdtabsgen.py @@ -4,17 +4,20 @@ from __future__ import annotations from textwrap import indent +from typing import TYPE_CHECKING + +from schema_tools.constants import LICENSE_HEADER -from ..constants import LICENSE_HEADER -from ..metaschema.meta_schema_model import AristaAvdSchema from .tablegen import get_table from .yamlgen import get_yaml +if TYPE_CHECKING: + from schema_tools.metaschema.meta_schema_model import AristaAvdSchema + def get_md_tabs(schema: AristaAvdSchema, target_table: str | None = None) -> str: """ - Generate the content of a markdown file with mkdocs tabs containing documentation - of of the schema optionally filtered using "target_table". + Generate the content of a markdown file with mkdocs tabs containing documentation of of the schema optionally filtered using "target_table". - Table tab contains a markdown table. - YAML tab contains a markdown code block with YAML. diff --git a/python-avd/schema_tools/generate_docs/tablegen.py b/python-avd/schema_tools/generate_docs/tablegen.py index 9a0ccbd0252..7e780359ea0 100644 --- a/python-avd/schema_tools/generate_docs/tablegen.py +++ b/python-avd/schema_tools/generate_docs/tablegen.py @@ -3,7 +3,10 @@ # that can be found in the LICENSE file. from __future__ import annotations -from ..metaschema.meta_schema_model import AristaAvdSchema +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from schema_tools.metaschema.meta_schema_model import AristaAvdSchema TABLE_HEADER = [ "| Variable | Type | Required | Default | Value Restrictions | Description |", @@ -12,9 +15,7 @@ def get_table(schema: AristaAvdSchema, target_table: str | None = None) -> str: - """ - Returns one markdown table either containing all keys of the given schema or only a subset if "target_table" is set. - """ + """Returns one markdown table either containing all keys of the given schema or only a subset if "target_table" is set.""" lines = [*TABLE_HEADER] lines.extend(str(row) for row in schema._generate_table_rows(target_table=target_table)) lines.append("") # Add final newline diff --git a/python-avd/schema_tools/generate_docs/tablerowgen.py b/python-avd/schema_tools/generate_docs/tablerowgen.py index 04e7af3a6c1..28383f8739f 100644 --- a/python-avd/schema_tools/generate_docs/tablerowgen.py +++ b/python-avd/schema_tools/generate_docs/tablerowgen.py @@ -3,15 +3,16 @@ # that can be found in the LICENSE file. from __future__ import annotations -from abc import ABC -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING from pydantic import BaseModel from .utils import render_schema_field if TYPE_CHECKING: - from ..metaschema.meta_schema_model import AvdSchemaField + from collections.abc import Generator + + from schema_tools.metaschema.meta_schema_model import AvdSchemaField LEGACY_OUTPUT = False @@ -19,6 +20,7 @@ class TableRow(BaseModel): """ Dataclass for one table row. + Content is markdown formatted so it can be rendered directly. """ @@ -29,11 +31,11 @@ class TableRow(BaseModel): restrictions: str | None = None description: str | None = None - def __str__(self): + def __str__(self) -> str: return f"| {self.key} | {self.type} | {self.required or ''} | {self.default or ''} | {self.restrictions or ''} | {self.description or ''} |" -class TableRowGenBase(ABC): +class TableRowGenBase: """ Base class to be used with schema pydantic models. @@ -73,7 +75,7 @@ def generate_table_rows( def get_indentation(self) -> str: """ - Indentation is two spaces for dicts and 4 spaces for lists (so the hyphen will be indented 2) + Indentation is two spaces for dicts and 4 spaces for lists (so the hyphen will be indented 2). For the variable {"my":{"random":{"list":[]}}} the schema._path would be ["my", "random", "list", []]. The indentation would be 4*2-2+2 = 8 spaces. Since all items are simple values (not a dict with keys) @@ -99,9 +101,7 @@ def get_indentation(self) -> str: return i * indentation_count def get_deprecation_label(self) -> str | None: - """ - Returns None or a markdown formatted colored string with the deprecation status. - """ + """Returns None or a markdown formatted colored string with the deprecation status.""" if self.schema.deprecation is None: return "" @@ -110,9 +110,7 @@ def get_deprecation_label(self) -> str | None: return f' {label}' def get_deprecation_description(self) -> str | None: - """ - Returns None or a markdown formatted colored string with the deprecation description. - """ + """Returns None or a markdown formatted colored string with the deprecation description.""" if self.schema.deprecation is None: return None @@ -141,9 +139,7 @@ def get_deprecation_description(self) -> str | None: return f'{description}' def render_key(self) -> str: - """ - Renders markdown for "key" field including mouse-over and deprecation label with color. - """ + """Renders markdown for "key" field including mouse-over and deprecation label with color.""" path = ".".join(self.schema._path) if self.schema._key: @@ -161,9 +157,7 @@ def render_key(self) -> str: return f'[{self.get_indentation()}{key}](## "{path}"){self.get_deprecation_label()}' def render_type(self) -> str: - """ - Renders markdown for "type" field. - """ + """Renders markdown for "type" field.""" type_converters = { "str": "String", "int": "Integer", @@ -174,28 +168,24 @@ def render_type(self) -> str: return type_converters[self.schema.type] def render_required(self) -> str | None: - """ - Render markdown for "required" field. - """ + """Render markdown for "required" field.""" if self.schema._is_primary_key: return "Required, Unique" if self.schema._is_unique else "Required" if self.schema.required: return "Required" + return None def render_default(self) -> str | None: - """ - Should render markdown for "default" field. - """ + """Should render markdown for "default" field.""" if self.schema.default is not None: if isinstance(self.schema.default, (list, dict)) and (len(self.schema.default) > 1 or len(str(self.schema.default)) > 40): return "See (+) on YAML tab" return f"`{self.schema.default}`" + return None def render_description(self) -> str | None: - """ - Renders markdown for "description" field including deprecation text with color. - """ + """Renders markdown for "description" field including deprecation text with color.""" descriptions = [] if self.schema.description: descriptions.append(self.schema.description.replace("\n", "
")) @@ -210,14 +200,13 @@ def render_children(self) -> Generator[TableRow]: yield from [] def render_restrictions(self) -> str | None: - """ - Renders markdown for "restrictions" field as a multiline text compatible with a markdown table cell. - """ + """Renders markdown for "restrictions" field as a multiline text compatible with a markdown table cell.""" return "
".join(self.get_restrictions()) or None def get_restrictions(self) -> list: """ Returns a list of field restrictions to be rendered in the docs. + Only covers generic restrictions. Should be overridden in type specific subclasses. """ restrictions = [] @@ -247,6 +236,7 @@ class TableRowGenInt(TableRowGenBase): def get_restrictions(self) -> list: """ Returns a list of field restrictions to be rendered in the docs. + Leverages common restrictions from base class. """ restrictions = [] @@ -264,6 +254,7 @@ class TableRowGenStr(TableRowGenBase): def get_restrictions(self) -> list: """ Returns a list of field restrictions to be rendered in the docs. + Leverages common restrictions from base class. """ restrictions = [] @@ -290,9 +281,7 @@ def get_restrictions(self) -> list: class TableRowGenList(TableRowGenBase): def render_type(self) -> str: - """ - Renders markdown for "type" field. - """ + """Renders markdown for "type" field.""" type_converters = { "str": "String", "int": "Integer", @@ -309,6 +298,7 @@ def render_type(self) -> str: def get_restrictions(self) -> list: """ Returns a list of field restrictions to be rendered in the docs. + Leverages common restrictions from base class. """ restrictions = [] @@ -322,7 +312,7 @@ def get_restrictions(self) -> list: return restrictions def render_children(self) -> Generator[TableRow]: - """yields TableRow from each child class""" + """Yields TableRow from each child class.""" if not self.schema.items: return @@ -339,8 +329,7 @@ def render_children(self) -> Generator[TableRow]: class TableRowGenDict(TableRowGenBase): def render_children(self) -> Generator[TableRow]: - """yields TableRow from each child class""" - + """Yields TableRow from each child class.""" if self.schema.documentation_options and self.schema.documentation_options.hide_keys: # Skip generating table fields for children, if "hide_keys" is set. return diff --git a/python-avd/schema_tools/generate_docs/utils.py b/python-avd/schema_tools/generate_docs/utils.py index 60a10e6d451..7e23f94e027 100644 --- a/python-avd/schema_tools/generate_docs/utils.py +++ b/python-avd/schema_tools/generate_docs/utils.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ..metaschema.meta_schema_model import AvdSchemaField + from schema_tools.metaschema.meta_schema_model import AvdSchemaField def render_schema_field(schema: AvdSchemaField, target_table: str | None) -> bool: diff --git a/python-avd/schema_tools/generate_docs/yamlgen.py b/python-avd/schema_tools/generate_docs/yamlgen.py index 9a3531dbc1e..30a2e82293d 100644 --- a/python-avd/schema_tools/generate_docs/yamlgen.py +++ b/python-avd/schema_tools/generate_docs/yamlgen.py @@ -3,12 +3,16 @@ # that can be found in the LICENSE file. from __future__ import annotations -from ..metaschema.meta_schema_model import AristaAvdSchema +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from schema_tools.metaschema.meta_schema_model import AristaAvdSchema def get_yaml(schema: AristaAvdSchema, target_table: str | None = None) -> str: """ Returns one markdown codeblock with YAML either containing all keys of the given schema or only a subset if "target_table" is set. + Also adds foot notes for use with mkdocs codeblock annotations as required. """ lines = [] diff --git a/python-avd/schema_tools/generate_docs/yamllinegen.py b/python-avd/schema_tools/generate_docs/yamllinegen.py index ba0d67744af..a74244de073 100644 --- a/python-avd/schema_tools/generate_docs/yamllinegen.py +++ b/python-avd/schema_tools/generate_docs/yamllinegen.py @@ -3,9 +3,8 @@ # that can be found in the LICENSE file. from __future__ import annotations -from abc import ABC from textwrap import indent -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING import yaml from pydantic import BaseModel @@ -13,7 +12,9 @@ from .utils import render_schema_field if TYPE_CHECKING: - from ..metaschema.meta_schema_model import AvdSchemaField + from collections.abc import Generator + + from schema_tools.metaschema.meta_schema_model import AvdSchemaField LEGACY_OUTPUT = False @@ -21,6 +22,7 @@ class YamlLine(BaseModel): """ Dataclass for one YAML line (and any associated descriptions). + Content of line is yaml formatted so it should be rendered in a code block. The field may need an mkdocs code block annotation link to show large default values. If so the content of the annotation is stored in the 'annotation' attribute. @@ -35,12 +37,12 @@ class YamlLine(BaseModel): line: str annotation: str | None = None - def __str__(self): + def __str__(self) -> str: return self.line def render_annotation(self, annotation_number: int) -> str: """ - Returns markdown for annotation foot note, providing the contents of mkdocs code block annotation popup + Returns markdown for annotation foot note, providing the contents of mkdocs code block annotation popup. Like below (including the leading blank line): @@ -58,17 +60,18 @@ def render_annotation(self, annotation_number: int) -> str: def render_annotation_link(self, annotation_number: int) -> str: """ - Returns codeblock comment used for mkdocs codeblock annotations + Returns codeblock comment used for mkdocs codeblock annotations. Like: " # (123)!" """ return "" if self.annotation is None else f" # ({annotation_number})!" -class YamlLineGenBase(ABC): +class YamlLineGenBase: """ Base class to be used with schema pydantic models. - Provides the method "generate_yaml_lines" to build documentation tables + + Provides the method "generate_yaml_lines" to build documentation tables. """ def generate_yaml_lines( @@ -96,10 +99,8 @@ def generate_yaml_lines( yield from self.render_children() - def get_indentation(self, honor_first_list_key: bool = True) -> str: - """ - Indentation is two spaces for dicts and 4 spaces for lists (so the hyphen will be indented 2) - """ + def get_indentation(self, *, honor_first_list_key: bool = True) -> str: + """Indentation is two spaces for dicts and 4 spaces for lists (so the hyphen will be indented 2).""" indentation_count = len(self.schema._path) * 2 - 2 if not self.schema._key: # this is a flat list item so path is one shorter than for dict. So we add 2 to the indentation @@ -116,10 +117,7 @@ def is_removed(self) -> bool: return self.schema.deprecation and self.schema.deprecation.removed def render_field(self) -> Generator[YamlLine]: - """ - Renders YamlLines for this field including description. - """ - + """Renders YamlLines for this field including description.""" # Build semicolon separated list of field properties. value_fields = [ self.schema.type, @@ -128,10 +126,7 @@ def render_field(self) -> Generator[YamlLine]: self.get_required(), ] # TODO: Remove legacy output - if LEGACY_OUTPUT: - value = self.schema.type - else: - value = "; ".join(field for field in value_fields if field) + value = self.schema.type if LEGACY_OUTPUT else "; ".join(field for field in value_fields if field) key = f"{self.schema._key}: " if self.schema._key else "" @@ -141,9 +136,7 @@ def render_field(self) -> Generator[YamlLine]: ) def render_description(self) -> Generator[YamlLine]: - """ - Yields YamlLine with description for this field. - """ + """Yields YamlLine with description for this field.""" if self.schema.description: indentation = self.get_indentation(honor_first_list_key=False) description = indent(self.schema.description.strip(), f"{indentation}# ") @@ -152,9 +145,7 @@ def render_description(self) -> Generator[YamlLine]: yield YamlLine(line=f"\n{description}") def render_deprecation_description(self) -> Generator[YamlLine]: - """ - Yields YamlLine with deprecation description for this field. - """ + """Yields YamlLine with deprecation description for this field.""" if self.schema.deprecation is None: return @@ -176,18 +167,18 @@ def render_deprecation_description(self) -> Generator[YamlLine]: yield YamlLine(line=description) def get_required(self) -> str | None: - """ - Returns "required", "required; unique" or None depending on self.schema.required and self.is_primary_key - """ + """Returns "required", "required; unique" or None depending on self.schema.required and self.is_primary_key.""" if self.schema._is_primary_key: return "required; unique" if self.schema.required: return "required" + return None @property def needs_annotation_for_default_value(self) -> bool: """ Determines if this field should use a mkdocs codeblock annotation / popup to display the default value. + Is true for list or dict with length above 1. Otherwise false. """ return ( @@ -199,6 +190,7 @@ def needs_annotation_for_default_value(self) -> bool: def get_default(self) -> str | None: """ Returns default value or None. + For list or dict with len > 1 it will return none. See get_default_popup. """ @@ -207,15 +199,15 @@ def get_default(self) -> str | None: # Add quotes to string default value. return f'default="{self.schema.default}"' return f"default={self.schema.default}" + return None def get_annotation(self) -> str | None: if self.needs_annotation_for_default_value: return yaml.dump({self.schema._key: self.schema.default}, indent=2) + return None def render_restrictions(self) -> str | None: - """ - Returns restrictions as inline semicolon separated strings. - """ + """Returns restrictions as inline semicolon separated strings.""" return "; ".join(self.get_restrictions()) or None def get_restrictions(self) -> list: @@ -256,6 +248,7 @@ class YamlLineGenInt(YamlLineGenBase): def get_restrictions(self) -> list: """ Returns a list of restrictions. + Leverages common restrictions from base class. """ restrictions = [] @@ -276,6 +269,7 @@ class YamlLineGenStr(YamlLineGenBase): def get_restrictions(self) -> list: """ Returns a list of restrictions. + Leverages common restrictions from base class. """ restrictions = [] @@ -294,10 +288,7 @@ def get_restrictions(self) -> list: class YamlLineGenList(YamlLineGenBase): def render_field(self) -> Generator[YamlLine]: - """ - Renders YamlLine for this field. - """ - + """Renders YamlLine for this field.""" # Build semicolon separated list of field properties. properties_fields = [ self.render_restrictions(), @@ -324,6 +315,7 @@ def render_field(self) -> Generator[YamlLine]: def get_restrictions(self) -> list: """ Returns a list of restrictions. + Leverages common restrictions from base class. """ restrictions = [] @@ -340,7 +332,7 @@ def get_restrictions(self) -> list: return restrictions def render_children(self) -> Generator[YamlLine]: - """yields TableRow from each child class""" + """Yields TableRow from each child class.""" if not self.schema.items: return @@ -357,10 +349,7 @@ def render_children(self) -> Generator[YamlLine]: class YamlLineGenDict(YamlLineGenBase): def render_field(self) -> Generator[YamlLine]: - """ - Renders YamlLine for this field. - """ - + """Renders YamlLine for this field.""" # Build semicolon separated list of field properties. properties_fields = [ self.render_restrictions(), @@ -387,11 +376,9 @@ def render_field(self) -> Generator[YamlLine]: ) def render_children(self) -> Generator[YamlLine]: - """yields TableRow from each child class""" - + """Yields TableRow from each child class.""" if self.schema.documentation_options and self.schema.documentation_options.hide_keys: # Skip generating table fields for children, if "hide_keys" is set. - # print(f"Skipping path {self.path} since hide_keys is set") return if self.schema.dynamic_keys: diff --git a/python-avd/schema_tools/key_to_display_name.py b/python-avd/schema_tools/key_to_display_name.py index fc1d2e6c076..a0c2f7989f5 100644 --- a/python-avd/schema_tools/key_to_display_name.py +++ b/python-avd/schema_tools/key_to_display_name.py @@ -107,11 +107,10 @@ def key_to_display_name(key: str) -> str: if not isinstance(key, str): - raise ValueError(f"Invalid argument passed to 'key_to_display_name'. Must be a string. Got '{type(key)}'") + msg = f"Invalid argument passed to 'key_to_display_name'. Must be a string. Got '{type(key)}'" + raise TypeError(msg) words = key.split("_") - output = [] - for word in words: - output.append(WORDLIST.get(word.lower(), word.title())) + output = [WORDLIST.get(word.lower(), word.title()) for word in words] return " ".join(output) diff --git a/python-avd/schema_tools/metaschema/meta_schema_model.py b/python-avd/schema_tools/metaschema/meta_schema_model.py index bd672edd681..5617a2baf35 100644 --- a/python-avd/schema_tools/metaschema/meta_schema_model.py +++ b/python-avd/schema_tools/metaschema/meta_schema_model.py @@ -6,14 +6,18 @@ from abc import ABC from enum import Enum from functools import cached_property -from typing import Annotated, Any, ClassVar, Generator, List, Literal +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal from pydantic import BaseModel, ConfigDict, Field, constr -from ..generate_docs.tablerowgen import TableRow, TableRowGenBase, TableRowGenBool, TableRowGenDict, TableRowGenInt, TableRowGenList, TableRowGenStr -from ..generate_docs.yamllinegen import YamlLine, YamlLineGenBase, YamlLineGenBool, YamlLineGenDict, YamlLineGenInt, YamlLineGenList, YamlLineGenStr +from schema_tools.generate_docs.tablerowgen import TableRow, TableRowGenBase, TableRowGenBool, TableRowGenDict, TableRowGenInt, TableRowGenList, TableRowGenStr +from schema_tools.generate_docs.yamllinegen import YamlLine, YamlLineGenBase, YamlLineGenBool, YamlLineGenDict, YamlLineGenInt, YamlLineGenList, YamlLineGenStr + from .resolvemodel import merge_schema_from_ref +if TYPE_CHECKING: + from collections.abc import Generator + """ This module provides Pydantic models (classes) representing the meta-schema of the AVD Schema. @@ -49,7 +53,7 @@ class AvdSchemaBaseModel(BaseModel, ABC): # Common nested models used by common fields class Deprecation(BaseModel): - """Deprecation settings""" + """Deprecation settings.""" warning: bool = True """Emit deprecation warning if key is set.""" @@ -65,7 +69,7 @@ class Deprecation(BaseModel): """URL detailing the deprecation and migration guidelines.""" class DocumentationOptions(BaseModel): - """Schema field options used for controlling documentation generation""" + """Schema field options used for controlling documentation generation.""" # Pydantic config option to forbid keys in the inputs that are not covered by the model model_config = ConfigDict(extra="forbid") @@ -120,15 +124,15 @@ class DocumentationOptions(BaseModel): # Signal to __init__ if the $ref in the schema should be resolved before initializing the pydantic model. _resolve_schema: ClassVar[bool] = True - def __init__(self, resolve_schema: bool | None = None, **data): + def __init__(self, resolve_schema: bool | None = None, **data: dict) -> None: """ - Overrides BaseModel.__init__(**data). + Overrides BaseModel.__init__. + Takes a kwarg "resolve_schema" which controls if all subclasses of AvdSchemaBaseModel should expand any $ref in the input schema. The $ref expansion _only_ covers this field. Any $ref on child fields are expanded as they are initialized by Pydantic since they are based on this base class. """ - # Setting the resolve_schema attribute on the class, so all sub-classes will inherit this automatically. if resolve_schema is not None: AvdSchemaBaseModel._resolve_schema = resolve_schema @@ -172,7 +176,8 @@ def _table(self) -> str | None: # This should never happen, since only the root key should be without a parent_schema. if len(self._path) != 1: - raise ValueError("Something went wrong in _table", self._path) + msg = "Something went wrong in _table" + raise ValueError(msg, self._path) # This is a root key the default table is the key with hyphens and removing <,> return self._key.replace("<", "").replace(">", "").replace("_", "-") @@ -181,19 +186,20 @@ def _table(self) -> str | None: def _path(self) -> list[str]: """ Returns the variable path for this field to be used in schema docs. + Like "rootkey.subkey.[].mykey". """ - # A list item has no key, so add "[]" to the parent schema for representing the list-item if not self._key: - return self._parent_schema._path + ["[]"] + return [*self._parent_schema._path, "[]"] # Add the key to the parent path - return self._parent_schema._path + [self._key] + return [*self._parent_schema._path, self._key] def _generate_table_rows(self, target_table: str | None = None) -> Generator[TableRow]: """ Yields "TableRow"s to be used in schema docs. + The function is called recursively inside the YamlLineGen classes for parsing children. """ # Using the Type of table row generator set in the subclass attribute _table_row_generator @@ -202,6 +208,7 @@ def _generate_table_rows(self, target_table: str | None = None) -> Generator[Tab def _generate_yaml_lines(self, target_table: str | None = None) -> Generator[YamlLine]: """ Yields "YamlLine"s to be used in schema docs. + The function is called recursively inside the YamlLineGen classes for parsing children. """ # Using the Type of yaml line generator set in the subclass attribute _yaml_line_generator @@ -225,7 +232,7 @@ class ConvertType(str, Enum): # AvdSchema field properties type: Literal["int"] - convert_types: List[ConvertType] | None = None + convert_types: list[ConvertType] | None = None """List of types to auto-convert from. For 'int' auto-conversion is supported from 'bool', 'str' and 'float'""" default: int | None = None """Default value""" @@ -233,7 +240,7 @@ class ConvertType(str, Enum): """Minimum value""" max: int | None = None """Maximum value""" - valid_values: List[int] | None = None + valid_values: list[int] | None = None """List of valid values""" dynamic_valid_values: str | None = None """ @@ -264,11 +271,11 @@ class ConvertType(str, Enum): # AvdSchema field properties type: Literal["bool"] - convert_types: List[ConvertType] | None = None + convert_types: list[ConvertType] | None = None """List of types to auto-convert from. For 'bool' auto-conversion is supported from 'int' and 'str'""" default: bool | None = None """Default value""" - valid_values: List[bool] | None = None + valid_values: list[bool] | None = None """List of valid values""" dynamic_valid_values: str | None = None """ @@ -307,14 +314,14 @@ class Format(str, Enum): cidr = "cidr" mac = "mac" - def __str__(self): + def __str__(self) -> str: return self.value # AvdSchema field properties type: Literal["str"] convert_to_lower_case: bool | None = False """Convert string value to lower case before performing validation""" - convert_types: List[ConvertType] | None = None + convert_types: list[ConvertType] | None = None """List of types to auto-convert from.\n\nFor 'str' auto-conversion is supported from 'bool' and 'int'""" default: str | None = None """Default value""" @@ -330,7 +337,7 @@ def __str__(self): The regular expression should be valid according to the ECMA 262 dialect. Remember to use double escapes. """ - valid_values: List[str] | None = None + valid_values: list[str] | None = None """List of valid values""" dynamic_valid_values: str | None = None """ @@ -362,13 +369,13 @@ class ConvertType(str, Enum): # AvdSchema field properties type: Literal["list"] - convert_types: List[ConvertType] | None = None + convert_types: list[ConvertType] | None = None """ List of types to auto-convert from. For 'list of dicts' auto-conversion is supported from 'dict' if 'primary_key' is set on the list schema. For other list item types conversion from dict will use the keys as list items. """ - default: List | None = None + default: list | None = None """Default value""" items: Annotated[AvdSchemaField, Field(discriminator="type")] | None = None """Schema for list items""" @@ -418,7 +425,8 @@ def _descendant_tables(self) -> set[str]: def model_post_init(self, __context: Any) -> None: """ - Overrides BaseModel.model_post_init(). + Overrides BaseModel.model_post_init. + Runs after this model including all child models have been initialized. Sets Internal attributes on child schema (if set): @@ -461,7 +469,7 @@ class AvdSchemaDict(AvdSchemaBaseModel): """ class DocumentationOptions(AvdSchemaBaseModel.DocumentationOptions): - """Extra schema field options used for controlling documentation generation for dicts""" + """Extra schema field options used for controlling documentation generation for dicts.""" hide_keys: bool | None = None # """ @@ -526,7 +534,8 @@ def _descendant_tables(self) -> set[str]: def model_post_init(self, __context: Any) -> None: """ - Overrides BaseModel.model_post_init(). + Overrides BaseModel.model_post_init. + Runs after this model including all child models have been initialized. Set Internal attributes on child schemas: diff --git a/python-avd/schema_tools/metaschema/resolvemodel.py b/python-avd/schema_tools/metaschema/resolvemodel.py index 21ce86e3185..1b721e0be86 100644 --- a/python-avd/schema_tools/metaschema/resolvemodel.py +++ b/python-avd/schema_tools/metaschema/resolvemodel.py @@ -8,7 +8,7 @@ from deepmerge import conservative_merger -from ..store import create_store +from schema_tools.store import create_store def merge_schema_from_ref(schema: dict) -> dict: @@ -27,13 +27,13 @@ def merge_schema_from_ref(schema: dict) -> dict: ref_schema = merge_schema_from_ref(get_schema_from_ref(ref)) if ref_schema["type"] != schema["type"]: # TODO: Consider if this should be a pyavd specific error - raise ValueError( + msg = ( f"Incompatible schema types from ref '{ref}' ref type '{ref_schema['type']}' schema type '{schema['type']}'\nschema: {schema}\nref_schema:" f" {ref_schema})" ) + raise ValueError(msg) - merged_schema = conservative_merger.merge(schema, ref_schema) - return merged_schema + return conservative_merger.merge(schema, ref_schema) @lru_cache @@ -46,17 +46,20 @@ def get_schema_from_ref(ref: str) -> dict: schema_store = create_store() if "#" not in ref: - raise ValueError("Missing # in ref") + msg = "Missing # in ref" + raise ValueError(msg) schema_name, ref = ref.split("#", maxsplit=1) if schema_name not in schema_store: - raise KeyError(f"Invalid schema name '{schema_name}'") + msg = f"Invalid schema name '{schema_name}'" + raise KeyError(msg) schema = schema_store[schema_name] path = ref.split("/") ref_schema = walk_schema(schema, path) if ref_schema is None: - raise KeyError(f"Unable to resolve schema ref '{ref}' for schema '{schema_name}'") + msg = f"Unable to resolve schema ref '{ref}' for schema '{schema_name}'" + raise KeyError(msg) return ref_schema diff --git a/python-avd/schema_tools/store.py b/python-avd/schema_tools/store.py index 412e5aef9fb..ca10c5257b0 100644 --- a/python-avd/schema_tools/store.py +++ b/python-avd/schema_tools/store.py @@ -4,6 +4,7 @@ from copy import deepcopy from functools import lru_cache from hashlib import sha1 +from pathlib import Path from pickle import HIGHEST_PROTOCOL from pickle import dump as pickle_dump from pickle import load as pickle_load @@ -15,9 +16,11 @@ @lru_cache -def create_store(load_from_yaml=False, force_rebuild=False) -> dict[str, dict]: +def create_store(*, load_from_yaml: bool = False, force_rebuild: bool = False) -> dict[str, dict]: """ - Create and return a schema store which is a dict of all our schemas like + Create and return a schema store. + + A schema store is a dict of all our schemas like { "avd_meta_schema": {...avd meta schema as dict...}, "eos_cli_config_gen": {...schema as dict...}, @@ -42,17 +45,15 @@ def create_store(load_from_yaml=False, force_rebuild=False) -> dict[str, dict]: return _compile_schemas() # Load from Pickle. - for id, schema_file in PICKLED_SCHEMAS.items(): - with open(schema_file, "rb") as file: - store[id] = pickle_load(file) + for schema_id, schema_file in PICKLED_SCHEMAS.items(): + with Path(schema_file).open("rb") as file: + store[schema_id] = pickle_load(file) # noqa: S301 return store def _should_recompile_schemas() -> bool: - """ - Returns true if pickled schemas should be recompiled - """ + """Returns true if pickled schemas should be recompiled.""" # Check if any pickled schema is missing for pickle_file in PICKLED_SCHEMAS.values(): if not pickle_file.exists(): @@ -73,21 +74,21 @@ def _should_recompile_schemas() -> bool: def _create_store_from_yaml() -> dict[str, dict]: - """ - Returns a schema store loaded from yaml files with $ref - """ + """Returns a schema store loaded from yaml files with $ref.""" store = {} - for id, schema_file in SCHEMA_PATHS.items(): - with open(schema_file, "r", encoding="UTF-8") as stream: - store[id] = safe_load(stream) + for schema_id, schema_file in SCHEMA_PATHS.items(): + with Path(schema_file).open(encoding="UTF-8") as stream: + store[schema_id] = safe_load(stream) return store def _compile_schemas() -> dict: """ - Load schemas from yaml files, + Resolve full schemas and save as pickle files. + + Load schemas from yaml files create a temporary "store", - resolve all $refs and save the resulting schemas as pickles + resolve all $refs and save the resulting schemas as pickles. """ schema_store = _create_store_from_yaml() @@ -123,7 +124,8 @@ def _compile_schemas() -> dict: def _resolve_schema(schema: dict, store: dict) -> dict: """ - Get fully resolved schema (where all $ref has been expanded recursively) + Get fully resolved schema (where all $ref has been expanded recursively). + .schemaresolver performs inplace update of the argument so we give it a copy of the existing schema. """ resolved_schema = deepcopy(schema) diff --git a/python-avd/scripts/build-schemas.py b/python-avd/scripts/build-schemas.py index a1e11d4438c..91119352b5b 100755 --- a/python-avd/scripts/build-schemas.py +++ b/python-avd/scripts/build-schemas.py @@ -27,28 +27,26 @@ def combine_schemas() -> None: for schema_name, fragments_path in SCHEMA_FRAGMENTS_PATHS.items(): print("Combining fragments", fragments_path) if schema_name not in SCHEMA_PATHS: - raise KeyError(f"Invalid schema name '{schema_name}'") + msg = f"Invalid schema name '{schema_name}'" + raise KeyError(msg) schema = {} for fragment_filename in sorted(fragments_path.glob(FRAGMENTS_PATTERN)): - # print("Combining fragment", fragment_filename) with fragment_filename.open(mode="r", encoding="UTF-8") as fragment_stream: schema = always_merger.merge(schema, yaml_load(fragment_stream, Loader=CSafeLoader)) with SCHEMA_PATHS[schema_name].open(mode="w", encoding="UTF-8") as schema_stream: schema_stream.write(indent(LICENSE_HEADER, prefix="# ") + "\n") schema_stream.write( - ( - "# yaml-language-server: $schema=../../../plugins/plugin_utils/schema/avd_meta_schema.json\n" - "# Line above is used by RedHat's YAML Schema vscode extension\n" - "# Use Ctrl + Space to get suggestions for every field. Autocomplete will pop up after typing 2 letters.\n" - ) + "# yaml-language-server: $schema=../../../plugins/plugin_utils/schema/avd_meta_schema.json\n" + "# Line above is used by RedHat's YAML Schema vscode extension\n" + "# Use Ctrl + Space to get suggestions for every field. Autocomplete will pop up after typing 2 letters.\n", ) schema_stream.write(yaml_dump(schema, Dumper=CSafeDumper, sort_keys=False)) -def build_schema_tables(schema_store) -> None: - """Build schema tables""" +def build_schema_tables(schema_store: dict) -> None: + """Build schema tables.""" for schema_name in SCHEMA_PATHS: if schema_name not in SCHEMA_FRAGMENTS_PATHS: continue @@ -59,7 +57,7 @@ def build_schema_tables(schema_store) -> None: for table_name in table_names: print(f"Building table: {table_name} from schema {schema_name}") table_file = output_dir.joinpath(f"{table_name}.md") - with open(table_file, mode="w", encoding="UTF-8") as file: + with Path(table_file).open(mode="w", encoding="UTF-8") as file: file.write(get_md_tabs(schema, table_name)) # Clean up other markdown files not covered by the tables. @@ -69,8 +67,9 @@ def build_schema_tables(schema_store) -> None: file.unlink() -def main(): - """Main entrypoint for the script. +def main() -> None: + """ + Main entrypoint for the script. It combines the schema fragments, and rebuild the pickled schemas. """ diff --git a/python-avd/scripts/custom_build_backend.py b/python-avd/scripts/custom_build_backend.py index c5a6c5e8b13..146c6c5d5f9 100644 --- a/python-avd/scripts/custom_build_backend.py +++ b/python-avd/scripts/custom_build_backend.py @@ -3,25 +3,20 @@ # that can be found in the LICENSE file. from pathlib import Path from subprocess import Popen -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from setuptools import build_meta as _orig from yaml import safe_load if not TYPE_CHECKING: - def __getattr__(name): - """ - Workaround to avoid 'from setuptools.build_meta import *' - """ + def __getattr__(name: str) -> Any: + """Workaround to avoid 'from setuptools.build_meta import *'.""" return locals().get(name, getattr(_orig, name)) def _translate_version(version: str, pyavd_prerelease: str) -> str: - """ - Translate an Ansible collection version to Python package version - """ - + """Translate an Ansible collection version to Python package version.""" avd_base_version = version.split("-", maxsplit=1)[0] if pyavd_prerelease: @@ -39,10 +34,10 @@ def _translate_version(version: str, pyavd_prerelease: str) -> str: def _insert_version() -> None: - with open(Path(__file__).parents[2].joinpath("ansible_collections/arista/avd/galaxy.yml"), encoding="UTF-8") as galaxy_file: + with Path(__file__).parents[2].joinpath("ansible_collections/arista/avd/galaxy.yml").open(encoding="UTF-8") as galaxy_file: ansible_version = dict(safe_load(galaxy_file)).get("version") - with open(Path(__file__).parents[1].joinpath("pyavd/__init__.py"), mode="r", encoding="UTF-8") as init_file: + with Path(__file__).parents[1].joinpath("pyavd/__init__.py").open(encoding="UTF-8") as init_file: init_lines = init_file.readlines() pyavd_prerelease = "" @@ -58,26 +53,28 @@ def _insert_version() -> None: init_lines[index] = f'__version__ = "{version}"\n' break - with open(Path(__file__).parents[1].joinpath("pyavd/__init__.py"), mode="w", encoding="UTF-8") as init_file: + with Path(__file__).parents[1].joinpath("pyavd/__init__.py").open(mode="w", encoding="UTF-8") as init_file: init_file.writelines(init_lines) -def get_requires_for_build_wheel(config_settings=None): +def get_requires_for_build_wheel(config_settings: dict | None = None) -> list[str]: print("Fetch version from ansible.avd ansible collection and insert into __init__.py") _insert_version() print("Running 'make dep' to generate compiled Jinja2 templates and schemas pickle files.") - with Popen("make dep", shell=True) as make_process: + with Popen("make dep", shell=True) as make_process: # noqa: S602,S607 if make_process.wait() != 0: - raise RuntimeError("Something went wrong during 'make dep'") + msg = "Something went wrong during 'make dep'" + raise RuntimeError(msg) return _orig.get_requires_for_build_wheel(config_settings) -def get_requires_for_build_editable(config_settings=None): +def get_requires_for_build_editable(config_settings: dict | None = None) -> list[str]: print("Running 'make dep' to generate compiled Jinja2 templates and schemas pickle files.") - with Popen("make dep", shell=True) as make_process: + with Popen("make dep", shell=True) as make_process: # noqa: S602,S607 if make_process.wait() != 0: - raise RuntimeError("Something went wrong during 'make dep'") + msg = "Something went wrong during 'make dep'" + raise RuntimeError(msg) return _orig.get_requires_for_build_editable(config_settings) diff --git a/python-avd/tests/pyavd/eos_designs/conftest.py b/python-avd/tests/pyavd/eos_designs/conftest.py index 7cbbe50738a..bc045a68180 100644 --- a/python-avd/tests/pyavd/eos_designs/conftest.py +++ b/python-avd/tests/pyavd/eos_designs/conftest.py @@ -5,6 +5,7 @@ from pathlib import Path import pytest + from pyavd import get_avd_facts from ...utils import read_file, read_vars diff --git a/python-avd/tests/pyavd/j2filters/test_add_md_toc.py b/python-avd/tests/pyavd/j2filters/test_add_md_toc.py index e4ebd5d05f8..2495efe55b5 100644 --- a/python-avd/tests/pyavd/j2filters/test_add_md_toc.py +++ b/python-avd/tests/pyavd/j2filters/test_add_md_toc.py @@ -8,6 +8,7 @@ from pathlib import Path import pytest + from pyavd.j2filters.add_md_toc import _get_anchor_id, add_md_toc DIR_PATH = Path(__file__).parent / "toc_files" @@ -26,7 +27,7 @@ class TestAddMdTocFilter: @pytest.mark.parametrize("skip_lines", SKIP_LINES_LIST) def test_add_md_toc(self, skip_lines): """Test add_md_toc success scenarii.""" - with open(MD_INPUT_VALID, "r", encoding="UTF-8") as input_file: + with Path(MD_INPUT_VALID).open("r", encoding="UTF-8") as input_file: resp = add_md_toc(input_file.read(), skip_lines=skip_lines, toc_levels=VALID_TOC_LEVEL, toc_marker=TOC_MARKER) with open(EXPECTED_TOC, "r", encoding="UTF-8") as input_file: @@ -36,19 +37,19 @@ def test_add_md_toc(self, skip_lines): def test_add_md_toc_invalid_skip_lines(self): """Test add_md_toc with invalid skip_lines.""" - with open(MD_INPUT_VALID, "r", encoding="UTF-8") as input_file: + with Path(MD_INPUT_VALID).open("r", encoding="UTF-8") as input_file: with pytest.raises(TypeError, match="add_md_toc 'skip_lines' argument must be an integer."): add_md_toc(input_file.read(), skip_lines="Not an int") def test_add_md_toc_invalid_toc_level(self): """Test add_md_toc with invalid toc level.""" - with open(MD_INPUT_VALID, "r", encoding="UTF-8") as input_file: + with Path(MD_INPUT_VALID).open("r", encoding="UTF-8") as input_file: with pytest.raises(TypeError): add_md_toc(input_file.read(), toc_levels=INVALID_TOC_LEVEL) def test_add_md_toc_invalid_toc_marker(self): """Test add_md_toc with invalid toc_marker.""" - with open(MD_INPUT_VALID, "r", encoding="UTF-8") as input_file: + with Path(MD_INPUT_VALID).open("r", encoding="UTF-8") as input_file: with pytest.raises(TypeError, match="add_md_toc 'toc_marker' argument must be a non-empty string."): add_md_toc(input_file.read(), toc_marker=["Not_as_string"]) @@ -59,16 +60,16 @@ def test_add_md_toc_invalid_md_input_type(self): def test_add_md_toc_invalid(self): """Test add_md_toc with invalid input file.""" - with open(MD_INPUT_INVALID, "r", encoding="UTF-8") as md_input_toc_invalid: + with Path(MD_INPUT_INVALID).open("r", encoding="UTF-8") as md_input_toc_invalid: with pytest.raises(ValueError, match="add_md_toc expects exactly two occurrences of the toc marker"): add_md_toc(md_input_toc_invalid.read()) def test_add_md_toc_btw_specific_markers(self): """Test to add the TOC at the end of the file using the specific markers features.""" - with open(DIR_PATH / "markers_at_bottom.md", "r", encoding="UTF-8") as input_file: + with DIR_PATH.joinpath("markers_at_bottom.md").open("r", encoding="UTF-8") as input_file: resp = add_md_toc(input_file.read(), skip_lines=0, toc_levels=2, toc_marker=TOC_MARKER) - with open(DIR_PATH / "expected_output_toc_at_bottom.md", "r", encoding="UTF-8") as expected_output: + with DIR_PATH.joinpath("expected_output_toc_at_bottom.md").open("r", encoding="UTF-8") as expected_output: assert resp == expected_output.read() def test__get_anchor_id_with_nonn_empty_anchor_id(self) -> None: diff --git a/python-avd/tests/pyavd/j2filters/test_convert_dict.py b/python-avd/tests/pyavd/j2filters/test_convert_dict.py index 81e0d95911c..e5b7a3e25ba 100644 --- a/python-avd/tests/pyavd/j2filters/test_convert_dict.py +++ b/python-avd/tests/pyavd/j2filters/test_convert_dict.py @@ -4,6 +4,7 @@ from __future__ import absolute_import, division, print_function import pytest + from pyavd.j2filters import convert_dicts DEFAULT_PRIMARY_KEY = "name" diff --git a/python-avd/tests/pyavd/j2filters/test_decrypt.py b/python-avd/tests/pyavd/j2filters/test_decrypt.py index 8c15192a7ab..0d4b0d735f6 100644 --- a/python-avd/tests/pyavd/j2filters/test_decrypt.py +++ b/python-avd/tests/pyavd/j2filters/test_decrypt.py @@ -6,6 +6,7 @@ from contextlib import nullcontext as does_not_raise import pytest + from pyavd.j2filters import decrypt @@ -27,8 +28,6 @@ ], ) def test_decrypt(password, passwd_type, key, kwargs, expected_raise): - """ - Test decrypt method for non existing and existing type - """ + """Test decrypt method for non existing and existing type.""" with expected_raise: decrypt(password, passwd_type=passwd_type, key=key, **kwargs) diff --git a/python-avd/tests/pyavd/j2filters/test_default.py b/python-avd/tests/pyavd/j2filters/test_default.py index 1751c684c20..984251aa2b4 100644 --- a/python-avd/tests/pyavd/j2filters/test_default.py +++ b/python-avd/tests/pyavd/j2filters/test_default.py @@ -3,6 +3,7 @@ # that can be found in the LICENSE file. import pytest from jinja2.runtime import Undefined + from pyavd.j2filters import default PRIMARY_VALUE_LIST = [1, "ABC", None, Undefined, {}, {"key": "value"}, [1, 2]] diff --git a/python-avd/tests/pyavd/j2filters/test_encrypt.py b/python-avd/tests/pyavd/j2filters/test_encrypt.py index f71c0bed3be..9a9ab91cab1 100644 --- a/python-avd/tests/pyavd/j2filters/test_encrypt.py +++ b/python-avd/tests/pyavd/j2filters/test_encrypt.py @@ -6,6 +6,7 @@ from contextlib import nullcontext as does_not_raise import pytest + from pyavd.j2filters import encrypt diff --git a/python-avd/tests/pyavd/j2filters/test_hide_passwords.py b/python-avd/tests/pyavd/j2filters/test_hide_passwords.py index 00ee0361650..531537c751c 100644 --- a/python-avd/tests/pyavd/j2filters/test_hide_passwords.py +++ b/python-avd/tests/pyavd/j2filters/test_hide_passwords.py @@ -6,6 +6,7 @@ __metaclass__ = type import pytest + from pyavd.j2filters import hide_passwords VALID_INPUT_HIDE_PASSWORDS = [ diff --git a/python-avd/tests/pyavd/j2filters/test_is_in_filter.py b/python-avd/tests/pyavd/j2filters/test_is_in_filter.py index 1787a1b8a62..bc3e6845fbd 100644 --- a/python-avd/tests/pyavd/j2filters/test_is_in_filter.py +++ b/python-avd/tests/pyavd/j2filters/test_is_in_filter.py @@ -4,6 +4,7 @@ from __future__ import absolute_import, division, print_function import pytest + from pyavd.j2filters import is_in_filter HOSTNAME_VALID = "test1.aristanetworks.com" diff --git a/python-avd/tests/pyavd/j2filters/test_list_compress.py b/python-avd/tests/pyavd/j2filters/test_list_compress.py index 68bb3048f02..1c9a6837ee3 100644 --- a/python-avd/tests/pyavd/j2filters/test_list_compress.py +++ b/python-avd/tests/pyavd/j2filters/test_list_compress.py @@ -6,6 +6,7 @@ __metaclass__ = type import pytest + from pyavd.j2filters import list_compress LIST_COMPRESS_INVALID_TESTS = [ diff --git a/python-avd/tests/pyavd/j2filters/test_natural_sort.py b/python-avd/tests/pyavd/j2filters/test_natural_sort.py index 52a3d947ffc..da9d05a6e94 100644 --- a/python-avd/tests/pyavd/j2filters/test_natural_sort.py +++ b/python-avd/tests/pyavd/j2filters/test_natural_sort.py @@ -5,6 +5,7 @@ from contextlib import nullcontext as does_not_raise import pytest + from pyavd.j2filters.natural_sort import convert, natural_sort diff --git a/python-avd/tests/pyavd/j2filters/test_range_expand.py b/python-avd/tests/pyavd/j2filters/test_range_expand.py index f151716f92c..405bf532c1b 100644 --- a/python-avd/tests/pyavd/j2filters/test_range_expand.py +++ b/python-avd/tests/pyavd/j2filters/test_range_expand.py @@ -5,6 +5,7 @@ from __future__ import annotations import pytest + from pyavd.j2filters import range_expand RANGE_TO_EXPAND_INVALID_VALUES = [ diff --git a/python-avd/tests/pyavd/j2filters/test_snmp_hash.py b/python-avd/tests/pyavd/j2filters/test_snmp_hash.py index b0faa4e08e6..6714e018030 100644 --- a/python-avd/tests/pyavd/j2filters/test_snmp_hash.py +++ b/python-avd/tests/pyavd/j2filters/test_snmp_hash.py @@ -8,6 +8,7 @@ from contextlib import nullcontext as does_not_raise import pytest + from pyavd.j2filters.snmp_hash import _PRIV_KEY_LENGTH, _get_hash_object, _key_from_passphrase, _localize_passphrase GET_HASH_OBJECT_TEST_CASES = [ diff --git a/python-avd/tests/pyavd/j2filters/test_status_render.py b/python-avd/tests/pyavd/j2filters/test_status_render.py index ed490788b20..66f30f03aa2 100644 --- a/python-avd/tests/pyavd/j2filters/test_status_render.py +++ b/python-avd/tests/pyavd/j2filters/test_status_render.py @@ -4,6 +4,7 @@ from __future__ import annotations import pytest + from pyavd.j2filters import status_render STATE_STRINGS = [("PASS", "github", ":white_check_mark:"), ("fail", "github", ":x:"), ("FAIL", "test", "FAIL")] diff --git a/python-avd/tests/pyavd/j2tests/test_contains.py b/python-avd/tests/pyavd/j2tests/test_contains.py index 2b05b05ec6d..6b68e60a433 100644 --- a/python-avd/tests/pyavd/j2tests/test_contains.py +++ b/python-avd/tests/pyavd/j2tests/test_contains.py @@ -7,6 +7,7 @@ import pytest from jinja2.runtime import Undefined + from pyavd.j2tests.contains import contains TEST_DATA = [ diff --git a/python-avd/tests/pyavd/j2tests/test_defined_plugin.py b/python-avd/tests/pyavd/j2tests/test_defined_plugin.py index 7fdc21471aa..aaac0eaf3d1 100644 --- a/python-avd/tests/pyavd/j2tests/test_defined_plugin.py +++ b/python-avd/tests/pyavd/j2tests/test_defined_plugin.py @@ -7,6 +7,7 @@ import pytest from jinja2.runtime import Undefined + from pyavd.j2tests.defined import defined VALUE_LIST = ["ab", None, 1, True, {"key": "value"}] diff --git a/python-avd/tests/pyavd/schema/test_avdschema.py b/python-avd/tests/pyavd/schema/test_avdschema.py index 1e5c972ab49..5806317ea64 100644 --- a/python-avd/tests/pyavd/schema/test_avdschema.py +++ b/python-avd/tests/pyavd/schema/test_avdschema.py @@ -6,6 +6,7 @@ import pytest import yaml from deepmerge import always_merger + from pyavd._errors import AvdValidationError from pyavd._schema.avdschema import DEFAULT_SCHEMA, AvdSchema diff --git a/python-avd/tests/pyavd/utils/merge/test_merge.py b/python-avd/tests/pyavd/utils/merge/test_merge.py index d7ae0898863..8ac457da051 100644 --- a/python-avd/tests/pyavd/utils/merge/test_merge.py +++ b/python-avd/tests/pyavd/utils/merge/test_merge.py @@ -9,6 +9,7 @@ import pytest import yaml + from pyavd._schema.avdschema import AvdSchema from pyavd._utils import merge diff --git a/python-avd/tests/pyavd/utils/password/test_password.py b/python-avd/tests/pyavd/utils/password/test_password.py index c7988b78d4f..de785336093 100644 --- a/python-avd/tests/pyavd/utils/password/test_password.py +++ b/python-avd/tests/pyavd/utils/password/test_password.py @@ -4,6 +4,7 @@ from __future__ import annotations import pytest + from pyavd._utils.password_utils import ( bgp_decrypt, bgp_encrypt, diff --git a/python-avd/tests/pyavd/utils/password/test_password_utils.py b/python-avd/tests/pyavd/utils/password/test_password_utils.py index bfe73fef76b..d891f06d2b4 100644 --- a/python-avd/tests/pyavd/utils/password/test_password_utils.py +++ b/python-avd/tests/pyavd/utils/password/test_password_utils.py @@ -4,6 +4,7 @@ from __future__ import annotations import pytest + from pyavd._utils.password_utils.password_utils import cbc_check_password, cbc_decrypt, cbc_encrypt # password used is "arista" diff --git a/python-avd/tests/pyavd/utils/test_get.py b/python-avd/tests/pyavd/utils/test_get.py index c64dd1ce020..75740f90008 100644 --- a/python-avd/tests/pyavd/utils/test_get.py +++ b/python-avd/tests/pyavd/utils/test_get.py @@ -9,6 +9,7 @@ from contextlib import contextmanager import pytest + from pyavd._errors import AristaAvdError from pyavd._utils import get diff --git a/python-avd/tests/pyavd/utils/test_get_ip_from_pool.py b/python-avd/tests/pyavd/utils/test_get_ip_from_pool.py index 472b6b64848..a2e477e3a3e 100644 --- a/python-avd/tests/pyavd/utils/test_get_ip_from_pool.py +++ b/python-avd/tests/pyavd/utils/test_get_ip_from_pool.py @@ -2,6 +2,7 @@ # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. import pytest + from pyavd._errors import AristaAvdError from pyavd._utils import get_ip_from_pool diff --git a/python-avd/tests/pyavd/utils/test_short_esi_to_route_target.py b/python-avd/tests/pyavd/utils/test_short_esi_to_route_target.py index 85f5a99d24a..1c1d6726270 100644 --- a/python-avd/tests/pyavd/utils/test_short_esi_to_route_target.py +++ b/python-avd/tests/pyavd/utils/test_short_esi_to_route_target.py @@ -6,6 +6,7 @@ __metaclass__ = type import pytest + from pyavd._utils import short_esi_to_route_target ESI_TO_RT_TEST_CASES = [ diff --git a/python-avd/tests/pyavd/utils/test_strip_empties.py b/python-avd/tests/pyavd/utils/test_strip_empties.py index 961e045d865..86716ce3407 100644 --- a/python-avd/tests/pyavd/utils/test_strip_empties.py +++ b/python-avd/tests/pyavd/utils/test_strip_empties.py @@ -6,6 +6,7 @@ __metaclass__ = type import pytest + from pyavd._utils import strip_null_from_data STRIP_EMPTIES_LIST = {