From bc4ac8a01163fb8518061b175679c9501caf160b Mon Sep 17 00:00:00 2001 From: aldbr Date: Tue, 3 Dec 2024 08:49:11 +0100 Subject: [PATCH] feat(consistency): make SQL Alchemy interfaces consistent --- diracx-db/src/diracx/db/sql/auth/db.py | 18 +-- diracx-db/src/diracx/db/sql/auth/schema.py | 46 ++++---- diracx-db/src/diracx/db/sql/dummy/schema.py | 12 +- diracx-db/src/diracx/db/sql/job/db.py | 22 ++-- diracx-db/src/diracx/db/sql/job/schema.py | 108 +++++++++--------- diracx-db/src/diracx/db/sql/job_logging/db.py | 44 +++---- .../src/diracx/db/sql/job_logging/schema.py | 18 +-- .../src/diracx/db/sql/pilot_agents/schema.py | 44 +++---- diracx-db/src/diracx/db/sql/utils/__init__.py | 4 +- .../tests/auth/test_authorization_flow.py | 2 +- diracx-db/tests/auth/test_device_flow.py | 6 +- diracx-db/tests/auth/test_refresh_token.py | 25 ++-- diracx-db/tests/test_dummy_db.py | 30 ++--- .../src/diracx/routers/auth/management.py | 2 +- .../src/diracx/routers/auth/token.py | 28 ++--- docs/CODING_CONVENTION.md | 22 ++++ 16 files changed, 230 insertions(+), 201 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/auth/db.py b/diracx-db/src/diracx/db/sql/auth/db.py index cd8d1d03..b587f869 100644 --- a/diracx-db/src/diracx/db/sql/auth/db.py +++ b/diracx-db/src/diracx/db/sql/auth/db.py @@ -58,7 +58,7 @@ async def get_device_flow(self, device_code: str, max_validity: int): stmt = select( DeviceFlows, (DeviceFlows.creation_time < substract_date(seconds=max_validity)).label( - "is_expired" + "IsExpired" ), ).with_for_update() stmt = stmt.where( @@ -66,10 +66,10 @@ async def get_device_flow(self, device_code: str, max_validity: int): ) res = dict((await self.conn.execute(stmt)).one()._mapping) - if res["is_expired"]: + if res["IsExpired"]: raise ExpiredFlowError() - if res["status"] == FlowStatus.READY: + if res["Status"] == FlowStatus.READY: # Update the status to Done before returning await self.conn.execute( update(DeviceFlows) @@ -81,10 +81,10 @@ async def get_device_flow(self, device_code: str, max_validity: int): ) return res - if res["status"] == FlowStatus.DONE: + if res["Status"] == FlowStatus.DONE: raise AuthorizationError("Code was already used") - if res["status"] == FlowStatus.PENDING: + if res["Status"] == FlowStatus.PENDING: raise PendingAuthorizationError() raise AuthorizationError("Bad state in device flow") @@ -190,7 +190,7 @@ async def authorization_flow_insert_id_token( stmt = select(AuthorizationFlows.code, AuthorizationFlows.redirect_uri) stmt = stmt.where(AuthorizationFlows.uuid == uuid) row = (await self.conn.execute(stmt)).one() - return code, row.redirect_uri + return code, row.RedirectURI async def get_authorization_flow(self, code: str, max_validity: int): hashed_code = hashlib.sha256(code.encode()).hexdigest() @@ -205,7 +205,7 @@ async def get_authorization_flow(self, code: str, max_validity: int): res = dict((await self.conn.execute(stmt)).one()._mapping) - if res["status"] == FlowStatus.READY: + if res["Status"] == FlowStatus.READY: # Update the status to Done before returning await self.conn.execute( update(AuthorizationFlows) @@ -215,7 +215,7 @@ async def get_authorization_flow(self, code: str, max_validity: int): return res - if res["status"] == FlowStatus.DONE: + if res["Status"] == FlowStatus.DONE: raise AuthorizationError("Code was already used") raise AuthorizationError("Bad state in authorization flow") @@ -247,7 +247,7 @@ async def insert_refresh_token( row = (await self.conn.execute(stmt)).one() # Return the JWT ID and the creation time - return jti, row.creation_time + return jti, row.CreationTime async def get_refresh_token(self, jti: str) -> dict: """Get refresh token details bound to a given JWT ID.""" diff --git a/diracx-db/src/diracx/db/sql/auth/schema.py b/diracx-db/src/diracx/db/sql/auth/schema.py index b6efbede..8d7dddc7 100644 --- a/diracx-db/src/diracx/db/sql/auth/schema.py +++ b/diracx-db/src/diracx/db/sql/auth/schema.py @@ -39,27 +39,27 @@ class FlowStatus(Enum): class DeviceFlows(Base): __tablename__ = "DeviceFlows" - user_code = Column(String(USER_CODE_LENGTH), primary_key=True) - status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name) - creation_time = DateNowColumn() - client_id = Column(String(255)) - scope = Column(String(1024)) - device_code = Column(String(128), unique=True) # Should be a hash - id_token = NullColumn(JSON()) + user_code = Column("UserCode", String(USER_CODE_LENGTH), primary_key=True) + status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name) + creation_time = DateNowColumn("CreationTime") + client_id = Column("ClientID", String(255)) + scope = Column("Scope", String(1024)) + device_code = Column("DeviceCode", String(128), unique=True) # Should be a hash + id_token = NullColumn("IDToken", JSON()) class AuthorizationFlows(Base): __tablename__ = "AuthorizationFlows" - uuid = Column(Uuid(as_uuid=False), primary_key=True) - status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name) - client_id = Column(String(255)) - creation_time = DateNowColumn() - scope = Column(String(1024)) - code_challenge = Column(String(255)) - code_challenge_method = Column(String(8)) - redirect_uri = Column(String(255)) - code = NullColumn(String(255)) # Should be a hash - id_token = NullColumn(JSON()) + uuid = Column("UUID", Uuid(as_uuid=False), primary_key=True) + status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name) + client_id = Column("ClientID", String(255)) + creation_time = DateNowColumn("CretionTime") + scope = Column("Scope", String(1024)) + code_challenge = Column("CodeChallenge", String(255)) + code_challenge_method = Column("CodeChallengeMethod", String(8)) + redirect_uri = Column("RedirectURI", String(255)) + code = NullColumn("Code", String(255)) # Should be a hash + id_token = NullColumn("IDToken", JSON()) class RefreshTokenStatus(Enum): @@ -85,13 +85,13 @@ class RefreshTokens(Base): __tablename__ = "RefreshTokens" # Refresh token attributes - jti = Column(Uuid(as_uuid=False), primary_key=True) + jti = Column("JTI", Uuid(as_uuid=False), primary_key=True) status = EnumColumn( - RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name + "Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name ) - creation_time = DateNowColumn() - scope = Column(String(1024)) + creation_time = DateNowColumn("CreationTime") + scope = Column("Scope", String(1024)) # User attributes bound to the refresh token - sub = Column(String(1024)) - preferred_username = Column(String(255)) + sub = Column("Sub", String(1024)) + preferred_username = Column("PreferredUsername", String(255)) diff --git a/diracx-db/src/diracx/db/sql/dummy/schema.py b/diracx-db/src/diracx/db/sql/dummy/schema.py index b6ddde79..a0c11c09 100644 --- a/diracx-db/src/diracx/db/sql/dummy/schema.py +++ b/diracx-db/src/diracx/db/sql/dummy/schema.py @@ -10,13 +10,13 @@ class Owners(Base): __tablename__ = "Owners" - owner_id = Column(Integer, primary_key=True, autoincrement=True) - creation_time = DateNowColumn() - name = Column(String(255)) + owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True) + creation_time = DateNowColumn("CreationTime") + name = Column("Name", String(255)) class Cars(Base): __tablename__ = "Cars" - license_plate = Column(Uuid(), primary_key=True) - model = Column(String(255)) - owner_id = Column(Integer, ForeignKey(Owners.owner_id)) + license_plate = Column("LicensePlate", Uuid(), primary_key=True) + model = Column("Model", String(255)) + owner_id = Column("OwnerID", Integer, ForeignKey(Owners.owner_id)) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 9542e57a..3597079b 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -58,7 +58,7 @@ class JobDB(BaseSQLDB): async def summary(self, group_by, search) -> list[dict[str, str | int]]: columns = _get_columns(Jobs.__table__, group_by) - stmt = select(*columns, func.count(Jobs.JobID).label("count")) + stmt = select(*columns, func.count(Jobs.job_id).label("count")) stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search) stmt = stmt.group_by(*columns) @@ -111,7 +111,7 @@ async def _insert_new_jdl(self, jdl) -> int: from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL stmt = insert(JobJDLs).values( - JDL="", JobRequirements="", OriginalJDL=compressJDL(jdl) + jdl="", job_requirements="", original_jdl=compressJDL(jdl) ) result = await self.conn.execute(stmt) # await self.engine.commit() @@ -129,7 +129,7 @@ async def set_job_attributes(self, job_id, job_data): """TODO: add myDate and force parameters.""" if "Status" in job_data: job_data = job_data | {"LastUpdateTime": datetime.now(tz=timezone.utc)} - stmt = update(Jobs).where(Jobs.JobID == job_id).values(job_data) + stmt = update(Jobs).where(Jobs.job_id == job_id).values(job_data) await self.conn.execute(stmt) async def _check_and_prepare_job( @@ -171,7 +171,7 @@ async def set_job_jdl(self, job_id, jdl): from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL stmt = ( - update(JobJDLs).where(JobJDLs.JobID == job_id).values(JDL=compressJDL(jdl)) + update(JobJDLs).where(JobJDLs.job_id == job_id).values(JDL=compressJDL(jdl)) ) await self.conn.execute(stmt) @@ -179,9 +179,9 @@ async def get_job_jdl(self, job_id: int, original: bool = False) -> str: from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL if original: - stmt = select(JobJDLs.OriginalJDL).where(JobJDLs.JobID == job_id) + stmt = select(JobJDLs.original_jdl).where(JobJDLs.job_id == job_id) else: - stmt = select(JobJDLs.JDL).where(JobJDLs.JobID == job_id) + stmt = select(JobJDLs.jdl).where(JobJDLs.job_id == job_id) jdl = (await self.conn.execute(stmt)).scalar_one() if jdl: @@ -431,9 +431,9 @@ async def reschedule_job(self, job_id) -> dict[str, Any]: async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn: try: - stmt = select(Jobs.Status, Jobs.MinorStatus, Jobs.ApplicationStatus).where( - Jobs.JobID == job_id - ) + stmt = select( + Jobs.status, Jobs.minor_status, Jobs.application_status + ).where(Jobs.job_id == job_id) return LimitedJobStatusReturn( **dict((await self.conn.execute(stmt)).one()._mapping) ) @@ -455,7 +455,7 @@ async def set_job_command(self, job_id: int, command: str, arguments: str = ""): async def delete_jobs(self, job_ids: list[int]): """Delete jobs from the database.""" - stmt = delete(JobJDLs).where(JobJDLs.JobID.in_(job_ids)) + stmt = delete(JobJDLs).where(JobJDLs.job_id.in_(job_ids)) await self.conn.execute(stmt) async def set_properties( @@ -488,7 +488,7 @@ async def set_properties( if update_timestamp: values["LastUpdateTime"] = datetime.now(tz=timezone.utc) - stmt = update(Jobs).where(Jobs.JobID == bindparam("job_id")).values(**values) + stmt = update(Jobs).where(Jobs.job_id == bindparam("job_id")).values(**values) rows = await self.conn.execute(stmt, update_parameters) return rows.rowcount diff --git a/diracx-db/src/diracx/db/sql/job/schema.py b/diracx-db/src/diracx/db/sql/job/schema.py index d17edf2d..eea1e3a1 100644 --- a/diracx-db/src/diracx/db/sql/job/schema.py +++ b/diracx-db/src/diracx/db/sql/job/schema.py @@ -17,34 +17,34 @@ class Jobs(JobDBBase): __tablename__ = "Jobs" - JobID = Column( + job_id = Column( "JobID", Integer, ForeignKey("JobJDLs.JobID", ondelete="CASCADE"), primary_key=True, default=0, ) - JobType = Column("JobType", String(32), default="user") - JobGroup = Column("JobGroup", String(32), default="00000000") - Site = Column("Site", String(100), default="ANY") - JobName = Column("JobName", String(128), default="Unknown") - Owner = Column("Owner", String(64), default="Unknown") - OwnerGroup = Column("OwnerGroup", String(128), default="Unknown") - VO = Column("VO", String(32)) - SubmissionTime = NullColumn("SubmissionTime", DateTime) - RescheduleTime = NullColumn("RescheduleTime", DateTime) - LastUpdateTime = NullColumn("LastUpdateTime", DateTime) - StartExecTime = NullColumn("StartExecTime", DateTime) - HeartBeatTime = NullColumn("HeartBeatTime", DateTime) - EndExecTime = NullColumn("EndExecTime", DateTime) - Status = Column("Status", String(32), default="Received") - MinorStatus = Column("MinorStatus", String(128), default="Unknown") - ApplicationStatus = Column("ApplicationStatus", String(255), default="Unknown") - UserPriority = Column("UserPriority", Integer, default=0) - RescheduleCounter = Column("RescheduleCounter", Integer, default=0) - VerifiedFlag = Column("VerifiedFlag", EnumBackedBool(), default=False) + job_type = Column("JobType", String(32), default="user") + job_group = Column("JobGroup", String(32), default="00000000") + site = Column("Site", String(100), default="ANY") + job_name = Column("JobName", String(128), default="Unknown") + owner = Column("Owner", String(64), default="Unknown") + owner_group = Column("OwnerGroup", String(128), default="Unknown") + vo = Column("VO", String(32)) + submission_time = NullColumn("SubmissionTime", DateTime) + reschedule_time = NullColumn("RescheduleTime", DateTime) + last_update_time = NullColumn("LastUpdateTime", DateTime) + start_exec_time = NullColumn("StartExecTime", DateTime) + heart_beat_time = NullColumn("HeartBeatTime", DateTime) + end_exec_time = NullColumn("EndExecTime", DateTime) + status = Column("Status", String(32), default="Received") + minor_status = Column("MinorStatus", String(128), default="Unknown") + application_status = Column("ApplicationStatus", String(255), default="Unknown") + user_priority = Column("UserPriority", Integer, default=0) + reschedule_counter = Column("RescheduleCounter", Integer, default=0) + verified_flag = Column("VerifiedFlag", EnumBackedBool(), default=False) # TODO: Should this be True/False/"Failed"? Or True/False/Null? - AccountedFlag = Column( + accounted_flag = Column( "AccountedFlag", Enum("True", "False", "Failed"), default="False" ) @@ -64,66 +64,66 @@ class Jobs(JobDBBase): class JobJDLs(JobDBBase): __tablename__ = "JobJDLs" - JobID = Column(Integer, autoincrement=True, primary_key=True) - JDL = Column(Text) - JobRequirements = Column(Text) - OriginalJDL = Column(Text) + job_id = Column("JobID", Integer, autoincrement=True, primary_key=True) + jdl = Column("JDL", Text) + job_requirements = Column("JobRequirements", Text) + original_jdl = Column("OriginalJDL", Text) class InputData(JobDBBase): __tablename__ = "InputData" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - LFN = Column(String(255), default="", primary_key=True) - Status = Column(String(32), default="AprioriGood") + lfn = Column("LFN", String(255), default="", primary_key=True) + status = Column("Status", String(32), default="AprioriGood") class JobParameters(JobDBBase): __tablename__ = "JobParameters" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - Name = Column(String(100), primary_key=True) - Value = Column(Text) + name = Column("Name", String(100), primary_key=True) + value = Column("Value", Text) class OptimizerParameters(JobDBBase): __tablename__ = "OptimizerParameters" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - Name = Column(String(100), primary_key=True) - Value = Column(Text) + name = Column("Name", String(100), primary_key=True) + value = Column("Value", Text) class AtticJobParameters(JobDBBase): __tablename__ = "AtticJobParameters" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - Name = Column(String(100), primary_key=True) - Value = Column(Text) - RescheduleCycle = Column(Integer) + name = Column("Name", String(100), primary_key=True) + value = Column("Value", Text) + reschedule_cycle = Column("RescheduleCycle", Integer) class HeartBeatLoggingInfo(JobDBBase): __tablename__ = "HeartBeatLoggingInfo" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - Name = Column(String(100), primary_key=True) - Value = Column(Text) - HeartBeatTime = Column(DateTime, primary_key=True) + name = Column("Name", String(100), primary_key=True) + value = Column("Value", Text) + heart_beat_time = Column("HeartBeatTime", DateTime, primary_key=True) class JobCommands(JobDBBase): __tablename__ = "JobCommands" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - Command = Column(String(100)) - Arguments = Column(String(100)) - Status = Column(String(64), default="Received") - ReceptionTime = Column(DateTime, primary_key=True) - ExecutionTime = NullColumn(DateTime) + command = Column("Command", String(100)) + arguments = Column("Arguments", String(100)) + status = Column("Status", String(64), default="Received") + reception_time = Column("ReceptionTime", DateTime, primary_key=True) + execution_time = NullColumn("ExecutionTime", DateTime) diff --git a/diracx-db/src/diracx/db/sql/job_logging/db.py b/diracx-db/src/diracx/db/sql/job_logging/db.py index 2edf12f2..e1b136a0 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/db.py +++ b/diracx-db/src/diracx/db/sql/job_logging/db.py @@ -45,9 +45,9 @@ async def insert_record( as datetime.datetime object. If the time stamp is not provided the current UTC time is used. """ - # First, fetch the maximum SeqNum for the given job_id - seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.SeqNum) + 1, 1)).where( - LoggingInfo.JobID == job_id + # First, fetch the maximum seq_num for the given job_id + seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.seq_num) + 1, 1)).where( + LoggingInfo.job_id == job_id ) seqnum = await self.conn.scalar(seqnum_stmt) @@ -58,14 +58,14 @@ async def insert_record( ) stmt = insert(LoggingInfo).values( - JobID=int(job_id), - SeqNum=seqnum, - Status=status, - MinorStatus=minor_status, - ApplicationStatus=application_status[:255], - StatusTime=date, - StatusTimeOrder=epoc, - Source=source[:32], + job_id=int(job_id), + seq_num=seqnum, + status=status, + minor_status=minor_status, + application_status=application_status[:255], + status_time=date, + status_time_order=epoc, + source=source[:32], ) await self.conn.execute(stmt) @@ -75,14 +75,14 @@ async def get_records(self, job_id: int) -> list[JobStatusReturn]: """ stmt = ( select( - LoggingInfo.Status, - LoggingInfo.MinorStatus, - LoggingInfo.ApplicationStatus, - LoggingInfo.StatusTime, - LoggingInfo.Source, + LoggingInfo.status, + LoggingInfo.minor_status, + LoggingInfo.application_status, + LoggingInfo.status_time, + LoggingInfo.source, ) - .where(LoggingInfo.JobID == int(job_id)) - .order_by(LoggingInfo.StatusTimeOrder, LoggingInfo.StatusTime) + .where(LoggingInfo.job_id == int(job_id)) + .order_by(LoggingInfo.status_time_order, LoggingInfo.status_time) ) rows = await self.conn.execute(stmt) @@ -139,7 +139,7 @@ async def get_records(self, job_id: int) -> list[JobStatusReturn]: async def delete_records(self, job_ids: list[int]): """Delete logging records for given jobs.""" - stmt = delete(LoggingInfo).where(LoggingInfo.JobID.in_(job_ids)) + stmt = delete(LoggingInfo).where(LoggingInfo.job_id.in_(job_ids)) await self.conn.execute(stmt) async def get_wms_time_stamps(self, job_id): @@ -148,9 +148,9 @@ async def get_wms_time_stamps(self, job_id): """ result = {} stmt = select( - LoggingInfo.Status, - LoggingInfo.StatusTimeOrder, - ).where(LoggingInfo.JobID == job_id) + LoggingInfo.status, + LoggingInfo.status_time_order, + ).where(LoggingInfo.job_id == job_id) rows = await self.conn.execute(stmt) if not rows.rowcount: raise JobNotFoundError(job_id) from None diff --git a/diracx-db/src/diracx/db/sql/job_logging/schema.py b/diracx-db/src/diracx/db/sql/job_logging/schema.py index 6f459c48..1c229bb7 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/schema.py +++ b/diracx-db/src/diracx/db/sql/job_logging/schema.py @@ -13,13 +13,15 @@ class LoggingInfo(JobLoggingDBBase): __tablename__ = "LoggingInfo" - JobID = Column(Integer) - SeqNum = Column(Integer) - Status = Column(String(32), default="") - MinorStatus = Column(String(128), default="") - ApplicationStatus = Column(String(255), default="") - StatusTime = DateNowColumn() + job_id = Column("JobID", Integer) + seq_num = Column("SeqNum", Integer) + status = Column("Status", String(32), default="") + minor_status = Column("MinorStatus", String(128), default="") + application_status = Column("ApplicationStatus", String(255), default="") + status_time = DateNowColumn("StatusTime") # TODO: Check that this corresponds to the DOUBLE(12,3) type in MySQL - StatusTimeOrder = Column(Numeric(precision=12, scale=3), default=0) - Source = Column(String(32), default="Unknown", name="StatusSource") + status_time_order = Column( + "StatusTimeOrder", Numeric(precision=12, scale=3), default=0 + ) + source = Column("StatusSource", String(32), default="Unknown") __table_args__ = (PrimaryKeyConstraint("JobID", "SeqNum"),) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py index 7a2a0c5e..76cd5c89 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py @@ -16,22 +16,22 @@ class PilotAgents(PilotAgentsDBBase): __tablename__ = "PilotAgents" - PilotID = Column("PilotID", Integer, autoincrement=True, primary_key=True) - InitialJobID = Column("InitialJobID", Integer, default=0) - CurrentJobID = Column("CurrentJobID", Integer, default=0) - PilotJobReference = Column("PilotJobReference", String(255), default="Unknown") - PilotStamp = Column("PilotStamp", String(32), default="") - DestinationSite = Column("DestinationSite", String(128), default="NotAssigned") - Queue = Column("Queue", String(128), default="Unknown") - GridSite = Column("GridSite", String(128), default="Unknown") - VO = Column("VO", String(128)) - GridType = Column("GridType", String(32), default="LCG") - BenchMark = Column("BenchMark", Double, default=0.0) - SubmissionTime = NullColumn("SubmissionTime", DateTime) - LastUpdateTime = NullColumn("LastUpdateTime", DateTime) - Status = Column("Status", String(32), default="Unknown") - StatusReason = Column("StatusReason", String(255), default="Unknown") - AccountingSent = Column("AccountingSent", EnumBackedBool(), default=False) + pilot_id = Column("PilotID", Integer, autoincrement=True, primary_key=True) + initial_job_id = Column("InitialJobID", Integer, default=0) + current_job_id = Column("CurrentJobID", Integer, default=0) + pilot_job_reference = Column("PilotJobReference", String(255), default="Unknown") + pilot_stamp = Column("PilotStamp", String(32), default="") + destination_site = Column("DestinationSite", String(128), default="NotAssigned") + queue = Column("Queue", String(128), default="Unknown") + grid_site = Column("GridSite", String(128), default="Unknown") + vo = Column("VO", String(128)) + grid_type = Column("GridType", String(32), default="LCG") + benchmark = Column("BenchMark", Double, default=0.0) + submission_time = NullColumn("SubmissionTime", DateTime) + last_update_time = NullColumn("LastUpdateTime", DateTime) + status = Column("Status", String(32), default="Unknown") + status_reason = Column("StatusReason", String(255), default="Unknown") + accounting_sent = Column("AccountingSent", EnumBackedBool(), default=False) __table_args__ = ( Index("PilotJobReference", "PilotJobReference"), @@ -43,9 +43,9 @@ class PilotAgents(PilotAgentsDBBase): class JobToPilotMapping(PilotAgentsDBBase): __tablename__ = "JobToPilotMapping" - PilotID = Column("PilotID", Integer, primary_key=True) - JobID = Column("JobID", Integer, primary_key=True) - StartTime = Column("StartTime", DateTime) + pilot_id = Column("PilotID", Integer, primary_key=True) + job_id = Column("JobID", Integer, primary_key=True) + start_time = Column("StartTime", DateTime) __table_args__ = (Index("JobID", "JobID"), Index("PilotID", "PilotID")) @@ -53,6 +53,6 @@ class JobToPilotMapping(PilotAgentsDBBase): class PilotOutput(PilotAgentsDBBase): __tablename__ = "PilotOutput" - PilotID = Column("PilotID", Integer, primary_key=True) - StdOutput = Column("StdOutput", Text) - StdError = Column("StdError", Text) + pilot_id = Column("PilotID", Integer, primary_key=True) + std_output = Column("StdOutput", Text) + std_error = Column("StdError", Text) diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index cc2c5e8b..d79aca44 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -125,8 +125,8 @@ def substract_date(**kwargs: float) -> datetime: DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=UTCNow()) -def EnumColumn(enum_type, **kwargs): # noqa: N802 - return Column(Enum(enum_type, native_enum=False, length=16), **kwargs) +def EnumColumn(name, enum_type, **kwargs): # noqa: N802 + return Column(name, Enum(enum_type, native_enum=False, length=16), **kwargs) class EnumBackedBool(types.TypeDecorator): diff --git a/diracx-db/tests/auth/test_authorization_flow.py b/diracx-db/tests/auth/test_authorization_flow.py index 153896a9..240cd55e 100644 --- a/diracx-db/tests/auth/test_authorization_flow.py +++ b/diracx-db/tests/auth/test_authorization_flow.py @@ -49,7 +49,7 @@ async def test_insert_id_token(auth_db: AuthDB): with pytest.raises(NoResultFound): await auth_db.get_authorization_flow(code, EXPIRED) res = await auth_db.get_authorization_flow(code, MAX_VALIDITY) - assert res["id_token"] == id_token + assert res["IDToken"] == id_token # Cannot add a id_token after finishing the flow async with auth_db as auth_db: diff --git a/diracx-db/tests/auth/test_device_flow.py b/diracx-db/tests/auth/test_device_flow.py index e1cb0e6b..45093d2e 100644 --- a/diracx-db/tests/auth/test_device_flow.py +++ b/diracx-db/tests/auth/test_device_flow.py @@ -107,8 +107,8 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch): await auth_db.get_device_flow(device_code1, EXPIRED) res = await auth_db.get_device_flow(device_code1, MAX_VALIDITY) - assert res["user_code"] == user_code1 - assert res["id_token"] == {"token": "mytoken"} + assert res["UserCode"] == user_code1 + assert res["IDToken"] == {"token": "mytoken"} # cannot get it a second time async with auth_db as auth_db: @@ -147,4 +147,4 @@ async def test_device_flow_insert_id_token(auth_db: AuthDB): async with auth_db as auth_db: res = await auth_db.get_device_flow(device_code, MAX_VALIDITY) - assert res["id_token"] == id_token + assert res["IDToken"] == id_token diff --git a/diracx-db/tests/auth/test_refresh_token.py b/diracx-db/tests/auth/test_refresh_token.py index 2b0cb4f0..2d72cef0 100644 --- a/diracx-db/tests/auth/test_refresh_token.py +++ b/diracx-db/tests/auth/test_refresh_token.py @@ -55,16 +55,21 @@ async def test_get(auth_db: AuthDB): ) # Enrich the dict with the generated refresh token attributes - refresh_token_details["jti"] = jti - refresh_token_details["status"] = RefreshTokenStatus.CREATED - refresh_token_details["creation_time"] = creation_time + expected_refresh_token = { + "Sub": refresh_token_details["sub"], + "PreferredUsername": refresh_token_details["preferred_username"], + "Scope": refresh_token_details["scope"], + "JTI": jti, + "Status": RefreshTokenStatus.CREATED, + "CreationTime": creation_time, + } # Get refresh token details async with auth_db as auth_db: result = await auth_db.get_refresh_token(jti) # Make sure they are identical - assert result == refresh_token_details + assert result == expected_refresh_token async def test_get_user_refresh_tokens(auth_db: AuthDB): @@ -96,11 +101,11 @@ async def test_get_user_refresh_tokens(auth_db: AuthDB): # And check that the subject value corresponds to the user's subject assert len(refresh_tokens_user1) == 2 for refresh_token in refresh_tokens_user1: - assert refresh_token["sub"] == sub1 + assert refresh_token["Sub"] == sub1 assert len(refresh_tokens_user2) == 1 for refresh_token in refresh_tokens_user2: - assert refresh_token["sub"] == sub2 + assert refresh_token["Sub"] == sub2 async def test_revoke(auth_db: AuthDB): @@ -121,7 +126,7 @@ async def test_revoke(auth_db: AuthDB): async with auth_db as auth_db: refresh_token_details = await auth_db.get_refresh_token(jti) - assert refresh_token_details["status"] == RefreshTokenStatus.REVOKED + assert refresh_token_details["Status"] == RefreshTokenStatus.REVOKED async def test_revoke_user_refresh_tokens(auth_db: AuthDB): @@ -194,7 +199,7 @@ async def test_revoke_and_get_user_refresh_tokens(auth_db: AuthDB): # And check that the subject value corresponds to the user's subject assert len(refresh_tokens_user) == nb_tokens for refresh_token in refresh_tokens_user: - assert refresh_token["sub"] == sub + assert refresh_token["Sub"] == sub # Revoke one of the tokens async with auth_db as auth_db: @@ -208,8 +213,8 @@ async def test_revoke_and_get_user_refresh_tokens(auth_db: AuthDB): # And check that the subject value corresponds to the user's subject assert len(refresh_tokens_user) == nb_tokens - 1 for refresh_token in refresh_tokens_user: - assert refresh_token["sub"] == sub - assert refresh_token["jti"] != jtis[0] + assert refresh_token["Sub"] == sub + assert refresh_token["JTI"] != jtis[0] async def test_get_refresh_tokens(auth_db: AuthDB): diff --git a/diracx-db/tests/test_dummy_db.py b/diracx-db/tests/test_dummy_db.py index e7011539..023899ce 100644 --- a/diracx-db/tests/test_dummy_db.py +++ b/diracx-db/tests/test_dummy_db.py @@ -27,7 +27,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): # So it is important to write test this way async with dummy_db as dummy_db: # First we check that the DB is empty - result = await dummy_db.summary(["model"], []) + result = await dummy_db.summary(["Model"], []) assert not result # Now we add some data in the DB @@ -44,14 +44,14 @@ async def test_insert_and_summary(dummy_db: DummyDB): # Check that there are now 10 cars assigned to a single driver async with dummy_db as dummy_db: - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 # Test the selection async with dummy_db as dummy_db: result = await dummy_db.summary( - ["owner_id"], [{"parameter": "model", "operator": "eq", "value": "model_1"}] + ["OwnerID"], [{"parameter": "Model", "operator": "eq", "value": "model_1"}] ) assert result[0]["count"] == 1 @@ -59,7 +59,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): async with dummy_db as dummy_db: with pytest.raises(InvalidQueryError): result = await dummy_db.summary( - ["owner_id"], + ["OwnerID"], [ { "parameter": "model", @@ -93,7 +93,7 @@ async def test_successful_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result # Add data @@ -104,7 +104,7 @@ async def test_successful_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 # The connection is closed when the context manager is exited @@ -114,7 +114,7 @@ async def test_successful_transaction(dummy_db): # Start a new transaction # The previous data should still be there because the transaction was committed (successful) async with dummy_db as dummy_db: - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 @@ -134,7 +134,7 @@ async def test_failed_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result # Add data @@ -159,7 +159,7 @@ async def test_failed_transaction(dummy_db): # Start a new transaction # The previous data should not be there because the transaction was rolled back (failed) async with dummy_db as dummy_db: - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result @@ -203,7 +203,7 @@ async def test_successful_with_exception_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result # Add data @@ -217,7 +217,7 @@ async def test_successful_with_exception_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 # This will raise an exception but the transaction will be rolled back @@ -231,7 +231,7 @@ async def test_successful_with_exception_transaction(dummy_db): # Start a new transaction # The previous data should not be there because the transaction was rolled back (failed) async with dummy_db as dummy_db: - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result # Start a new transaction, this time we commit it manually @@ -240,7 +240,7 @@ async def test_successful_with_exception_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result # Add data @@ -254,7 +254,7 @@ async def test_successful_with_exception_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 # Manually commit the transaction, and then raise an exception @@ -271,5 +271,5 @@ async def test_successful_with_exception_transaction(dummy_db): # Start a new transaction # The previous data should be there because the transaction was committed before the exception async with dummy_db as dummy_db: - result = await dummy_db.summary(["owner_id"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 diff --git a/diracx-routers/src/diracx/routers/auth/management.py b/diracx-routers/src/diracx/routers/auth/management.py index e8b59356..7bd7c1b9 100644 --- a/diracx-routers/src/diracx/routers/auth/management.py +++ b/diracx-routers/src/diracx/routers/auth/management.py @@ -66,7 +66,7 @@ async def revoke_refresh_token( detail="JTI provided does not exist", ) - if PROXY_MANAGEMENT not in user_info.properties and user_info.sub != res["sub"]: + if PROXY_MANAGEMENT not in user_info.properties and user_info.sub != res["Sub"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Cannot revoke a refresh token owned by someone else", diff --git a/diracx-routers/src/diracx/routers/auth/token.py b/diracx-routers/src/diracx/routers/auth/token.py index 8346e2b9..d21416ea 100644 --- a/diracx-routers/src/diracx/routers/auth/token.py +++ b/diracx-routers/src/diracx/routers/auth/token.py @@ -130,17 +130,17 @@ async def get_oidc_token_info_from_device_flow( # raise DiracHttpResponseError(status.HTTP_400_BAD_REQUEST, {"error": "slow_down"}) # raise DiracHttpResponseError(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"}) - if info["client_id"] != client_id: + if info["ClientID"] != client_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Bad client_id", ) - oidc_token_info = info["id_token"] - scope = info["scope"] + oidc_token_info = info["IDToken"] + scope = info["Scope"] # TODO: use HTTPException while still respecting the standard format # required by the RFC - if info["status"] != FlowStatus.READY: + if info["Status"] != FlowStatus.READY: # That should never ever happen raise NotImplementedError(f"Unexpected flow status {info['status']!r}") return (oidc_token_info, scope) @@ -159,12 +159,12 @@ async def get_oidc_token_info_from_authorization_flow( info = await auth_db.get_authorization_flow( code, settings.authorization_flow_expiration_seconds ) - if redirect_uri != info["redirect_uri"]: + if redirect_uri != info["RedirectURI"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid redirect_uri", ) - if client_id != info["client_id"]: + if client_id != info["ClientID"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Bad client_id", @@ -184,18 +184,18 @@ async def get_oidc_token_info_from_authorization_flow( detail="Malformed code_verifier", ) from e - if code_challenge != info["code_challenge"]: + if code_challenge != info["CodeChallenge"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid code_challenge", ) - oidc_token_info = info["id_token"] - scope = info["scope"] + oidc_token_info = info["IDToken"] + scope = info["Scope"] # TODO: use HTTPException while still respecting the standard format # required by the RFC - if info["status"] != FlowStatus.READY: + if info["Status"] != FlowStatus.READY: # That should never ever happen raise NotImplementedError(f"Unexpected flow status {info['status']!r}") @@ -214,7 +214,7 @@ async def get_oidc_token_info_from_refresh_flow( # Get some useful user information from the refresh token entry in the DB refresh_token_attributes = await auth_db.get_refresh_token(jti) - sub = refresh_token_attributes["sub"] + sub = refresh_token_attributes["Sub"] # Check if the refresh token was obtained from the legacy_exchange endpoint # If it is the case, we bypass the refresh token rotation mechanism @@ -224,7 +224,7 @@ async def get_oidc_token_info_from_refresh_flow( # This might indicate that a potential attacker try to impersonate someone # In such case, all the refresh tokens bound to a given user (subject) should be revoked # Forcing the user to reauthenticate interactively through an authorization/device flow (recommended practice) - if refresh_token_attributes["status"] == RefreshTokenStatus.REVOKED: + if refresh_token_attributes["Status"] == RefreshTokenStatus.REVOKED: # Revoke all the user tokens from the subject await auth_db.revoke_user_refresh_tokens(sub) @@ -246,9 +246,9 @@ async def get_oidc_token_info_from_refresh_flow( # The sub attribute coming from the DB contains the VO name # We need to remove it as if it were coming from an ID token from an external IdP "sub": sub.split(":", 1)[1], - "preferred_username": refresh_token_attributes["preferred_username"], + "preferred_username": refresh_token_attributes["PreferredUsername"], } - scope = refresh_token_attributes["scope"] + scope = refresh_token_attributes["Scope"] return (oidc_token_info, scope, legacy_exchange) diff --git a/docs/CODING_CONVENTION.md b/docs/CODING_CONVENTION.md index 9ea27510..50363485 100644 --- a/docs/CODING_CONVENTION.md +++ b/docs/CODING_CONVENTION.md @@ -47,6 +47,28 @@ ALWAYS DO from __future__ import annotations ``` +# SQL Alchemy + +DO + +```python +class Owners(Base): + __tablename__ = "Owners" + owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True) + creation_time = DateNowColumn("CreationTime") + name = Column("Name", String(255)) +``` + +DONT + +```python +class Owners(Base): + __tablename__ = "Owners" + OwnerID = Column(Integer, primary_key=True, autoincrement=True) + CreationTime = DateNowColumn() + Name = Column(String(255)) +``` + # Structure