From 78ce9911fa4911d094056363cdfe634abcc1b093 Mon Sep 17 00:00:00 2001 From: joseph-sentry <136376984+joseph-sentry@users.noreply.github.com> Date: Thu, 9 Jan 2025 11:59:05 -0500 Subject: [PATCH] feat: add capability for writing to bigquery (#993) --- generated_proto/testrun/__init__.py | 0 generated_proto/testrun/ta_testrun_pb2.py | 35 +++ generated_proto/testrun/ta_testrun_pb2.pyi | 83 ++++++ protobuf/ta_testrun.proto | 30 ++ requirements.in | 3 + requirements.txt | 58 +++- services/bigquery.py | 175 +++++++++++ services/tests/test_bigquery.py | 148 +++++++++ ta_storage/__init__.py | 0 ta_storage/base.py | 20 ++ ta_storage/bq.py | 85 ++++++ ta_storage/pg.py | 331 +++++++++++++++++++++ ta_storage/tests/test_bq.py | 117 ++++++++ ta_storage/tests/test_pg.py | 63 ++++ 14 files changed, 1139 insertions(+), 9 deletions(-) create mode 100644 generated_proto/testrun/__init__.py create mode 100644 generated_proto/testrun/ta_testrun_pb2.py create mode 100644 generated_proto/testrun/ta_testrun_pb2.pyi create mode 100644 protobuf/ta_testrun.proto create mode 100644 services/bigquery.py create mode 100644 services/tests/test_bigquery.py create mode 100644 ta_storage/__init__.py create mode 100644 ta_storage/base.py create mode 100644 ta_storage/bq.py create mode 100644 ta_storage/pg.py create mode 100644 ta_storage/tests/test_bq.py create mode 100644 ta_storage/tests/test_pg.py diff --git a/generated_proto/testrun/__init__.py b/generated_proto/testrun/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/generated_proto/testrun/ta_testrun_pb2.py b/generated_proto/testrun/ta_testrun_pb2.py new file mode 100644 index 000000000..ba133248f --- /dev/null +++ b/generated_proto/testrun/ta_testrun_pb2.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# source: ta_testrun.proto +# Protobuf Python Version: 5.29.2 +"""Generated protocol buffer code.""" + +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, 5, 29, 2, "", "ta_testrun.proto" +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x10ta_testrun.proto"\xda\x02\n\x07TestRun\x12\x11\n\ttimestamp\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tclassname\x18\x03 \x01(\t\x12\x11\n\ttestsuite\x18\x04 \x01(\t\x12\x15\n\rcomputed_name\x18\x05 \x01(\t\x12!\n\x07outcome\x18\x06 \x01(\x0e\x32\x10.TestRun.Outcome\x12\x17\n\x0f\x66\x61ilure_message\x18\x07 \x01(\t\x12\x18\n\x10\x64uration_seconds\x18\x08 \x01(\x02\x12\x0e\n\x06repoid\x18\n \x01(\x03\x12\x12\n\ncommit_sha\x18\x0b \x01(\t\x12\x13\n\x0b\x62ranch_name\x18\x0c \x01(\t\x12\r\n\x05\x66lags\x18\r \x03(\t\x12\x10\n\x08\x66ilename\x18\x0e \x01(\t\x12\x11\n\tframework\x18\x0f \x01(\t".\n\x07Outcome\x12\n\n\x06PASSED\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07SKIPPED\x10\x02' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "ta_testrun_pb2", _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals["_TESTRUN"]._serialized_start = 21 + _globals["_TESTRUN"]._serialized_end = 367 + _globals["_TESTRUN_OUTCOME"]._serialized_start = 321 + _globals["_TESTRUN_OUTCOME"]._serialized_end = 367 +# @@protoc_insertion_point(module_scope) diff --git a/generated_proto/testrun/ta_testrun_pb2.pyi b/generated_proto/testrun/ta_testrun_pb2.pyi new file mode 100644 index 000000000..bd9eb0d61 --- /dev/null +++ b/generated_proto/testrun/ta_testrun_pb2.pyi @@ -0,0 +1,83 @@ +from typing import ClassVar as _ClassVar +from typing import Iterable as _Iterable +from typing import Optional as _Optional +from typing import Union as _Union + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper + +DESCRIPTOR: _descriptor.FileDescriptor + +class TestRun(_message.Message): + __slots__ = ( + "timestamp", + "name", + "classname", + "testsuite", + "computed_name", + "outcome", + "failure_message", + "duration_seconds", + "repoid", + "commit_sha", + "branch_name", + "flags", + "filename", + "framework", + ) + class Outcome(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + PASSED: _ClassVar[TestRun.Outcome] + FAILED: _ClassVar[TestRun.Outcome] + SKIPPED: _ClassVar[TestRun.Outcome] + + PASSED: TestRun.Outcome + FAILED: TestRun.Outcome + SKIPPED: TestRun.Outcome + TIMESTAMP_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + CLASSNAME_FIELD_NUMBER: _ClassVar[int] + TESTSUITE_FIELD_NUMBER: _ClassVar[int] + COMPUTED_NAME_FIELD_NUMBER: _ClassVar[int] + OUTCOME_FIELD_NUMBER: _ClassVar[int] + FAILURE_MESSAGE_FIELD_NUMBER: _ClassVar[int] + DURATION_SECONDS_FIELD_NUMBER: _ClassVar[int] + REPOID_FIELD_NUMBER: _ClassVar[int] + 
COMMIT_SHA_FIELD_NUMBER: _ClassVar[int] + BRANCH_NAME_FIELD_NUMBER: _ClassVar[int] + FLAGS_FIELD_NUMBER: _ClassVar[int] + FILENAME_FIELD_NUMBER: _ClassVar[int] + FRAMEWORK_FIELD_NUMBER: _ClassVar[int] + timestamp: int + name: str + classname: str + testsuite: str + computed_name: str + outcome: TestRun.Outcome + failure_message: str + duration_seconds: float + repoid: int + commit_sha: str + branch_name: str + flags: _containers.RepeatedScalarFieldContainer[str] + filename: str + framework: str + def __init__( + self, + timestamp: _Optional[int] = ..., + name: _Optional[str] = ..., + classname: _Optional[str] = ..., + testsuite: _Optional[str] = ..., + computed_name: _Optional[str] = ..., + outcome: _Optional[_Union[TestRun.Outcome, str]] = ..., + failure_message: _Optional[str] = ..., + duration_seconds: _Optional[float] = ..., + repoid: _Optional[int] = ..., + commit_sha: _Optional[str] = ..., + branch_name: _Optional[str] = ..., + flags: _Optional[_Iterable[str]] = ..., + filename: _Optional[str] = ..., + framework: _Optional[str] = ..., + ) -> None: ... diff --git a/protobuf/ta_testrun.proto b/protobuf/ta_testrun.proto new file mode 100644 index 000000000..c0530b875 --- /dev/null +++ b/protobuf/ta_testrun.proto @@ -0,0 +1,30 @@ +syntax = "proto2"; + +message TestRun { + optional int64 timestamp = 1; + optional string name = 2; + optional string classname = 3; + optional string testsuite = 4; + optional string computed_name = 5; + + enum Outcome { + PASSED = 0; + FAILED = 1; + SKIPPED = 2; + } + + optional Outcome outcome = 6; + + optional string failure_message = 7; + optional float duration_seconds = 8; + + optional int64 repoid = 10; + optional string commit_sha = 11; + + optional string branch_name = 12; + + repeated string flags = 13; + + optional string filename = 14; + optional string framework = 15; +} diff --git a/requirements.in b/requirements.in index 878718cbd..7782f3afa 100644 --- a/requirements.in +++ b/requirements.in @@ -12,6 +12,8 @@ coverage Django>=4.2.16 django-postgres-extra>=2.0.8 factory-boy +google-cloud-bigquery>=3.27.0 +google-cloud-bigquery-storage>=2.27.0 google-cloud-pubsub google-cloud-storage>=2.10.0 grpcio>=1.66.2 @@ -26,6 +28,7 @@ pre-commit polars==1.12.0 proto-plus>=1.25.0 psycopg2>=2.9.10 +protobuf>=5.29.2 pydantic>=2.9.0 PyJWT>=2.4.0 pytest diff --git a/requirements.txt b/requirements.txt index 4c1a38016..a8c9acd69 100644 --- a/requirements.txt +++ b/requirements.txt @@ -79,6 +79,10 @@ coverage==7.5.0 # pytest-cov cryptography==43.0.1 # via shared +deprecated==1.2.15 + # via + # opentelemetry-api + # opentelemetry-semantic-conventions distlib==0.3.7 # via virtualenv distro==1.8.0 @@ -109,18 +113,29 @@ freezegun==1.5.0 # via pytest-freezegun google-api-core==2.23.0 # via + # google-cloud-bigquery + # google-cloud-bigquery-storage # google-cloud-core # google-cloud-pubsub # google-cloud-storage google-auth==2.36.0 # via # google-api-core + # google-cloud-bigquery + # google-cloud-bigquery-storage # google-cloud-core + # google-cloud-pubsub # google-cloud-storage # shared -google-cloud-core==2.3.3 - # via google-cloud-storage -google-cloud-pubsub==2.18.4 +google-cloud-bigquery==3.27.0 + # via -r requirements.in +google-cloud-bigquery-storage==2.27.0 + # via -r requirements.in +google-cloud-core==2.4.1 + # via + # google-cloud-bigquery + # google-cloud-storage +google-cloud-pubsub==2.27.1 # via # -r requirements.in # shared @@ -133,13 +148,15 @@ google-crc32c==1.1.2 # google-cloud-storage # google-resumable-media google-resumable-media==2.7.2 - # via 
google-cloud-storage -googleapis-common-protos==1.59.1 + # via + # google-cloud-bigquery + # google-cloud-storage +googleapis-common-protos==1.66.0 # via # google-api-core # grpc-google-iam-v1 # grpcio-status -grpc-google-iam-v1==0.12.6 +grpc-google-iam-v1==0.14.0 # via google-cloud-pubsub grpcio==1.68.1 # via @@ -173,6 +190,8 @@ idna==3.7 # yarl ijson==3.2.3 # via shared +importlib-metadata==8.5.0 + # via opentelemetry-api iniconfig==1.1.1 # via pytest jinja2==3.1.4 @@ -205,12 +224,23 @@ oauthlib==3.1.0 # via shared openai==1.2.4 # via -r requirements.in +opentelemetry-api==1.29.0 + # via + # google-cloud-pubsub + # opentelemetry-sdk + # opentelemetry-semantic-conventions +opentelemetry-sdk==1.29.0 + # via google-cloud-pubsub +opentelemetry-semantic-conventions==0.50b0 + # via opentelemetry-sdk orjson==3.10.11 # via # -r requirements.in # shared packaging==24.1 - # via pytest + # via + # google-cloud-bigquery + # pytest platformdirs==3.11.0 # via virtualenv pluggy==1.5.0 @@ -229,10 +259,13 @@ proto-plus==1.25.0 # via # -r requirements.in # google-api-core + # google-cloud-bigquery-storage # google-cloud-pubsub -protobuf==4.24.3 +protobuf==5.29.2 # via + # -r requirements.in # google-api-core + # google-cloud-bigquery-storage # google-cloud-pubsub # googleapis-common-protos # grpc-google-iam-v1 @@ -293,6 +326,7 @@ python-dateutil==2.9.0.post0 # django-postgres-extra # faker # freezegun + # google-cloud-bigquery # time-machine python-json-logger==0.1.11 # via -r requirements.in @@ -320,6 +354,7 @@ requests==2.32.3 # -r requirements.in # analytics-python # google-api-core + # google-cloud-bigquery # google-cloud-storage # shared # stripe @@ -380,6 +415,7 @@ tqdm==4.66.1 typing-extensions==4.12.2 # via # openai + # opentelemetry-sdk # pydantic # pydantic-core # stripe @@ -404,9 +440,13 @@ virtualenv==20.24.5 wcwidth==0.2.5 # via prompt-toolkit wrapt==1.16.0 - # via vcrpy + # via + # deprecated + # vcrpy yarl==1.9.4 # via vcrpy +zipp==3.21.0 + # via importlib-metadata zstandard==0.23.0 # via # -r requirements.in diff --git a/services/bigquery.py b/services/bigquery.py new file mode 100644 index 000000000..4da0de60a --- /dev/null +++ b/services/bigquery.py @@ -0,0 +1,175 @@ +from types import ModuleType +from typing import Dict, List, cast + +import polars as pl +from google.api_core import retry +from google.cloud import bigquery +from google.cloud.bigquery_storage_v1 import BigQueryWriteClient, types +from google.cloud.bigquery_storage_v1.writer import AppendRowsStream +from google.oauth2.service_account import Credentials +from google.protobuf import descriptor_pb2 +from shared.config import get_config + + +class BigQueryService: + """ + Requires a table to be created with a time partitioning schema. + """ + + def __init__( + self, + gcp_config: dict[str, str], + ) -> None: + """Initialize BigQuery client with GCP credentials. 
+
+        Args:
+            gcp_config: Dictionary containing Google Cloud service account credentials
+                including project_id, private_key and other required fields
+        Raises:
+            google.api_core.exceptions.GoogleAPIError: If client initialization fails
+            ValueError: If required credentials are missing from gcp_config
+            google.auth.exceptions.DefaultCredentialsError: If credentials are not found
+        """
+        self.credentials = Credentials.from_service_account_info(gcp_config)
+
+        if not self.credentials.project_id:
+            raise ValueError("Project ID is not set")
+
+        self.project_id = cast(str, self.credentials.project_id)
+
+        self.client = bigquery.Client(
+            project=self.project_id, credentials=self.credentials
+        )
+
+    def write(
+        self,
+        dataset_id: str,
+        table_id: str,
+        proto_module: ModuleType,
+        data: list[bytes],
+    ) -> None:
+        """Write records to the BigQuery table using the Storage Write API.
+        Uses protobuf-encoded models defined in the protobuf directory.
+
+        Args:
+            dataset_id: BigQuery dataset ID containing the target table
+            table_id: BigQuery table ID within the dataset
+            proto_module: Generated protobuf module whose message type matches the table schema
+            data: List of proto2 messages already serialized to bytes
+
+        Raises:
+            google.api_core.exceptions.GoogleAPIError: If the API request fails
+        """
+
+        self.write_client = BigQueryWriteClient(credentials=self.credentials)
+
+        parent = self.write_client.table_path(self.project_id, dataset_id, table_id)
+
+        write_stream = types.WriteStream()
+
+        write_stream.type_ = types.WriteStream.Type.PENDING
+        write_stream = self.write_client.create_write_stream(
+            parent=parent, write_stream=write_stream
+        )
+        stream_name = write_stream.name
+
+        request_template = types.AppendRowsRequest()
+        request_template.write_stream = stream_name
+
+        proto_descriptor = descriptor_pb2.DescriptorProto()
+        proto_module.DESCRIPTOR.message_types_by_name.values()[0].CopyToProto(
+            proto_descriptor
+        )
+
+        proto_schema = types.ProtoSchema()
+        proto_schema.proto_descriptor = proto_descriptor
+
+        proto_data = types.AppendRowsRequest.ProtoData()
+        proto_data.writer_schema = proto_schema
+
+        request_template.proto_rows = proto_data
+
+        append_rows_stream = AppendRowsStream(self.write_client, request_template)
+
+        proto_rows = types.ProtoRows()
+        proto_rows.serialized_rows = data
+
+        request = types.AppendRowsRequest()
+        request.offset = 0
+        proto_data = types.AppendRowsRequest.ProtoData()
+        proto_data.rows = proto_rows
+        request.proto_rows = proto_data
+
+        _ = append_rows_stream.send(request)
+
+        self.write_client.finalize_write_stream(name=write_stream.name)
+
+        batch_commit_write_streams_request = types.BatchCommitWriteStreamsRequest()
+        batch_commit_write_streams_request.parent = parent
+        batch_commit_write_streams_request.write_streams = [write_stream.name]
+
+        self.write_client.batch_commit_write_streams(batch_commit_write_streams_request)
+
+    def query(self, query: str, params: dict | None = None) -> List[Dict]:
+        """Execute a BigQuery SQL query and return results.
+        Try not to write INSERT statements and use the write method instead.
+ + Args: + query: SQL query string + params: Optional dict of query parameters + + Returns: + List of dictionaries containing the query results + + Raises: + google.api_core.exceptions.GoogleAPIError: If the query fails + """ + job_config = bigquery.QueryJobConfig() + + if params: + job_config.query_parameters = [ + bigquery.ScalarQueryParameter(k, "STRING", v) for k, v in params.items() + ] + + row_iterator = self.client.query_and_wait( + query, job_config=job_config, retry=retry.Retry(deadline=30) + ) + + return [dict(row.items()) for row in row_iterator] + + def query_polars( + self, + query: str, + params: dict | None = None, + schema: list[str | tuple[str, pl.DataType]] | None = None, + ) -> pl.DataFrame: + """Execute a BigQuery SQL query and return results. + Try not to write INSERT statements and use the write method instead. + + Args: + query: SQL query string + params: Optional dict of query parameters + + Returns: + List of dictionaries containing the query results + + Raises: + google.api_core.exceptions.GoogleAPIError: If the query fails + """ + job_config = bigquery.QueryJobConfig() + + if params: + job_config.query_parameters = [ + bigquery.ScalarQueryParameter(k, "STRING", v) for k, v in params.items() + ] + + row_iterator = self.client.query_and_wait( + query, job_config=job_config, retry=retry.Retry(deadline=30) + ) + + return pl.DataFrame( + (dict(row.items()) for row in row_iterator), schema=schema, orient="row" + ) + + +def get_bigquery_service(): + gcp_config: dict[str, str] = get_config("services", "gcp", default={}) + return BigQueryService(gcp_config) diff --git a/services/tests/test_bigquery.py b/services/tests/test_bigquery.py new file mode 100644 index 000000000..9a9198d69 --- /dev/null +++ b/services/tests/test_bigquery.py @@ -0,0 +1,148 @@ +import datetime as dt +from datetime import datetime + +import polars as pl +import pytest + +import generated_proto.testrun.ta_testrun_pb2 as ta_testrun_pb2 +from services.bigquery import BigQueryService + +fake_private_key = """-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQDCFqq2ygFh9UQU/6PoDJ6L9e4ovLPCHtlBt7vzDwyfwr3XGxln +0VbfycVLc6unJDVEGZ/PsFEuS9j1QmBTTEgvCLR6RGpfzmVuMO8wGVEO52pH73h9 +rviojaheX/u3ZqaA0di9RKy8e3L+T0ka3QYgDx5wiOIUu1wGXCs6PhrtEwICBAEC +gYBu9jsi0eVROozSz5dmcZxUAzv7USiUcYrxX007SUpm0zzUY+kPpWLeWWEPaddF +VONCp//0XU8hNhoh0gedw7ZgUTG6jYVOdGlaV95LhgY6yXaQGoKSQNNTY+ZZVT61 +zvHOlPynt3GZcaRJOlgf+3hBF5MCRoWKf+lDA5KiWkqOYQJBAMQp0HNVeTqz+E0O +6E0neqQDQb95thFmmCI7Kgg4PvkS5mz7iAbZa5pab3VuyfmvnVvYLWejOwuYSp0U +9N8QvUsCQQD9StWHaVNM4Lf5zJnB1+lJPTXQsmsuzWvF3HmBkMHYWdy84N/TdCZX +Cxve1LR37lM/Vijer0K77wAx2RAN/ppZAkB8+GwSh5+mxZKydyPaPN29p6nC6aLx +3DV2dpzmhD0ZDwmuk8GN+qc0YRNOzzJ/2UbHH9L/lvGqui8I6WLOi8nDAkEA9CYq +ewfdZ9LcytGz7QwPEeWVhvpm0HQV9moetFWVolYecqBP4QzNyokVnpeUOqhIQAwe +Z0FJEQ9VWsG+Df0noQJBALFjUUZEtv4x31gMlV24oiSWHxIRX4fEND/6LpjleDZ5 +C/tY+lZIEO1Gg/FxSMB+hwwhwfSuE3WohZfEcSy+R48= +-----END RSA PRIVATE KEY-----""" + +gcp_config = { + "type": "service_account", + "project_id": "codecov-dev", + "private_key_id": "testu7gvpfyaasze2lboblawjb3032mbfisy9gpg", + "private_key": fake_private_key, + "client_email": "localstoragetester@genuine-polymer-165712.iam.gserviceaccount.com", + "client_id": "110927033630051704865", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": 
"https://www.googleapis.com/robot/v1/metadata/x509/localstoragetester%40genuine-polymer-165712.iam.gserviceaccount.com", +} + + +sql = """ +WITH sample_data AS ( + SELECT * FROM UNNEST([ + STRUCT(TIMESTAMP '2025-01-01T00:00:00Z' AS timestamp, 1 AS id, 'name' AS name), + STRUCT(TIMESTAMP '2024-12-30T00:00:00Z' AS timestamp, 2 AS id, 'name2' AS name) + ]) +) +SELECT * FROM sample_data +""" + + +@pytest.mark.skip(reason="This test requires being run using actual working creds") +def test_bigquery_service(): + bigquery_service = BigQueryService(gcp_config) + + results = bigquery_service.query(sql) + + assert len(results) == 2 + assert {row["timestamp"] for row in results} == { + datetime.fromisoformat("2025-01-01T00:00:00Z"), + datetime.fromisoformat("2024-12-30T00:00:00Z"), + } + assert {row["name"] for row in results} == {"name", "name2"} + assert {row["id"] for row in results} == {1, 2} + + +@pytest.mark.skip(reason="This test requires being run using actual working creds") +def test_bigquery_service_polars(): + bigquery_service = BigQueryService(gcp_config) + + results = bigquery_service.query_polars( + sql, + None, + [ + ("timestamp", pl.Datetime(time_zone=dt.UTC)), + "id", + "name", + ], + ) + + assert len(results) == 2 + assert {x for x in results["timestamp"].to_list()} == { + datetime.fromisoformat("2025-01-01T00:00:00Z"), + datetime.fromisoformat("2024-12-30T00:00:00Z"), + } + assert {x for x in results["name"].to_list()} == {"name", "name2"} + assert {x for x in results["id"].to_list()} == {1, 2} + + +# this test should only be run manually when making changes to the way we write to bigquery +# the reason it's not automated is because vcrpy does not seem to work with the gRPC requests +@pytest.mark.skip(reason="This test requires being run using actual working creds") +def test_bigquery_service_write(): + table_name = "codecov-dev.test_dataset.testruns" + bigquery_service = BigQueryService(gcp_config) + + bigquery_service.query(f"TRUNCATE TABLE `{table_name}`") + + data = [ + ta_testrun_pb2.TestRun( + timestamp=int( + datetime.fromisoformat("2025-01-01T00:00:00.000000Z").timestamp() + * 1000000 + ), + name="name", + classname="classname", + testsuite="testsuite", + computed_name="computed_name", + outcome=ta_testrun_pb2.TestRun.Outcome.PASSED, + failure_message="failure_message", + duration_seconds=1.0, + filename="filename", + ), + ta_testrun_pb2.TestRun( + timestamp=int( + datetime.fromisoformat("2024-12-30T00:00:00.000000Z").timestamp() + * 1000000 + ), + name="name2", + classname="classname2", + testsuite="testsuite2", + computed_name="computed_name2", + outcome=ta_testrun_pb2.TestRun.Outcome.FAILED, + failure_message="failure_message2", + duration_seconds=2.0, + filename="filename2", + ), + ] + + serialized_data = [row.SerializeToString() for row in data] + + bigquery_service.write( + "test_dataset", + "testruns", + ta_testrun_pb2, + serialized_data, + ) + + results = bigquery_service.query(f"SELECT * FROM `{table_name}`") + + assert len(results) == 2 + + assert {row["timestamp"] for row in results} == set( + [ + datetime.fromisoformat("2025-01-01T00:00:00Z"), + datetime.fromisoformat("2024-12-30T00:00:00Z"), + ] + ) + assert {row["name"] for row in results} == set(["name", "name2"]) diff --git a/ta_storage/__init__.py b/ta_storage/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ta_storage/base.py b/ta_storage/base.py new file mode 100644 index 000000000..f8b3a4629 --- /dev/null +++ b/ta_storage/base.py @@ -0,0 +1,20 @@ +from abc import ABC, 
abstractmethod + +from test_results_parser import Testrun + +from database.models.reports import Upload + + +class TADriver(ABC): + @abstractmethod + def write_testruns( + self, + timestamp: int, + repo_id: int, + commit_sha: str, + branch_name: str, + upload: Upload, + framework: str | None, + testruns: list[Testrun], + ): + pass diff --git a/ta_storage/bq.py b/ta_storage/bq.py new file mode 100644 index 000000000..6617e31f4 --- /dev/null +++ b/ta_storage/bq.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import Literal, TypedDict, cast + +from shared.config import get_config +from test_results_parser import Testrun + +import generated_proto.testrun.ta_testrun_pb2 as ta_testrun_pb2 +from database.models.reports import Upload +from services.bigquery import get_bigquery_service +from ta_storage.base import TADriver + +DATASET_NAME: str = cast( + str, get_config("services", "bigquery", "dataset_name", default="codecov_prod") +) + +TESTRUN_TABLE_NAME: str = cast( + str, get_config("services", "bigquery", "testrun_table_name", default="testruns") +) + + +def outcome_to_int( + outcome: Literal["pass", "skip", "failure", "error"], +) -> ta_testrun_pb2.TestRun.Outcome: + match outcome: + case "pass": + return ta_testrun_pb2.TestRun.Outcome.PASSED + case "skip": + return ta_testrun_pb2.TestRun.Outcome.SKIPPED + case "failure" | "error": + return ta_testrun_pb2.TestRun.Outcome.FAILED + case _: + raise ValueError(f"Invalid outcome: {outcome}") + + +class TransformedTestrun(TypedDict): + name: str + classname: str + testsuite: str + computed_name: str + outcome: int + failure_message: str + duration: float + filename: str + + +class BQDriver(TADriver): + def write_testruns( + self, + timestamp: int | None, + repo_id: int, + commit_sha: str, + branch_name: str, + upload: Upload, + framework: str | None, + testruns: list[Testrun], + ): + bq_service = get_bigquery_service() + + if timestamp is None: + timestamp = int(datetime.now().timestamp() * 1000000) + + flag_names = upload.flag_names + testruns_pb: list[bytes] = [] + + for t in testruns: + test_run = ta_testrun_pb2.TestRun( + timestamp=timestamp, + repoid=repo_id, + commit_sha=commit_sha, + framework=framework, + branch_name=branch_name, + flags=list(flag_names), + classname=t["classname"], + name=t["name"], + testsuite=t["testsuite"], + computed_name=t["computed_name"], + outcome=outcome_to_int(t["outcome"]), + failure_message=t["failure_message"], + duration_seconds=t["duration"], + filename=t["filename"], + ) + testruns_pb.append(test_run.SerializeToString()) + flag_names = upload.flag_names + + bq_service.write(DATASET_NAME, TESTRUN_TABLE_NAME, ta_testrun_pb2, testruns_pb) diff --git a/ta_storage/pg.py b/ta_storage/pg.py new file mode 100644 index 000000000..becf4fd7b --- /dev/null +++ b/ta_storage/pg.py @@ -0,0 +1,331 @@ +from datetime import date, datetime +from typing import Any, Literal, TypedDict + +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.orm import Session +from test_results_parser import Testrun + +from database.models import ( + DailyTestRollup, + RepositoryFlag, + Test, + TestFlagBridge, + TestInstance, + Upload, +) +from services.test_results import generate_flags_hash, generate_test_id +from ta_storage.base import TADriver + + +class DailyTotals(TypedDict): + test_id: str + repoid: int + pass_count: int + fail_count: int + skip_count: int + flaky_fail_count: int + branch: str + date: date + latest_run: datetime + commits_where_fail: list[str] + last_duration_seconds: float + 
avg_duration_seconds: float + + +def get_repo_flag_ids(db_session: Session, repoid: int, flags: list[str]) -> set[int]: + if not flags: + return set() + + return set( + db_session.query(RepositoryFlag.id_) + .filter( + RepositoryFlag.repository_id == repoid, + RepositoryFlag.flag_name.in_(flags), + ) + .all() + ) + + +def modify_structures( + tests_to_write: dict[str, dict[str, Any]], + test_instances_to_write: list[dict[str, Any]], + test_flag_bridge_data: list[dict], + daily_totals: dict[str, DailyTotals], + testrun: Testrun, + upload: Upload, + repoid: int, + branch: str | None, + commit_sha: str, + repo_flag_ids: set[int], + flaky_test_set: set[str], + framework: str | None, +): + flags_hash = generate_flags_hash(upload.flag_names) + test_id = generate_test_id( + repoid, + testrun["testsuite"], + testrun["name"], + flags_hash, + ) + + test = generate_test_dict(test_id, repoid, testrun, flags_hash, framework) + tests_to_write[test_id] = test + + test_instance = generate_test_instance_dict( + test_id, upload, testrun, commit_sha, branch, repoid + ) + test_instances_to_write.append(test_instance) + + if repo_flag_ids: + test_flag_bridge_data.extend( + {"test_id": test_id, "flag_id": flag_id} for flag_id in repo_flag_ids + ) + + if test["id"] in daily_totals: + update_daily_totals( + daily_totals, + test["id"], + testrun["duration"], + testrun["outcome"], + ) + else: + create_daily_totals( + daily_totals, + test_id, + repoid, + testrun["duration"], + testrun["outcome"], + branch, + commit_sha, + flaky_test_set, + ) + + +def generate_test_dict( + test_id: str, + repoid: int, + testrun: Testrun, + flags_hash: str, + framework: str | None, +) -> dict[str, Any]: + return { + "id": test_id, + "repoid": repoid, + "name": f"{testrun['classname']}\x1f{testrun['name']}", + "testsuite": testrun["testsuite"], + "flags_hash": flags_hash, + "framework": framework, + "filename": testrun["filename"], + "computed_name": testrun["computed_name"], + } + + +def generate_test_instance_dict( + test_id: str, + upload: Upload, + testrun: Testrun, + commit_sha: str, + branch: str | None, + repoid: int, +) -> dict[str, Any]: + return { + "test_id": test_id, + "upload_id": upload.id, + "duration_seconds": testrun["duration"], + "outcome": testrun["outcome"], + "failure_message": testrun["failure_message"], + "commitid": commit_sha, + "branch": branch, + "reduced_error_id": None, + "repoid": repoid, + } + + +def update_daily_totals( + daily_totals: dict, + test_id: str, + duration_seconds: float | None, + outcome: Literal["pass", "failure", "error", "skip"], +): + daily_totals[test_id]["last_duration_seconds"] = duration_seconds + + # logic below is a little complicated but we're basically doing: + + # (old_avg * num of values used to compute old avg) + new value + # ------------------------------------------------------------- + # num of values used to compute old avg + 1 + if ( + duration_seconds is not None + and daily_totals[test_id]["avg_duration_seconds"] is not None + ): + daily_totals[test_id]["avg_duration_seconds"] = ( + daily_totals[test_id]["avg_duration_seconds"] + * ( + daily_totals[test_id]["pass_count"] + + daily_totals[test_id]["fail_count"] + ) + + duration_seconds + ) / ( + daily_totals[test_id]["pass_count"] + + daily_totals[test_id]["fail_count"] + + 1 + ) + + if outcome == "pass": + daily_totals[test_id]["pass_count"] += 1 + elif outcome == "failure" or outcome == "error": + daily_totals[test_id]["fail_count"] += 1 + elif outcome == "skip": + daily_totals[test_id]["skip_count"] += 1 + + +def 
create_daily_totals( + daily_totals: dict, + test_id: str, + repoid: int, + duration_seconds: float | None, + outcome: Literal["pass", "failure", "error", "skip"], + branch: str | None, + commit_sha: str, + flaky_test_set: set[str], +): + daily_totals[test_id] = { + "test_id": test_id, + "repoid": repoid, + "last_duration_seconds": duration_seconds, + "avg_duration_seconds": duration_seconds, + "pass_count": 1 if outcome == "pass" else 0, + "fail_count": 1 if outcome == "failure" or outcome == "error" else 0, + "skip_count": 1 if outcome == "skip" else 0, + "flaky_fail_count": 1 + if test_id in flaky_test_set and (outcome == "failure" or outcome == "error") + else 0, + "branch": branch, + "date": date.today(), + "latest_run": datetime.now(), + "commits_where_fail": [commit_sha] + if (outcome == "failure" or outcome == "error") + else [], + } + + +def save_tests(db_session: Session, tests_to_write: dict[str, dict[str, Any]]): + test_data = sorted( + tests_to_write.values(), + key=lambda x: str(x["id"]), + ) + + test_insert = insert(Test.__table__).values(test_data) + insert_on_conflict_do_update = test_insert.on_conflict_do_update( + index_elements=["repoid", "name", "testsuite", "flags_hash"], + set_={ + "framework": test_insert.excluded.framework, + "computed_name": test_insert.excluded.computed_name, + "filename": test_insert.excluded.filename, + }, + ) + db_session.execute(insert_on_conflict_do_update) + db_session.commit() + + +def save_test_flag_bridges(db_session: Session, test_flag_bridge_data: list[dict]): + insert_on_conflict_do_nothing_flags = ( + insert(TestFlagBridge.__table__) + .values(test_flag_bridge_data) + .on_conflict_do_nothing(index_elements=["test_id", "flag_id"]) + ) + db_session.execute(insert_on_conflict_do_nothing_flags) + db_session.commit() + + +def save_daily_test_rollups(db_session: Session, daily_rollups: dict[str, DailyTotals]): + sorted_rollups = sorted(daily_rollups.values(), key=lambda x: str(x["test_id"])) + rollup_table = DailyTestRollup.__table__ + stmt = insert(rollup_table).values(sorted_rollups) + stmt = stmt.on_conflict_do_update( + index_elements=[ + "repoid", + "branch", + "test_id", + "date", + ], + set_={ + "last_duration_seconds": stmt.excluded.last_duration_seconds, + "avg_duration_seconds": ( + rollup_table.c.avg_duration_seconds + * (rollup_table.c.pass_count + rollup_table.c.fail_count) + + stmt.excluded.avg_duration_seconds + ) + / (rollup_table.c.pass_count + rollup_table.c.fail_count + 1), + "latest_run": stmt.excluded.latest_run, + "pass_count": rollup_table.c.pass_count + stmt.excluded.pass_count, + "skip_count": rollup_table.c.skip_count + stmt.excluded.skip_count, + "fail_count": rollup_table.c.fail_count + stmt.excluded.fail_count, + "flaky_fail_count": rollup_table.c.flaky_fail_count + + stmt.excluded.flaky_fail_count, + "commits_where_fail": rollup_table.c.commits_where_fail + + stmt.excluded.commits_where_fail, + }, + ) + db_session.execute(stmt) + db_session.commit() + + +def save_test_instances(db_session: Session, test_instance_data: list[dict]): + insert_test_instances = insert(TestInstance.__table__).values(test_instance_data) + db_session.execute(insert_test_instances) + db_session.commit() + + +class PGDriver(TADriver): + def __init__(self, db_session: Session, flaky_test_set: set): + self.db_session = db_session + self.flaky_test_set = flaky_test_set + + def write_testruns( + self, + timestamp: int | None, + repo_id: int, + commit_sha: str, + branch_name: str, + upload: Upload, + framework: str | None, + testruns: 
list[Testrun], + ): + tests_to_write: dict[str, dict[str, Any]] = {} + test_instances_to_write: list[dict[str, Any]] = [] + daily_totals: dict[str, DailyTotals] = dict() + test_flag_bridge_data: list[dict] = [] + + repo_flag_ids = get_repo_flag_ids(self.db_session, repo_id, upload.flag_names) + + for testrun in testruns: + modify_structures( + tests_to_write, + test_instances_to_write, + test_flag_bridge_data, + daily_totals, + testrun, + upload, + repo_id, + branch_name, + commit_sha, + repo_flag_ids, + self.flaky_test_set, + framework, + ) + + if len(tests_to_write) > 0: + save_tests(self.db_session, tests_to_write) + + if len(test_flag_bridge_data) > 0: + save_test_flag_bridges(self.db_session, test_flag_bridge_data) + + if len(daily_totals) > 0: + save_daily_test_rollups(self.db_session, daily_totals) + + if len(test_instances_to_write) > 0: + save_test_instances(self.db_session, test_instances_to_write) + + upload.state = "v2_persisted" + self.db_session.commit() diff --git a/ta_storage/tests/test_bq.py b/ta_storage/tests/test_bq.py new file mode 100644 index 000000000..43f20e83a --- /dev/null +++ b/ta_storage/tests/test_bq.py @@ -0,0 +1,117 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from test_results_parser import Testrun + +import generated_proto.testrun.ta_testrun_pb2 as ta_testrun_pb2 +from database.tests.factories import RepositoryFlagFactory, UploadFactory +from ta_storage.bq import DATASET_NAME, TESTRUN_TABLE_NAME, BQDriver + + +@pytest.fixture +def mock_bigquery_service(): + with patch("ta_storage.bq.get_bigquery_service") as mock: + service = MagicMock() + mock.return_value = service + yield service + + +def test_bigquery_driver(dbsession, mock_bigquery_service): + bq = BQDriver() + + upload = UploadFactory() + dbsession.add(upload) + dbsession.flush() + + repo_flag_1 = RepositoryFlagFactory( + repository=upload.report.commit.repository, flag_name="flag1" + ) + repo_flag_2 = RepositoryFlagFactory( + repository=upload.report.commit.repository, flag_name="flag2" + ) + dbsession.add(repo_flag_1) + dbsession.add(repo_flag_2) + dbsession.flush() + + upload.flags.append(repo_flag_1) + upload.flags.append(repo_flag_2) + dbsession.flush() + + test_data: list[Testrun] = [ + { + "name": "test_name", + "classname": "test_class", + "testsuite": "test_suite", + "duration": 100.0, + "outcome": "pass", + "build_url": "https://example.com/build/123", + "filename": "test_file", + "computed_name": "test_computed_name", + "failure_message": None, + }, + { + "name": "test_name2", + "classname": "test_class2", + "testsuite": "test_suite2", + "duration": 100.0, + "outcome": "failure", + "build_url": "https://example.com/build/123", + "filename": "test_file2", + "computed_name": "test_computed_name2", + "failure_message": "test_failure_message", + }, + ] + + timestamp = int(datetime.now().timestamp() * 1000000) + + bq.write_testruns( + timestamp, + upload.report.commit.repoid, + upload.report.commit.commitid, + upload.report.commit.branch, + upload, + "pytest", + test_data, + ) + + # Verify the BigQuery service was called correctly + mock_bigquery_service.write.assert_called_once_with( + DATASET_NAME, + TESTRUN_TABLE_NAME, + ta_testrun_pb2, + [ + ta_testrun_pb2.TestRun( + timestamp=timestamp, + name="test_name", + classname="test_class", + testsuite="test_suite", + duration_seconds=100.0, + outcome=ta_testrun_pb2.TestRun.Outcome.PASSED, + filename="test_file", + computed_name="test_computed_name", + failure_message=None, + 
repoid=upload.report.commit.repoid, + commit_sha=upload.report.commit.commitid, + framework="pytest", + branch_name=upload.report.commit.branch, + flags=["flag1", "flag2"], + ).SerializeToString(), + ta_testrun_pb2.TestRun( + timestamp=timestamp, + name="test_name2", + classname="test_class2", + testsuite="test_suite2", + duration_seconds=100.0, + outcome=ta_testrun_pb2.TestRun.Outcome.FAILED, + filename="test_file2", + computed_name="test_computed_name2", + failure_message="test_failure_message", + repoid=upload.report.commit.repoid, + commit_sha=upload.report.commit.commitid, + framework="pytest", + branch_name=upload.report.commit.branch, + flags=["flag1", "flag2"], + ).SerializeToString(), + ], + ) diff --git a/ta_storage/tests/test_pg.py b/ta_storage/tests/test_pg.py new file mode 100644 index 000000000..370cd1b42 --- /dev/null +++ b/ta_storage/tests/test_pg.py @@ -0,0 +1,63 @@ +from database.models import DailyTestRollup, Test, TestFlagBridge, TestInstance +from database.tests.factories import RepositoryFlagFactory, UploadFactory +from ta_storage.pg import PGDriver + + +def test_pg_driver(dbsession): + pg = PGDriver(dbsession, set()) + + upload = UploadFactory() + dbsession.add(upload) + dbsession.flush() + + repo_flag_1 = RepositoryFlagFactory( + repository=upload.report.commit.repository, flag_name="flag1" + ) + repo_flag_2 = RepositoryFlagFactory( + repository=upload.report.commit.repository, flag_name="flag2" + ) + dbsession.add(repo_flag_1) + dbsession.add(repo_flag_2) + dbsession.flush() + + upload.flags.append(repo_flag_1) + upload.flags.append(repo_flag_2) + dbsession.flush() + + pg.write_testruns( + None, + upload.report.commit.repoid, + upload.report.commit.id, + upload.report.commit.branch, + upload, + "pytest", + [ + { + "name": "test_name", + "classname": "test_class", + "testsuite": "test_suite", + "duration": 100.0, + "outcome": "pass", + "build_url": "https://example.com/build/123", + "filename": "test_file", + "computed_name": "test_computed_name", + "failure_message": None, + }, + { + "name": "test_name2", + "classname": "test_class2", + "testsuite": "test_suite2", + "duration": 100.0, + "outcome": "failure", + "build_url": "https://example.com/build/123", + "filename": "test_file2", + "computed_name": "test_computed_name2", + "failure_message": "test_failure_message", + }, + ], + ) + + assert dbsession.query(Test).count() == 2 + assert dbsession.query(TestInstance).count() == 2 + assert dbsession.query(TestFlagBridge).count() == 4 + assert dbsession.query(DailyTestRollup).count() == 2
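
A minimal usage sketch of the two drivers added in this patch, assuming the caller is an existing test-results processing task that already has a SQLAlchemy session, an Upload, the set of known flaky test ids, and the parsed testruns. The helper name persist_testruns and the hard-coded "pytest" framework are illustrative only and are not part of the patch.

from datetime import datetime

from ta_storage.bq import BQDriver
from ta_storage.pg import PGDriver


def persist_testruns(db_session, upload, testruns, flaky_test_set):
    # Hypothetical call site, not part of this patch: write the same parsed
    # testruns to both backends using the drivers defined in ta_storage/.
    commit = upload.report.commit

    # Postgres path: upserts Test rows, inserts TestInstance and TestFlagBridge
    # rows, updates DailyTestRollup, and marks the upload as "v2_persisted".
    PGDriver(db_session, flaky_test_set).write_testruns(
        None,  # the Postgres driver does not use the timestamp argument
        commit.repoid,
        commit.commitid,
        commit.branch,
        upload,
        "pytest",  # framework, when known
        testruns,
    )

    # BigQuery path: serializes each testrun to the TestRun proto and appends
    # the rows to the configured table via the Storage Write API.
    BQDriver().write_testruns(
        int(datetime.now().timestamp() * 1_000_000),  # microseconds since epoch
        commit.repoid,
        commit.commitid,
        commit.branch,
        upload,
        "pytest",
        testruns,
    )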