diff --git a/bin/meshroom_batch b/bin/meshroom_batch index 36b8fef689..6bee4c1f6d 100755 --- a/bin/meshroom_batch +++ b/bin/meshroom_batch @@ -154,10 +154,10 @@ with meshroom.core.graph.GraphModification(graph): # initialize template pipeline loweredPipelineTemplates = dict((k.lower(), v) for k, v in meshroom.core.pipelineTemplates.items()) if args.pipeline.lower() in loweredPipelineTemplates: - graph.load(loweredPipelineTemplates[args.pipeline.lower()], setupProjectFile=False, publishOutputs=True if args.output else False) + graph.initFromTemplate(loweredPipelineTemplates[args.pipeline.lower()], publishOutputs=True if args.output else False) else: # custom pipeline - graph.load(args.pipeline, setupProjectFile=False, publishOutputs=True if args.output else False) + graph.initFromTemplate(args.pipeline, publishOutputs=True if args.output else False) def parseInputs(inputs, uniqueInitNode): """Utility method for parsing the input and inputRecursive arguments.""" diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index e63aceca1a..f2f7d7dfb6 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -4,6 +4,7 @@ import logging import os import re +from typing import Any, Optional import weakref from collections import defaultdict, OrderedDict from contextlib import contextmanager @@ -16,7 +17,10 @@ from meshroom.core import Version from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit -from meshroom.core.node import nodeFactory, Status, Node, CompatibilityNode +from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer, PartialGraphSerializer +from meshroom.core.node import BaseNode, Status, Node, CompatibilityNode +from meshroom.core.nodeFactory import nodeFactory +from meshroom.core.typing import PathLike # Replace default encoder to support Enums @@ -148,6 +152,21 @@ def decorator(self, *args, **kwargs): return decorator +def blockNodeCallbacks(func): + """ + Graph methods loading serialized graph content must be decorated with 'blockNodeCallbacks', + to avoid attribute changed callbacks defined on node descriptions to be triggered during + this process. + """ + def inner(self, *args, **kwargs): + self._loading = True + try: + return func(self, *args, **kwargs) + finally: + self._loading = False + return inner + + class Graph(BaseObject): """ _________________ _________________ _________________ @@ -165,52 +184,6 @@ class Graph(BaseObject): """ _cacheDir = "" - class IO(object): - """ Centralize Graph file keys and IO version. """ - __version__ = "2.0" - - class Keys(object): - """ File Keys. """ - # Doesn't inherit enum to simplify usage (Graph.IO.Keys.XX, without .value) - Header = "header" - NodesVersions = "nodesVersions" - ReleaseVersion = "releaseVersion" - FileVersion = "fileVersion" - Graph = "graph" - - class Features(Enum): - """ File Features. """ - Graph = "graph" - Header = "header" - NodesVersions = "nodesVersions" - PrecomputedOutputs = "precomputedOutputs" - NodesPositions = "nodesPositions" - - @staticmethod - def getFeaturesForVersion(fileVersion): - """ Return the list of supported features based on a file version. - - Args: - fileVersion (str, Version): the file version - - Returns: - tuple of Graph.IO.Features: the list of supported features - """ - if isinstance(fileVersion, str): - fileVersion = Version(fileVersion) - - features = [Graph.IO.Features.Graph] - if fileVersion >= Version("1.0"): - features += [Graph.IO.Features.Header, - Graph.IO.Features.NodesVersions, - Graph.IO.Features.PrecomputedOutputs, - ] - - if fileVersion >= Version("1.1"): - features += [Graph.IO.Features.NodesPositions] - - return tuple(features) - def __init__(self, name, parent=None): super(Graph, self).__init__(parent) self.name = name @@ -224,7 +197,6 @@ def __init__(self, name, parent=None): self._nodes = DictModel(keyAttrName='name', parent=self) # Edges: use dst attribute as unique key since it can only have one input connection self._edges = DictModel(keyAttrName='dst', parent=self) - self._importedNodes = DictModel(keyAttrName='name', parent=self) self._compatibilityNodes = DictModel(keyAttrName='name', parent=self) self.cacheDir = meshroom.core.defaultCacheFolder self._filepath = '' @@ -232,20 +204,22 @@ def __init__(self, name, parent=None): self.header = {} def clear(self): + self._clearGraphContent() self.header.clear() - self._compatibilityNodes.clear() + self._unsetFilepath() + + def _clearGraphContent(self): self._edges.clear() # Tell QML nodes are going to be deleted for node in self._nodes: node.alive = False - self._importedNodes.clear() self._nodes.clear() - self._unsetFilepath() + self._compatibilityNodes.clear() @property def fileFeatures(self): """ Get loaded file supported features based on its version. """ - return Graph.IO.getFeaturesForVersion(self.header.get(Graph.IO.Keys.FileVersion, "0.0")) + return GraphIO.getFeaturesForVersion(self.header.get(GraphIO.Keys.FileVersion, "0.0")) @property def isLoading(self): @@ -253,37 +227,83 @@ def isLoading(self): return self._loading @Slot(str) - def load(self, filepath, setupProjectFile=True, importProject=False, publishOutputs=False): + def load(self, filepath: PathLike): """ - Load a Meshroom graph ".mg" file. + Load a Meshroom Graph ".mg" file in place. Args: - filepath: project filepath to load - setupProjectFile: Store the reference to the project file and setup the cache directory. - If false, it only loads the graph of the project file as a template. - importProject: True if the project that is loaded will be imported in the current graph, instead - of opened. - publishOutputs: True if "Publish" nodes from templates should not be ignored. + filepath: The path to the Meshroom Graph file to load. """ - self._loading = True - try: - return self._load(filepath, setupProjectFile, importProject, publishOutputs) - finally: - self._loading = False + self._deserialize(Graph._loadGraphData(filepath)) + self._setFilepath(filepath) + self._fileDateVersion = os.path.getmtime(filepath) + + def initFromTemplate(self, filepath: PathLike, publishOutputs: bool = False): + """ + Deserialize a template Meshroom Graph ".mg" file in place. + + When initializing from a template, the internal filepath of the graph instance is not set. + Saving the file on disk will require to specify a filepath. + + Args: + filepath: The path to the Meshroom Graph file to load. + publishOutputs: (optional) Whether to keep 'Publish' nodes. + """ + self._deserialize(Graph._loadGraphData(filepath)) + + if not publishOutputs: + for node in [node for node in self.nodes if node.nodeType == "Publish"]: + self.removeNode(node.name) + + @staticmethod + def _loadGraphData(filepath: PathLike) -> dict: + """Deserialize the content of the Meshroom Graph file at `filepath` to a dictionnary.""" + with open(filepath) as file: + graphData = json.load(file) + return graphData + + @blockNodeCallbacks + def _deserialize(self, graphData: dict): + """Deserialize `graphData` in the current Graph instance. + + Args: + graphData: The serialized Graph. + """ + self.clear() + self.header = graphData.get(GraphIO.Keys.Header, {}) + fileVersion = Version(self.header.get(GraphIO.Keys.FileVersion, "0.0")) + graphContent = self._normalizeGraphContent(graphData, fileVersion) + isTemplate = self.header.get(GraphIO.Keys.Template, False) + + with GraphModification(self): + # iterate over nodes sorted by suffix index in their names + for nodeName, nodeData in sorted( + graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0]) + ): + self._deserializeNode(nodeData, nodeName, self) - def _load(self, filepath, setupProjectFile, importProject, publishOutputs): - if not importProject: - self.clear() - with open(filepath) as jsonFile: - fileData = json.load(jsonFile) + # Create graph edges by resolving attributes expressions + self._applyExpr() + + # Templates are specific: they contain only the minimal amount of + # serialized data to describe the graph structure. + # They are not meant to be computed: therefore, we can early return here, + # as uid conflict evaluation is only meaningful for nodes with computed data. + if isTemplate: + return - self.header = fileData.get(Graph.IO.Keys.Header, {}) + # By this point, the graph has been fully loaded and an updateInternals has been triggered, so all the + # nodes' links have been resolved and their UID computations are all complete. + # It is now possible to check whether the UIDs stored in the graph file for each node correspond to the ones + # that were computed. + self._evaluateUidConflicts(graphContent) - fileVersion = self.header.get(Graph.IO.Keys.FileVersion, "0.0") - # Retro-compatibility for all project files with the previous UID format - if Version(fileVersion) < Version("2.0"): + def _normalizeGraphContent(self, graphData: dict, fileVersion: Version) -> dict: + graphContent = graphData.get(GraphIO.Keys.Graph, graphData) + + if fileVersion < Version("2.0"): # For internal folders, all "{uid0}" keys should be replaced with "{uid}" - updatedFileData = json.dumps(fileData).replace("{uid0}", "{uid}") + updatedFileData = json.dumps(graphContent).replace("{uid0}", "{uid}") # For fileVersion < 2.0, the nodes' UID is stored as: # "uids": {"0": "hashvalue"} @@ -295,239 +315,123 @@ def _load(self, filepath, setupProjectFile, importProject, publishOutputs): uid = occ.split("\"")[-2] # UID is second to last element newUidStr = r'"uid": "{}"'.format(uid) updatedFileData = updatedFileData.replace(occ, newUidStr) - fileData = json.loads(updatedFileData) - - # Older versions of Meshroom files only contained the serialized nodes - graphData = fileData.get(Graph.IO.Keys.Graph, fileData) - - if importProject: - self._importedNodes.clear() - graphData = self.updateImportedProject(graphData) - - if not isinstance(graphData, dict): - raise RuntimeError('loadGraph error: Graph is not a dict. File: {}'.format(filepath)) - - nodesVersions = self.header.get(Graph.IO.Keys.NodesVersions, {}) - - self._fileDateVersion = os.path.getmtime(filepath) - - # Check whether the file was saved as a template in minimal mode - isTemplate = self.header.get("template", False) - - with GraphModification(self): - # iterate over nodes sorted by suffix index in their names - for nodeName, nodeData in sorted(graphData.items(), key=lambda x: self.getNodeIndexFromName(x[0])): - if not isinstance(nodeData, dict): - raise RuntimeError('loadGraph error: Node is not a dict. File: {}'.format(filepath)) - - # retrieve version from - # 1. nodeData: node saved from a CompatibilityNode - # 2. nodesVersion in file header: node saved from a Node - # 3. fallback to no version "0.0": retro-compatibility - if "version" not in nodeData: - nodeData["version"] = nodesVersions.get(nodeData["nodeType"], "0.0") - - # if the node is a "Publish" node and comes from a template file, it should be ignored - # unless publishOutputs is True - if isTemplate and not publishOutputs and nodeData["nodeType"] == "Publish": - continue - - n = nodeFactory(nodeData, nodeName, template=isTemplate) - - # Add node to the graph with raw attributes values - self._addNode(n, nodeName) - - if importProject: - self._importedNodes.add(n) + graphContent = json.loads(updatedFileData) + + return graphContent + + def _deserializeNode(self, nodeData: dict, nodeName: str, fromGraph: "Graph"): + # Retrieve version from + # 1. nodeData: node saved from a CompatibilityNode + # 2. nodesVersion in file header: node saved from a Node + # 3. fallback behavior: default to "0.0" + if "version" not in nodeData: + nodeData["version"] = fromGraph._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0") + inTemplate = fromGraph.header.get(GraphIO.Keys.Template, False) + node = nodeFactory(nodeData, nodeName, inTemplate=inTemplate) + self._addNode(node, nodeName) + return node - # Create graph edges by resolving attributes expressions - self._applyExpr() + def _getNodeTypeVersionFromHeader(self, nodeType: str, default: Optional[str] = None) -> Optional[str]: + nodeVersions = self.header.get(GraphIO.Keys.NodesVersions, {}) + return nodeVersions.get(nodeType, default) - if setupProjectFile: - # Update filepath related members - # Note: needs to be done at the end as it will trigger an updateInternals. - self._setFilepath(filepath) - elif not isTemplate: - # If no filepath is being set but the graph is not a template, trigger an updateInternals either way. - self.updateInternals() - - # By this point, the graph has been fully loaded and an updateInternals has been triggered, so all the - # nodes' links have been resolved and their UID computations are all complete. - # It is now possible to check whether the UIDs stored in the graph file for each node correspond to the ones - # that were computed. - if not isTemplate: # UIDs are not stored in templates - self._evaluateUidConflicts(graphData) - try: - self._applyExpr() - except Exception as e: - logging.warning(e) - - return True - - def _evaluateUidConflicts(self, data): + def _evaluateUidConflicts(self, graphContent: dict): """ - Compare the UIDs of all the nodes in the graph with the UID that is expected in the graph file. If there + Compare the computed UIDs of all the nodes in the graph with the UIDs serialized in `graphContent`. If there are mismatches, the nodes with the unexpected UID are replaced with "UidConflict" compatibility nodes. - Already existing nodes are removed and re-added to the graph identically to preserve all the edges, - which may otherwise be invalidated when a node with output edges but a UID conflict is re-generated as a - compatibility node. - - Args: - data (dict): the dictionary containing all the nodes to import and their data - """ - for nodeName, nodeData in sorted(data.items(), key=lambda x: self.getNodeIndexFromName(x[0])): - node = self.node(nodeName) - - savedUid = nodeData.get("uid", None) - graphUid = node._uid # Node's UID from the graph itself - - if savedUid != graphUid and graphUid is not None: - # Different UIDs, remove the existing node from the graph and replace it with a CompatibilityNode - logging.debug("UID conflict detected for {}".format(nodeName)) - self.removeNode(nodeName) - n = nodeFactory(nodeData, nodeName, template=False, uidConflict=True) - self._addNode(n, nodeName) - else: - # f connecting nodes have UID conflicts and are removed/re-added to the graph, some edges may be lost: - # the links will be erroneously updated, and any further resolution will fail. - # Recreating the entire graph as it was ensures that all edges will be correctly preserved. - self.removeNode(nodeName) - n = nodeFactory(nodeData, nodeName, template=False, uidConflict=False) - self._addNode(n, nodeName) - - def updateImportedProject(self, data): - """ - Update the names and links of the project to import so that it can fit - correctly in the existing graph. - - Parse all the nodes from the project that is going to be imported. - If their name already exists in the graph, replace them with new names, - then parse all the nodes' inputs/outputs to replace the old names with - the new ones in the links. - + Args: - data (dict): the dictionary containing all the nodes to import and their data - - Returns: - updatedData (dict): the dictionary containing all the nodes to import with their updated names and data + graphContent: The serialized Graph content. """ - nameCorrespondences = {} # maps the old node name to its updated one - updatedData = {} # input data with updated node names and links - - def createUniqueNodeName(nodeNames, inputName): - """ - Create a unique name that does not already exist in the current graph or in the list - of nodes that will be imported. - """ - i = 1 - while i: - newName = "{name}_{index}".format(name=inputName, index=i) - if newName not in nodeNames and newName not in updatedData.keys(): - return newName - i += 1 - - # First pass to get all the names that already exist in the graph, update them, and keep track of the changes - for nodeName, nodeData in sorted(data.items(), key=lambda x: self.getNodeIndexFromName(x[0])): - if not isinstance(nodeData, dict): - raise RuntimeError('updateImportedProject error: Node is not a dict.') - - if nodeName in self._nodes.keys() or nodeName in updatedData.keys(): - newName = createUniqueNodeName(self._nodes.keys(), nodeData["nodeType"]) - updatedData[newName] = nodeData - nameCorrespondences[nodeName] = newName - - else: - updatedData[nodeName] = nodeData - - newNames = [nodeName for nodeName in updatedData] # names of all the nodes that will be added - - # Second pass to update all the links in the input/output attributes for every node with the new names - for nodeName, nodeData in updatedData.items(): - nodeType = nodeData.get("nodeType", None) - nodeDesc = meshroom.core.nodesDesc[nodeType] - - inputs = nodeData.get("inputs", {}) - outputs = nodeData.get("outputs", {}) - - if inputs: - inputs = self.updateLinks(inputs, nameCorrespondences) - inputs = self.resetExternalLinks(inputs, nodeDesc.inputs, newNames) - updatedData[nodeName]["inputs"] = inputs - if outputs: - outputs = self.updateLinks(outputs, nameCorrespondences) - outputs = self.resetExternalLinks(outputs, nodeDesc.outputs, newNames) - updatedData[nodeName]["outputs"] = outputs - return updatedData + def _serializedNodeUidMatchesComputedUid(nodeData: dict, node: BaseNode) -> bool: + """Returns whether the serialized UID matches the one computed in the `node` instance.""" + if isinstance(node, CompatibilityNode): + return True + serializedUid = nodeData.get("uid", None) + computedUid = node._uid + return serializedUid is None or computedUid is None or serializedUid == computedUid + + uidConflictingNodes = [ + node + for node in self.nodes + if not _serializedNodeUidMatchesComputedUid(graphContent[node.name], node) + ] + + if not uidConflictingNodes: + return - @staticmethod - def updateLinks(attributes, nameCorrespondences): - """ - Update all the links that refer to nodes that are going to be imported and whose - names have to be updated. + logging.warning("UID Compatibility issues found: recreating conflicting nodes as CompatibilityNodes.") + + # A uid conflict is contagious: if a node has a uid conflict, all of its downstream nodes may be + # impacted as well, as the uid flows through connections. + # Therefore, we deal with conflicting uid nodes by depth: replacing a node with a CompatibilityNode restores + # the serialized uid, which might solve "false-positives" downstream conflicts as well. + nodesSortedByDepth = sorted(uidConflictingNodes, key=lambda node: node.minDepth) + for node in nodesSortedByDepth: + nodeData = graphContent[node.name] + # Evaluate if the node uid is still conflicting at this point, or if it has been resolved by an + # upstream node replacement. + if _serializedNodeUidMatchesComputedUid(nodeData, node): + continue + expectedUid = node._uid + compatibilityNode = nodeFactory(graphContent[node.name], node.name, expectedUid=expectedUid) + # This operation will trigger a graph update that will recompute the uids of all nodes, + # allowing the iterative resolution of uid conflicts. + self.replaceNode(node.name, compatibilityNode) + + + def importGraphContentFromFile(self, filepath: PathLike) -> list[Node]: + """Import the content (nodes and edges) of another Graph file into this Graph instance. Args: - attributes (dict): attributes whose links need to be updated - nameCorrespondences (dict): node names to replace in the links with the name to replace them with + filepath: The path to the Graph file to import. Returns: - attributes (dict): the attributes with all the updated links + The list of newly created Nodes. """ - for key, val in attributes.items(): - for corr in nameCorrespondences.keys(): - if isinstance(val, str) and corr in val: - attributes[key] = val.replace(corr, nameCorrespondences[corr]) - elif isinstance(val, list): - for v in val: - if isinstance(v, str): - if corr in v: - val[val.index(v)] = v.replace(corr, nameCorrespondences[corr]) - else: # the list does not contain strings, so there cannot be links to update - break - attributes[key] = val - - return attributes + graph = loadGraph(filepath) + return self.importGraphContent(graph) - @staticmethod - def resetExternalLinks(attributes, nodeDesc, newNames): + @blockNodeCallbacks + def importGraphContent(self, graph: "Graph") -> list[Node]: """ - Reset all links to nodes that are not part of the nodes which are going to be imported: - if there are links to nodes that are not in the list, then it means that the references - are made to external nodes, and we want to get rid of those. + Import the content (node and edges) of another `graph` into this Graph instance. + + Nodes are imported with their original names if possible, otherwise a new unique name is generated + from their node type. Args: - attributes (dict): attributes whose links might need to be reset - nodeDesc (list): list with all the attributes' description (including their default value) - newNames (list): names of the nodes that are going to be imported; no node name should be referenced - in the links except those contained in this list + graph: The graph to import. Returns: - attributes (dict): the attributes with all the links referencing nodes outside those which will be imported - reset to their default values + The list of newly created Nodes. """ - for key, val in attributes.items(): - defaultValue = None - for desc in nodeDesc: - if desc.name == key: - defaultValue = desc.value - break - - if isinstance(val, str): - if Attribute.isLinkExpression(val) and not any(name in val for name in newNames): - if defaultValue is not None: # prevents from not entering condition if defaultValue = '' - attributes[key] = defaultValue - - elif isinstance(val, list): - removedCnt = len(val) # counter to know whether all the list entries will be deemed invalid - tmpVal = list(val) # deep copy to ensure we iterate over the entire list (even if elements are removed) - for v in tmpVal: - if isinstance(v, str) and Attribute.isLinkExpression(v) and not any(name in v for name in newNames): - val.remove(v) - removedCnt -= 1 - if removedCnt == 0 and defaultValue is not None: # if all links were wrong, reset the attribute - attributes[key] = defaultValue - - return attributes + + def _renameClashingNodes(): + if not self.nodes: + return + unavailableNames = set(self.nodes.keys()) + for node in graph.nodes: + if node._name in unavailableNames: + node._name = self._createUniqueNodeName(node.nodeType, unavailableNames) + unavailableNames.add(node._name) + + def _importNodeAndEdges() -> list[Node]: + importedNodes = [] + # If we import the content of the graph within itself, + # iterate over a copy of the nodes as the graph is modified during the iteration. + nodes = graph.nodes if graph is not self else list(graph.nodes) + with GraphModification(self): + for srcNode in nodes: + node = self._deserializeNode(srcNode.toDict(), srcNode.name, graph) + importedNodes.append(node) + self._applyExpr() + return importedNodes + + _renameClashingNodes() + importedNodes = _importNodeAndEdges() + return importedNodes @property def updateEnabled(self): @@ -642,41 +546,6 @@ def duplicateNodes(self, srcNodes): return duplicates - def pasteNodes(self, data, position): - """ - Paste node(s) in the graph with their connections. The connections can only be between - the pasted nodes and not with the rest of the graph. - - Args: - data (dict): the dictionary containing the information about the nodes to paste, with their names and - links already updated to be added to the graph - position (list): the list of positions for each node to paste - - Returns: - list: the list of Node objects that were pasted and added to the graph - """ - nodes = [] - with GraphModification(self): - positionCnt = 0 # always valid because we know the data is sorted the same way as the position list - for key in sorted(data): - nodeType = data[key].get("nodeType", None) - if not nodeType: # this case should never occur, as the data should have been prefiltered first - pass - - attributes = {} - attributes.update(data[key].get("inputs", {})) - attributes.update(data[key].get("outputs", {})) - attributes.update(data[key].get("internalInputs", {})) - - node = Node(nodeType, position=position[positionCnt], **attributes) - self._addNode(node, key) - - nodes.append(node) - positionCnt += 1 - - self._applyExpr() - return nodes - def outEdges(self, attribute): """ Return the list of edges starting from the given attribute """ # type: (Attribute,) -> [Edge] @@ -740,8 +609,6 @@ def removeNode(self, nodeName): node.alive = False self._nodes.remove(node) - if node in self._importedNodes: - self._importedNodes.remove(node) self.update() return inEdges, outEdges, outListAttributes @@ -766,18 +633,26 @@ def addNewNode(self, nodeType, name=None, position=None, **kwargs): n.updateInternals() return n - def _createUniqueNodeName(self, inputName): - i = 1 - while i: - newName = "{name}_{index}".format(name=inputName, index=i) - if newName not in self._nodes.objects: + def _createUniqueNodeName(self, inputName: str, existingNames: Optional[set[str]] = None): + """Create a unique node name based on the input name. + + Args: + inputName: The desired node name. + existingNames: (optional) If specified, consider this set for uniqueness check, instead of the list of nodes. + """ + existingNodeNames = existingNames or set(self._nodes.objects.keys()) + + idx = 1 + while idx: + newName = f"{inputName}_{idx}" + if newName not in existingNodeNames: return newName - i += 1 + idx += 1 def node(self, nodeName): return self._nodes.get(nodeName) - def upgradeNode(self, nodeName): + def upgradeNode(self, nodeName) -> Node: """ Upgrade the CompatibilityNode identified as 'nodeName' Args: @@ -797,25 +672,49 @@ def upgradeNode(self, nodeName): if not isinstance(node, CompatibilityNode): raise ValueError("Upgrade is only available on CompatibilityNode instances.") upgradedNode = node.upgrade() + self.replaceNode(nodeName, upgradedNode) + return upgradedNode + + @changeTopology + def replaceNode(self, nodeName: str, newNode: BaseNode): + """Replace the node idenfitied by `nodeName` with `newNode`, while restoring compatible edges. + + Args: + nodeName: The name of the Node to replace. + newNode: The Node instance to replace it with. + """ with GraphModification(self): - inEdges, outEdges, outListAttributes = self.removeNode(nodeName) - self.addNode(upgradedNode, nodeName) - for dst, src in outEdges.items(): - # Re-create the entries in ListAttributes that were completely removed during the call to "removeNode" - # If they are not re-created first, adding their edges will lead to errors - # 0 = attribute name, 1 = attribute index, 2 = attribute value - if dst in outListAttributes.keys(): - listAttr = self.attribute(outListAttributes[dst][0]) - if isinstance(outListAttributes[dst][2], list): - listAttr[outListAttributes[dst][1]:outListAttributes[dst][1]] = outListAttributes[dst][2] - else: - listAttr.insert(outListAttributes[dst][1], outListAttributes[dst][2]) - try: - self.addEdge(self.attribute(src), self.attribute(dst)) - except (KeyError, ValueError) as e: - logging.warning("Failed to restore edge {} -> {}: {}".format(src, dst, str(e))) - - return upgradedNode, inEdges, outEdges, outListAttributes + _, outEdges, outListAttributes = self.removeNode(nodeName) + self.addNode(newNode, nodeName) + self._restoreOutEdges(outEdges, outListAttributes) + + def _restoreOutEdges(self, outEdges: dict[str, str], outListAttributes): + """Restore output edges that were removed during a call to "removeNode". + + Args: + outEdges: a dictionary containing the outgoing edges removed by a call to "removeNode". + {dstAttr.getFullNameToNode(), srcAttr.getFullNameToNode()} + outListAttributes: a dictionary containing the values, indices and keys of attributes that were connected + to a ListAttribute prior to the removal of all edges. + {dstAttr.getFullNameToNode(), (dstAttr.root.getFullNameToNode(), dstAttr.index, dstAttr.value)} + """ + def _recreateTargetListAttributeChildren(listAttrName: str, index: int, value: Any): + listAttr = self.attribute(listAttrName) + if not isinstance(listAttr, ListAttribute): + return + if isinstance(value, list): + listAttr[index:index] = value + else: + listAttr.insert(index, value) + + for dstName, srcName in outEdges.items(): + # Re-create the entries in ListAttributes that were completely removed during the call to "removeNode" + if dstName in outListAttributes: + _recreateTargetListAttributeChildren(*outListAttributes[dstName]) + try: + self.addEdge(self.attribute(srcName), self.attribute(dstName)) + except (KeyError, ValueError) as e: + logging.warning(f"Failed to restore edge {srcName} -> {dstName}: {str(e)}") def upgradeAllNodes(self): """ Upgrade all upgradable CompatibilityNode instances in the graph. """ @@ -1346,39 +1245,41 @@ def toDict(self): def asString(self): return str(self.toDict()) + def copy(self) -> "Graph": + """Create a copy of this Graph instance.""" + graph = Graph("") + graph._deserialize(self.serialize()) + return graph + + def serialize(self, asTemplate: bool = False) -> dict: + """Serialize this Graph instance. + + Args: + asTemplate: Whether to use the template serialization. + + Returns: + The serialized graph data. + """ + SerializerClass = TemplateGraphSerializer if asTemplate else GraphSerializer + return SerializerClass(self).serialize() + + def serializePartial(self, nodes: list[Node]) -> dict: + """Partially serialize this graph considering only the given list of `nodes`. + + Args: + nodes: The list of nodes to serialize. + + Returns: + The serialized graph data. + """ + return PartialGraphSerializer(self, nodes=nodes).serialize() + def save(self, filepath=None, setupProjectFile=True, template=False): path = filepath or self._filepath if not path: raise ValueError("filepath must be specified for unsaved files.") - self.header[Graph.IO.Keys.ReleaseVersion] = meshroom.__version__ - self.header[Graph.IO.Keys.FileVersion] = Graph.IO.__version__ - - # Store versions of node types present in the graph (excluding CompatibilityNode instances) - # and remove duplicates - usedNodeTypes = set([n.nodeDesc.__class__ for n in self._nodes if isinstance(n, Node)]) - # Convert to node types to "name: version" - nodesVersions = { - "{}".format(p.__name__): meshroom.core.nodeVersion(p, "0.0") - for p in usedNodeTypes - } - # Sort them by name (to avoid random order changing from one save to another) - nodesVersions = dict(sorted(nodesVersions.items())) - # Add it the header - self.header[Graph.IO.Keys.NodesVersions] = nodesVersions - self.header["template"] = template - - data = {} - if template: - data = { - Graph.IO.Keys.Header: self.header, - Graph.IO.Keys.Graph: self.getNonDefaultInputAttributes() - } - else: - data = { - Graph.IO.Keys.Header: self.header, - Graph.IO.Keys.Graph: self.toDict() - } + data = self.serialize(template) with open(path, 'w') as jsonFile: json.dump(data, jsonFile, indent=4) @@ -1389,51 +1290,6 @@ def save(self, filepath=None, setupProjectFile=True, template=False): # update the file date version self._fileDateVersion = os.path.getmtime(path) - def getNonDefaultInputAttributes(self): - """ - Instead of getting all the inputs and internal attribute keys, only get the keys of - the attributes whose value is not the default one. - The output attributes, UIDs, parallelization parameters and internal folder are - not relevant for templates, so they are explicitly removed from the returned dictionary. - - Returns: - dict: self.toDict() with the output attributes, UIDs, parallelization parameters, internal folder - and input/internal attributes with default values removed - """ - graph = self.toDict() - for nodeName in graph.keys(): - node = self.node(nodeName) - - inputKeys = list(graph[nodeName]["inputs"].keys()) - - internalInputKeys = [] - internalInputs = graph[nodeName].get("internalInputs", None) - if internalInputs: - internalInputKeys = list(internalInputs.keys()) - - for attrName in inputKeys: - attribute = node.attribute(attrName) - # check that attribute is not a link for choice attributes - if attribute.isDefault and not attribute.isLink: - del graph[nodeName]["inputs"][attrName] - - for attrName in internalInputKeys: - attribute = node.internalAttribute(attrName) - # check that internal attribute is not a link for choice attributes - if attribute.isDefault and not attribute.isLink: - del graph[nodeName]["internalInputs"][attrName] - - # If all the internal attributes are set to their default values, remove the entry - if len(graph[nodeName]["internalInputs"]) == 0: - del graph[nodeName]["internalInputs"] - - del graph[nodeName]["outputs"] - del graph[nodeName]["uid"] - del graph[nodeName]["internalFolder"] - del graph[nodeName]["parallelization"] - - return graph - def _setFilepath(self, filepath): """ Set the internal filepath of this Graph. @@ -1592,11 +1448,6 @@ def nodes(self): def edges(self): return self._edges - @property - def importedNodes(self): - """" Return the list of nodes that were added to the graph with the latest 'Import Project' action. """ - return self._importedNodes - @property def cacheDir(self): return self._cacheDir @@ -1636,7 +1487,7 @@ def setVerbose(self, v): edges = Property(BaseObject, edges.fget, constant=True) filepathChanged = Signal() filepath = Property(str, lambda self: self._filepath, notify=filepathChanged) - fileReleaseVersion = Property(str, lambda self: self.header.get(Graph.IO.Keys.ReleaseVersion, "0.0"), + fileReleaseVersion = Property(str, lambda self: self.header.get(GraphIO.Keys.ReleaseVersion, "0.0"), notify=filepathChanged) fileDateVersion = Property(float, fileDateVersion.fget, fileDateVersion.fset, notify=filepathChanged) cacheDirChanged = Signal() diff --git a/meshroom/core/graphIO.py b/meshroom/core/graphIO.py new file mode 100644 index 0000000000..196888036a --- /dev/null +++ b/meshroom/core/graphIO.py @@ -0,0 +1,229 @@ +from enum import Enum +from typing import Any, TYPE_CHECKING, Union + +import meshroom +from meshroom.core import Version +from meshroom.core.attribute import Attribute, GroupAttribute, ListAttribute +from meshroom.core.node import Node + +if TYPE_CHECKING: + from meshroom.core.graph import Graph + + +class GraphIO: + """Centralize Graph file keys and IO version.""" + + __version__ = "2.0" + + class Keys(object): + """File Keys.""" + + # Doesn't inherit enum to simplify usage (GraphIO.Keys.XX, without .value) + Header = "header" + NodesVersions = "nodesVersions" + ReleaseVersion = "releaseVersion" + FileVersion = "fileVersion" + Graph = "graph" + Template = "template" + + class Features(Enum): + """File Features.""" + + Graph = "graph" + Header = "header" + NodesVersions = "nodesVersions" + PrecomputedOutputs = "precomputedOutputs" + NodesPositions = "nodesPositions" + + @staticmethod + def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features", ...]: + """Return the list of supported features based on a file version. + + Args: + fileVersion (str, Version): the file version + + Returns: + tuple of GraphIO.Features: the list of supported features + """ + if isinstance(fileVersion, str): + fileVersion = Version(fileVersion) + + features = [GraphIO.Features.Graph] + if fileVersion >= Version("1.0"): + features += [ + GraphIO.Features.Header, + GraphIO.Features.NodesVersions, + GraphIO.Features.PrecomputedOutputs, + ] + + if fileVersion >= Version("1.1"): + features += [GraphIO.Features.NodesPositions] + + return tuple(features) + + +class GraphSerializer: + """Standard Graph serializer.""" + + def __init__(self, graph: "Graph") -> None: + self._graph = graph + + def serialize(self) -> dict: + """ + Serialize the Graph. + """ + return { + GraphIO.Keys.Header: self.serializeHeader(), + GraphIO.Keys.Graph: self.serializeContent(), + } + + @property + def nodes(self) -> list[Node]: + return self._graph.nodes + + def serializeHeader(self) -> dict: + """Build and return the graph serialization header. + + The header contains metadata about the graph, such as the: + - version of the software used to create it. + - version of the file format. + - version of the nodes types used in the graph. + - template flag. + + Args: + nodes: (optional) The list of nodes to consider for node types versions - use all nodes if not specified. + template: Whether the graph is going to be serialized as a template. + """ + header: dict[str, Any] = {} + header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__ + header[GraphIO.Keys.FileVersion] = GraphIO.__version__ + header[GraphIO.Keys.NodesVersions] = self._getNodeTypesVersions() + return header + + def _getNodeTypesVersions(self) -> dict[str, str]: + """Get registered versions of each node types in `nodes`, excluding CompatibilityNode instances.""" + nodeTypes = set([node.nodeDesc.__class__ for node in self.nodes if isinstance(node, Node)]) + nodeTypesVersions = { + nodeType.__name__: meshroom.core.nodeVersion(nodeType, "0.0") for nodeType in nodeTypes + } + # Sort them by name (to avoid random order changing from one save to another). + return dict(sorted(nodeTypesVersions.items())) + + def serializeContent(self) -> dict: + """Graph content serialization logic.""" + return {node.name: self.serializeNode(node) for node in sorted(self.nodes, key=lambda n: n.name)} + + def serializeNode(self, node: Node) -> dict: + """Node serialization logic.""" + return node.toDict() + + +class TemplateGraphSerializer(GraphSerializer): + """Serializer for serializing a graph as a template.""" + + def serializeHeader(self) -> dict: + header = super().serializeHeader() + header[GraphIO.Keys.Template] = True + return header + + def serializeNode(self, node: Node) -> dict: + """Adapt node serialization to template graphs. + + Instead of getting all the inputs and internal attribute keys, only get the keys of + the attributes whose value is not the default one. + The output attributes, UIDs, parallelization parameters and internal folder are + not relevant for templates, so they are explicitly removed from the returned dictionary. + """ + # For now, implemented as a post-process to update the default serialization. + nodeData = super().serializeNode(node) + + inputKeys = list(nodeData["inputs"].keys()) + + internalInputKeys = [] + internalInputs = nodeData.get("internalInputs", None) + if internalInputs: + internalInputKeys = list(internalInputs.keys()) + + for attrName in inputKeys: + attribute = node.attribute(attrName) + # check that attribute is not a link for choice attributes + if attribute.isDefault and not attribute.isLink: + del nodeData["inputs"][attrName] + + for attrName in internalInputKeys: + attribute = node.internalAttribute(attrName) + # check that internal attribute is not a link for choice attributes + if attribute.isDefault and not attribute.isLink: + del nodeData["internalInputs"][attrName] + + # If all the internal attributes are set to their default values, remove the entry + if len(nodeData["internalInputs"]) == 0: + del nodeData["internalInputs"] + + del nodeData["outputs"] + del nodeData["uid"] + del nodeData["internalFolder"] + del nodeData["parallelization"] + + return nodeData + + +class PartialGraphSerializer(GraphSerializer): + """Serializer to serialize a partial graph (a subset of nodes).""" + + def __init__(self, graph: "Graph", nodes: list[Node]): + super().__init__(graph) + self._nodes = nodes + + @property + def nodes(self) -> list[Node]: + """Override to consider only the subset of nodes.""" + return self._nodes + + def serializeNode(self, node: Node) -> dict: + """Adapt node serialization to partial graph serialization.""" + # NOTE: For now, implemented as a post-process to the default serialization. + nodeData = super().serializeNode(node) + + # Override input attributes with custom serialization logic, to handle attributes + # connected to nodes that are not in the list of nodes to serialize. + for attributeName in nodeData["inputs"]: + nodeData["inputs"][attributeName] = self._serializeAttribute(node.attribute(attributeName)) + + # Clear UID for non-compatibility nodes, as the custom attribute serialization + # can be impacting the UID by removing connections to missing nodes. + if not node.isCompatibilityNode: + del nodeData["uid"] + + return nodeData + + def _serializeAttribute(self, attribute: Attribute) -> Any: + """ + Serialize `attribute` (recursively for list/groups) and deal with attributes being connected + to nodes that are not part of the partial list of nodes to serialize. + """ + # If the attribute is connected to a node that is not in the list of nodes to serialize, + # the link expression should not be serialized. + if attribute.isLink and attribute.getLinkParam().node not in self.nodes: + # If part of a list, this entry can be discarded. + if isinstance(attribute.root, ListAttribute): + return None + # Otherwise, return the default value for this attribute. + return attribute.defaultValue() + + if isinstance(attribute, ListAttribute): + # Recusively serialize each child of the ListAttribute, skipping those for which the attribute + # serialization logic above returns None. + return [ + exportValue + for child in attribute + if (exportValue := self._serializeAttribute(child)) is not None + ] + + if isinstance(attribute, GroupAttribute): + # Recursively serialize each child of the group attribute. + return {name: self._serializeAttribute(child) for name, child in attribute.value.items()} + + return attribute.getExportValue() + + diff --git a/meshroom/core/node.py b/meshroom/core/node.py index 1b8806e2e4..6f42272481 100644 --- a/meshroom/core/node.py +++ b/meshroom/core/node.py @@ -1668,7 +1668,17 @@ def attributeDescFromValue(attrName, value, isOutput): elif isinstance(value, float): return desc.FloatParam(range=None, **params) elif isinstance(value, str): - if isOutput or os.path.isabs(value) or Attribute.isLinkExpression(value): + if isOutput or os.path.isabs(value): + return desc.File(**params) + elif Attribute.isLinkExpression(value): + # Do not consider link expression as a valid default desc value. + # When the link expression is applied and transformed to an actual link, + # the systems resets the value using `Attribute.resetToDefaultValue` to indicate + # that this link expression has been handled. + # If the link expression is stored as the default value, it will never be cleared, + # leading to unexpected behavior where the link expression on a CompatibilityNode + # could be evaluated several times and/or incorrectly. + params["value"] = "" return desc.File(**params) else: return desc.StringParam(**params) @@ -1851,113 +1861,3 @@ def upgrade(self): canUpgrade = Property(bool, canUpgrade.fget, constant=True) issueDetails = Property(str, issueDetails.fget, constant=True) - -def nodeFactory(nodeDict, name=None, template=False, uidConflict=False): - """ - Create a node instance by deserializing the given node data. - If the serialized data matches the corresponding node type description, a Node instance is created. - If any compatibility issue occurs, a NodeCompatibility instance is created instead. - - Args: - nodeDict (dict): the serialization of the node - name (str): (optional) the node's name - template (bool): (optional) true if the node is part of a template, false otherwise - uidConflict (bool): (optional) true if a UID conflict has been detected externally on that node - - Returns: - BaseNode: the created node - """ - nodeType = nodeDict["nodeType"] - - # Retro-compatibility: inputs were previously saved as "attributes" - if "inputs" not in nodeDict and "attributes" in nodeDict: - nodeDict["inputs"] = nodeDict["attributes"] - del nodeDict["attributes"] - - # Get node inputs/outputs - inputs = nodeDict.get("inputs", {}) - internalInputs = nodeDict.get("internalInputs", {}) - outputs = nodeDict.get("outputs", {}) - version = nodeDict.get("version", None) - internalFolder = nodeDict.get("internalFolder", None) - position = Position(*nodeDict.get("position", [])) - uid = nodeDict.get("uid", None) - - compatibilityIssue = None - - nodeDesc = None - try: - nodeDesc = meshroom.core.nodesDesc[nodeType] - except KeyError: - # Unknown node type - compatibilityIssue = CompatibilityIssue.UnknownNodeType - - # Unknown node type should take precedence over UID conflict, as it cannot be resolved - if uidConflict and nodeDesc: - compatibilityIssue = CompatibilityIssue.UidConflict - - if nodeDesc and not uidConflict: # if uidConflict, there is no need to look for another compatibility issue - # Compare serialized node version with current node version - currentNodeVersion = meshroom.core.nodeVersion(nodeDesc) - # If both versions are available, check for incompatibility in major version - if version and currentNodeVersion and Version(version).major != Version(currentNodeVersion).major: - compatibilityIssue = CompatibilityIssue.VersionConflict - # In other cases, check attributes compatibility between serialized node and its description - else: - # Check that the node has the exact same set of inputs/outputs as its description, except - # if the node is described in a template file, in which only non-default parameters are saved; - # do not perform that check for internal attributes because there is no point in - # raising compatibility issues if their number differs: in that case, it is only useful - # if some internal attributes do not exist or are invalid - if not template and (sorted([attr.name for attr in nodeDesc.inputs - if not isinstance(attr, desc.PushButtonParam)]) != sorted(inputs.keys()) or - sorted([attr.name for attr in nodeDesc.outputs if not attr.isDynamicValue]) != - sorted(outputs.keys())): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - - # Check whether there are any internal attributes that are invalidating in the node description: if there - # are, then check that these internal attributes are part of nodeDict; if they are not, a compatibility - # issue must be raised to warn the user, as this will automatically change the node's UID - if not template: - invalidatingIntInputs = [] - for attr in nodeDesc.internalInputs: - if attr.invalidate: - invalidatingIntInputs.append(attr.name) - for attr in invalidatingIntInputs: - if attr not in internalInputs.keys(): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - - # Verify that all inputs match their descriptions - for attrName, value in inputs.items(): - if not CompatibilityNode.attributeDescFromName(nodeDesc.inputs, attrName, value): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - # Verify that all internal inputs match their description - for attrName, value in internalInputs.items(): - if not CompatibilityNode.attributeDescFromName(nodeDesc.internalInputs, attrName, value): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - # Verify that all outputs match their descriptions - for attrName, value in outputs.items(): - if not CompatibilityNode.attributeDescFromName(nodeDesc.outputs, attrName, value): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - - if compatibilityIssue is None: - node = Node(nodeType, position, uid=uid, **inputs, **internalInputs, **outputs) - else: - logging.debug("Compatibility issue detected for node '{}': {}".format(name, compatibilityIssue.name)) - node = CompatibilityNode(nodeType, nodeDict, position, compatibilityIssue) - # Retro-compatibility: no internal folder saved - # can't spawn meaningful CompatibilityNode with precomputed outputs - # => automatically try to perform node upgrade - if not internalFolder and nodeDesc: - logging.warning("No serialized output data: performing automatic upgrade on '{}'".format(name)) - node = node.upgrade() - # If the node comes from a template file and there is a conflict, it should be upgraded anyway unless it is - # an "unknown node type" conflict (in which case the upgrade would fail) - elif template and compatibilityIssue is not CompatibilityIssue.UnknownNodeType: - node = node.upgrade() - - return node diff --git a/meshroom/core/nodeFactory.py b/meshroom/core/nodeFactory.py new file mode 100644 index 0000000000..f030b9c5b2 --- /dev/null +++ b/meshroom/core/nodeFactory.py @@ -0,0 +1,197 @@ +import logging +from typing import Any, Iterable, Optional, Union + +import meshroom.core +from meshroom.core import Version, desc +from meshroom.core.node import CompatibilityIssue, CompatibilityNode, Node, Position + + +def nodeFactory( + nodeData: dict, + name: Optional[str] = None, + inTemplate: bool = False, + expectedUid: Optional[str] = None, +) -> Union[Node, CompatibilityNode]: + """ + Create a node instance by deserializing the given node data. + If the serialized data matches the corresponding node type description, a Node instance is created. + If any compatibility issue occurs, a NodeCompatibility instance is created instead. + + Args: + nodeDict: The serialized Node data. + name: (optional) The node's name. + inTemplate: (optional) True if the node is created as part of a graph template. + expectedUid: (optional) The expected UID of the node within the context of a Graph. + + Returns: + The created Node instance. + """ + return _NodeCreator(nodeData, name, inTemplate, expectedUid).create() + + +class _NodeCreator: + + def __init__( + self, + nodeData: dict, + name: Optional[str] = None, + inTemplate: bool = False, + expectedUid: Optional[str] = None, + ): + self.nodeData = nodeData + self.name = name + self.inTemplate = inTemplate + self.expectedUid = expectedUid + + self._normalizeNodeData() + + self.nodeType = self.nodeData["nodeType"] + self.inputs = self.nodeData.get("inputs", {}) + self.internalInputs = self.nodeData.get("internalInputs", {}) + self.outputs = self.nodeData.get("outputs", {}) + self.version = self.nodeData.get("version", None) + self.internalFolder = self.nodeData.get("internalFolder") + self.position = Position(*self.nodeData.get("position", [])) + self.uid = self.nodeData.get("uid", None) + self.nodeDesc = meshroom.core.nodesDesc.get(self.nodeType, None) + + def create(self) -> Union[Node, CompatibilityNode]: + compatibilityIssue = self._checkCompatibilityIssues() + if compatibilityIssue: + node = self._createCompatibilityNode(compatibilityIssue) + node = self._tryUpgradeCompatibilityNode(node) + else: + node = self._createNode() + return node + + def _normalizeNodeData(self): + """Consistency fixes for backward compatibility with older serialized data.""" + # Inputs were previously saved as "attributes". + if "inputs" not in self.nodeData and "attributes" in self.nodeData: + self.nodeData["inputs"] = self.nodeData["attributes"] + del self.nodeData["attributes"] + + def _checkCompatibilityIssues(self) -> Optional[CompatibilityIssue]: + if self.nodeDesc is None: + return CompatibilityIssue.UnknownNodeType + + if not self._checkUidCompatibility(): + return CompatibilityIssue.UidConflict + + if not self._checkVersionCompatibility(): + return CompatibilityIssue.VersionConflict + + if not self._checkDescriptionCompatibility(): + return CompatibilityIssue.DescriptionConflict + + return None + + def _checkUidCompatibility(self) -> bool: + return self.expectedUid is None or self.expectedUid == self.uid + + def _checkVersionCompatibility(self) -> bool: + # Special case: a node with a version set to None indicates + # that it has been created from the current version of the node type. + nodeCreatedFromCurrentVersion = self.version is None + if nodeCreatedFromCurrentVersion: + return True + nodeTypeCurrentVersion = meshroom.core.nodeVersion(self.nodeDesc, "0.0") + return Version(self.version).major == Version(nodeTypeCurrentVersion).major + + def _checkDescriptionCompatibility(self) -> bool: + # Only perform strict attribute name matching for non-template graphs, + # since only non-default-value input attributes are serialized in templates. + if not self.inTemplate: + if not self._checkAttributesNamesMatchDescription(): + return False + + return self._checkAttributesAreCompatibleWithDescription() + + def _checkAttributesNamesMatchDescription(self) -> bool: + return ( + self._checkInputAttributesNames() + and self._checkOutputAttributesNames() + and self._checkInternalAttributesNames() + ) + + def _checkAttributesAreCompatibleWithDescription(self) -> bool: + return ( + self._checkAttributesCompatibility(self.nodeDesc.inputs, self.inputs) + and self._checkAttributesCompatibility(self.nodeDesc.internalInputs, self.internalInputs) + and self._checkAttributesCompatibility(self.nodeDesc.outputs, self.outputs) + ) + + def _checkInputAttributesNames(self) -> bool: + def serializedInput(attr: desc.Attribute) -> bool: + """Filter that excludes not-serialized desc input attributes.""" + if isinstance(attr, desc.PushButtonParam): + # PushButtonParam are not serialized has they do not hold a value. + return False + return True + + refAttributes = filter(serializedInput, self.nodeDesc.inputs) + return self._checkAttributesNamesStrictlyMatch(refAttributes, self.inputs) + + def _checkOutputAttributesNames(self) -> bool: + def serializedOutput(attr: desc.Attribute) -> bool: + """Filter that excludes not-serialized desc output attributes.""" + if attr.isDynamicValue: + # Dynamic outputs values are not serialized with the node, + # as their value is written in the computed output data. + return False + return True + + refAttributes = filter(serializedOutput, self.nodeDesc.outputs) + return self._checkAttributesNamesStrictlyMatch(refAttributes, self.outputs) + + def _checkInternalAttributesNames(self) -> bool: + invalidatingDescAttributes = [attr.name for attr in self.nodeDesc.internalInputs if attr.invalidate] + return all(attr in self.internalInputs.keys() for attr in invalidatingDescAttributes) + + def _checkAttributesNamesStrictlyMatch( + self, descAttributes: Iterable[desc.Attribute], attributesDict: dict[str, Any] + ) -> bool: + refNames = sorted([attr.name for attr in descAttributes]) + attrNames = sorted(attributesDict.keys()) + return refNames == attrNames + + def _checkAttributesCompatibility( + self, descAttributes: list[desc.Attribute], attributesDict: dict[str, Any] + ) -> bool: + return all( + CompatibilityNode.attributeDescFromName(descAttributes, attrName, value) is not None + for attrName, value in attributesDict.items() + ) + + def _createNode(self) -> Node: + logging.info(f"Creating node '{self.name}'") + return Node( + self.nodeType, + position=self.position, + uid=self.uid, + **self.inputs, + **self.internalInputs, + **self.outputs, + ) + + def _createCompatibilityNode(self, compatibilityIssue) -> CompatibilityNode: + logging.warning(f"Compatibility issue detected for node '{self.name}': {compatibilityIssue.name}") + return CompatibilityNode( + self.nodeType, self.nodeData, position=self.position, issue=compatibilityIssue + ) + + def _tryUpgradeCompatibilityNode(self, node: CompatibilityNode) -> Union[Node, CompatibilityNode]: + """Handle possible upgrades of CompatibilityNodes, when no computed data is associated to the Node.""" + if node.issue == CompatibilityIssue.UnknownNodeType: + return node + + # Nodes in templates are not meant to hold computation data. + if self.inTemplate: + logging.warning(f"Compatibility issue in template: performing automatic upgrade on '{self.name}'") + return node.upgrade() + + # Backward compatibility: "internalFolder" was not serialized. + if not self.internalFolder: + logging.warning(f"No serialized output data: performing automatic upgrade on '{self.name}'") + + return node diff --git a/meshroom/core/typing.py b/meshroom/core/typing.py new file mode 100644 index 0000000000..f526fb3e3d --- /dev/null +++ b/meshroom/core/typing.py @@ -0,0 +1,8 @@ +""" +Common typing aliases used in Meshroom. +""" + +from pathlib import Path +from typing import Union + +PathLike = Union[Path, str] diff --git a/meshroom/ui/commands.py b/meshroom/ui/commands.py index 47be0a3385..d1e8d8bf3e 100755 --- a/meshroom/ui/commands.py +++ b/meshroom/ui/commands.py @@ -6,8 +6,10 @@ from PySide6.QtCore import Property, Signal from meshroom.core.attribute import ListAttribute, Attribute -from meshroom.core.graph import GraphModification -from meshroom.core.node import nodeFactory, Position +from meshroom.core.graph import Graph, GraphModification +from meshroom.core.node import Position, CompatibilityIssue +from meshroom.core.nodeFactory import nodeFactory +from meshroom.core.typing import PathLike class UndoCommand(QUndoCommand): @@ -168,19 +170,7 @@ def undoImpl(self): node = nodeFactory(self.nodeDict, self.nodeName) self.graph.addNode(node, self.nodeName) assert (node.getName() == self.nodeName) - # recreate out edges deleted on node removal - for dstAttr, srcAttr in self.outEdges.items(): - # if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute - # 0 = attribute name, 1 = attribute index, 2 = attribute value - if dstAttr in self.outListAttributes.keys(): - listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0]) - if isinstance(self.outListAttributes[dstAttr][2], list): - listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2] - else: - listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2]) - - self.graph.addEdge(self.graph.attribute(srcAttr), - self.graph.attribute(dstAttr)) + self.graph._restoreOutEdges(self.outEdges, self.outListAttributes) class DuplicateNodesCommand(GraphCommand): @@ -209,15 +199,27 @@ class PasteNodesCommand(GraphCommand): """ Handle node pasting in a Graph. """ - def __init__(self, graph, data, position=None, parent=None): + def __init__(self, graph: "Graph", data: dict, position: Position, parent=None): super(PasteNodesCommand, self).__init__(graph, parent) self.data = data self.position = position - self.nodeNames = [] + self.nodeNames: list[str] = [] def redoImpl(self): - data = self.graph.updateImportedProject(self.data) - nodes = self.graph.pasteNodes(data, self.position) + graph = Graph("") + try: + graph._deserialize(self.data) + except: + return False + + boundingBoxCenter = self._boundingBoxCenter(graph.nodes) + offset = Position(self.position.x - boundingBoxCenter.x, self.position.y - boundingBoxCenter.y) + + for node in graph.nodes: + node.position = Position(node.position.x + offset.x, node.position.y + offset.y) + + nodes = self.graph.importGraphContent(graph) + self.nodeNames = [node.name for node in nodes] self.setText("Paste Node{} ({})".format("s" if len(self.nodeNames) > 1 else "", ", ".join(self.nodeNames))) return nodes @@ -226,12 +228,31 @@ def undoImpl(self): for name in self.nodeNames: self.graph.removeNode(name) + def _boundingBox(self, nodes) -> tuple[int, int, int, int]: + if not nodes: + return (0, 0, 0 , 0) + + minX = maxX = nodes[0].x + minY = maxY = nodes[0].y + + for node in nodes[1:]: + minX = min(minX, node.x) + minY = min(minY, node.y) + maxX = max(maxX, node.x) + maxY = max(maxY, node.y) + + return (minX, minY, maxX, maxY) + + def _boundingBoxCenter(self, nodes): + minX, minY, maxX, maxY = self._boundingBox(nodes) + return Position((minX + maxX) / 2, (minY + maxY) / 2) class ImportProjectCommand(GraphCommand): """ Handle the import of a project into a Graph. """ - def __init__(self, graph, filepath=None, position=None, yOffset=0, parent=None): + + def __init__(self, graph: Graph, filepath: PathLike, position=None, yOffset=0, parent=None): super(ImportProjectCommand, self).__init__(graph, parent) self.filepath = filepath self.importedNames = [] @@ -239,9 +260,8 @@ def __init__(self, graph, filepath=None, position=None, yOffset=0, parent=None): self.yOffset = yOffset def redoImpl(self): - status = self.graph.load(self.filepath, setupProjectFile=False, importProject=True) - importedNodes = self.graph.importedNodes - self.setText("Import Project ({} nodes)".format(importedNodes.count)) + importedNodes = self.graph.importGraphContentFromFile(self.filepath) + self.setText(f"Import Project ({len(importedNodes)} nodes)") lowestY = 0 for node in self.graph.nodes: @@ -419,37 +439,24 @@ def __init__(self, graph, node, parent=None): super(UpgradeNodeCommand, self).__init__(graph, parent) self.nodeDict = node.toDict() self.nodeName = node.getName() - self.outEdges = {} - self.outListAttributes = {} + self.compatibilityIssue = None self.setText("Upgrade Node {}".format(self.nodeName)) def redoImpl(self): - if not self.graph.node(self.nodeName).canUpgrade: + if not (node := self.graph.node(self.nodeName)).canUpgrade: return False - upgradedNode, _, self.outEdges, self.outListAttributes = self.graph.upgradeNode(self.nodeName) - return upgradedNode + self.compatibilityIssue = node.issue + return self.graph.upgradeNode(self.nodeName) def undoImpl(self): - # delete upgraded node - self.graph.removeNode(self.nodeName) + expectedUid = None + if self.compatibilityIssue == CompatibilityIssue.UidConflict: + expectedUid = self.graph.node(self.nodeName)._uid + # recreate compatibility node with GraphModification(self.graph): - # We come back from an upgrade, so we enforce uidConflict=True as there was a uid conflict before - node = nodeFactory(self.nodeDict, name=self.nodeName, uidConflict=True) - self.graph.addNode(node, self.nodeName) - # recreate out edges - for dstAttr, srcAttr in self.outEdges.items(): - # if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute - # 0 = attribute name, 1 = attribute index, 2 = attribute value - if dstAttr in self.outListAttributes.keys(): - listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0]) - if isinstance(self.outListAttributes[dstAttr][2], list): - listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2] - else: - listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2]) - - self.graph.addEdge(self.graph.attribute(srcAttr), - self.graph.attribute(dstAttr)) + node = nodeFactory(self.nodeDict, name=self.nodeName, expectedUid=expectedUid) + self.graph.replaceNode(self.nodeName, node) class EnableGraphUpdateCommand(GraphCommand): diff --git a/meshroom/ui/graph.py b/meshroom/ui/graph.py index 0cde6b2ddd..c7dabc14bc 100644 --- a/meshroom/ui/graph.py +++ b/meshroom/ui/graph.py @@ -25,6 +25,7 @@ from meshroom.common.qt import QObjectListModel from meshroom.core.attribute import Attribute, ListAttribute from meshroom.core.graph import Graph, Edge +from meshroom.core.graphIO import GraphIO from meshroom.core.taskManager import TaskManager @@ -396,7 +397,7 @@ def setGraph(self, g): self.updateChunks() # perform auto-layout if graph does not provide nodes positions - if Graph.IO.Features.NodesPositions not in self._graph.fileFeatures: + if GraphIO.Features.NodesPositions not in self._graph.fileFeatures: self._layout.reset() # clear undo-stack after layout self._undoStack.clear() @@ -451,17 +452,21 @@ def stopChildThreads(self): self.stopExecution() self._chunksMonitor.stop() - @Slot(str, result=bool) - def loadGraph(self, filepath, setupProjectFile=True, publishOutputs=False): - g = Graph('') - status = True + @Slot(str) + def loadGraph(self, filepath): + g = Graph("") if filepath: - status = g.load(filepath, setupProjectFile, importProject=False, publishOutputs=publishOutputs) + g.load(filepath) if not os.path.exists(g.cacheDir): os.mkdir(g.cacheDir) - g.fileDateVersion = os.path.getmtime(filepath) self.setGraph(g) - return status + + @Slot(str, bool, result=bool) + def initFromTemplate(self, filepath, publishOutputs=False): + graph = Graph("") + if filepath: + graph.initFromTemplate(filepath, publishOutputs=publishOutputs) + self.setGraph(graph) @Slot(QUrl, result="QVariantList") @Slot(QUrl, QPoint, result="QVariantList") @@ -1045,126 +1050,43 @@ def getSelectedNodesContent(self) -> str: """ if not self._nodeSelection.hasSelection(): return "" - serializedSelection = {node.name: node.toDict() for node in self.iterSelectedNodes()} - return json.dumps(serializedSelection, indent=4) + graphData = self._graph.serializePartial(self.getSelectedNodes()) + return json.dumps(graphData, indent=4) - @Slot(str, QPoint, bool, result=list) - def pasteNodes(self, clipboardContent, position=None, centerPosition=False) -> list[Node]: + @Slot(str, QPoint, result=list) + def pasteNodes(self, serializedData: str, position: Optional[QPoint]=None) -> list[Node]: """ - Parse the content of the clipboard to see whether it contains - valid node descriptions. If that is the case, the nodes described - in the clipboard are built with the available information. - Otherwise, nothing is done. + Import string-serialized graph content `serializedData` in the current graph, optionally at the given + `position`. + If the `serializedData` does not contain valid serialized graph data, nothing is done. - This function does not need to be preceded by a call to "getSelectedNodesContent". - Any clipboard content that contains at least a node type with a valid JSON - formatting (dictionary form with double quotes around the keys and values) - can be used to generate a node. + This method can be used with the result of "getSelectedNodesContent". + But it also accepts any serialized content that matches the graph data or graph content format. For example, it is enough to have: {"nodeName_1": {"nodeType":"CameraInit"}, "nodeName_2": {"nodeType":"FeatureMatching"}} - in the clipboard to create a default CameraInit and a default FeatureMatching nodes. + in `serializedData` to create a default CameraInit and a default FeatureMatching nodes. Args: - clipboardContent (str): the string contained in the clipboard, that may or may not contain valid - node information - position (QPoint): the position of the mouse in the Graph Editor when the function was called - centerPosition (bool): whether the provided position is not the top-left corner of the pasting - zone, but its center + serializedData: The string-serialized graph data. + position: The position where to paste the nodes. If None, the nodes are pasted at (0, 0). Returns: list: the list of Node objects that were pasted and added to the graph """ - if not clipboardContent: - return - try: - d = json.loads(clipboardContent) - except ValueError as e: - raise ValueError(e) - - if not isinstance(d, dict): - raise ValueError("The clipboard does not contain a valid node. Cannot paste it.") - - # If the clipboard contains a header, then a whole file is contained in the clipboard - # Extract the "graph" part and paste it all, ignore the rest - if d.get("header", None): - d = d.get("graph", None) - if not d: - return - - if isinstance(position, QPoint): - position = Position(position.x(), position.y()) - if self.hoveredNode: - # If a node is hovered, add an offset to prevent complete occlusion - position = Position(position.x + self.layout.gridSpacing, position.y + self.layout.gridSpacing) - - # Get the position of the first node in a zone whose top-left corner is the mouse and the bottom-right - # corner the (x, y) coordinates, with x the maximum of all the nodes' position along the x-axis, and y the - # maximum of all the nodes' position along the y-axis. All nodes with a position will be placed relatively - # to the first node within that zone. - firstNodePos = None - minX = 0 - maxX = 0 - minY = 0 - maxY = 0 - for key in sorted(d): - nodeType = d[key].get("nodeType", None) - if not nodeType: - raise ValueError("Invalid node description: no provided node type for '{}'".format(key)) - - pos = d[key].get("position", None) - if pos: - if not firstNodePos: - firstNodePos = pos - minX = pos[0] - maxX = pos[0] - minY = pos[1] - maxY = pos[1] - else: - if minX > pos[0]: - minX = pos[0] - if maxX < pos[0]: - maxX = pos[0] - if minY > pos[1]: - minY = pos[1] - if maxY < pos[1]: - maxY = pos[1] - - # Ensure there will not be an error if no node has a specified position - if not firstNodePos: - firstNodePos = [0, 0] - - # Position of the first node within the zone - position = Position(position.x + firstNodePos[0] - minX, position.y + firstNodePos[1] - minY) - - if centerPosition: # Center the zone around the mouse's position (mouse's position might be artificial) - maxX = maxX + self.layout.nodeWidth # maxX and maxY are the position of the furthest node's top-left corner - maxY = maxY + self.layout.nodeHeight # We want the position of the furthest node's bottom-right corner - position = Position(position.x - ((maxX - minX) / 2), position.y - ((maxY - minY) / 2)) - - finalPosition = None - prevPosition = None - positions = [] - - for key in sorted(d): - currentPosition = d[key].get("position", None) - if not finalPosition: - finalPosition = position - else: - if prevPosition and currentPosition: - # If the nodes both have a position, recreate the distance between them with a different - # starting point - x = finalPosition.x + (currentPosition[0] - prevPosition[0]) - y = finalPosition.y + (currentPosition[1] - prevPosition[1]) - finalPosition = Position(x, y) - else: - # If either the current node or previous one lacks a position, use a custom one - finalPosition = Position(finalPosition.x + self.layout.gridSpacing + self.layout.nodeWidth, finalPosition.y) - prevPosition = currentPosition - positions.append(finalPosition) + graphData = json.loads(serializedData) + except json.JSONDecodeError: + logging.warning("Content is not a valid JSON string.") + return [] + + pos = Position(position.x(), position.y()) if position else Position(0, 0) + result = self.push(commands.PasteNodesCommand(self._graph, graphData, pos)) + if result is False: + logging.warning("Content is not a valid graph data.") + return [] + return result - return self.push(commands.PasteNodesCommand(self.graph, d, position=positions)) undoStack = Property(QObject, lambda self: self._undoStack, constant=True) graphChanged = Signal() diff --git a/meshroom/ui/qml/Application.qml b/meshroom/ui/qml/Application.qml index 48884e2f33..fb5fc67c67 100644 --- a/meshroom/ui/qml/Application.qml +++ b/meshroom/ui/qml/Application.qml @@ -141,7 +141,7 @@ Page { nameFilters: ["Meshroom Graphs (*.mg)"] onAccepted: { // Open the template as a regular file - if (_reconstruction.loadUrl(currentFile, true, true)) { + if (_reconstruction.load(currentFile)) { MeshroomApp.addRecentProjectFile(currentFile.toString()) } } @@ -356,7 +356,7 @@ Page { text: "Reload File" onClicked: { - _reconstruction.loadUrl(_reconstruction.graph.filepath) + _reconstruction.load(_reconstruction.graph.filepath) fileModifiedDialog.close() } } @@ -661,7 +661,7 @@ Page { MenuItem { onTriggered: ensureSaved(function() { openRecentMenu.dismiss() - if (_reconstruction.loadUrl(modelData["path"])) { + if (_reconstruction.load(modelData["path"])) { MeshroomApp.addRecentProjectFile(modelData["path"]) } else { MeshroomApp.removeRecentProjectFile(modelData["path"]) diff --git a/meshroom/ui/qml/GraphEditor/GraphEditor.qml b/meshroom/ui/qml/GraphEditor/GraphEditor.qml index c74acbc7d1..1a7813ac5b 100755 --- a/meshroom/ui/qml/GraphEditor/GraphEditor.qml +++ b/meshroom/ui/qml/GraphEditor/GraphEditor.qml @@ -82,25 +82,18 @@ Item { /// Paste content of clipboard to graph editor and create new node if valid function pasteNodes() { - var finalPosition = undefined - var centerPosition = false + let finalPosition = undefined; if (mouseArea.containsMouse) { - if (uigraph.hoveredNode !== null) { - var node = nodeDelegate(uigraph.hoveredNode) - finalPosition = Qt.point(node.mousePosition.x + node.x, node.mousePosition.y + node.y) - } else { - finalPosition = mapToItem(draggable, mouseArea.mouseX, mouseArea.mouseY) - } + finalPosition = mapToItem(draggable, mouseArea.mouseX, mouseArea.mouseY); } else { - finalPosition = getCenterPosition() - centerPosition = true + finalPosition = getCenterPosition(); } - var copiedContent = Clipboard.getText() - var nodes = uigraph.pasteNodes(copiedContent, finalPosition, centerPosition) + const copiedContent = Clipboard.getText(); + const nodes = uigraph.pasteNodes(copiedContent, finalPosition); if (nodes.length > 0) { - uigraph.selectedNode = nodes[0] - uigraph.selectNodes(nodes) + uigraph.selectedNode = nodes[0]; + uigraph.selectNodes(nodes); } } diff --git a/meshroom/ui/qml/Homepage.qml b/meshroom/ui/qml/Homepage.qml index ef27a22dac..fe7f9ff4a8 100644 --- a/meshroom/ui/qml/Homepage.qml +++ b/meshroom/ui/qml/Homepage.qml @@ -384,7 +384,7 @@ Page { } else { // Open project mainStack.push("Application.qml") - if (_reconstruction.loadUrl(modelData["path"])) { + if (_reconstruction.load(modelData["path"])) { MeshroomApp.addRecentProjectFile(modelData["path"]) } else { MeshroomApp.removeRecentProjectFile(modelData["path"]) diff --git a/meshroom/ui/qml/main.qml b/meshroom/ui/qml/main.qml index 16940a74c5..20c2f81fa1 100644 --- a/meshroom/ui/qml/main.qml +++ b/meshroom/ui/qml/main.qml @@ -128,7 +128,7 @@ ApplicationWindow { if (mainStack.currentItem instanceof Homepage) { mainStack.push("Application.qml") } - if (_reconstruction.loadUrl(currentFile)) { + if (_reconstruction.load(currentFile)) { MeshroomApp.addRecentProjectFile(currentFile.toString()) } } diff --git a/meshroom/ui/reconstruction.py b/meshroom/ui/reconstruction.py index c774527f34..94d926a0c1 100755 --- a/meshroom/ui/reconstruction.py +++ b/meshroom/ui/reconstruction.py @@ -5,6 +5,7 @@ from collections.abc import Iterable from multiprocessing.pool import ThreadPool from threading import Thread +from typing import Callable from PySide6.QtCore import QObject, Slot, Property, Signal, QUrl, QSizeF, QPoint from PySide6.QtGui import QMatrix4x4, QMatrix3x3, QQuaternion, QVector3D, QVector2D @@ -534,17 +535,24 @@ def new(self, pipeline=None): # - correct pipeline name but the case does not match (e.g. panoramaHDR instead of panoramaHdr) # - lowercase pipeline name given through the "New Pipeline" menu loweredPipelineTemplates = dict((k.lower(), v) for k, v in meshroom.core.pipelineTemplates.items()) - if p.lower() in loweredPipelineTemplates: - self.load(loweredPipelineTemplates[p.lower()], setupProjectFile=False) - else: - # use the user-provided default project file - self.load(p, setupProjectFile=False) + filepath = loweredPipelineTemplates.get(p.lower(), p) + return self._loadWithErrorReport(self.initFromTemplate, filepath) @Slot(str, result=bool) - def load(self, filepath, setupProjectFile=True, publishOutputs=False): + @Slot(QUrl, result=bool) + def load(self, url): + if isinstance(url, QUrl): + # depending how the QUrl has been initialized, + # toLocalFile() may return the local path or an empty string + localFile = url.toLocalFile() or url.toString() + else: + localFile = url + return self._loadWithErrorReport(self.loadGraph, localFile) + + def _loadWithErrorReport(self, loadFunction: Callable[[str], None], filepath: str): logging.info(f"Load project file: '{filepath}'") try: - status = super(Reconstruction, self).loadGraph(filepath, setupProjectFile, publishOutputs) + loadFunction(filepath) # warn about pre-release projects being automatically upgraded if Version(self._graph.fileReleaseVersion).major == "0": self.warning.emit(Message( @@ -554,8 +562,8 @@ def load(self, filepath, setupProjectFile=True, publishOutputs=False): "Open it with the corresponding version of Meshroom to recover your data." )) self.setActive(True) - return status - except FileNotFoundError as e: + return True + except FileNotFoundError: self.error.emit( Message( "No Such File", @@ -564,8 +572,7 @@ def load(self, filepath, setupProjectFile=True, publishOutputs=False): ) ) logging.error("Error while loading '{}': No Such File.".format(filepath)) - return False - except Exception as e: + except Exception: import traceback trace = traceback.format_exc() self.error.emit( @@ -577,20 +584,8 @@ def load(self, filepath, setupProjectFile=True, publishOutputs=False): ) logging.error("Error while loading '{}'.".format(filepath)) logging.error(trace) - return False - @Slot(QUrl, result=bool) - @Slot(QUrl, bool, bool, result=bool) - def loadUrl(self, url, setupProjectFile=True, publishOutputs=False): - if isinstance(url, (QUrl)): - # depending how the QUrl has been initialized, - # toLocalFile() may return the local path or an empty string - localFile = url.toLocalFile() - if not localFile: - localFile = url.toString() - else: - localFile = url - return self.load(localFile, setupProjectFile, publishOutputs) + return False def onGraphChanged(self): """ React to the change of the internal graph. """ @@ -860,7 +855,7 @@ def handleFilesUrl(self, filesByType, cameraInit=None, position=None): ) ) else: - return self.loadUrl(filesByType["meshroomScenes"][0]) + return self.load(filesByType["meshroomScenes"][0]) diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py index 00505352e7..07ba2526ce 100644 --- a/tests/test_compatibility.py +++ b/tests/test_compatibility.py @@ -4,6 +4,7 @@ import os import copy +from typing import Type import pytest import meshroom.core @@ -12,6 +13,8 @@ from meshroom.core.graph import Graph, loadGraph from meshroom.core.node import CompatibilityNode, CompatibilityIssue, Node +from .utils import registeredNodeTypes + SampleGroupV1 = [ desc.IntParam(name="a", label="a", description="", value=0, range=None), @@ -156,6 +159,12 @@ class SampleInputNodeV2(desc.InputNode): ] + +def replaceNodeTypeDesc(nodeType: str, nodeDesc: Type[desc.Node]): + """Change the `nodeDesc` associated to `nodeType`.""" + meshroom.core.nodesDesc[nodeType] = nodeDesc + + def test_unknown_node_type(): """ Test compatibility behavior for unknown node type. @@ -218,8 +227,7 @@ def test_description_conflict(): g.save(graphFile) # reload file as-is, ensure no compatibility issue is detected (no CompatibilityNode instances) - g = loadGraph(graphFile) - assert all(isinstance(n, Node) for n in g.nodes) + loadGraph(graphFile, strictCompatibility=True) # offset node types register to create description conflicts # each node type name now reference the next one's implementation @@ -247,7 +255,7 @@ def test_description_conflict(): assert not hasattr(compatNode, "in") # perform upgrade - upgradedNode = g.upgradeNode(nodeName)[0] + upgradedNode = g.upgradeNode(nodeName) assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV2) assert list(upgradedNode.attributes.keys()) == ["in", "paramA", "output"] @@ -262,7 +270,7 @@ def test_description_conflict(): assert hasattr(compatNode, "paramA") # perform upgrade - upgradedNode = g.upgradeNode(nodeName)[0] + upgradedNode = g.upgradeNode(nodeName) assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV3) assert not hasattr(upgradedNode, "paramA") @@ -275,7 +283,7 @@ def test_description_conflict(): assert not hasattr(compatNode, "paramA") # perform upgrade - upgradedNode = g.upgradeNode(nodeName)[0] + upgradedNode = g.upgradeNode(nodeName) assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV4) assert hasattr(upgradedNode, "paramA") @@ -295,7 +303,7 @@ def test_description_conflict(): assert isinstance(elt, next(a for a in SampleGroupV1 if a.name == elt.name).__class__) # perform upgrade - upgradedNode = g.upgradeNode(nodeName)[0] + upgradedNode = g.upgradeNode(nodeName) assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV5) assert hasattr(upgradedNode, "paramA") @@ -399,20 +407,220 @@ def test_conformUpgrade(): class TestGraphLoadingWithStrictCompatibility: + def test_failsOnUnknownNodeType(self, graphSavedOnDisk): + with registeredNodeTypes([SampleNodeV1]): + graph: Graph = graphSavedOnDisk + graph.addNewNode(SampleNodeV1.__name__) + graph.save() + + with pytest.raises(GraphCompatibilityError): + loadGraph(graph.filepath, strictCompatibility=True) + + def test_failsOnNodeDescriptionCompatibilityIssue(self, graphSavedOnDisk): - registerNodeType(SampleNodeV1) - registerNodeType(SampleNodeV2) - graph: Graph = graphSavedOnDisk - graph.addNewNode(SampleNodeV1.__name__) - graph.save() + with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): + graph: Graph = graphSavedOnDisk + graph.addNewNode(SampleNodeV1.__name__) + graph.save() + + replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2) + + with pytest.raises(GraphCompatibilityError): + loadGraph(graph.filepath, strictCompatibility=True) + - # Replace saved node description by V2 - meshroom.core.nodesDesc[SampleNodeV1.__name__] = SampleNodeV2 +class TestGraphTemplateLoading: + + def test_failsOnUnknownNodeTypeError(self, graphSavedOnDisk): + + with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): + graph: Graph = graphSavedOnDisk + graph.addNewNode(SampleNodeV1.__name__) + graph.save(template=True) with pytest.raises(GraphCompatibilityError): loadGraph(graph.filepath, strictCompatibility=True) - unregisterNodeType(SampleNodeV1) - unregisterNodeType(SampleNodeV2) + def test_loadsIfIncompatibleNodeHasDefaultAttributeValues(self, graphSavedOnDisk): + with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): + graph: Graph = graphSavedOnDisk + graph.addNewNode(SampleNodeV1.__name__) + graph.save(template=True) + + replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2) + + loadGraph(graph.filepath, strictCompatibility=True) + + def test_loadsIfValueSetOnCompatibleAttribute(self, graphSavedOnDisk): + with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): + graph: Graph = graphSavedOnDisk + node = graph.addNewNode(SampleNodeV1.__name__, paramA="foo") + graph.save(template=True) + + replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2) + + loadedGraph = loadGraph(graph.filepath, strictCompatibility=True) + assert loadedGraph.nodes.get(node.name).paramA.value == "foo" + + def test_loadsIfValueSetOnIncompatibleAttribute(self, graphSavedOnDisk): + with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): + graph: Graph = graphSavedOnDisk + graph.addNewNode(SampleNodeV1.__name__, input="foo") + graph.save(template=True) + + replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2) + + loadGraph(graph.filepath, strictCompatibility=True) + + +class UidTestingNodeV1(desc.Node): + inputs = [ + desc.File(name="input", label="Input", description="", value="", invalidate=True), + ] + outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)] + + +class UidTestingNodeV2(desc.Node): + """ + Changes from SampleNodeBV1: + * 'param' has been added + """ + + inputs = [ + desc.File(name="input", label="Input", description="", value="", invalidate=True), + desc.ListAttribute( + name="param", + label="Param", + elementDesc=desc.File( + name="file", + label="File", + description="", + value="", + ), + description="", + ), + ] + outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)] + + +class UidTestingNodeV3(desc.Node): + """ + Changes from SampleNodeBV2: + * 'input' is not invalidating the UID. + """ + + inputs = [ + desc.File(name="input", label="Input", description="", value="", invalidate=False), + desc.ListAttribute( + name="param", + label="Param", + elementDesc=desc.File( + name="file", + label="File", + description="", + value="", + ), + description="", + ), + ] + outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)] + + +class TestUidConflict: + def test_changingInvalidateOnAttributeDescCreatesUidConflict(self, graphSavedOnDisk): + with registeredNodeTypes([UidTestingNodeV2]): + graph: Graph = graphSavedOnDisk + node = graph.addNewNode(UidTestingNodeV2.__name__) + + graph.save() + replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3) + + with pytest.raises(GraphCompatibilityError): + loadGraph(graph.filepath, strictCompatibility=True) + + loadedGraph = loadGraph(graph.filepath) + loadedNode = loadedGraph.node(node.name) + assert isinstance(loadedNode, CompatibilityNode) + assert loadedNode.issue == CompatibilityIssue.UidConflict + + def test_uidConflictingNodesPreserveConnectionsOnGraphLoad(self, graphSavedOnDisk): + with registeredNodeTypes([UidTestingNodeV2]): + graph: Graph = graphSavedOnDisk + nodeA = graph.addNewNode(UidTestingNodeV2.__name__) + nodeB = graph.addNewNode(UidTestingNodeV2.__name__) + + nodeB.param.append("") + graph.addEdge(nodeA.output, nodeB.param.at(0)) + + graph.save() + replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3) + + loadedGraph = loadGraph(graph.filepath) + assert len(loadedGraph.compatibilityNodes) == 2 + + loadedNodeA = loadedGraph.node(nodeA.name) + loadedNodeB = loadedGraph.node(nodeB.name) + + assert loadedNodeB.param.at(0).linkParam == loadedNodeA.output + + def test_upgradingConflictingNodesPreserveConnections(self, graphSavedOnDisk): + with registeredNodeTypes([UidTestingNodeV2]): + graph: Graph = graphSavedOnDisk + nodeA = graph.addNewNode(UidTestingNodeV2.__name__) + nodeB = graph.addNewNode(UidTestingNodeV2.__name__) + + # Double-connect nodeA.output to nodeB, on both a single attribute and a list attribute + nodeB.param.append("") + graph.addEdge(nodeA.output, nodeB.param.at(0)) + graph.addEdge(nodeA.output, nodeB.input) + + graph.save() + replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3) + + def checkNodeAConnectionsToNodeB(): + loadedNodeA = loadedGraph.node(nodeA.name) + loadedNodeB = loadedGraph.node(nodeB.name) + return ( + loadedNodeB.param.at(0).linkParam == loadedNodeA.output + and loadedNodeB.input.linkParam == loadedNodeA.output + ) + + loadedGraph = loadGraph(graph.filepath) + loadedGraph.upgradeNode(nodeA.name) + + assert checkNodeAConnectionsToNodeB() + loadedGraph.upgradeNode(nodeB.name) + + assert checkNodeAConnectionsToNodeB() + assert len(loadedGraph.compatibilityNodes) == 0 + + + def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughConnection(self, graphSavedOnDisk): + with registeredNodeTypes([UidTestingNodeV1, UidTestingNodeV2]): + graph: Graph = graphSavedOnDisk + nodeA = graph.addNewNode(UidTestingNodeV2.__name__) + nodeB = graph.addNewNode(UidTestingNodeV1.__name__) + + graph.addEdge(nodeA.output, nodeB.input) + + graph.save() + replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3) + + loadedGraph = loadGraph(graph.filepath) + assert len(loadedGraph.compatibilityNodes) == 1 + + def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughListConnection(self, graphSavedOnDisk): + with registeredNodeTypes([UidTestingNodeV2, UidTestingNodeV3]): + graph: Graph = graphSavedOnDisk + nodeA = graph.addNewNode(UidTestingNodeV2.__name__) + nodeB = graph.addNewNode(UidTestingNodeV3.__name__) + + nodeB.param.append("") + graph.addEdge(nodeA.output, nodeB.param.at(0)) + + graph.save() + replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3) + loadedGraph = loadGraph(graph.filepath) + assert len(loadedGraph.compatibilityNodes) == 1 diff --git a/tests/test_graphIO.py b/tests/test_graphIO.py new file mode 100644 index 0000000000..65835a5a69 --- /dev/null +++ b/tests/test_graphIO.py @@ -0,0 +1,299 @@ +from meshroom.core import desc +from meshroom.core.graph import Graph + +from .utils import registeredNodeTypes + + +class SimpleNode(desc.Node): + inputs = [ + desc.File(name="input", label="Input", description="", value=""), + ] + outputs = [ + desc.File(name="output", label="Output", description="", value=""), + ] + + +class NodeWithListAttributes(desc.Node): + inputs = [ + desc.ListAttribute( + name="listInput", + label="List Input", + description="", + elementDesc=desc.File(name="file", label="File", description="", value=""), + exposed=True, + ), + desc.GroupAttribute( + name="group", + label="Group", + description="", + groupDesc=[ + desc.ListAttribute( + name="listInput", + label="List Input", + description="", + elementDesc=desc.File(name="file", label="File", description="", value=""), + exposed=True, + ), + ], + ), + ] + + +def compareGraphsContent(graphA: Graph, graphB: Graph) -> bool: + """Returns whether the content (node and deges) of two graphs are considered identical. + + Similar nodes: nodes with the same name, type and compatibility status. + Similar edges: edges with the same source and destination attribute names. + """ + + def _buildNodesSet(graph: Graph): + return set([(node.name, node.nodeType, node.isCompatibilityNode) for node in graph.nodes]) + + def _buildEdgesSet(graph: Graph): + return set([(edge.src.fullName, edge.dst.fullName) for edge in graph.edges]) + + nodesSetA, edgesSetA = _buildNodesSet(graphA), _buildEdgesSet(graphA) + nodesSetB, edgesSetB = _buildNodesSet(graphB), _buildEdgesSet(graphB) + + return nodesSetA == nodesSetB and edgesSetA == edgesSetB + + +class TestImportGraphContent: + def test_importEmptyGraph(self): + graph = Graph("") + + otherGraph = Graph("") + nodes = otherGraph.importGraphContent(graph) + + assert len(nodes) == 0 + assert len(graph.nodes) == 0 + + def test_importGraphWithSingleNode(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + graph.addNewNode(SimpleNode.__name__) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + + assert compareGraphsContent(graph, otherGraph) + + def test_importGraphWithSeveralNodes(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + graph.addNewNode(SimpleNode.__name__) + graph.addNewNode(SimpleNode.__name__) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + + assert compareGraphsContent(graph, otherGraph) + + def test_importingGraphWithNodesAndEdges(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + assert compareGraphsContent(graph, otherGraph) + + def test_edgeRemappingOnImportingGraphSeveralTimes(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + otherGraph.importGraphContent(graph) + + def test_edgeRemappingOnImportingGraphWithUnkownNodeTypesSeveralTimes(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + otherGraph.importGraphContent(graph) + + assert len(otherGraph.nodes) == 4 + assert len(otherGraph.compatibilityNodes) == 4 + assert len(otherGraph.edges) == 2 + + def test_importGraphWithUnknownNodeTypesCreatesCompatibilityNodes(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + graph.addNewNode(SimpleNode.__name__) + + otherGraph = Graph("") + importedNode = otherGraph.importGraphContent(graph) + + assert len(importedNode) == 1 + assert importedNode[0].isCompatibilityNode + + def test_importGraphContentInPlace(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + + graph.importGraphContent(graph) + + assert len(graph.nodes) == 4 + + def test_importGraphContentFromFile(self, graphSavedOnDisk): + graph: Graph = graphSavedOnDisk + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + graph.save() + + otherGraph = Graph("") + nodes = otherGraph.importGraphContentFromFile(graph.filepath) + + assert len(nodes) == 2 + + assert compareGraphsContent(graph, otherGraph) + + def test_importGraphContentFromFileWithCompatibilityNodes(self, graphSavedOnDisk): + graph: Graph = graphSavedOnDisk + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + graph.save() + + otherGraph = Graph("") + nodes = otherGraph.importGraphContentFromFile(graph.filepath) + + assert len(nodes) == 2 + assert len(otherGraph.compatibilityNodes) == 2 + assert not compareGraphsContent(graph, otherGraph) + + +class TestGraphPartialSerialization: + def test_emptyGraph(self): + graph = Graph("") + serializedGraph = graph.serializePartial([]) + + otherGraph = Graph("") + otherGraph._deserialize(serializedGraph) + assert compareGraphsContent(graph, otherGraph) + + def test_serializeAllNodesIsSimilarToStandardSerialization(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA.output, nodeB.input) + + partialSerializedGraph = graph.serializePartial([nodeA, nodeB]) + standardSerializedGraph = graph.serialize() + + graphA = Graph("") + graphA._deserialize(partialSerializedGraph) + + graphB = Graph("") + graphB._deserialize(standardSerializedGraph) + + assert compareGraphsContent(graph, graphA) + assert compareGraphsContent(graphA, graphB) + + def test_singleNodeWithInputConnectionFromNonSerializedNodeRemovesEdge(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA.output, nodeB.input) + + serializedGraph = graph.serializePartial([nodeB]) + + otherGraph = Graph("") + otherGraph._deserialize(serializedGraph) + + assert len(otherGraph.compatibilityNodes) == 0 + assert len(otherGraph.nodes) == 1 + assert len(otherGraph.edges) == 0 + + def test_serializeSingleNodeWithInputConnectionToListAttributeRemovesListEntry(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode, NodeWithListAttributes]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(NodeWithListAttributes.__name__) + + nodeB.listInput.append("") + graph.addEdge(nodeA.output, nodeB.listInput.at(0)) + + otherGraph = Graph("") + otherGraph._deserialize(graph.serializePartial([nodeB])) + + assert len(otherGraph.node(nodeB.name).listInput) == 0 + + def test_serializeSingleNodeWithInputConnectionToNestedListAttributeRemovesListEntry(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode, NodeWithListAttributes]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(NodeWithListAttributes.__name__) + + nodeB.group.listInput.append("") + graph.addEdge(nodeA.output, nodeB.group.listInput.at(0)) + + otherGraph = Graph("") + otherGraph._deserialize(graph.serializePartial([nodeB])) + + assert len(otherGraph.node(nodeB.name).group.listInput) == 0 + + +class TestGraphCopy: + def test_graphCopyIsIdenticalToOriginalGraph(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA.output, nodeB.input) + + graphCopy = graph.copy() + assert compareGraphsContent(graph, graphCopy) + + def test_graphCopyWithUnknownNodeTypesDiffersFromOriginalGraph(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA.output, nodeB.input) + + graphCopy = graph.copy() + assert not compareGraphsContent(graph, graphCopy) diff --git a/tests/test_nodeAttributeChangedCallback.py b/tests/test_nodeAttributeChangedCallback.py index edd14bc8dc..faee0e00ba 100644 --- a/tests/test_nodeAttributeChangedCallback.py +++ b/tests/test_nodeAttributeChangedCallback.py @@ -431,3 +431,28 @@ def test_loadingGraphWithComputedDynamicOutputValueDoesNotTriggerDownstreamAttri assert nodeB.affectedInput.value == 0 +class TestAttributeCallbackBehaviorOnGraphImport: + @classmethod + def setup_class(cls): + registerNodeType(NodeWithAttributeChangedCallback) + + @classmethod + def teardown_class(cls): + unregisterNodeType(NodeWithAttributeChangedCallback) + + def test_importingGraphDoesNotTriggerAttributeChangedCallbacks(self): + graph = Graph("") + + nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__) + nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__) + + graph.addEdge(nodeA.affectedInput, nodeB.input) + + nodeA.input.value = 5 + nodeB.affectedInput.value = 2 + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + + assert otherGraph.node(nodeB.name).affectedInput.value == 2 + diff --git a/tests/test_templatesVersion.py b/tests/test_templatesVersion.py index 402a228ac4..eb23628f72 100644 --- a/tests/test_templatesVersion.py +++ b/tests/test_templatesVersion.py @@ -4,6 +4,7 @@ from meshroom.core.graph import Graph from meshroom.core import pipelineTemplates, Version from meshroom.core.node import CompatibilityIssue, CompatibilityNode +from meshroom.core.graphIO import GraphIO import json import meshroom @@ -24,13 +25,13 @@ def test_templateVersions(): with open(path) as jsonFile: fileData = json.load(jsonFile) - graphData = fileData.get(Graph.IO.Keys.Graph, fileData) + graphData = fileData.get(GraphIO.Keys.Graph, fileData) assert isinstance(graphData, dict) - header = fileData.get(Graph.IO.Keys.Header, {}) + header = fileData.get(GraphIO.Keys.Header, {}) assert header.get("template", False) - nodesVersions = header.get(Graph.IO.Keys.NodesVersions, {}) + nodesVersions = header.get(GraphIO.Keys.NodesVersions, {}) for _, nodeData in graphData.items(): nodeType = nodeData["nodeType"] diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000..30745c5f43 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,15 @@ +from contextlib import contextmanager +from typing import Type +from meshroom.core import registerNodeType, unregisterNodeType + +from meshroom.core import desc + +@contextmanager +def registeredNodeTypes(nodeTypes: list[Type[desc.Node]]): + for nodeType in nodeTypes: + registerNodeType(nodeType) + + yield + + for nodeType in nodeTypes: + unregisterNodeType(nodeType)