Skip to content

Commit

Permalink
Cleanup insert_bulk - use taskgroups, and single row inserts for job ID generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
ryuwd committed Dec 16, 2024
1 parent 5871b2e commit ad7e5e8
Showing 1 changed file with 112 additions and 135 deletions.
247 changes: 112 additions & 135 deletions diracx-db/src/diracx/db/sql/job/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from asyncio import TaskGroup
from copy import deepcopy
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -47,29 +48,6 @@ def _get_columns(table, parameters):
return columns


async def get_inserted_job_ids(conn, table, rows):
    """Bulk-insert *rows* into *table* and return the auto-generated JobIDs.

    :param conn: async SQLAlchemy connection (``conn.engine`` identifies the backend)
    :param table: SQLAlchemy ``Table`` whose primary key column is ``JobID``
    :param rows: list of parameter dicts, one per row to insert
    :returns: list of JobIDs, in the same order as ``rows``
    :raises NotImplementedError: for backends other than MySQL / SQLite

    TODO: We are assuming contiguous inserts for MySQL. Is that the correct
    thing? Should we be stricter about enforcing that with an explicit
    transaction handling?
    """
    if conn.engine.name == "mysql":
        # Bulk insert for MySQL. LAST_INSERT_ID() returns the id of the
        # *first* row generated by a multi-row insert, so the batch is
        # start_id .. start_id + len(rows) - 1 (assuming contiguity).
        await conn.execute(table.insert(), rows)
        start_id = await conn.scalar(select(func.LAST_INSERT_ID()))
        return list(range(start_id, start_id + len(rows)))
    elif conn.engine.name == "sqlite":
        if conn.engine.dialect.server_version_info >= (3, 35, 0):
            # SQLite >= 3.35 supports RETURNING: collect every id explicitly,
            # with no contiguity assumption.
            results = await conn.execute(table.insert().returning(table.c.JobID), rows)
            return [row[0] for row in results]
        else:
            await conn.execute(table.insert(), rows)
            # Fix 1: a bare SQL string is not executable on an async
            # SQLAlchemy connection — build the statement with select()/func.
            # Fix 2: last_insert_rowid() is the id of the *last* inserted
            # row (unlike MySQL's LAST_INSERT_ID()), so count backwards.
            last_id = await conn.scalar(select(func.last_insert_rowid()))
            return list(range(last_id - len(rows) + 1, last_id + 1))
    else:
        raise NotImplementedError("Unsupported database backend")


class JobDB(BaseSQLDB):
metadata = JobDBBase.metadata

Expand Down Expand Up @@ -131,16 +109,6 @@ async def search(
dict(row._mapping) async for row in (await self.conn.stream(stmt))
]

async def _insertNewJDL(self, jdl) -> int:
    """Insert *jdl* as a new (compressed) JobJDLs row and return its JobID."""
    from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL

    result = await self.conn.execute(
        insert(JobJDLs).values(
            JDL="",
            JobRequirements="",
            OriginalJDL=compressJDL(jdl),
        )
    )
    # lastrowid is the auto-generated primary key of the inserted row
    return result.lastrowid

async def _insertJob(self, jobData: dict[str, Any]):
    """Insert a single row of job attributes into the Jobs table."""
    await self.conn.execute(insert(Jobs).values(jobData))
