From 6240cdaef25d860b9d9bd6f81fe1d63decbdcd85 Mon Sep 17 00:00:00 2001
From: Ryunosuke O'Neil
Date: Mon, 16 Sep 2024 17:07:50 +0200
Subject: [PATCH] job_manager --> jobs; split jobs router into multiple jobs
 router modules; simplify status routes for jobs router

---
 .../src/diracx/db/sql/utils/job_status.py     | 191 ++--
 diracx-routers/pyproject.toml                 |   6 +-
 .../src/diracx/routers/jobs/__init__.py       | 830 +-----------------
 .../src/diracx/routers/jobs/legacy.py         |   0
 .../src/diracx/routers/jobs/query.py          | 306 +++++++
 .../src/diracx/routers/jobs/status.py         | 266 ++++++
 .../src/diracx/routers/jobs/submission.py     | 204 +++++
 .../tests/jobs/test_wms_access_policy.py      |   2 +-
 diracx-testing/src/diracx/testing/__init__.py |   4 +-
 docs/SERVICES.md                              |   4 +-
 10 files changed, 865 insertions(+), 948 deletions(-)
 create mode 100644 diracx-routers/src/diracx/routers/jobs/legacy.py
 create mode 100644 diracx-routers/src/diracx/routers/jobs/query.py
 create mode 100644 diracx-routers/src/diracx/routers/jobs/status.py
 create mode 100644 diracx-routers/src/diracx/routers/jobs/submission.py

diff --git a/diracx-db/src/diracx/db/sql/utils/job_status.py b/diracx-db/src/diracx/db/sql/utils/job_status.py
index d7b7b7287..5807e2b65 100644
--- a/diracx-db/src/diracx/db/sql/utils/job_status.py
+++ b/diracx-db/src/diracx/db/sql/utils/job_status.py
@@ -16,11 +16,47 @@ from .. import JobDB, JobLoggingDB, SandboxMetadataDB, TaskQueueDB
 
 
+async def set_job_statuses(
+    job_update: dict[int, dict[datetime, JobStatusUpdate]],
+    config: Config,
+    job_db: JobDB,
+    job_logging_db: JobLoggingDB,
+    task_queue_db: TaskQueueDB,
+    background_task: BackgroundTasks,
+    force: bool = False,
+):
+    """Bulk operation setting statuses on multiple jobs, returning a dictionary that maps each job ID to its result.
+
+    This calls set_job_status for every (job ID, status dict) pair inside a ForgivingTaskGroup, so every update is attempted even if some of them fail.
+    """
+    async with ForgivingTaskGroup() as tg:
+        tasks = [
+            tg.create_task(
+                set_job_status(
+                    job_id,
+                    status_dict,
+                    config,
+                    job_db,
+                    job_logging_db,
+                    task_queue_db,
+                    background_task,
+                    force=force,
+                )
+            )
+            for job_id, status_dict in job_update.items()
+        ]
+
+    # The task group has exited, so every task is done: unwrap the results.
+    return {job_id: task.result() for job_id, task in zip(job_update.keys(), tasks)}
+
+
 async def set_job_status(
     job_id: int,
     status: dict[datetime, JobStatusUpdate],
+    config: Config,
     job_db: JobDB,
     job_logging_db: JobLoggingDB,
+    task_queue_db: TaskQueueDB,
+    background_task: BackgroundTasks,
     force: bool = False,
 ) -> SetJobStatusReturn:
     """Set various status fields for job specified by its jobId.
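
(A minimal, hypothetical sketch of how the new bulk helper is meant to be driven — the `set_job_status_bulk` route in `status.py` below does essentially this; the job IDs and payload here are invented for illustration:)

```python
from datetime import datetime, timezone

from diracx.core.models import JobStatus, JobStatusUpdate


async def example_bulk_kill(config, job_db, job_logging_db, task_queue_db, background_task):
    now = datetime.now(timezone.utc)  # timestamps must be timezone-aware
    update = JobStatusUpdate(
        Status=JobStatus.KILLED, MinorStatus="Marked for termination", Source="job_manager"
    )
    job_update = {123: {now: update}, 456: {now: update}}
    # ForgivingTaskGroup overrides _abort so a failing task does not cancel its
    # siblings; failures such as JobNotFound surface afterwards as an
    # ExceptionGroup, which callers handle with `except*`.
    return await set_job_statuses(
        job_update, config, job_db, job_logging_db, task_queue_db, background_task, force=True
    )
```
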
@@ -118,133 +154,56 @@ async def set_job_status(
     if not endTime and newEndTime:
         job_data["EndExecTime"] = newEndTime
 
-    if job_data:
-        await job_db.setJobAttributes(job_id, job_data)
-
-    for updTime in updateTimes:
-        sDict = statusDict[updTime]
-        if not sDict.get("Status"):
-            sDict["Status"] = "idem"
-        if not sDict.get("MinorStatus"):
-            sDict["MinorStatus"] = "idem"
-        if not sDict.get("ApplicationStatus"):
-            sDict["ApplicationStatus"] = "idem"
-        if not sDict.get("Source"):
-            sDict["Source"] = "Unknown"
-
-        await job_logging_db.insert_record(
-            job_id,
-            sDict["Status"],
-            sDict["MinorStatus"],
-            sDict["ApplicationStatus"],
-            updTime,
-            sDict["Source"],
-        )
-
-    return SetJobStatusReturn(**job_data)
-
-
-class ForgivingTaskGroup(asyncio.TaskGroup):
-    # Hacky way, check https://stackoverflow.com/questions/75250788/how-to-prevent-python3-11-taskgroup-from-canceling-all-the-tasks
-    # Basically e're using this because we want to wait for all tasks to finish, even if one of them raises an exception
-    def _abort(self):
-        return None
+    #####################################################################################################
+    async with asyncio.TaskGroup() as tg:
+        # delete or kill job, if we transition to DELETED or KILLED state
+        # TODO
+        if new_status in [JobStatus.DELETED, JobStatus.KILLED]:
+            tg.create_task(
+                _remove_jobs_from_task_queue(
+                    [job_id], config, task_queue_db, background_task
+                )
+            )
+            # TODO: implement StorageManagerClient
+            # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID(job_ids))
 
-async def delete_jobs(
-    job_ids: list[int],
-    config: Config,
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    task_queue_db: TaskQueueDB,
-    background_task: BackgroundTasks,
-):
-    """Removing jobs from task queues, send a kill command and set status to DELETED.
+            tg.create_task(job_db.set_job_command(job_id, "Kill"))
 
-    :raises: BaseExceptionGroup[JobNotFound] for every job that was not found.
-    """
-    await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task)
-    # TODO: implement StorageManagerClient
-    # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID(job_ids))
+        # Update database tables
+        if job_data:
+            tg.create_task(job_db.setJobAttributes(job_id, job_data))
 
-    async with ForgivingTaskGroup() as task_group:
-        for job_id in job_ids:
-            task_group.create_task(job_db.set_job_command(job_id, "Kill"))
+        for updTime in updateTimes:
+            sDict = statusDict[updTime]
+            if not sDict.get("Status"):
+                sDict["Status"] = "idem"
+            if not sDict.get("MinorStatus"):
+                sDict["MinorStatus"] = "idem"
+            if not sDict.get("ApplicationStatus"):
+                sDict["ApplicationStatus"] = "idem"
+            if not sDict.get("Source"):
+                sDict["Source"] = "Unknown"
 
-        task_group.create_task(
-            set_job_status(
+            tg.create_task(
+                job_logging_db.insert_record(
                     job_id,
-                {
-                    datetime.now(timezone.utc): JobStatusUpdate(
-                        Status=JobStatus.DELETED,
-                        MinorStatus="Checking accounting",
-                        Source="job_manager",
-                    )
-                },
-                job_db,
-                job_logging_db,
-                force=True,
+                    sDict["Status"],
+                    sDict["MinorStatus"],
+                    sDict["ApplicationStatus"],
+                    updTime,
+                    sDict["Source"],
                 )
             )
 
+    return SetJobStatusReturn(**job_data)
 
-async def kill_jobs(
-    job_ids: list[int],
-    config: Config,
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    task_queue_db: TaskQueueDB,
-    background_task: BackgroundTasks,
-):
-    """Kill jobs by removing them from the task queues, set kill as a job command and setting the job status to KILLED.
 
-    :raises: BaseExceptionGroup[JobNotFound] for every job that was not found.
-    """
-    await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task)
-    # TODO: implement StorageManagerClient
-    # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID(job_ids))
-
-    async with ForgivingTaskGroup() as task_group:
-        for job_id in job_ids:
-            task_group.create_task(job_db.set_job_command(job_id, "Kill"))
-            task_group.create_task(
-                set_job_status(
-                    job_id,
-                    {
-                        datetime.now(timezone.utc): JobStatusUpdate(
-                            Status=JobStatus.KILLED,
-                            MinorStatus="Marked for termination",
-                            Source="job_manager",
-                        )
-                    },
-                    job_db,
-                    job_logging_db,
-                    force=True,
-                )
-            )
-    # TODO: Consider using the code below instead, probably more stable but less performant
-    # errors = []
-    # for job_id in job_ids:
-    #     try:
-    #         await job_db.set_job_command(job_id, "Kill")
-    #         await set_job_status(
-    #             job_id,
-    #             {
-    #                 datetime.now(timezone.utc): JobStatusUpdate(
-    #                     Status=JobStatus.KILLED,
-    #                     MinorStatus="Marked for termination",
-    #                     Source="job_manager",
-    #                 )
-    #             },
-    #             job_db,
-    #             job_logging_db,
-    #             force=True,
-    #         )
-    #     except JobNotFound as e:
-    #         errors.append(e)
-
-    # if errors:
-    #     raise BaseExceptionGroup("Some job ids were not found", errors)
+class ForgivingTaskGroup(asyncio.TaskGroup):
+    # Hacky way, check https://stackoverflow.com/questions/75250788/how-to-prevent-python3-11-taskgroup-from-canceling-all-the-tasks
+    # Basically we're using this because we want to wait for all tasks to finish, even if one of them raises an exception
+    def _abort(self):
+        return None
 
 
 async def remove_jobs(
diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml
index 6316fce00..6c188994c 100644
--- a/diracx-routers/pyproject.toml
+++ b/diracx-routers/pyproject.toml
@@ -48,14 +48,14 @@ types = [
 ]
 
 [project.entry-points."diracx.services"]
-jobs = "diracx.routers.job_manager:router"
+jobs = "diracx.routers.jobs:router"
 config = "diracx.routers.configuration:router"
 auth = "diracx.routers.auth:router"
 ".well-known" = "diracx.routers.auth.well_known:router"
 
 [project.entry-points."diracx.access_policies"]
-WMSAccessPolicy = "diracx.routers.job_manager.access_policies:WMSAccessPolicy"
-SandboxAccessPolicy = "diracx.routers.job_manager.access_policies:SandboxAccessPolicy"
+WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy"
+SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy"
 
 [tool.setuptools.packages.find]
diff --git a/diracx-routers/src/diracx/routers/jobs/__init__.py b/diracx-routers/src/diracx/routers/jobs/__init__.py
index 5dd9f7d46..8dd828748 100644
--- a/diracx-routers/src/diracx/routers/jobs/__init__.py
+++ b/diracx-routers/src/diracx/routers/jobs/__init__.py
@@ -1,835 +1,17 @@
 from __future__ import annotations
 
-import asyncio
 import logging
-from datetime import datetime, timezone
-from http import HTTPStatus
-from typing import Annotated, Any
 
-from fastapi import BackgroundTasks, Body, Depends, HTTPException, Query, Response
-from pydantic import BaseModel
-from sqlalchemy.exc import NoResultFound
-from typing_extensions import TypedDict
-
-from diracx.core.exceptions import JobNotFound
-from diracx.core.models import (
-    JobStatus,
-    JobStatusReturn,
-    JobStatusUpdate,
-    LimitedJobStatusReturn,
-    ScalarSearchOperator,
-    SearchSpec,
-    SetJobStatusReturn,
-    SortSpec,
-)
-from diracx.db.sql.utils.job_status import (
-    delete_jobs,
-    kill_jobs,
-    remove_jobs,
-    set_job_status,
-)
-
-from ..dependencies import (
-    Config,
-    JobDB,
-    JobLoggingDB,
-    JobParametersDB,
-    SandboxMetadataDB,
-    TaskQueueDB,
-)
 from ..fastapi_classes import DiracxRouter
-from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token
-from .access_policies import ActionType, CheckWMSPolicyCallable
+from .query import router as query_router
 from .sandboxes import router as sandboxes_router
-
-MAX_PARAMETRIC_JOBS = 20
+from .status import router as status_router
+from .submission import router as submission_router
 
 logger = logging.getLogger(__name__)
 
 router = DiracxRouter()
 router.include_router(sandboxes_router)
-
-
-class JobSummaryParams(BaseModel):
-    grouping: list[str]
-    search: list[SearchSpec] = []
-    # TODO: Add more validation
-
-
-class JobSearchParams(BaseModel):
-    parameters: list[str] | None = None
-    search: list[SearchSpec] = []
-    sort: list[SortSpec] = []
-    distinct: bool = False
-    # TODO: Add more validation
-
-
-class InsertedJob(TypedDict):
-    JobID: int
-    Status: str
-    MinorStatus: str
-    TimeStamp: datetime
-
-
-class JobID(BaseModel):
-    job_id: int
-
-
-EXAMPLE_JDLS = {
-    "Simple JDL": {
-        "value": [
-            """Arguments = "jobDescription.xml -o LogLevel=INFO";
-Executable = "dirac-jobexec";
-JobGroup = jobGroup;
-JobName = jobName;
-JobType = User;
-LogLevel = INFO;
-OutputSandbox =
-    {
-        Script1_CodeOutput.log,
-        std.err,
-        std.out
-    };
-Priority = 1;
-Site = ANY;
-StdError = std.err;
-StdOutput = std.out;"""
-        ]
-    },
-    "Parametric JDL": {
-        "value": ["""Arguments = "jobDescription.xml -o LogLevel=INFO"""]
-    },
-}
-
-
-@router.post("/")
-async def submit_bulk_jobs(
-    job_definitions: Annotated[list[str], Body(openapi_examples=EXAMPLE_JDLS)],
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
-    check_permissions: CheckWMSPolicyCallable,
-) -> list[InsertedJob]:
-    await check_permissions(action=ActionType.CREATE, job_db=job_db)
-
-    from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
-    from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise
-    from DIRAC.WorkloadManagementSystem.Service.JobPolicy import RIGHT_SUBMIT, JobPolicy
-    from DIRAC.WorkloadManagementSystem.Utilities.ParametricJob import (
-        generateParametricJobs,
-        getParameterVectorLength,
-    )
-
-    class DiracxJobPolicy(JobPolicy):
-        def __init__(self, user_info: AuthorizedUserInfo, allInfo: bool = True):
-            self.userName = user_info.preferred_username
-            self.userGroup = user_info.dirac_group
-            self.userProperties = user_info.properties
-            self.jobDB = None
-            self.allInfo = allInfo
-            self._permissions: dict[str, bool] = {}
-            self._getUserJobPolicy()
-
-    # Check job submission permission
-    policyDict = returnValueOrRaise(DiracxJobPolicy(user_info).getJobPolicy())
-    if not policyDict[RIGHT_SUBMIT]:
-        raise HTTPException(HTTPStatus.FORBIDDEN, "You are not allowed to submit jobs")
-
-    # TODO: that needs to go in the legacy adapter (Does it ? Because bulk submission is not supported there)
-    for i in range(len(job_definitions)):
-        job_definition = job_definitions[i].strip()
-        if not (job_definition.startswith("[") and job_definition.endswith("]")):
-            job_definition = f"[{job_definition}]"
-        job_definitions[i] = job_definition
-
-    if len(job_definitions) == 1:
-        # Check if the job is a parametric one
-        jobClassAd = ClassAd(job_definitions[0])
-        result = getParameterVectorLength(jobClassAd)
-        if not result["OK"]:
-            print("Issue with getParameterVectorLength", result["Message"])
-            return result
-        nJobs = result["Value"]
-        parametricJob = False
-        if nJobs is not None and nJobs > 0:
-            # if we are here, then jobDesc was the description of a parametric job. So we start unpacking
-            parametricJob = True
-            result = generateParametricJobs(jobClassAd)
-            if not result["OK"]:
-                return result
-            jobDescList = result["Value"]
-        else:
-            # if we are here, then jobDesc was the description of a single job.
-            jobDescList = job_definitions
-    else:
-        # if we are here, then jobDesc is a list of JDLs
-        # we need to check that none of them is a parametric
-        for job_definition in job_definitions:
-            res = getParameterVectorLength(ClassAd(job_definition))
-            if not res["OK"]:
-                raise HTTPException(
-                    status_code=HTTPStatus.BAD_REQUEST, detail=res["Message"]
-                )
-            if res["Value"]:
-                raise HTTPException(
-                    status_code=HTTPStatus.BAD_REQUEST,
-                    detail="You cannot submit parametric jobs in a bulk fashion",
-                )
-
-        jobDescList = job_definitions
-        parametricJob = True
-
-    # TODO: make the max number of jobs configurable in the CS
-    if len(jobDescList) > MAX_PARAMETRIC_JOBS:
-        raise HTTPException(
-            status_code=HTTPStatus.BAD_REQUEST,
-            detail=f"Normal user cannot submit more than {MAX_PARAMETRIC_JOBS} jobs at once",
-        )
-
-    result = []
-
-    if parametricJob:
-        initialStatus = JobStatus.SUBMITTING
-        initialMinorStatus = "Bulk transaction confirmation"
-    else:
-        initialStatus = JobStatus.RECEIVED
-        initialMinorStatus = "Job accepted"
-
-    for (
-        jobDescription
-    ) in (
-        jobDescList
-    ):  # jobDescList because there might be a list generated by a parametric job
-        res = await job_db.insert(
-            jobDescription,
-            user_info.preferred_username,
-            user_info.dirac_group,
-            initialStatus,
-            initialMinorStatus,
-            user_info.vo,
-        )
-
-        job_id = res["JobID"]
-        logging.debug(
-            f'Job added to the JobDB", "{job_id} for {user_info.preferred_username}/{user_info.dirac_group}'
-        )
-
-        await job_logging_db.insert_record(
-            int(job_id),
-            initialStatus,
-            initialMinorStatus,
-            "Unknown",
-            datetime.now(timezone.utc),
-            "JobManager",
-        )
-
-        result.append(res)
-
-    return result
-
-    # TODO: is this needed ?
-    # if not parametricJob:
-    #     self.__sendJobsToOptimizationMind(jobIDList)
-    #     return result
-
-    return await asyncio.gather(
-        *(job_db.insert(j.owner, j.group, j.vo) for j in job_definitions)
-    )
-
-
-@router.delete("/")
-async def delete_bulk_jobs(
-    job_ids: Annotated[list[int], Query()],
-    config: Config,
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    task_queue_db: TaskQueueDB,
-    background_task: BackgroundTasks,
-    check_permissions: CheckWMSPolicyCallable,
-):
-
-    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids)
-    # TODO: implement job policy
-
-    try:
-        await delete_jobs(
-            job_ids,
-            config,
-            job_db,
-            job_logging_db,
-            task_queue_db,
-            background_task,
-        )
-    except* JobNotFound as group_exc:
-        failed_job_ids: list[int] = list({e.job_id for e in group_exc.exceptions})  # type: ignore
-
-        raise HTTPException(
-            status_code=HTTPStatus.NOT_FOUND,
-            detail={
-                "message": f"Failed to delete {len(failed_job_ids)} jobs out of {len(job_ids)}",
-                "valid_job_ids": list(set(job_ids) - set(failed_job_ids)),
-                "failed_job_ids": failed_job_ids,
-            },
-        ) from group_exc
-
-    return job_ids
-
-
-@router.post("/kill")
-async def kill_bulk_jobs(
-    job_ids: Annotated[list[int], Query()],
-    config: Config,
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    task_queue_db: TaskQueueDB,
-    background_task: BackgroundTasks,
-    check_permissions: CheckWMSPolicyCallable,
-):
-    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids)
-    # TODO: implement job policy
-    try:
-        await kill_jobs(
-            job_ids,
-            config,
-            job_db,
-            job_logging_db,
-            task_queue_db,
-            background_task,
-        )
-    except* JobNotFound as group_exc:
-        failed_job_ids: list[int] = list({e.job_id for e in group_exc.exceptions})  # type: ignore
-
-        raise HTTPException(
-            status_code=HTTPStatus.NOT_FOUND,
-            detail={
-                "message": f"Failed to kill {len(failed_job_ids)} jobs out of {len(job_ids)}",
-                "valid_job_ids": list(set(job_ids) - set(failed_job_ids)),
-                "failed_job_ids": failed_job_ids,
-            },
-        ) from group_exc
-
-    return job_ids
-
-
-@router.post("/remove")
-async def remove_bulk_jobs(
-    job_ids: Annotated[list[int], Query()],
-    config: Config,
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    sandbox_metadata_db: SandboxMetadataDB,
-    task_queue_db: TaskQueueDB,
-    background_task: BackgroundTasks,
-    check_permissions: CheckWMSPolicyCallable,
-):
-    """Fully remove a list of jobs from the WMS databases.
-
-    WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS
-    and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should
-    be removed, and the delete endpoint should be used instead for any other purpose.
-    """
-    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids)
-    # TODO: Remove once legacy DIRAC no longer needs this
-
-    # TODO: implement job policy
-    # Some tests have already been written in the test_job_manager,
-    # but they need to be uncommented and are not complete
-
-    await remove_jobs(
-        job_ids,
-        config,
-        job_db,
-        job_logging_db,
-        sandbox_metadata_db,
-        task_queue_db,
-        background_task,
-    )
-
-    return job_ids
-
-
-@router.get("/status")
-async def get_job_status_bulk(
-    job_ids: Annotated[list[int], Query()],
-    job_db: JobDB,
-    check_permissions: CheckWMSPolicyCallable,
-) -> dict[int, LimitedJobStatusReturn]:
-    await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=job_ids)
-    try:
-        result = await asyncio.gather(
-            *(job_db.get_job_status(job_id) for job_id in job_ids)
-        )
-        return {job_id: status for job_id, status in zip(job_ids, result)}
-    except JobNotFound as e:
-        raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e
-
-
-@router.patch("/status")
-async def set_job_status_bulk(
-    job_update: dict[int, dict[datetime, JobStatusUpdate]],
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    check_permissions: CheckWMSPolicyCallable,
-    force: bool = False,
-) -> dict[int, SetJobStatusReturn]:
-    await check_permissions(
-        action=ActionType.MANAGE, job_db=job_db, job_ids=list(job_update)
-    )
-    # check that the datetime contains timezone info
-    for job_id, status in job_update.items():
-        for dt in status:
-            if dt.tzinfo is None:
-                raise HTTPException(
-                    status_code=HTTPStatus.BAD_REQUEST,
-                    detail=f"Timestamp {dt} is not timezone aware for job {job_id}",
-                )
-
-    res = await asyncio.gather(
-        *(
-            set_job_status(job_id, status, job_db, job_logging_db, force)
-            for job_id, status in job_update.items()
-        )
-    )
-    return {job_id: status for job_id, status in zip(job_update.keys(), res)}
-
-
-@router.get("/status/history")
-async def get_job_status_history_bulk(
-    job_ids: Annotated[list[int], Query()],
-    job_logging_db: JobLoggingDB,
-    job_db: JobDB,
-    check_permissions: CheckWMSPolicyCallable,
-) -> dict[int, list[JobStatusReturn]]:
-    await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=job_ids)
-    result = await asyncio.gather(
-        *(job_logging_db.get_records(job_id) for job_id in job_ids)
-    )
-    return {job_id: status for job_id, status in zip(job_ids, result)}
-
-
-@router.post("/reschedule")
-async def reschedule_bulk_jobs(
-    job_ids: Annotated[list[int], Query()],
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    check_permissions: CheckWMSPolicyCallable,
-):
-    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids)
-    rescheduled_jobs = []
-    # TODO: Joblist Policy:
-    # validJobList, invalidJobList, nonauthJobList, ownerJobList = self.jobPolicy.evaluateJobRights(
-    #     jobList, RIGHT_RESCHEDULE
-    # )
-    # For the moment all jobs are valid:
-    valid_job_list = job_ids
-    for job_id in valid_job_list:
-        # TODO: delete job in TaskQueueDB
-        # self.taskQueueDB.deleteJob(jobID)
-        result = await job_db.rescheduleJob(job_id)
-        try:
-            res_status = await job_db.get_job_status(job_id)
-        except NoResultFound as e:
-            raise HTTPException(
-                status_code=HTTPStatus.NOT_FOUND, detail=f"Job {job_id} not found"
-            ) from e
-
-        initial_status = res_status.Status
-        initial_minor_status = res_status.MinorStatus
-
-        await job_logging_db.insert_record(
-            int(job_id),
-            initial_status,
-            initial_minor_status,
-            "Unknown",
-            datetime.now(timezone.utc),
-            "JobManager",
-        )
-        if result:
-            rescheduled_jobs.append(job_id)
-    # To uncomment when jobPolicy is setup:
-    # if invalid_job_list or non_auth_job_list:
-    #     logging.error("Some jobs failed to reschedule")
-    #     if invalid_job_list:
-    #         logging.info(f"Invalid jobs: {invalid_job_list}")
-    #     if non_auth_job_list:
-    #         logging.info(f"Non authorized jobs: {nonauthJobList}")
-
-    # TODO: send jobs to OtimizationMind
-    # self.__sendJobsToOptimizationMind(validJobList)
-    return rescheduled_jobs
-
-
-@router.post("/{job_id}/reschedule")
-async def reschedule_single_job(
-    job_id: int,
-    job_db: JobDB,
-    check_permissions: CheckWMSPolicyCallable,
-):
-    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
-    try:
-        result = await job_db.rescheduleJob(job_id)
-    except ValueError as e:
-        raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e
-    return result
-
-
-EXAMPLE_SEARCHES = {
-    "Show all": {
-        "summary": "Show all",
-        "description": "Shows all jobs the current user has access to.",
-        "value": {},
-    },
-    "A specific job": {
-        "summary": "A specific job",
-        "description": "Search for a specific job by ID",
-        "value": {"search": [{"parameter": "JobID", "operator": "eq", "value": "5"}]},
-    },
-    "Get ordered job statuses": {
-        "summary": "Get ordered job statuses",
-        "description": "Get only job statuses for specific jobs, ordered by status",
-        "value": {
-            "parameters": ["JobID", "Status"],
-            "search": [
-                {"parameter": "JobID", "operator": "in", "values": ["6", "2", "3"]}
-            ],
-            "sort": [{"parameter": "JobID", "direction": "asc"}],
-        },
-    },
-}
-
-
-EXAMPLE_RESPONSES: dict[int | str, dict[str, Any]] = {
-    200: {
-        "description": "List of matching results",
-        "content": {
-            "application/json": {
-                "example": [
-                    {
-                        "JobID": 1,
-                        "JobGroup": "jobGroup",
-                        "Owner": "myvo:my_nickname",
-                        "SubmissionTime": "2023-05-25T07:03:35.602654",
-                        "LastUpdateTime": "2023-05-25T07:03:35.602652",
-                        "Status": "RECEIVED",
-                        "MinorStatus": "Job accepted",
-                        "ApplicationStatus": "Unknown",
-                    },
-                    {
-                        "JobID": 2,
-                        "JobGroup": "my_nickname",
-                        "Owner": "myvo:cburr",
-                        "SubmissionTime": "2023-05-25T07:03:36.256378",
-                        "LastUpdateTime": "2023-05-25T07:10:11.974324",
-                        "Status": "Done",
-                        "MinorStatus": "Application Exited Successfully",
-                        "ApplicationStatus": "All events processed",
-                    },
-                ]
-            }
-        },
-    },
-    206: {
-        "description": "Partial Content. Only a part of the requested range could be served.",
-        "headers": {
-            "Content-Range": {
-                "description": "The range of jobs returned in this response",
-                "schema": {"type": "string", "example": "jobs 0-1/4"},
-            }
-        },
-        "model": list[dict[str, Any]],
-        "content": {
-            "application/json": {
-                "example": [
-                    {
-                        "JobID": 1,
-                        "JobGroup": "jobGroup",
-                        "Owner": "myvo:my_nickname",
-                        "SubmissionTime": "2023-05-25T07:03:35.602654",
-                        "LastUpdateTime": "2023-05-25T07:03:35.602652",
-                        "Status": "RECEIVED",
-                        "MinorStatus": "Job accepted",
-                        "ApplicationStatus": "Unknown",
-                    },
-                    {
-                        "JobID": 2,
-                        "JobGroup": "my_nickname",
-                        "Owner": "myvo:cburr",
-                        "SubmissionTime": "2023-05-25T07:03:36.256378",
-                        "LastUpdateTime": "2023-05-25T07:10:11.974324",
-                        "Status": "Done",
-                        "MinorStatus": "Application Exited Successfully",
-                        "ApplicationStatus": "All events processed",
-                    },
-                ]
-            }
-        },
-    },
-}
-
-MAX_PER_PAGE = 10000
-
-
-@router.post("/search", responses=EXAMPLE_RESPONSES)
-async def search(
-    config: Config,
-    job_db: JobDB,
-    job_parameters_db: JobParametersDB,
-    user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
-    check_permissions: CheckWMSPolicyCallable,
-    response: Response,
-    page: int = 1,
-    per_page: int = 100,
-    body: Annotated[
-        JobSearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES)
-    ] = None,
-) -> list[dict[str, Any]]:
-    """Retrieve information about jobs.
-
-    **TODO: Add more docs**
-    """
-    await check_permissions(action=ActionType.QUERY, job_db=job_db)
-
-    # Apply a limit to per_page to prevent abuse of the API
-    if per_page > MAX_PER_PAGE:
-        per_page = MAX_PER_PAGE
-
-    if body is None:
-        body = JobSearchParams()
-    # TODO: Apply all the job policy stuff properly using user_info
-    if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo:
-        body.search.append(
-            {
-                "parameter": "Owner",
-                "operator": ScalarSearchOperator.EQUAL,
-                "value": user_info.sub,
-            }
-        )
-
-    total, jobs = await job_db.search(
-        body.parameters,
-        body.search,
-        body.sort,
-        distinct=body.distinct,
-        page=page,
-        per_page=per_page,
-    )
-    # Set the Content-Range header if needed
-    # https://datatracker.ietf.org/doc/html/rfc7233#section-4
-
-    # No jobs found but there are jobs for the requested search
-    # https://datatracker.ietf.org/doc/html/rfc7233#section-4.4
-    if len(jobs) == 0 and total > 0:
-        response.headers["Content-Range"] = f"jobs */{total}"
-        response.status_code = HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE
-
-    # The total number of jobs is greater than the number of jobs returned
-    # https://datatracker.ietf.org/doc/html/rfc7233#section-4.2
-    elif len(jobs) < total:
-        first_idx = per_page * (page - 1)
-        last_idx = min(first_idx + len(jobs), total) - 1 if total > 0 else 0
-        response.headers["Content-Range"] = f"jobs {first_idx}-{last_idx}/{total}"
-        response.status_code = HTTPStatus.PARTIAL_CONTENT
-    return jobs
-
-
-@router.post("/summary")
-async def summary(
-    config: Config,
-    job_db: JobDB,
-    user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
-    body: JobSummaryParams,
-    check_permissions: CheckWMSPolicyCallable,
-):
-    """Show information suitable for plotting."""
-    await check_permissions(action=ActionType.QUERY, job_db=job_db)
-    # TODO: Apply all the job policy stuff properly using user_info
-    if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo:
-        body.search.append(
-            {
-                "parameter": "Owner",
-                "operator": ScalarSearchOperator.EQUAL,
-                "value": user_info.sub,
-            }
-        )
-    return await job_db.summary(body.grouping, body.search)
-
-
-@router.get("/{job_id}")
-async def get_single_job(
-    job_id: int,
-    job_db: JobDB,
-    check_permissions: CheckWMSPolicyCallable,
-):
-    await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
-    return f"This job {job_id}"
-
-
-@router.delete("/{job_id}")
-async def delete_single_job(
-    job_id: int,
-    config: Config,
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    task_queue_db: TaskQueueDB,
-    background_task: BackgroundTasks,
-    check_permissions: CheckWMSPolicyCallable,
-):
-    """Delete a job by killing and setting the job status to DELETED."""
-    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
-
-    # TODO: implement job policy
-    try:
-        await delete_jobs(
-            [job_id],
-            config,
-            job_db,
-            job_logging_db,
-            task_queue_db,
-            background_task,
-        )
-    except* JobNotFound as e:
-        raise HTTPException(
-            status_code=HTTPStatus.NOT_FOUND.value, detail=str(e.exceptions[0])
-        ) from e
-
-    return f"Job {job_id} has been successfully deleted"
-
-
-@router.post("/{job_id}/kill")
-async def kill_single_job(
-    job_id: int,
-    config: Config,
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    task_queue_db: TaskQueueDB,
-    background_task: BackgroundTasks,
-    check_permissions: CheckWMSPolicyCallable,
-):
-    """Kill a job."""
-    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
-
-    # TODO: implement job policy
-
-    try:
-        await kill_jobs(
-            [job_id], config, job_db, job_logging_db, task_queue_db, background_task
-        )
-    except* JobNotFound as e:
-        raise HTTPException(
-            status_code=HTTPStatus.NOT_FOUND, detail=str(e.exceptions[0])
-        ) from e
-
-    return f"Job {job_id} has been successfully killed"
-
-
-@router.post("/{job_id}/remove")
-async def remove_single_job(
-    job_id: int,
-    config: Config,
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    sandbox_metadata_db: SandboxMetadataDB,
-    task_queue_db: TaskQueueDB,
-    background_task: BackgroundTasks,
-    check_permissions: CheckWMSPolicyCallable,
-):
-    """Fully remove a job from the WMS databases.
-
-    WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS
-    and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should
-    be removed, and the delete endpoint should be used instead.
-    """
-    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
-    # TODO: Remove once legacy DIRAC no longer needs this
-
-    # TODO: implement job policy
-
-    await remove_jobs(
-        [job_id],
-        config,
-        job_db,
-        job_logging_db,
-        sandbox_metadata_db,
-        task_queue_db,
-        background_task,
-    )
-
-    return f"Job {job_id} has been successfully removed"
-
-
-@router.get("/{job_id}/status")
-async def get_single_job_status(
-    job_id: int,
-    job_db: JobDB,
-    check_permissions: CheckWMSPolicyCallable,
-) -> dict[int, LimitedJobStatusReturn]:
-    await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
-    try:
-        status = await job_db.get_job_status(job_id)
-    except JobNotFound as e:
-        raise HTTPException(
-            status_code=HTTPStatus.NOT_FOUND, detail=f"Job {job_id} not found"
-        ) from e
-    return {job_id: status}
-
-
-@router.patch("/{job_id}/status")
-async def set_single_job_status(
-    job_id: int,
-    status: Annotated[dict[datetime, JobStatusUpdate], Body()],
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    check_permissions: CheckWMSPolicyCallable,
-    force: bool = False,
-) -> dict[int, SetJobStatusReturn]:
-    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
-    # check that the datetime contains timezone info
-    for dt in status:
-        if dt.tzinfo is None:
-            raise HTTPException(
-                status_code=HTTPStatus.BAD_REQUEST,
-                detail=f"Timestamp {dt} is not timezone aware",
-            )
-
-    try:
-        latest_status = await set_job_status(
-            job_id, status, job_db, job_logging_db, force
-        )
-    except JobNotFound as e:
-        raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e
-    return {job_id: latest_status}
-
-
-@router.get("/{job_id}/status/history")
-async def get_single_job_status_history(
-    job_id: int,
-    job_db: JobDB,
-    job_logging_db: JobLoggingDB,
-    check_permissions: CheckWMSPolicyCallable,
-) -> dict[int, list[JobStatusReturn]]:
-    await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
-    try:
-        status = await job_logging_db.get_records(job_id)
-    except JobNotFound as e:
-        raise HTTPException(
-            status_code=HTTPStatus.NOT_FOUND, detail="Job not found"
-        ) from e
-    return {job_id: status}
-
-
-@router.patch("/{job_id}")
-async def set_single_job_properties(
-    job_id: int,
-    job_properties: Annotated[dict[str, Any], Body()],
-    job_db: JobDB,
-    check_permissions: CheckWMSPolicyCallable,
-    update_timestamp: bool = False,
-):
-    """Update the given job properties (MinorStatus, ApplicationStatus, etc)."""
-    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
-
-    rowcount = await job_db.set_properties(
-        {job_id: job_properties}, update_timestamp=update_timestamp
-    )
-    if not rowcount:
-        raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Job not found")
+router.include_router(status_router)
+router.include_router(query_router)
+router.include_router(submission_router)
diff --git a/diracx-routers/src/diracx/routers/jobs/legacy.py b/diracx-routers/src/diracx/routers/jobs/legacy.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/diracx-routers/src/diracx/routers/jobs/query.py b/diracx-routers/src/diracx/routers/jobs/query.py
new file mode 100644
index 000000000..97687e742
--- /dev/null
+++ b/diracx-routers/src/diracx/routers/jobs/query.py
@@ -0,0 +1,306 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+from http import HTTPStatus
+from typing import Annotated, Any
+
+from fastapi import Body, Depends, HTTPException, Query, Response
+from pydantic import BaseModel
+
+from diracx.core.exceptions import JobNotFound
+from diracx.core.models import (
+    JobStatusReturn,
+    LimitedJobStatusReturn,
+    ScalarSearchOperator,
+    SearchSpec,
+    SortSpec,
+)
+from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER
+
+from ..auth import has_properties
+from ..dependencies import (
+    Config,
+    JobDB,
+    JobLoggingDB,
+    JobParametersDB,
+)
+from ..fastapi_classes import DiracxRouter
+from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token
+from .access_policies import ActionType, CheckWMSPolicyCallable
+
+logger = logging.getLogger(__name__)
+
+router = DiracxRouter(dependencies=[has_properties(NORMAL_USER | JOB_ADMINISTRATOR)])
+
+
+class JobSummaryParams(BaseModel):
+    grouping: list[str]
+    search: list[SearchSpec] = []
+    # TODO: Add more validation
+
+
+class JobSearchParams(BaseModel):
+    parameters: list[str] | None = None
+    search: list[SearchSpec] = []
+    sort: list[SortSpec] = []
+    distinct: bool = False
+    # TODO: Add more validation
+
+
+MAX_PER_PAGE = 10000
+
+
+EXAMPLE_SEARCHES = {
+    "Show all": {
+        "summary": "Show all",
+        "description": "Shows all jobs the current user has access to.",
+        "value": {},
+    },
+    "A specific job": {
+        "summary": "A specific job",
+        "description": "Search for a specific job by ID",
+        "value": {"search": [{"parameter": "JobID", "operator": "eq", "value": "5"}]},
+    },
+    "Get ordered job statuses": {
+        "summary": "Get ordered job statuses",
+        "description": "Get only job statuses for specific jobs, ordered by status",
+        "value": {
+            "parameters": ["JobID", "Status"],
+            "search": [
+                {"parameter": "JobID", "operator": "in", "values": ["6", "2", "3"]}
+            ],
+            "sort": [{"parameter": "JobID", "direction": "asc"}],
+        },
+    },
+}
+
+
+EXAMPLE_RESPONSES: dict[int | str, dict[str, Any]] = {
+    200: {
+        "description": "List of matching results",
+        "content": {
+            "application/json": {
+                "example": [
+                    {
+                        "JobID": 1,
+                        "JobGroup": "jobGroup",
+                        "Owner": "myvo:my_nickname",
+                        "SubmissionTime": "2023-05-25T07:03:35.602654",
+                        "LastUpdateTime": "2023-05-25T07:03:35.602652",
+                        "Status": "RECEIVED",
+                        "MinorStatus": "Job accepted",
+                        "ApplicationStatus": "Unknown",
+                    },
+                    {
+                        "JobID": 2,
+                        "JobGroup": "my_nickname",
+                        "Owner": "myvo:cburr",
+                        "SubmissionTime": "2023-05-25T07:03:36.256378",
+                        "LastUpdateTime": "2023-05-25T07:10:11.974324",
+                        "Status": "Done",
+                        "MinorStatus": "Application Exited Successfully",
+                        "ApplicationStatus": "All events processed",
+                    },
+                ]
+            }
+        },
+    },
+    206: {
+        "description": "Partial Content. Only a part of the requested range could be served.",
+        "headers": {
+            "Content-Range": {
+                "description": "The range of jobs returned in this response",
+                "schema": {"type": "string", "example": "jobs 0-1/4"},
+            }
+        },
+        "model": list[dict[str, Any]],
+        "content": {
+            "application/json": {
+                "example": [
+                    {
+                        "JobID": 1,
+                        "JobGroup": "jobGroup",
+                        "Owner": "myvo:my_nickname",
+                        "SubmissionTime": "2023-05-25T07:03:35.602654",
+                        "LastUpdateTime": "2023-05-25T07:03:35.602652",
+                        "Status": "RECEIVED",
+                        "MinorStatus": "Job accepted",
+                        "ApplicationStatus": "Unknown",
+                    },
+                    {
+                        "JobID": 2,
+                        "JobGroup": "my_nickname",
+                        "Owner": "myvo:cburr",
+                        "SubmissionTime": "2023-05-25T07:03:36.256378",
+                        "LastUpdateTime": "2023-05-25T07:10:11.974324",
+                        "Status": "Done",
+                        "MinorStatus": "Application Exited Successfully",
+                        "ApplicationStatus": "All events processed",
+                    },
+                ]
+            }
+        },
+    },
+}
+
+
+@router.post("/search", responses=EXAMPLE_RESPONSES)
+async def search(
+    config: Config,
+    job_db: JobDB,
+    job_parameters_db: JobParametersDB,
+    user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
+    check_permissions: CheckWMSPolicyCallable,
+    response: Response,
+    page: int = 1,
+    per_page: int = 100,
+    body: Annotated[
+        JobSearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES)
+    ] = None,
+) -> list[dict[str, Any]]:
+    """Retrieve information about jobs.
+
+    **TODO: Add more docs**
+    """
+    await check_permissions(action=ActionType.QUERY, job_db=job_db)
+
+    # Apply a limit to per_page to prevent abuse of the API
+    if per_page > MAX_PER_PAGE:
+        per_page = MAX_PER_PAGE
+
+    if body is None:
+        body = JobSearchParams()
+    # TODO: Apply all the job policy stuff properly using user_info
+    if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo:
+        body.search.append(
+            {
+                "parameter": "Owner",
+                "operator": ScalarSearchOperator.EQUAL,
+                "value": user_info.sub,
+            }
+        )
+
+    total, jobs = await job_db.search(
+        body.parameters,
+        body.search,
+        body.sort,
+        distinct=body.distinct,
+        page=page,
+        per_page=per_page,
+    )
+    # Set the Content-Range header if needed
+    # https://datatracker.ietf.org/doc/html/rfc7233#section-4
+
+    # No jobs found but there are jobs for the requested search
+    # https://datatracker.ietf.org/doc/html/rfc7233#section-4.4
+    if len(jobs) == 0 and total > 0:
+        response.headers["Content-Range"] = f"jobs */{total}"
+        response.status_code = HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE
+
+    # The total number of jobs is greater than the number of jobs returned
+    # https://datatracker.ietf.org/doc/html/rfc7233#section-4.2
+    elif len(jobs) < total:
+        first_idx = per_page * (page - 1)
+        last_idx = min(first_idx + len(jobs), total) - 1 if total > 0 else 0
+        response.headers["Content-Range"] = f"jobs {first_idx}-{last_idx}/{total}"
+        response.status_code = HTTPStatus.PARTIAL_CONTENT
+    return jobs
+
+
+@router.post("/summary")
+async def summary(
+    config: Config,
+    job_db: JobDB,
+    user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
+    body: JobSummaryParams,
+    check_permissions: CheckWMSPolicyCallable,
+):
+    """Show information suitable for plotting."""
+    await check_permissions(action=ActionType.QUERY, job_db=job_db)
+    # TODO: Apply all the job policy stuff properly using user_info
+    if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo:
+        body.search.append(
+            {
+                "parameter": "Owner",
+                "operator": ScalarSearchOperator.EQUAL,
+                "value": user_info.sub,
+            }
+        )
+    return await job_db.summary(body.grouping, body.search)
+
+
+@router.get("/{job_id}")
+async def get_single_job(
+    job_id: int,
+    job_db: JobDB,
+    check_permissions: CheckWMSPolicyCallable,
+):
+    await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
+    return f"This job {job_id}"
+
+
+# TODO: To remove?
+@router.get("/{job_id}/status")
+async def get_single_job_status(
+    job_id: int,
+    job_db: JobDB,
+    check_permissions: CheckWMSPolicyCallable,
+) -> dict[int, LimitedJobStatusReturn]:
+    await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
+    try:
+        status = await job_db.get_job_status(job_id)
+    except JobNotFound as e:
+        raise HTTPException(
+            status_code=HTTPStatus.NOT_FOUND, detail=f"Job {job_id} not found"
+        ) from e
+    return {job_id: status}
+
+
+@router.get("/{job_id}/status/history")
+async def get_single_job_status_history(
+    job_id: int,
+    job_db: JobDB,
+    job_logging_db: JobLoggingDB,
+    check_permissions: CheckWMSPolicyCallable,
+) -> dict[int, list[JobStatusReturn]]:
+    await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
+    try:
+        status = await job_logging_db.get_records(job_id)
+    except JobNotFound as e:
+        raise HTTPException(
+            status_code=HTTPStatus.NOT_FOUND, detail="Job not found"
+        ) from e
+    return {job_id: status}
+
+
+# TODO: To remove?
+@router.get("/status/history")
+async def get_job_status_history_bulk(
+    job_ids: Annotated[list[int], Query()],
+    job_logging_db: JobLoggingDB,
+    job_db: JobDB,
+    check_permissions: CheckWMSPolicyCallable,
+) -> dict[int, list[JobStatusReturn]]:
+    await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=job_ids)
+    result = await asyncio.gather(
+        *(job_logging_db.get_records(job_id) for job_id in job_ids)
+    )
+    return {job_id: status for job_id, status in zip(job_ids, result)}
+
+
+# TODO: To remove?
+@router.get("/status")
+async def get_job_status_bulk(
+    job_ids: Annotated[list[int], Query()],
+    job_db: JobDB,
+    check_permissions: CheckWMSPolicyCallable,
+) -> dict[int, LimitedJobStatusReturn]:
+    await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=job_ids)
+    try:
+        result = await asyncio.gather(
+            *(job_db.get_job_status(job_id) for job_id in job_ids)
+        )
+        return {job_id: status for job_id, status in zip(job_ids, result)}
+    except JobNotFound as e:
+        raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e
diff --git a/diracx-routers/src/diracx/routers/jobs/status.py b/diracx-routers/src/diracx/routers/jobs/status.py
new file mode 100644
index 000000000..44f6f5a26
--- /dev/null
+++ b/diracx-routers/src/diracx/routers/jobs/status.py
@@ -0,0 +1,266 @@
+from __future__ import annotations
+
+import logging
+from datetime import datetime, timezone
+from http import HTTPStatus
+from typing import Annotated, Any
+
+from fastapi import BackgroundTasks, Body, HTTPException, Query
+from sqlalchemy.exc import NoResultFound
+
+from diracx.core.exceptions import JobNotFound
+from diracx.core.models import (
+    JobStatusUpdate,
+    SetJobStatusReturn,
+)
+from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER
+from diracx.db.sql.utils.job_status import (
+    remove_jobs,
+    set_job_status,
+    set_job_statuses,
+)
+
+from ..auth import has_properties
+from ..dependencies import (
+    Config,
+    JobDB,
+    JobLoggingDB,
+    SandboxMetadataDB,
+    TaskQueueDB,
+)
+from ..fastapi_classes import DiracxRouter
+from .access_policies import ActionType, CheckWMSPolicyCallable
+
+logger = logging.getLogger(__name__)
+
+router = DiracxRouter(dependencies=[has_properties(NORMAL_USER | JOB_ADMINISTRATOR)])
+
+
+# NOTE: this replaces the legacy POST /remove route, now exposed as DELETE /
+@router.delete("/")
+async def remove_bulk_jobs(
+    job_ids: Annotated[list[int], Query()],
+    config: Config,
+    job_db: JobDB,
+    job_logging_db: JobLoggingDB,
+    sandbox_metadata_db: SandboxMetadataDB,
+    task_queue_db: TaskQueueDB,
+    background_task: BackgroundTasks,
+    check_permissions: CheckWMSPolicyCallable,
+):
+    """Fully remove a list of jobs from the WMS databases.
+
+    WARNING: This endpoint has been implemented for compatibility with the legacy DIRAC WMS
+    and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should
+    be removed, and the delete endpoint should be used instead for any other purpose.
+    """
+    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids)
+    # TODO: Remove once legacy DIRAC no longer needs this
+
+    # TODO: implement job policy
+    # Some tests have already been written in the test_job_manager,
+    # but they need to be uncommented and are not complete
+
+    await remove_jobs(
+        job_ids,
+        config,
+        job_db,
+        job_logging_db,
+        sandbox_metadata_db,
+        task_queue_db,
+        background_task,
+    )
+
+    return job_ids
+
+
+@router.patch("/{job_id}/status")
+async def set_single_job_status(
+    job_id: int,
+    status: Annotated[dict[datetime, JobStatusUpdate], Body()],
+    config: Config,
+    job_db: JobDB,
+    job_logging_db: JobLoggingDB,
+    task_queue_db: TaskQueueDB,
+    background_task: BackgroundTasks,
+    check_permissions: CheckWMSPolicyCallable,
+    force: bool = False,
+) -> dict[int, SetJobStatusReturn]:
+    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
+    # check that the datetime contains timezone info
+    for dt in status:
+        if dt.tzinfo is None:
+            raise HTTPException(
+                status_code=HTTPStatus.BAD_REQUEST,
+                detail=f"Timestamp {dt} is not timezone aware",
+            )
+
+    try:
+        latest_status = await set_job_status(
+            job_id,
+            status,
+            config,
+            job_db,
+            job_logging_db,
+            task_queue_db,
+            background_task,
+            force,
+        )
+    except JobNotFound as e:
+        raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e
+    return {job_id: latest_status}
+
+
+@router.patch("/status")
+async def set_job_status_bulk(
+    job_update: dict[int, dict[datetime, JobStatusUpdate]],
+    config: Config,
+    job_db: JobDB,
+    job_logging_db: JobLoggingDB,
+    task_queue_db: TaskQueueDB,
+    background_task: BackgroundTasks,
+    check_permissions: CheckWMSPolicyCallable,
+    force: bool = False,
+) -> dict[int, SetJobStatusReturn]:
+    await check_permissions(
+        action=ActionType.MANAGE, job_db=job_db, job_ids=list(job_update)
+    )
+    # check that the datetime contains timezone info
+    for job_id, status in job_update.items():
+        for dt in status:
+            if dt.tzinfo is None:
+                raise HTTPException(
+                    status_code=HTTPStatus.BAD_REQUEST,
+                    detail=f"Timestamp {dt} is not timezone aware for job {job_id}",
+                )
+
+    return await set_job_statuses(
+        job_update,
+        config,
+        job_db,
+        job_logging_db,
+        task_queue_db,
+        background_task,
+        force=force,
+    )
+
+
+# TODO: Add a parameter to replace "resetJob"
+@router.post("/reschedule")
+async def reschedule_bulk_jobs(
+    job_ids: Annotated[list[int], Query()],
+    job_db: JobDB,
+    job_logging_db: JobLoggingDB,
+    check_permissions: CheckWMSPolicyCallable,
+):
+    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids)
+    rescheduled_jobs = []
+    # TODO: Joblist Policy:
+    # validJobList, invalidJobList, nonauthJobList, ownerJobList = self.jobPolicy.evaluateJobRights(
+    #     jobList, RIGHT_RESCHEDULE
+    # )
+    # For the moment all jobs are valid:
+    valid_job_list = job_ids
+    for job_id in valid_job_list:
+        # TODO: delete job in TaskQueueDB
+        # self.taskQueueDB.deleteJob(jobID)
+        result = await job_db.rescheduleJob(job_id)
+        try:
+            res_status = await job_db.get_job_status(job_id)
+        except NoResultFound as e:
+            raise HTTPException(
+                status_code=HTTPStatus.NOT_FOUND, detail=f"Job {job_id} not found"
+            ) from e
+
+        initial_status = res_status.Status
+        initial_minor_status = res_status.MinorStatus
+
+        await job_logging_db.insert_record(
+            int(job_id),
+            initial_status,
+            initial_minor_status,
+            "Unknown",
+            datetime.now(timezone.utc),
+            "JobManager",
+        )
+        if result:
+            rescheduled_jobs.append(job_id)
+    # To uncomment when jobPolicy is set up:
+    # if invalid_job_list or non_auth_job_list:
+    #     logging.error("Some jobs failed to reschedule")
+    #     if invalid_job_list:
+    #         logging.info(f"Invalid jobs: {invalid_job_list}")
+    #     if non_auth_job_list:
+    #         logging.info(f"Non authorized jobs: {nonauthJobList}")
+
+    # TODO: send jobs to OptimizationMind
+    # self.__sendJobsToOptimizationMind(validJobList)
+    return rescheduled_jobs
+
+
+# TODO: Add a parameter to replace "resetJob"
+@router.post("/{job_id}/reschedule")
+async def reschedule_single_job(
+    job_id: int,
+    job_db: JobDB,
+    check_permissions: CheckWMSPolicyCallable,
+):
+    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
+    try:
+        result = await job_db.rescheduleJob(job_id)
+    except ValueError as e:
+        raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e
+    return result
+
+
+@router.delete("/{job_id}")
+async def remove_single_job(
+    job_id: int,
+    config: Config,
+    job_db: JobDB,
+    job_logging_db: JobLoggingDB,
+    sandbox_metadata_db: SandboxMetadataDB,
+    task_queue_db: TaskQueueDB,
+    background_task: BackgroundTasks,
+    check_permissions: CheckWMSPolicyCallable,
+):
+    """Fully remove a job from the WMS databases.
+
+    WARNING: This endpoint has been implemented for compatibility with the legacy DIRAC WMS
+    and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should
+    be removed, and the delete endpoint should be used instead.
+    """
+    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
+    # TODO: Remove once legacy DIRAC no longer needs this
+
+    # TODO: implement job policy
+
+    await remove_jobs(
+        [job_id],
+        config,
+        job_db,
+        job_logging_db,
+        sandbox_metadata_db,
+        task_queue_db,
+        background_task,
+    )
+
+    return f"Job {job_id} has been successfully removed"
+
+
+@router.patch("/{job_id}")
+async def set_single_job_properties(
+    job_id: int,
+    job_properties: Annotated[dict[str, Any], Body()],
+    job_db: JobDB,
+    check_permissions: CheckWMSPolicyCallable,
+    update_timestamp: bool = False,
+):
+    """Update the given job properties (MinorStatus, ApplicationStatus, etc)."""
+    await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
+
+    rowcount = await job_db.set_properties(
+        {job_id: job_properties}, update_timestamp=update_timestamp
+    )
+    if not rowcount:
+        raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Job not found")
diff --git a/diracx-routers/src/diracx/routers/jobs/submission.py b/diracx-routers/src/diracx/routers/jobs/submission.py
new file mode 100644
index 000000000..c9d03c3b4
--- /dev/null
+++ b/diracx-routers/src/diracx/routers/jobs/submission.py
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+from datetime import datetime, timezone
+from http import HTTPStatus
+from typing import Annotated
+
+from fastapi import Body, Depends, HTTPException
+from pydantic import BaseModel
+from typing_extensions import TypedDict
+
+from diracx.core.models import (
+    JobStatus,
+)
+from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER
+
+from ..auth import has_properties
+from ..dependencies import (
+    JobDB,
+    JobLoggingDB,
+)
+from ..fastapi_classes import DiracxRouter
+from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token
+from .access_policies import ActionType, CheckWMSPolicyCallable
+
+logger = logging.getLogger(__name__)
+
+router = DiracxRouter(dependencies=[has_properties(NORMAL_USER | JOB_ADMINISTRATOR)])
+
+
+class InsertedJob(TypedDict):
+    JobID: int
+    Status: str
+    MinorStatus: str
+    TimeStamp: datetime
+
+
+class JobID(BaseModel):
+    job_id: int
+
+
+MAX_PARAMETRIC_JOBS = 20
+
+EXAMPLE_JDLS = {
+    "Simple JDL": {
+        "value": [
+            """Arguments = "jobDescription.xml -o LogLevel=INFO";
+Executable = "dirac-jobexec";
+JobGroup = jobGroup;
+JobName = jobName;
+JobType = User;
+LogLevel = INFO;
+OutputSandbox =
+    {
+        Script1_CodeOutput.log,
+        std.err,
+        std.out
+    };
+Priority = 1;
+Site = ANY;
+StdError = std.err;
+StdOutput = std.out;"""
+        ]
+    },
+    "Parametric JDL": {
+        "value": ["""Arguments = "jobDescription.xml -o LogLevel=INFO"""]
+    },
+}
+
+
+@router.post("/")
+async def submit_bulk_jobs(
+    job_definitions: Annotated[list[str], Body(openapi_examples=EXAMPLE_JDLS)],
+    job_db: JobDB,
+    job_logging_db: JobLoggingDB,
+    user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
+    check_permissions: CheckWMSPolicyCallable,
+) -> list[InsertedJob]:
+    await check_permissions(action=ActionType.CREATE, job_db=job_db)
+
+    from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
+    from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise
+    from DIRAC.WorkloadManagementSystem.Service.JobPolicy import RIGHT_SUBMIT, JobPolicy
+    from DIRAC.WorkloadManagementSystem.Utilities.ParametricJob import (
+        generateParametricJobs,
+        getParameterVectorLength,
+    )
+
+    class DiracxJobPolicy(JobPolicy):
+        def __init__(self, user_info: AuthorizedUserInfo, allInfo: bool = True):
+            self.userName = user_info.preferred_username
+            self.userGroup = user_info.dirac_group
+            self.userProperties = user_info.properties
+            self.jobDB = None
+            self.allInfo = allInfo
+            self._permissions: dict[str, bool] = {}
+            self._getUserJobPolicy()
+
+    # Check job submission permission
+    policyDict = returnValueOrRaise(DiracxJobPolicy(user_info).getJobPolicy())
+    if not policyDict[RIGHT_SUBMIT]:
+        raise HTTPException(HTTPStatus.FORBIDDEN, "You are not allowed to submit jobs")
+
+    # TODO: that needs to go in the legacy adapter (Does it ? Because bulk submission is not supported there)
+    for i in range(len(job_definitions)):
+        job_definition = job_definitions[i].strip()
+        if not (job_definition.startswith("[") and job_definition.endswith("]")):
+            job_definition = f"[{job_definition}]"
+        job_definitions[i] = job_definition
+
+    if len(job_definitions) == 1:
+        # Check if the job is a parametric one
+        jobClassAd = ClassAd(job_definitions[0])
+        result = getParameterVectorLength(jobClassAd)
+        if not result["OK"]:
+            # Surface the failure as an HTTP error instead of leaking the DIRAC S_ERROR dict
+            logger.error("Issue with getParameterVectorLength: %s", result["Message"])
+            raise HTTPException(
+                status_code=HTTPStatus.BAD_REQUEST, detail=result["Message"]
+            )
+        nJobs = result["Value"]
+        parametricJob = False
+        if nJobs is not None and nJobs > 0:
+            # if we are here, then jobDesc was the description of a parametric job. So we start unpacking
+            parametricJob = True
+            result = generateParametricJobs(jobClassAd)
+            if not result["OK"]:
+                raise HTTPException(
+                    status_code=HTTPStatus.BAD_REQUEST, detail=result["Message"]
+                )
+            jobDescList = result["Value"]
+        else:
+            # if we are here, then jobDesc was the description of a single job.
+            jobDescList = job_definitions
+    else:
+        # if we are here, then jobDesc is a list of JDLs
+        # we need to check that none of them is a parametric
+        for job_definition in job_definitions:
+            res = getParameterVectorLength(ClassAd(job_definition))
+            if not res["OK"]:
+                raise HTTPException(
+                    status_code=HTTPStatus.BAD_REQUEST, detail=res["Message"]
+                )
+            if res["Value"]:
+                raise HTTPException(
+                    status_code=HTTPStatus.BAD_REQUEST,
+                    detail="You cannot submit parametric jobs in a bulk fashion",
+                )
+
+        jobDescList = job_definitions
+        parametricJob = True
+
+    # TODO: make the max number of jobs configurable in the CS
+    if len(jobDescList) > MAX_PARAMETRIC_JOBS:
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            detail=f"Normal user cannot submit more than {MAX_PARAMETRIC_JOBS} jobs at once",
+        )
+
+    result = []
+
+    if parametricJob:
+        initialStatus = JobStatus.SUBMITTING
+        initialMinorStatus = "Bulk transaction confirmation"
+    else:
+        initialStatus = JobStatus.RECEIVED
+        initialMinorStatus = "Job accepted"
+
+    for (
+        jobDescription
+    ) in (
+        jobDescList
+    ):  # jobDescList because there might be a list generated by a parametric job
+        res = await job_db.insert(
+            jobDescription,
+            user_info.preferred_username,
+            user_info.dirac_group,
+            initialStatus,
+            initialMinorStatus,
+            user_info.vo,
+        )
+
+        job_id = res["JobID"]
+        logger.debug(
+            f"Job {job_id} added to the JobDB for {user_info.preferred_username}/{user_info.dirac_group}"
+        )
+
+        await job_logging_db.insert_record(
+            int(job_id),
+            initialStatus,
+            initialMinorStatus,
+            "Unknown",
+            datetime.now(timezone.utc),
+            "JobManager",
+        )
+
+        result.append(res)
+
+    return result
+
+    # TODO: is this needed ?
+    # if not parametricJob:
+    #     self.__sendJobsToOptimizationMind(jobIDList)
+    #     return result
+
+    # NOTE: unreachable leftover from the port (plain JDL strings have no
+    # .owner/.group/.vo attributes); kept commented out for reference only:
+    # return await asyncio.gather(
+    #     *(job_db.insert(j.owner, j.group, j.vo) for j in job_definitions)
+    # )
diff --git a/diracx-routers/tests/jobs/test_wms_access_policy.py b/diracx-routers/tests/jobs/test_wms_access_policy.py
index 40e05d29b..0746317ca 100644
--- a/diracx-routers/tests/jobs/test_wms_access_policy.py
+++ b/diracx-routers/tests/jobs/test_wms_access_policy.py
@@ -4,7 +4,7 @@ from fastapi import HTTPException, status
 
 from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER
-from diracx.routers.job_manager.access_policies import (
+from diracx.routers.jobs.access_policies import (
     ActionType,
     SandboxAccessPolicy,
     WMSAccessPolicy,
diff --git a/diracx-testing/src/diracx/testing/__init__.py b/diracx-testing/src/diracx/testing/__init__.py
index 373e33104..1fa569afe 100644
--- a/diracx-testing/src/diracx/testing/__init__.py
+++ b/diracx-testing/src/diracx/testing/__init__.py
@@ -20,7 +20,7 @@
 if TYPE_CHECKING:
     from diracx.core.settings import DevelopmentSettings
-    from diracx.routers.job_manager.sandboxes import SandboxStoreSettings
+    from diracx.routers.jobs.sandboxes import SandboxStoreSettings
     from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings
@@ -123,7 +123,7 @@ def aio_moto(worker_id):
 @pytest.fixture(scope="session")
 def test_sandbox_settings(aio_moto) -> SandboxStoreSettings:
-    from diracx.routers.job_manager.sandboxes import SandboxStoreSettings
+    from diracx.routers.jobs.sandboxes import SandboxStoreSettings
 
     yield SandboxStoreSettings(
         bucket_name="sandboxes",
diff --git a/docs/SERVICES.md b/docs/SERVICES.md
index 0d6d78d4f..d6cd19b33 100644
--- a/docs/SERVICES.md
+++ b/docs/SERVICES.md
@@ -149,8 +149,8 @@ The various policies are defined in `diracx-routers/pyproject.toml`:
 
 ```toml
 [project.entry-points."diracx.access_policies"]
-WMSAccessPolicy = "diracx.routers.job_manager.access_policies:WMSAccessPolicy"
-SandboxAccessPolicy = "diracx.routers.job_manager.access_policies:SandboxAccessPolicy"
+WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy"
+SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy"
 ```
 
 Each route must have a policy as an argument and call it:
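
(The excerpt is cut off here; as an illustration of the pattern the sentence refers to, here is how the routes in this very patch do it — `get_single_job` from `query.py`, where the policy arrives as a dependency-injected callable:)

```python
@router.get("/{job_id}")
async def get_single_job(
    job_id: int,
    job_db: JobDB,
    check_permissions: CheckWMSPolicyCallable,
):
    # Every route must await its access-policy callable, passing the action
    # and the affected job IDs, before touching any database.
    await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
    ...
```
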