diff --git a/mars/services/task/analyzer/analyzer.py b/mars/services/task/analyzer/analyzer.py index e70390e157..95b694ae86 100644 --- a/mars/services/task/analyzer/analyzer.py +++ b/mars/services/task/analyzer/analyzer.py @@ -128,6 +128,7 @@ def __init__( if graph_assigner_cls is None: graph_assigner_cls = GraphAssigner self._graph_assigner_cls = graph_assigner_cls + self._chunk_to_copied = dict() self._logic_key_generator = LogicKeyGenerator() @classmethod @@ -226,6 +227,7 @@ def _gen_subtask_info( result_chunks_set = set() chunk_graph = ChunkGraph(result_chunks) out_of_scope_chunks = [] + chunk_to_copied = self._chunk_to_copied update_meta_chunks = [] # subtask properties band = None @@ -271,11 +273,13 @@ def _gen_subtask_info( chunk_priority = chunk.op.priority # process input chunks inp_chunks = [] + input_changed = False build_fetch_index_to_chunks = dict() for i, inp_chunk in enumerate(chunk.inputs): if inp_chunk in chunks_set: - inp_chunks.append(inp_chunk) + inp_chunks.append(chunk_to_copied[inp_chunk]) else: + input_changed = True build_fetch_index_to_chunks[i] = inp_chunk inp_chunks.append(None) if not isinstance(inp_chunk.op, Fetch): @@ -285,14 +289,31 @@ def _gen_subtask_info( ) for i, fetch_chunk in zip(build_fetch_index_to_chunks, fetch_chunks): inp_chunks[i] = fetch_chunk - for out_chunk in chunk.op.outputs: + + if input_changed: + copied_op = chunk.op.copy() + copied_op._key = chunk.op.key + out_chunks = [ + c.data + for c in copied_op.new_chunks( + inp_chunks, kws=[c.params.copy() for c in chunk.op.outputs] + ) + ] + else: + out_chunks = chunk.op.outputs # Note: `dtypes`, `index_value`, and `columns_value` are lazily # initialized, so we should call property `params` to initialize # these fields. - out_chunk.params - processed.add(out_chunk) + [o.params for o in out_chunks] + + for src_chunk, out_chunk in zip(chunk.op.outputs, out_chunks): + processed.add(src_chunk) + out_chunk._key = src_chunk.key chunk_graph.add_node(out_chunk) - if out_chunk in self._final_result_chunks_set: + # cannot be copied twice + assert src_chunk not in chunk_to_copied + chunk_to_copied[src_chunk] = out_chunk + if src_chunk in self._final_result_chunks_set: if out_chunk not in result_chunks_set: # add to result chunks result_chunks.append(out_chunk) @@ -320,12 +341,18 @@ def _gen_subtask_info( if out_of_scope_chunks: inp_subtasks = [] for out_of_scope_chunk in out_of_scope_chunks: + copied_out_of_scope_chunk = chunk_to_copied[out_of_scope_chunk] inp_subtask = chunk_to_subtask[out_of_scope_chunk] - if out_of_scope_chunk not in inp_subtask.chunk_graph.result_chunks: + if ( + copied_out_of_scope_chunk + not in inp_subtask.chunk_graph.result_chunks + ): # make sure the chunk that out of scope # is in the input subtask's results, # or the meta may be lost - inp_subtask.chunk_graph.result_chunks.append(out_of_scope_chunk) + inp_subtask.chunk_graph.result_chunks.append( + copied_out_of_scope_chunk + ) inp_subtasks.append(inp_subtask) depth = max(st.priority[0] for st in inp_subtasks) + 1 else: @@ -383,9 +410,10 @@ def _gen_map_reduce_info( # record analyzer map reduce id for mapper op # copied chunk exists because map chunk must have # been processed before shuffle proxy - if not hasattr(map_chunk, "extra_params"): # pragma: no cover - map_chunk.extra_params = dict() - map_chunk.extra_params["analyzer_map_reduce_id"] = map_reduce_id + copied_map_chunk = self._chunk_to_copied[map_chunk] + if not hasattr(copied_map_chunk, "extra_params"): # pragma: no cover + copied_map_chunk.extra_params = dict() + copied_map_chunk.extra_params["analyzer_map_reduce_id"] = map_reduce_id reducer_bands = [assign_results[r.outputs[0]] for r in reducer_ops] map_reduce_info = MapReduceInfo( map_reduce_id=map_reduce_id, diff --git a/mars/services/task/supervisor/tests/task_preprocessor.py b/mars/services/task/supervisor/tests/task_preprocessor.py index 496c0037e4..89ea51faa3 100644 --- a/mars/services/task/supervisor/tests/task_preprocessor.py +++ b/mars/services/task/supervisor/tests/task_preprocessor.py @@ -180,7 +180,11 @@ def analyze( map_reduce_id_to_infos=self.map_reduce_id_to_infos, ) subtask_graph = analyzer.gen_subtask_graph() - results = set(c for c in chunk_graph.results if not isinstance(c.op, Fetch)) + results = set( + analyzer._chunk_to_copied[c] + for c in chunk_graph.results + if not isinstance(c.op, Fetch) + ) for subtask in subtask_graph: if subtask.extra_config is None: subtask.extra_config = dict()