diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index c7cc0de6b..affc853ff 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -569,6 +569,8 @@ async def execute_query( will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions from any retries that failed google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error + google.cloud.bigtable.data.exceptions.ParameterTypeInferenceFailed: Raised if + a parameter is passed without an explicit type, and the type cannot be infered """ warnings.warn( "ExecuteQuery is in preview and may change in the future.", diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index 66f264610..a8f60be36 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -22,7 +22,6 @@ Tuple, TYPE_CHECKING, ) - from google.api_core import retry as retries from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor @@ -116,7 +115,6 @@ def __init__( exception_factory=_retry_exception_factory, ) self._req_metadata = req_metadata - try: self._register_instance_task = CrossSync.create_task( self._client._register_instance, diff --git a/google/cloud/bigtable/data/execute_query/_parameters_formatting.py b/google/cloud/bigtable/data/execute_query/_parameters_formatting.py index edb7a6380..eadda21f4 100644 --- a/google/cloud/bigtable/data/execute_query/_parameters_formatting.py +++ b/google/cloud/bigtable/data/execute_query/_parameters_formatting.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional import datetime +from typing import Any, Dict, Optional + from google.api_core.datetime_helpers import DatetimeWithNanoseconds + from google.cloud.bigtable.data.exceptions import ParameterTypeInferenceFailed -from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType from google.cloud.bigtable.data.execute_query.metadata import SqlType +from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType def _format_execute_query_params( @@ -48,7 +50,6 @@ def _format_execute_query_params( parameter_types = parameter_types or {} result_values = {} - for key, value in params.items(): user_provided_type = parameter_types.get(key) try: @@ -109,6 +110,16 @@ def _detect_type(value: ExecuteQueryValueType) -> SqlType.Type: "Cannot infer type of None, please provide the type manually." ) + if isinstance(value, list): + raise ParameterTypeInferenceFailed( + "Cannot infer type of ARRAY parameters, please provide the type manually." + ) + + if isinstance(value, float): + raise ParameterTypeInferenceFailed( + "Cannot infer type of float, must specify either FLOAT32 or FLOAT64 type manually." + ) + for field_type, type_dict in _TYPES_TO_TYPE_DICTS: if isinstance(value, field_type): return type_dict diff --git a/google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py b/google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py index b65dce27b..4cb5db291 100644 --- a/google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py +++ b/google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py @@ -22,6 +22,7 @@ SqlType.Bytes: "bytes_value", SqlType.String: "string_value", SqlType.Int64: "int_value", + SqlType.Float32: "float_value", SqlType.Float64: "float_value", SqlType.Bool: "bool_value", SqlType.Timestamp: "timestamp_value", diff --git a/google/cloud/bigtable/data/execute_query/metadata.py b/google/cloud/bigtable/data/execute_query/metadata.py index 0c9cf9697..bb29588d0 100644 --- a/google/cloud/bigtable/data/execute_query/metadata.py +++ b/google/cloud/bigtable/data/execute_query/metadata.py @@ -21,23 +21,16 @@ """ from collections import defaultdict -from typing import ( - Optional, - List, - Dict, - Set, - Type, - Union, - Tuple, - Any, -) +import datetime +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union + +from google.api_core.datetime_helpers import DatetimeWithNanoseconds +from google.protobuf import timestamp_pb2 # type: ignore +from google.type import date_pb2 # type: ignore + from google.cloud.bigtable.data.execute_query.values import _NamedList from google.cloud.bigtable_v2 import ResultSetMetadata from google.cloud.bigtable_v2 import Type as PBType -from google.type import date_pb2 # type: ignore -from google.protobuf import timestamp_pb2 # type: ignore -from google.api_core.datetime_helpers import DatetimeWithNanoseconds -import datetime class SqlType: @@ -127,6 +120,8 @@ class Array(Type): def __init__(self, element_type: "SqlType.Type"): if isinstance(element_type, SqlType.Array): raise ValueError("Arrays of arrays are not supported.") + if isinstance(element_type, SqlType.Map): + raise ValueError("Arrays of Maps are not supported.") self._element_type = element_type @property @@ -140,10 +135,21 @@ def from_pb_type(cls, type_pb: Optional[PBType] = None) -> "SqlType.Array": return cls(_pb_type_to_metadata_type(type_pb.array_type.element_type)) def _to_value_pb_dict(self, value: Any): - raise NotImplementedError("Array is not supported as a query parameter") + if value is None: + return {} + + return { + "array_value": { + "values": [ + self.element_type._to_value_pb_dict(entry) for entry in value + ] + } + } def _to_type_pb_dict(self) -> Dict[str, Any]: - raise NotImplementedError("Array is not supported as a query parameter") + return { + "array_type": {"element_type": self.element_type._to_type_pb_dict()} + } def __eq__(self, other): return super().__eq__(other) and self.element_type == other.element_type @@ -222,6 +228,13 @@ class Float64(Type): value_pb_dict_field_name = "float_value" type_field_name = "float64_type" + class Float32(Type): + """Float32 SQL type.""" + + expected_type = float + value_pb_dict_field_name = "float_value" + type_field_name = "float32_type" + class Bool(Type): """Bool SQL type.""" @@ -376,6 +389,7 @@ def _pb_metadata_to_metadata_types( "bytes_type": SqlType.Bytes, "string_type": SqlType.String, "int64_type": SqlType.Int64, + "float32_type": SqlType.Float32, "float64_type": SqlType.Float64, "bool_type": SqlType.Bool, "timestamp_type": SqlType.Timestamp, diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 74f318d39..5c11c1990 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -14,13 +14,16 @@ import pytest import asyncio +import datetime import uuid import os from google.api_core import retry from google.api_core.exceptions import ClientError +from google.cloud.bigtable.data.execute_query.metadata import SqlType from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.type import date_pb2 from google.cloud.bigtable.data._cross_sync import CrossSync @@ -1027,3 +1030,83 @@ async def test_execute_query_simple(self, client, table_id, instance_id): row = rows[0] assert row["a"] == 1 assert row["b"] == "foo" + + @CrossSync.pytest + @pytest.mark.usefixtures("client") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + async def test_execute_query_params(self, client, table_id, instance_id): + query = ( + "SELECT @stringParam AS strCol, @bytesParam as bytesCol, @int64Param AS intCol, " + "@float32Param AS float32Col, @float64Param AS float64Col, @boolParam AS boolCol, " + "@tsParam AS tsCol, @dateParam AS dateCol, @byteArrayParam AS byteArrayCol, " + "@stringArrayParam AS stringArrayCol, @intArrayParam AS intArrayCol, " + "@float32ArrayParam AS float32ArrayCol, @float64ArrayParam AS float64ArrayCol, " + "@boolArrayParam AS boolArrayCol, @tsArrayParam AS tsArrayCol, " + "@dateArrayParam AS dateArrayCol" + ) + parameters = { + "stringParam": "foo", + "bytesParam": b"bar", + "int64Param": 12, + "float32Param": 1.1, + "float64Param": 1.2, + "boolParam": True, + "tsParam": datetime.datetime.fromtimestamp(1000, tz=datetime.timezone.utc), + "dateParam": datetime.date(2025, 1, 16), + "byteArrayParam": [b"foo", b"bar", None], + "stringArrayParam": ["foo", "bar", None], + "intArrayParam": [1, None, 2], + "float32ArrayParam": [1.2, None, 1.3], + "float64ArrayParam": [1.4, None, 1.5], + "boolArrayParam": [None, False, True], + "tsArrayParam": [ + datetime.datetime.fromtimestamp(1000, tz=datetime.timezone.utc), + datetime.datetime.fromtimestamp(2000, tz=datetime.timezone.utc), + None, + ], + "dateArrayParam": [ + datetime.date(2025, 1, 16), + datetime.date(2025, 1, 17), + None, + ], + } + param_types = { + "float32Param": SqlType.Float32(), + "float64Param": SqlType.Float64(), + "byteArrayParam": SqlType.Array(SqlType.Bytes()), + "stringArrayParam": SqlType.Array(SqlType.String()), + "intArrayParam": SqlType.Array(SqlType.Int64()), + "float32ArrayParam": SqlType.Array(SqlType.Float32()), + "float64ArrayParam": SqlType.Array(SqlType.Float64()), + "boolArrayParam": SqlType.Array(SqlType.Bool()), + "tsArrayParam": SqlType.Array(SqlType.Timestamp()), + "dateArrayParam": SqlType.Array(SqlType.Date()), + } + result = await client.execute_query( + query, instance_id, parameters=parameters, parameter_types=param_types + ) + rows = [r async for r in result] + assert len(rows) == 1 + row = rows[0] + assert row["strCol"] == parameters["stringParam"] + assert row["bytesCol"] == parameters["bytesParam"] + assert row["intCol"] == parameters["int64Param"] + assert row["float32Col"] == pytest.approx(parameters["float32Param"]) + assert row["float64Col"] == pytest.approx(parameters["float64Param"]) + assert row["boolCol"] == parameters["boolParam"] + assert row["tsCol"] == parameters["tsParam"] + assert row["dateCol"] == date_pb2.Date(year=2025, month=1, day=16) + assert row["stringArrayCol"] == parameters["stringArrayParam"] + assert row["byteArrayCol"] == parameters["byteArrayParam"] + assert row["intArrayCol"] == parameters["intArrayParam"] + assert row["float32ArrayCol"] == pytest.approx(parameters["float32ArrayParam"]) + assert row["float64ArrayCol"] == pytest.approx(parameters["float64ArrayParam"]) + assert row["boolArrayCol"] == parameters["boolArrayParam"] + assert row["tsArrayCol"] == parameters["tsArrayParam"] + assert row["dateArrayCol"] == [ + date_pb2.Date(year=2025, month=1, day=16), + date_pb2.Date(year=2025, month=1, day=17), + None, + ] diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index c96cfdb50..cbaaf5a8f 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -16,12 +16,15 @@ # This file is automatically generated by CrossSync. Do not edit manually. import pytest +import datetime import uuid import os from google.api_core import retry from google.api_core.exceptions import ClientError +from google.cloud.bigtable.data.execute_query.metadata import SqlType from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.type import date_pb2 from google.cloud.bigtable.data._cross_sync import CrossSync from . import TEST_FAMILY, TEST_FAMILY_2 @@ -838,3 +841,74 @@ def test_execute_query_simple(self, client, table_id, instance_id): row = rows[0] assert row["a"] == 1 assert row["b"] == "foo" + + @pytest.mark.usefixtures("client") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_execute_query_params(self, client, table_id, instance_id): + query = "SELECT @stringParam AS strCol, @bytesParam as bytesCol, @int64Param AS intCol, @float32Param AS float32Col, @float64Param AS float64Col, @boolParam AS boolCol, @tsParam AS tsCol, @dateParam AS dateCol, @byteArrayParam AS byteArrayCol, @stringArrayParam AS stringArrayCol, @intArrayParam AS intArrayCol, @float32ArrayParam AS float32ArrayCol, @float64ArrayParam AS float64ArrayCol, @boolArrayParam AS boolArrayCol, @tsArrayParam AS tsArrayCol, @dateArrayParam AS dateArrayCol" + parameters = { + "stringParam": "foo", + "bytesParam": b"bar", + "int64Param": 12, + "float32Param": 1.1, + "float64Param": 1.2, + "boolParam": True, + "tsParam": datetime.datetime.fromtimestamp(1000, tz=datetime.timezone.utc), + "dateParam": datetime.date(2025, 1, 16), + "byteArrayParam": [b"foo", b"bar", None], + "stringArrayParam": ["foo", "bar", None], + "intArrayParam": [1, None, 2], + "float32ArrayParam": [1.2, None, 1.3], + "float64ArrayParam": [1.4, None, 1.5], + "boolArrayParam": [None, False, True], + "tsArrayParam": [ + datetime.datetime.fromtimestamp(1000, tz=datetime.timezone.utc), + datetime.datetime.fromtimestamp(2000, tz=datetime.timezone.utc), + None, + ], + "dateArrayParam": [ + datetime.date(2025, 1, 16), + datetime.date(2025, 1, 17), + None, + ], + } + param_types = { + "float32Param": SqlType.Float32(), + "float64Param": SqlType.Float64(), + "byteArrayParam": SqlType.Array(SqlType.Bytes()), + "stringArrayParam": SqlType.Array(SqlType.String()), + "intArrayParam": SqlType.Array(SqlType.Int64()), + "float32ArrayParam": SqlType.Array(SqlType.Float32()), + "float64ArrayParam": SqlType.Array(SqlType.Float64()), + "boolArrayParam": SqlType.Array(SqlType.Bool()), + "tsArrayParam": SqlType.Array(SqlType.Timestamp()), + "dateArrayParam": SqlType.Array(SqlType.Date()), + } + result = client.execute_query( + query, instance_id, parameters=parameters, parameter_types=param_types + ) + rows = [r for r in result] + assert len(rows) == 1 + row = rows[0] + assert row["strCol"] == parameters["stringParam"] + assert row["bytesCol"] == parameters["bytesParam"] + assert row["intCol"] == parameters["int64Param"] + assert row["float32Col"] == pytest.approx(parameters["float32Param"]) + assert row["float64Col"] == pytest.approx(parameters["float64Param"]) + assert row["boolCol"] == parameters["boolParam"] + assert row["tsCol"] == parameters["tsParam"] + assert row["dateCol"] == date_pb2.Date(year=2025, month=1, day=16) + assert row["stringArrayCol"] == parameters["stringArrayParam"] + assert row["byteArrayCol"] == parameters["byteArrayParam"] + assert row["intArrayCol"] == parameters["intArrayParam"] + assert row["float32ArrayCol"] == pytest.approx(parameters["float32ArrayParam"]) + assert row["float64ArrayCol"] == pytest.approx(parameters["float64ArrayParam"]) + assert row["boolArrayCol"] == parameters["boolArrayParam"] + assert row["tsArrayCol"] == parameters["tsArrayParam"] + assert row["dateArrayCol"] == [ + date_pb2.Date(year=2025, month=1, day=16), + date_pb2.Date(year=2025, month=1, day=17), + None, + ] diff --git a/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py b/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py index f7159fb71..bebbd8d45 100644 --- a/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py +++ b/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py @@ -12,17 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime + +from google.api_core.datetime_helpers import DatetimeWithNanoseconds +from google.type import date_pb2 import pytest + from google.cloud.bigtable.data.execute_query._parameters_formatting import ( _format_execute_query_params, ) from google.cloud.bigtable.data.execute_query.metadata import SqlType from google.cloud.bigtable.data.execute_query.values import Struct -import datetime - -from google.type import date_pb2 -from google.api_core.datetime_helpers import DatetimeWithNanoseconds - +from google.protobuf import timestamp_pb2 timestamp = int( datetime.datetime(2024, 5, 12, 17, 44, 12, tzinfo=datetime.timezone.utc).timestamp() @@ -71,7 +72,7 @@ ), ], ) -def test_instance_execute_query_parameters_simple_types_parsing( +def test_execute_query_parameters_inferred_types_parsing( input_value, value_field, type_field, expected_value ): result = _format_execute_query_params( @@ -84,7 +85,161 @@ def test_instance_execute_query_parameters_simple_types_parsing( assert type_field in result["test"]["type_"] -def test_instance_execute_query_parameters_not_supported_types(): +@pytest.mark.parametrize( + "value, sql_type, proto_result", + [ + (1.3, SqlType.Float32(), {"type_": {"float32_type": {}}, "float_value": 1.3}), + (1.3, SqlType.Float64(), {"type_": {"float64_type": {}}, "float_value": 1.3}), + ( + [1, 2, 3, 4], + SqlType.Array(SqlType.Int64()), + { + "type_": {"array_type": {"element_type": {"int64_type": {}}}}, + "array_value": { + "values": [ + {"int_value": 1}, + {"int_value": 2}, + {"int_value": 3}, + {"int_value": 4}, + ] + }, + }, + ), + ( + [1, None, 2, None], + SqlType.Array(SqlType.Int64()), + { + "type_": {"array_type": {"element_type": {"int64_type": {}}}}, + "array_value": { + "values": [ + {"int_value": 1}, + {}, + {"int_value": 2}, + {}, + ] + }, + }, + ), + ( + None, + SqlType.Array(SqlType.Int64()), + { + "type_": {"array_type": {"element_type": {"int64_type": {}}}}, + }, + ), + ( + ["foo", "bar", None], + SqlType.Array(SqlType.String()), + { + "type_": {"array_type": {"element_type": {"string_type": {}}}}, + "array_value": { + "values": [ + {"string_value": "foo"}, + {"string_value": "bar"}, + {}, + ] + }, + }, + ), + ( + [b"foo", b"bar", None], + SqlType.Array(SqlType.Bytes()), + { + "type_": {"array_type": {"element_type": {"bytes_type": {}}}}, + "array_value": { + "values": [ + {"bytes_value": b"foo"}, + {"bytes_value": b"bar"}, + {}, + ] + }, + }, + ), + ( + [ + datetime.datetime.fromtimestamp(1000, tz=datetime.timezone.utc), + datetime.datetime.fromtimestamp(2000, tz=datetime.timezone.utc), + None, + ], + SqlType.Array(SqlType.Timestamp()), + { + "type_": {"array_type": {"element_type": {"timestamp_type": {}}}}, + "array_value": { + "values": [ + {"timestamp_value": timestamp_pb2.Timestamp(seconds=1000)}, + {"timestamp_value": timestamp_pb2.Timestamp(seconds=2000)}, + {}, + ], + }, + }, + ), + ( + [True, False, None], + SqlType.Array(SqlType.Bool()), + { + "type_": {"array_type": {"element_type": {"bool_type": {}}}}, + "array_value": { + "values": [ + {"bool_value": True}, + {"bool_value": False}, + {}, + ], + }, + }, + ), + ( + [datetime.date(2025, 1, 16), datetime.date(2025, 1, 17), None], + SqlType.Array(SqlType.Date()), + { + "type_": {"array_type": {"element_type": {"date_type": {}}}}, + "array_value": { + "values": [ + {"date_value": date_pb2.Date(year=2025, month=1, day=16)}, + {"date_value": date_pb2.Date(year=2025, month=1, day=17)}, + {}, + ], + }, + }, + ), + ( + [1.1, 1.2, None], + SqlType.Array(SqlType.Float32()), + { + "type_": {"array_type": {"element_type": {"float32_type": {}}}}, + "array_value": { + "values": [ + {"float_value": 1.1}, + {"float_value": 1.2}, + {}, + ] + }, + }, + ), + ( + [1.1, 1.2, None], + SqlType.Array(SqlType.Float64()), + { + "type_": {"array_type": {"element_type": {"float64_type": {}}}}, + "array_value": { + "values": [ + {"float_value": 1.1}, + {"float_value": 1.2}, + {}, + ] + }, + }, + ), + ], +) +def test_execute_query_explicit_parameter_parsing(value, sql_type, proto_result): + result = _format_execute_query_params( + {"param_name": value}, {"param_name": sql_type} + ) + print(result) + assert result["param_name"] == proto_result + + +def test_execute_query_parameters_not_supported_types(): with pytest.raises(ValueError): _format_execute_query_params({"test1": 1.1}, None) @@ -105,14 +260,6 @@ def test_instance_execute_query_parameters_not_supported_types(): }, ) - with pytest.raises(NotImplementedError, match="not supported"): - _format_execute_query_params( - {"test1": [1]}, - { - "test1": SqlType.Array(SqlType.Int64()), - }, - ) - with pytest.raises(NotImplementedError, match="not supported"): _format_execute_query_params( {"test1": Struct([("field1", 1)])}, @@ -132,3 +279,16 @@ def test_instance_execute_query_parameters_not_match(): "test2": SqlType.String(), }, ) + + +def test_array_params_enforce_element_type(): + with pytest.raises(ValueError, match="Error when parsing parameter p") as e1: + _format_execute_query_params( + {"p": ["a", 1, None]}, {"p": SqlType.Array(SqlType.String())} + ) + with pytest.raises(ValueError, match="Error when parsing parameter p") as e2: + _format_execute_query_params( + {"p": ["a", 1, None]}, {"p": SqlType.Array(SqlType.Int64())} + ) + assert "Expected query parameter of type str, got int" in str(e1.value.__cause__) + assert "Expected query parameter of type int, got str" in str(e2.value.__cause__)