Skip to content

Commit

Permalink
🔧 Add strict typing for sphinx_needs.utils
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjsewell committed Aug 29, 2023
1 parent ab5ea94 commit ce9ac85
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 38 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ module = [
'sphinx_needs.directives.needuml',
'sphinx_needs.functions.functions',
'sphinx_needs.layout',
'sphinx_needs.utils',
]
ignore_errors = true

Expand Down
2 changes: 1 addition & 1 deletion sphinx_needs/functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def resolve_variants_options(env: BuildEnvironment):
needs = data.get_or_create_needs()
for need in needs.values():
# Data to use as filter context.
need_context: Dict = {**need}
need_context: Dict[str, Any] = {**need}
need_context.update(**needs_config.filter_data) # Add needs_filter_data to filter context
_sphinx_tags = env.app.builder.tags.tags # Get sphinx tags
need_context.update(**_sphinx_tags) # Add sphinx tags to filter context
Expand Down
98 changes: 62 additions & 36 deletions sphinx_needs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,22 @@
import re
from functools import reduce, wraps
from re import Pattern
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from urllib.parse import urlparse

from docutils import nodes
from jinja2 import BaseLoader, Environment, Template
from jinja2 import Environment, Template
from matplotlib.figure import FigureBase
from sphinx.application import BuildEnvironment, Sphinx

Expand All @@ -18,9 +29,23 @@
from sphinx_needs.defaults import NEEDS_PROFILING
from sphinx_needs.logging import get_logger

try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict

if TYPE_CHECKING:
from sphinx_needs.functions.functions import DynamicFunction

logger = get_logger(__name__)

NEEDS_FUNCTIONS = {}

class NeedFunctionsType(TypedDict):
name: str
function: "DynamicFunction"


NEEDS_FUNCTIONS: Dict[str, NeedFunctionsType] = {}