Expand Down Expand Up @@ -272,118 +240,127 @@ async def insert_bulk(
original_jdls = []

# generate the jobIDs first
for job in jobs:
original_jdl = deepcopy(job.jdl)
jobManifest = returnValueOrRaise(
checkAndAddOwner(original_jdl, job.owner, job.owner_group)
)
async with TaskGroup() as tg:
for job in jobs:
original_jdl = deepcopy(job.jdl)
jobManifest = returnValueOrRaise(
checkAndAddOwner(original_jdl, job.owner, job.owner_group)
)

# Fix possible lack of brackets
if original_jdl.strip()[0] != "[":
original_jdl = f"[{original_jdl}]"
# Fix possible lack of brackets
if original_jdl.strip()[0] != "[":
original_jdl = f"[{original_jdl}]"

original_jdls.append(
(
original_jdl,
jobManifest,
tg.create_task(
self.conn.execute(
JobJDLs.__table__.insert().values(
JDL="",
JobRequirements="",
OriginalJDL=compressJDL(original_jdl),
)
)
),
)
)

original_jdls.append((original_jdl, jobManifest))
job_ids = []

job_ids = await get_inserted_job_ids(
self.conn,
JobJDLs.__table__,
[
{
"JDL": "",
"JobRequirements": "",
"OriginalJDL": compressJDL(original_jdl),
async with TaskGroup() as tg:
for job, (original_jdl, jobManifest_, job_id_task) in zip(
jobs, original_jdls
):
job_id = job_id_task.result().lastrowid
job_attrs = {
"JobID": job_id,
"LastUpdateTime": datetime.now(tz=timezone.utc),
"SubmissionTime": datetime.now(tz=timezone.utc),
"Owner": job.owner,
"OwnerGroup": job.owner_group,
"VO": job.vo,
}
for original_jdl, _ in original_jdls
],
)

for job_id, job, (original_jdl, jobManifest_) in zip(
job_ids, jobs, original_jdls
):
job_attrs = {
"LastUpdateTime": datetime.now(tz=timezone.utc),
"SubmissionTime": datetime.now(tz=timezone.utc),
"Owner": job.owner,
"OwnerGroup": job.owner_group,
"VO": job.vo,
"JobID": job_id,
}

jobManifest_.setOption("JobID", job_id)

# 2.- Check JDL and Prepare DIRAC JDL
jobJDL = jobManifest_.dumpAsJDL()

# Replace the JobID placeholder if any
if jobJDL.find("%j") != -1:
jobJDL = jobJDL.replace("%j", str(job_id))

class_ad_job = ClassAd(jobJDL)

class_ad_req = ClassAd("[]")
if not class_ad_job.isOK():
# Rollback the entire transaction
raise ValueError(f"Error in JDL syntax for job JDL: {original_jdl}")
# TODO: check if that is actually true
if class_ad_job.lookupAttribute("Parameters"):
raise NotImplementedError("Parameters in the JDL are not supported")

# TODO is this even needed?
class_ad_job.insertAttributeInt("JobID", job_id)

await self.checkAndPrepareJob(
job_id,
class_ad_job,
class_ad_req,
job.owner,
job.owner_group,
job_attrs,
job.vo,
)
jobJDL = createJDLWithInitialStatus(
class_ad_job,
class_ad_req,
self.jdl2DBParameters,
job_attrs,
job.initial_status,
job.initial_minor_status,
modern=True,
)
# assert "JobType" in job_attrs, job_attrs
jobs_to_insert.append(job_attrs)
jdls_to_update.append(
{
"b_JobID": job_id,
"JDL": compressJDL(jobJDL),
}
)
jobManifest_.setOption("JobID", job_id)

if class_ad_job.lookupAttribute("InputData"):
inputData = class_ad_job.getListFromExpression("InputData")
inputdata_to_insert += [
{"JobID": job_id, "LFN": lfn} for lfn in inputData if lfn
]
await self.conn.execute(
JobJDLs.__table__.update().where(
JobJDLs.__table__.c.JobID == bindparam("b_JobID")
),
jdls_to_update,
)
# 2.- Check JDL and Prepare DIRAC JDL
jobJDL = jobManifest_.dumpAsJDL()

plen = len(jobs_to_insert[0].keys())
for item in jobs_to_insert:
assert plen == len(item.keys()), f"{plen} is not == {len(item.keys())}"
# Replace the JobID placeholder if any
if jobJDL.find("%j") != -1:
jobJDL = jobJDL.replace("%j", str(job_id))

await self.conn.execute(
Jobs.__table__.insert(),
jobs_to_insert,
)
class_ad_job = ClassAd(jobJDL)

class_ad_req = ClassAd("[]")
if not class_ad_job.isOK():
# Rollback the entire transaction
raise ValueError(f"Error in JDL syntax for job JDL: {original_jdl}")
# TODO: check if that is actually true
if class_ad_job.lookupAttribute("Parameters"):
raise NotImplementedError("Parameters in the JDL are not supported")

# TODO is this even needed?
class_ad_job.insertAttributeInt("JobID", job_id)

if inputdata_to_insert:
await self.conn.execute(
InputData.__table__.insert(),
inputdata_to_insert,
await self.checkAndPrepareJob(
job_id,
class_ad_job,
class_ad_req,
job.owner,
job.owner_group,
job_attrs,
job.vo,
)
jobJDL = createJDLWithInitialStatus(
class_ad_job,
class_ad_req,
self.jdl2DBParameters,
job_attrs,
job.initial_status,
job.initial_minor_status,
modern=True,
)
# assert "JobType" in job_attrs, job_attrs
job_ids.append(job_id)
jobs_to_insert.append(job_attrs)
jdls_to_update.append(
{
"b_JobID": job_id,
"JDL": compressJDL(jobJDL),
}
)

if class_ad_job.lookupAttribute("InputData"):
inputData = class_ad_job.getListFromExpression("InputData")
inputdata_to_insert += [
{"JobID": job_id, "LFN": lfn} for lfn in inputData if lfn
]

tg.create_task(
self.conn.execute(
JobJDLs.__table__.update().where(
JobJDLs.__table__.c.JobID == bindparam("b_JobID")
),
jdls_to_update,
)
)
tg.create_task(
self.conn.execute(
Jobs.__table__.insert(),
jobs_to_insert,
)
)

if inputdata_to_insert:
tg.create_task(
self.conn.execute(
InputData.__table__.insert(),
inputdata_to_insert,
)
)

return job_ids

Expand Down

0 comments on commit ad7e5e8

Please sign in to comment.