From a5aff6dd6ee02920c38f59ef7e3339c352dd6b1b Mon Sep 17 00:00:00 2001 From: zazulam Date: Tue, 24 Sep 2024 13:24:26 -0400 Subject: [PATCH] refactor: add conditional for primitives Signed-off-by: zazulam --- .../kfp/compiler/pipeline_spec_builder.py | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/sdk/python/kfp/compiler/pipeline_spec_builder.py b/sdk/python/kfp/compiler/pipeline_spec_builder.py index ac11e9582d0..0024f6214ad 100644 --- a/sdk/python/kfp/compiler/pipeline_spec_builder.py +++ b/sdk/python/kfp/compiler/pipeline_spec_builder.py @@ -193,15 +193,31 @@ def check_task_input_types(input_value, input_name, pipeline_task_spec, task, elif isinstance(input_value, list): for item in input_value: - check_task_input_types(item, input_name, pipeline_task_spec, task, - parent_component_inputs, - tasks_in_current_dag) + if isinstance(item, (pipeline_channel.PipelineArtifactChannel, + pipeline_channel.PipelineParameterChannel)): + check_task_input_types(item, input_name, pipeline_task_spec, + task, parent_component_inputs, + tasks_in_current_dag) + elif isinstance(item, (str, int, float, bool)): + pipeline_task_spec.inputs.parameters[ + input_name].runtime_value.constant.CopyFrom( + to_protobuf_value(input_value)) elif isinstance(input_value, dict): - for _, value in input_value.items(): - check_task_input_types(value, input_name, pipeline_task_spec, task, - parent_component_inputs, - tasks_in_current_dag) + for key, value in input_value.items(): + if isinstance(value, (pipeline_channel.PipelineArtifactChannel, + pipeline_channel.PipelineParameterChannel)): + check_task_input_types(value, input_name, pipeline_task_spec, + task, parent_component_inputs, + tasks_in_current_dag) + + elif isinstance(value, (str, int, float, bool)): + pipeline_task_spec.inputs.parameters[ + input_name].runtime_value.constant.CopyFrom( + to_protobuf_value(input_value)) + # check_task_input_types(value, input_name, pipeline_task_spec, task, + # parent_component_inputs, + # tasks_in_current_dag) elif isinstance(input_value, (str, int, float, bool)): pipeline_channels = (