# List of internal need option names. They should not be used by or presented to user.
INTERNALS = [
Expand Down Expand Up @@ -116,14 +141,15 @@ def row_col_maker(
for v in needs_config.string_links.values():
needs_string_links_option.extend(v["options"])

if need_key in need_info and need_info[need_key] is not None:
if isinstance(need_info[need_key], (list, set)):
data = need_info[need_key]
elif isinstance(need_info[need_key], str) and need_key in needs_string_links_option:
data = re.split(r",|;", need_info[need_key])
if need_key in need_info and need_info[need_key] is not None: # type: ignore[literal-required]
value = need_info[need_key] # type: ignore[literal-required]
if isinstance(value, (list, set)):
data = value
elif isinstance(value, str) and need_key in needs_string_links_option:
data = re.split(r",|;", value)
data = [i.strip() for i in data if len(i) != 0]
else:
data = [need_info[need_key]]
data = [value]

for index, datum in enumerate(data):
link_id = datum
Expand All @@ -138,10 +164,8 @@ def row_col_maker(
link_string_list = {}
for link_name, link_conf in needs_config.string_links.items():
link_string_list[link_name] = {
"url_template": Environment(loader=BaseLoader, autoescape=True).from_string(link_conf["link_url"]),
"name_template": Environment(loader=BaseLoader, autoescape=True).from_string(
link_conf["link_name"]
),
"url_template": Environment(autoescape=True).from_string(link_conf["link_url"]),
"name_template": Environment(autoescape=True).from_string(link_conf["link_name"]),
"regex_compiled": re.compile(link_conf["regex"]),
"options": link_conf["options"],
"name": link_name,
Expand Down Expand Up @@ -170,6 +194,7 @@ def row_col_maker(

if make_ref:
if need_info["is_external"]:
assert need_info["external_url"] is not None, "external_url must be set for external needs"
ref_col["refuri"] = check_and_calc_base_url_rel_path(need_info["external_url"], fromdocname)
ref_col["classes"].append(need_info["external_css"])
row_col["classes"].append(need_info["external_css"])
Expand All @@ -179,6 +204,7 @@ def row_col_maker(
elif ref_lookup:
temp_need = all_needs[link_id]
if temp_need["is_external"]:
assert temp_need["external_url"] is not None, "external_url must be set for external needs"
ref_col["refuri"] = check_and_calc_base_url_rel_path(temp_need["external_url"], fromdocname)
ref_col["classes"].append(temp_need["external_css"])
row_col["classes"].append(temp_need["external_css"])
Expand Down Expand Up @@ -263,9 +289,9 @@ def profile(keyword: str) -> Callable[[FuncT], FuncT]:
Activation only happens, if given keyword is part of ``needs_profiling``.
"""

def inner(func):
def inner(func): # type: ignore
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs): # type: ignore
with cProfile.Profile() as pr:
result = func(*args, **kwargs)

Expand Down Expand Up @@ -304,11 +330,11 @@ def check_and_calc_base_url_rel_path(external_url: str, fromdocname: str) -> str
return ref_uri


def check_and_get_external_filter_func(filter_func_ref: Optional[str]):
def check_and_get_external_filter_func(filter_func_ref: Optional[str]) -> Tuple[Any, str]:
"""Check and import filter function from external python file."""
# Check if external filter code is defined
filter_func = None
filter_args = []
filter_args = ""

if filter_func_ref:
try:
Expand All @@ -317,23 +343,25 @@ def check_and_get_external_filter_func(filter_func_ref: Optional[str]):
logger.warning(
f'Filter function not valid "{filter_func_ref}". Example: my_module:my_func [needs]', type="needs"
)
return [] # No needs were found because of invalid filter function
return filter_func, filter_args

result = re.search(r"^(\w+)(?:\((.*)\))*$", filter_function)
if not result:
return filter_func, filter_args
filter_function = result.group(1)
filter_args = result.group(2) or []
filter_args = result.group(2) or ""

try:
final_module = importlib.import_module(filter_module)
filter_func = getattr(final_module, filter_function)
except Exception:
logger.warning(f"Could not import filter function: {filter_func_ref} [needs]", type="needs")
return []
return filter_func, filter_args

return filter_func, filter_args


def jinja_parse(context: Dict, jinja_string: str) -> str:
def jinja_parse(context: Dict[str, Any], jinja_string: str) -> str:
"""
Function to parse mapping options set to a string containing jinja template format.
Expand Down Expand Up @@ -393,7 +421,7 @@ def save_matplotlib_figure(app: Sphinx, figure: FigureBase, basename: str, fromd
return image_node


def dict_get(root, items, default=None) -> Any:
def dict_get(root: Dict[str, Any], items: Any, default: Any = None) -> Any:
"""
Access a nested object in root by item sequence.
Expand All @@ -411,7 +439,7 @@ def dict_get(root, items, default=None) -> Any:


def match_string_link(
text_item: str, data: str, need_key: str, matching_link_confs: List[Dict], render_context: Dict[str, Any]
text_item: str, data: str, need_key: str, matching_link_confs: List[Dict[str, Any]], render_context: Dict[str, Any]
) -> Any:
try:
link_name = None
Expand All @@ -438,23 +466,22 @@ def match_string_link(
return ref_item


def match_variants(option_value: Union[str, List], keywords: Dict, needs_variants: Dict) -> Union[str, List, None]:
def match_variants(
option_value: Union[str, List[str]], keywords: Dict[str, Any], needs_variants: Dict[str, str]
) -> Union[None, str, List[str]]:
"""
Function to handle variant option management.
:param option_value: Value assigned to an option
:type option_value: Union[str, List]
:param keywords: Data to use as filtering context
:type keywords: Dict
:param needs_variants: Needs variants data set in users conf.py
:type needs_variants: Dict
:return: A string, list, or None to be used as value for option.
:rtype: Union[str, List, None]
"""

def variant_handling(
variant_definitions: List, variant_data: Dict, variant_pattern: Pattern
) -> Union[str, List, None]:
variant_definitions: List[str], variant_data: Dict[str, Any], variant_pattern: Pattern # type: ignore[type-arg]
) -> Optional[str]:
filter_context = variant_data
# filter_result = []
no_variants_in_option = False
Expand Down Expand Up @@ -505,8 +532,8 @@ def variant_handling(

# Handling multiple variant definitions
if isinstance(option_value, str):
multiple_variants: List = variant_splitting.split(rf"""{option_value}""")
multiple_variants: List = [
multiple_variants: List[str] = variant_splitting.split(rf"""{option_value}""")
multiple_variants = [
re.sub(r"^([;, ]+)|([;, ]+$)", "", i) for i in multiple_variants if i not in (None, ";", "", " ")
]
if len(multiple_variants) == 1 and not variant_rule_matching.search(multiple_variants[0]):
Expand All @@ -516,7 +543,7 @@ def variant_handling(
return option_value
return new_option_value
elif isinstance(option_value, (list, set, tuple)):
multiple_variants: List = list(option_value)
multiple_variants = list(option_value)
# In case an option value is a list (:tags: open; close), and does not contain any variant definition,
# then return the unmodified value
# options = all([bool(not variant_rule_matching.search(i)) for i in multiple_variants])
Expand Down Expand Up @@ -574,10 +601,9 @@ def node_match(node_types: Union[Type[nodes.Element], List[Type[nodes.Element]]]
:param node_types: List of docutils node types
:return: function, which can be used as constraint-function for docutils findall()
"""
if not isinstance(node_types, list):
node_types = [node_types]
node_types_list = node_types if isinstance(node_types, list) else [node_types]

def condition(node, node_types=node_types):
def condition(node: nodes.Node, node_types: List[Type[nodes.Element]] = node_types_list) -> bool:
return any(isinstance(node, x) for x in node_types)

return condition
Expand Down Expand Up @@ -619,7 +645,7 @@ def _is_valid(link_type: str) -> bool:

def get_scale(options: Dict[str, Any], location: Any) -> str:
"""Get scale for diagram, from directive option."""
scale = options.get("scale", "100").replace("%", "")
scale: str = options.get("scale", "100").replace("%", "")
if not scale.isdigit():
logger.warning(
f'scale value must be a number. "{scale}" found [needs]',
Expand Down

0 comments on commit ce9ac85

Please sign in to comment.