diff --git a/CHANGELOG.rst b/CHANGELOG.rst index efbd22d..95177a9 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,17 @@ Changelog ========== ++++++++++ +v1.8.0 (28/08/2024) ++++++++++ + +**Added** + +- Particle count per defocus value endpoint (:code:`/dataCollections/{collectionId}/ctf`) +- Particle count per resolution bin endpoint (:code:`/dataCollections/{collectionId}/particleCountPerResolution`) +- Custom model upload endpoint +- Sample handling redirect endpoint + +++++++++ v1.7.0 (20/06/2024) +++++++++ diff --git a/Dockerfile b/Dockerfile index 0ed3750..c62f4ce 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ # The devcontainer should use the build target and run as root with podman # or docker with user namespaces. # -FROM docker.io/library/python:3.12.1-slim-bullseye as build +FROM docker.io/library/python:3.12.4-slim-bookworm as build # Add any system dependencies for the developer/build environment here RUN apt-get update && apt-get upgrade -y && \ diff --git a/config.json b/config.json index b4b4c8a..7716e80 100644 --- a/config.json +++ b/config.json @@ -36,7 +36,8 @@ "contact_email": "admin@facility.co.uk", "smtp_port": 8025, "smtp_server": "mail.service.com", - "active_session_cutoff": 5 + "active_session_cutoff": 5, + "sample_handling_url": "https://ebic-sample-handling.diamond.ac.uk" }, "enable_cors": false } diff --git a/pyproject.toml b/pyproject.toml index 956dd76..a46c78c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,16 +14,17 @@ classifiers = [ ] description = "PATO's backend" dependencies = [ + "python-multipart~=0.0.9", "pika~=1.3.2", - "SQLAlchemy~=2.0.16", - "fastapi~=0.109.0", - "uvicorn[standard]~=0.22.0", - "requests~=2.31.0", + "SQLAlchemy~=2.0.31", + "fastapi~=0.111.0", + "uvicorn[standard]~=0.30.1", + "requests~=2.32.3", "mysqlclient~=2.1.1", "mysql-connector-python~=8.2.0", "pydantic~=2.5.3", "types-requests", - "lims-utils~=0.1.2" + "lims-utils~=0.2.2" ] dynamic = ["version"] license.file = "LICENSE" @@ -111,7 +112,7 @@ setenv = [tool.ruff] src = ["src", "tests"] line-length = 120 -select = [ +lint.select = [ "C4", # flake8-comprehensions - https://beta.ruff.rs/docs/rules/#flake8-comprehensions-c4 "E", # pycodestyle errors - https://beta.ruff.rs/docs/rules/#error-e "F", # pyflakes rules - https://beta.ruff.rs/docs/rules/#pyflakes-f diff --git a/src/pato/crud/collections.py b/src/pato/crud/collections.py index 8a0476b..ea6f3b1 100644 --- a/src/pato/crud/collections.py +++ b/src/pato/crud/collections.py @@ -12,21 +12,28 @@ DataCollectionGroup, MotionCorrection, Movie, + ParticlePicker, ProcessingJob, ProcessingJobParameter, Proposal, TiltImageAlignment, Tomogram, ) -from sqlalchemy import Column, and_, case, extract, func, select +from sqlalchemy import Column, ColumnElement, Select, and_, case, extract, func, select from ..models.parameters import ( SPAReprocessingParameters, TomogramReprocessingParameters, ) -from ..models.response import FullMovie, ProcessingJobResponse, TomogramFullResponse +from ..models.response import ( + DataPoint, + FullMovie, + ItemList, + ProcessingJobResponse, + TomogramFullResponse, +) from ..utils.database import db, paginate -from ..utils.generic import check_session_active +from ..utils.generic import check_session_active, parse_count from ..utils.pika import pika_publisher _job_status_description = case( @@ -292,3 +299,69 @@ def get_processing_jobs( ) return paginate(query, limit, page, slow_count=False) + + +def _with_ctf_joins(query: Select, collectionId: int): + return ( + 
        query.select_from(ProcessingJob)
+        .filter(ProcessingJob.dataCollectionId == collectionId)
+        .join(AutoProcProgram)
+        .join(MotionCorrection)
+        .join(CTF, CTF.motionCorrectionId == MotionCorrection.motionCorrectionId)
+        .join(
+            ParticlePicker,
+            ParticlePicker.firstMotionCorrectionId
+            == MotionCorrection.motionCorrectionId,
+        )
+    )
+
+
+def get_ctf(collectionId: int):
+    data = db.session.execute(
+        _with_ctf_joins(
+            select(
+                CTF.estimatedDefocus.label("x"),
+                ParticlePicker.numberOfParticles.label("y"),
+            ),
+            collectionId,
+        ).group_by(MotionCorrection.imageNumber)
+    ).all()
+
+    return ItemList[DataPoint](items=data)
+
+
+def _histogram_sum_bin(condition: ColumnElement):
+    return func.coalesce(
+        func.sum(
+            case(
+                (
+                    condition,
+                    ParticlePicker.numberOfParticles,
+                ),
+            )
+        ),
+        0,
+    )
+
+
+def get_particle_count_per_resolution(collectionId: int) -> ItemList[DataPoint]:
+    data = parse_count(
+        _with_ctf_joins(
+            select(
+                _histogram_sum_bin(CTF.estimatedResolution < 1).label("<1"),
+                *[
+                    _histogram_sum_bin(
+                        and_(
+                            CTF.estimatedResolution >= i,
+                            CTF.estimatedResolution < i + 1,
+                        )
+                    ).label(str(i))
+                    for i in range(1, 9)
+                ],
+                _histogram_sum_bin(CTF.estimatedResolution >= 9).label(">9"),
+            ),
+            collectionId,
+        )
+    )
+
+    return data
diff --git a/src/pato/crud/generic.py b/src/pato/crud/generic.py
index 6264c45..15697ca 100644
--- a/src/pato/crud/generic.py
+++ b/src/pato/crud/generic.py
@@ -1,6 +1,5 @@
 from typing import Literal
 
-from fastapi import HTTPException, status
 from lims_utils.tables import (
     CTF,
     AutoProcProgram,
@@ -11,11 +10,11 @@
     ProcessingJob,
     RelativeIceThickness,
 )
-from sqlalchemy import Column, and_, case, literal_column, select
+from sqlalchemy import Column, and_, case, select
 from sqlalchemy import func as f
 
 from ..models.response import DataPoint, ItemList
-from ..utils.database import db
+from ..utils.generic import parse_count
 
 
 def _generate_buckets(bin: float, minimum: float, column: Column):
@@ -33,22 +32,9 @@
                 )
             )
         ).label(str(bin * i + minimum))
-        for i in range(0, 10)
+        for i in range(0, 8)
     ],
-        f.count(case((column >= bin * 10 + minimum, 1))).label(f">{bin*10+minimum}"),
-    )
-
-
-def _parse_count(query):
-    data = db.session.execute(query.order_by(literal_column("1"))).mappings().one()
-    if not any(value != 0 for value in data.values()):
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="No items found",
-        )
-
-    return ItemList[DataPoint](
-        items=[{"x": key, "y": value} for (key, value) in dict(data).items()]
+        f.count(case((column >= bin * 8 + minimum, 1))).label(f">{bin*8+minimum}"),
     )
 
 def get_ice_histogram(
@@ -73,7 +59,7 @@
         .join(RelativeIceThickness)
     )
 
-    return _parse_count(query)
+    return parse_count(query)
 
 
 def get_motion(
@@ -94,7 +80,7 @@
         .join(MotionCorrection)
     )
 
-    return _parse_count(query)
+    return parse_count(query)
 
 
 def get_resolution(
@@ -116,7 +102,7 @@
         .join(CTF)
     )
 
-    return _parse_count(query)
+    return parse_count(query)
 
 
 def get_particle_count(
@@ -137,4 +123,4 @@
         .join(ParticlePicker)
     )
 
-    return _parse_count(query)
+    return parse_count(query)
diff --git a/src/pato/crud/sessions.py b/src/pato/crud/sessions.py
index 6705be2..0b99576 100644
--- a/src/pato/crud/sessions.py
+++ b/src/pato/crud/sessions.py
@@ -1,12 +1,14 @@
 import pathlib
+import shutil
 from datetime import datetime
 from typing import Optional
 
-from fastapi import HTTPException, status
+from fastapi import HTTPException, UploadFile, status
 from lims_utils.auth import GenericUser
+from lims_utils.logging import app_logger
 from lims_utils.models import Paged
 from lims_utils.tables import BLSession, DataCollection, DataCollectionGroup, Proposal
-from sqlalchemy import Label, and_, extract, func, insert, or_, select
+from sqlalchemy import Label, and_, func, insert, or_, select
 
 from ..models.parameters import DataCollectionCreationParameters
 from ..models.response import SessionAllowsReprocessing, SessionResponse
@@ -15,17 +17,19 @@
 from ..utils.database import db, paginate, unravel
 from ..utils.generic import ProposalReference, check_session_active, parse_proposal
 
+HDF5_FILE_SIGNATURE = b"\x89\x48\x44\x46\x0d\x0a\x1a\x0a"
 
-def _validate_session_active(proposalReference: ProposalReference):
-    """Check if session is active and return session ID"""
+
+def _validate_session_active(proposal_reference: ProposalReference):
+    """Check if the session is active and return the session"""
     session = db.session.scalar(
         select(BLSession)
         .select_from(Proposal)
         .join(BLSession)
         .filter(
-            BLSession.visit_number == proposalReference.visit_number,
-            Proposal.proposalNumber == proposalReference.number,
-            Proposal.proposalCode == proposalReference.code,
+            BLSession.visit_number == proposal_reference.visit_number,
+            Proposal.proposalNumber == proposal_reference.number,
+            Proposal.proposalCode == proposal_reference.code,
         )
     )
 
@@ -35,7 +39,20 @@
             detail="Reprocessing cannot be fired on an inactive session",
         )
 
-    return session.sessionId
+    assert session is not None
+
+    return session
+
+
+def _get_folder_and_visit(prop_ref: ProposalReference):
+    session = _validate_session_active(prop_ref)
+    year = session.startDate.year
+
+    # TODO: Make the path string pattern configurable?
+    return (
+        f"/dls/{session.beamLineName}/data/{year}/{prop_ref.code}{prop_ref.number}-{prop_ref.visit_number}",
+        session,
+    )
 
 
 def _check_raw_files_exist(file_directory: str, glob_path: str):
@@ -153,26 +170,8 @@ def get_session(proposalReference: ProposalReference):
 def create_data_collection(
     proposalReference: ProposalReference, params: DataCollectionCreationParameters
 ):
-    session_id = _validate_session_active(proposalReference)
-
-    session = db.session.execute(
-        select(
-            BLSession.beamLineName,
-            BLSession.endDate,
-            extract("year", BLSession.startDate).label("year"),
-            func.concat(
-                Proposal.proposalCode,
-                Proposal.proposalNumber,
-                "-",
-                BLSession.visit_number,
-            ).label("name"),
-        )
-        .filter(BLSession.sessionId == session_id)
-        .join(Proposal, Proposal.proposalId == BLSession.proposalId)
-    ).one()
-
-    # TODO: Make the path string pattern configurable?
- file_directory = f"/dls/{session.beamLineName}/data/{session.year}/{session.name}/{params.fileDirectory}/" + session_folder, session = _get_folder_and_visit(proposalReference) + file_directory = f"{session_folder}/{params.fileDirectory}/" glob_path = f"GridSquare_*/Data/*{params.fileExtension}" _check_raw_files_exist(file_directory, glob_path) @@ -182,7 +181,7 @@ def create_data_collection( .filter( DataCollection.imageDirectory == file_directory, DataCollection.fileTemplate == glob_path, - DataCollectionGroup.sessionId == session_id, + DataCollectionGroup.sessionId == session.sessionId, ) .join(DataCollectionGroup) .limit(1) @@ -199,7 +198,7 @@ def create_data_collection( DataCollectionGroup.dataCollectionGroupId ), { - "sessionId": session_id, + "sessionId": session.sessionId, "comments": "Created by PATo", "experimentType": "EM", }, @@ -237,3 +236,30 @@ def check_reprocessing_enabled(proposalReference: ProposalReference): return SessionAllowsReprocessing( allowReprocessing=((bool(Config.mq.user)) and check_session_active(end_date)), ) + + +def upload_processing_model(file: UploadFile, proposal_reference: ProposalReference): + file_path = ( + f"{_get_folder_and_visit(proposal_reference)[0]}/processing/{file.filename}" + ) + file_signature = file.file.read(8) + file.file.seek(0) + + if file_signature != HDF5_FILE_SIGNATURE: + raise HTTPException( + detail="Invalid file type (must be HDF5 file)", + status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, + ) + + try: + with open(file_path, "wb") as f: + shutil.copyfileobj(file.file, f) + except OSError as e: + file.file.close() + app_logger.error(f"Failed to upload {file.filename}: {e}") + raise HTTPException( + detail="Failed to upload file", + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + file.file.close() diff --git a/src/pato/routes/collections.py b/src/pato/routes/collections.py index ef9a1a1..fb84986 100644 --- a/src/pato/routes/collections.py +++ b/src/pato/routes/collections.py @@ -9,7 +9,9 @@ TomogramReprocessingParameters, ) from ..models.response import ( + DataPoint, FullMovie, + ItemList, ProcessingJobResponse, ReprocessingResponse, TomogramFullResponse, @@ -135,3 +137,19 @@ def get_particle_count( minimum=minimum, dataBin=dataBin, ) + + +@router.get( + "/{collectionId}/ctf", + description="Get defocus/particle count information", + response_model=ItemList[DataPoint], +) +def get_ctf(collectionId: int = Depends(auth)): + return crud.get_ctf(collectionId) + + +@router.get( + "/{collectionId}/particleCountPerResolution", response_model=ItemList[DataPoint] +) +def get_particle_count_per_resolution(collectionId: int = Depends(auth)): + return crud.get_particle_count_per_resolution(collectionId) diff --git a/src/pato/routes/proposals.py b/src/pato/routes/proposals.py index 4a7e174..dd35243 100644 --- a/src/pato/routes/proposals.py +++ b/src/pato/routes/proposals.py @@ -1,4 +1,5 @@ -from fastapi import APIRouter, Body, Depends, status +from fastapi import APIRouter, Body, Depends, UploadFile, status +from fastapi.responses import RedirectResponse from lims_utils.models import Paged, pagination from ..auth import Permissions, User @@ -11,6 +12,7 @@ SessionAllowsReprocessing, SessionResponse, ) +from ..utils.config import Config router = APIRouter( tags=["Proposals"], @@ -56,3 +58,23 @@ def create_data_collection( def check_reprocessing_enabled(proposalReference=Depends(Permissions.session)): """Check if reprocessing is enabled for session""" return sessions_crud.check_reprocessing_enabled(proposalReference) + + 
+@router.get(
+    "/{proposalReference}/sessions/{visitNumber}/sampleHandling",
+    response_class=RedirectResponse,
+)
+def redirect_to_sample_handling(proposalReference: str, visitNumber: int):
+    """Sample handling redirect"""
+    suffix = f"/proposals/{proposalReference}/sessions/{visitNumber}"
+    return Config.facility.sample_handling_url + suffix
+
+
+@router.post("/{proposalReference}/sessions/{visitNumber}/processingModel")
+def upload_processing_model(
+    file: UploadFile, proposalReference=Depends(Permissions.session)
+):
+    """Upload custom processing model"""
+    return sessions_crud.upload_processing_model(
+        file=file, proposal_reference=proposalReference
+    )
diff --git a/src/pato/utils/config.py b/src/pato/utils/config.py
index 34d639b..4a96835 100644
--- a/src/pato/utils/config.py
+++ b/src/pato/utils/config.py
@@ -21,6 +21,7 @@ class Facility:
     smtp_server: str
     smtp_port: int = 587
     active_session_cutoff: int = 5
+    sample_handling_url: str = "https://ebic-sample-handling.diamond.ac.uk"
 
 
 @dataclass
diff --git a/src/pato/utils/generic.py b/src/pato/utils/generic.py
index 0c6be9b..2b0c811 100644
--- a/src/pato/utils/generic.py
+++ b/src/pato/utils/generic.py
@@ -3,12 +3,14 @@
 from os.path import isfile
 from typing import Literal, Optional
 
-from fastapi import HTTPException
+from fastapi import HTTPException, status
 from lims_utils.logging import app_logger
 from pydantic import BaseModel
+from sqlalchemy import literal_column
 
-from ..models.response import DataPoint
+from ..models.response import DataPoint, ItemList
 from ..utils.config import Config
+from .database import db
 
 # TODO: use 'type' when supported by Mypy
 MovieType = Literal["denoised", "segmented"] | None
@@ -108,3 +110,17 @@ def parse_proposal(proposalReference: str, visit_number: int | None = None):
         number=int(number),
         visit_number=visit_number,
     )
+
+
+def parse_count(query):
+    """Get mappings from query, return keys/values in graph format"""
+    data = db.session.execute(query.order_by(literal_column("1"))).mappings().one()
+    if not any(value != 0 for value in data.values()):
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="No items found",
+        )
+
+    return ItemList[DataPoint](
+        items=[{"x": key, "y": value} for (key, value) in dict(data).items()]
+    )
diff --git a/tests/collections/test_ctf_defocus.py b/tests/collections/test_ctf_defocus.py
new file mode 100644
index 0000000..c885eaf
--- /dev/null
+++ b/tests/collections/test_ctf_defocus.py
@@ -0,0 +1,9 @@
+def test_get(mock_permissions, client):
+    """Get defocus/particle count data for data collection"""
+    resp = client.get("/dataCollections/6017412/ctf")
+    assert resp.status_code == 200
+
+    items = resp.json()["items"]
+    assert items[0]["x"] == 5
+    assert items[0]["y"] == 10
+    assert len(items) == 1
diff --git a/tests/collections/test_particle_count_res.py b/tests/collections/test_particle_count_res.py
new file mode 100644
index 0000000..294a4bd
--- /dev/null
+++ b/tests/collections/test_particle_count_res.py
@@ -0,0 +1,12 @@
+def test_get(mock_permissions, client):
+    """Get particle count sum per resolution for data collection"""
+    resp = client.get("/dataCollections/6017412/particleCountPerResolution")
+    assert resp.status_code == 200
+    assert resp.json()["items"][5]["x"] == 5
+    assert resp.json()["items"][5]["y"] == 90
+
+
+def test_get_no_items(mock_permissions, client):
+    """Should return 404 if no resolution/particle count data is available"""
+    resp = client.get("/dataCollections/6017406/particleCountPerResolution")
+    assert resp.status_code == 404
diff --git a/tests/proposals/test_sh_redirect.py b/tests/proposals/test_sh_redirect.py
new file mode 100644
index 0000000..65d78cc
--- /dev/null
+++ b/tests/proposals/test_sh_redirect.py
@@ -0,0 +1,4 @@
+def test_redirect(mock_permissions, client):
+    """Get sample handling redirect"""
+    resp = client.get("/proposals/cm14451/sessions/1/sampleHandling", follow_redirects=False)
+    assert resp.headers["location"] == "https://ebic-sample-handling.diamond.ac.uk/proposals/cm14451/sessions/1"
diff --git a/tests/sessions/test_create_dc.py b/tests/sessions/test_create_dc.py
index 356a4a4..290beac 100644
--- a/tests/sessions/test_create_dc.py
+++ b/tests/sessions/test_create_dc.py
@@ -1,8 +1,11 @@
+from datetime import datetime
 from unittest.mock import patch
 
+from lims_utils.tables import BLSession
+
 
 def active_mock(_):
-    return 27464088
+    return BLSession(startDate=datetime(year=2022, month=1, day=1), sessionId=27464088, beamLineName="m12")
 
 
 def raw_check_mock(_, _1):
diff --git a/tests/sessions/test_upload_model.py b/tests/sessions/test_upload_model.py
new file mode 100644
index 0000000..fd407da
--- /dev/null
+++ b/tests/sessions/test_upload_model.py
@@ -0,0 +1,42 @@
+from datetime import datetime
+from unittest.mock import mock_open, patch
+
+from lims_utils.tables import BLSession
+
+VALID_FILE = b"\x89\x48\x44\x46\x0d\x0a\x1a\x0a\x01\x02\x03"
+
+def active_mock(_):
+    return BLSession(startDate=datetime(year=2022, month=1, day=1), sessionId=27464088, beamLineName="m12")
+
+@patch("pato.crud.sessions._validate_session_active", new=active_mock)
+@patch("builtins.open", new_callable=mock_open)
+def test_post(_, mock_permissions, client):
+    """Should write file successfully if first 8 bytes match expected signature"""
+    resp = client.post(
+        "/proposals/cm31111/sessions/5/processingModel",
+        files={"file": ("h5.h5", VALID_FILE, "application/octet-stream")},
+    )
+
+    assert resp.status_code == 200
+
+@patch("pato.crud.sessions._validate_session_active", new=active_mock)
+@patch("builtins.open", new_callable=mock_open)
+def test_invalid_file_signature(_, mock_permissions, client):
+    """Should raise exception if file signature doesn't match HDF5 file signature"""
+    resp = client.post(
+        "/proposals/cm31111/sessions/5/processingModel",
+        files={"file": ("not-h5.h5", b"\x01\x02", "application/octet-stream")},
+    )
+
+    assert resp.status_code == 415
+
+@patch("pato.crud.sessions._validate_session_active", new=active_mock)
+@patch("builtins.open", side_effect=OSError("Write Error"))
+def test_write_error(_, mock_permissions, client):
+    """Should return 500 if there was an error writing the file"""
+    resp = client.post(
+        "/proposals/cm31111/sessions/5/processingModel",
+        files={"file": ("not-h5.h5", VALID_FILE, "application/octet-stream")},
+    )
+
+    assert resp.status_code == 500
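
Usage sketches for the endpoints added above. These are hedged illustrations, not part of the codebase: BASE_URL, HEADERS, and the token are placeholders for an actual deployment. Both chart endpoints return an ItemList[DataPoint] payload (a list of {x, y} pairs), and the histogram endpoint answers 404 when every bin is zero. A minimal client sketch using requests, which is already a declared dependency:

    import requests

    BASE_URL = "http://localhost:8000"  # hypothetical PATO instance
    HEADERS = {"Authorization": "Bearer <token>"}  # placeholder credentials

    collection_id = 6017412

    # Defocus (x) against particle count (y), one point per motion-corrected image
    ctf = requests.get(f"{BASE_URL}/dataCollections/{collection_id}/ctf", headers=HEADERS)
    ctf.raise_for_status()
    for point in ctf.json()["items"]:
        print(point["x"], point["y"])

    # Particle counts summed into resolution bins; 404 when every bin is empty
    res = requests.get(
        f"{BASE_URL}/dataCollections/{collection_id}/particleCountPerResolution",
        headers=HEADERS,
    )
    if res.status_code == 404:
        print("no CTF/particle data for this collection")
    else:
        print(res.json()["items"])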
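The resolution histogram sums ParticlePicker.numberOfParticles into the buckets "<1", integer bins (each spanning [i, i + 1) ångströms), and ">9" for estimated resolutions of 9 and above. A pure-Python restatement of that bucketing, intended only as a mental model for the SQL CASE expressions built by _histogram_sum_bin:

    def resolution_bin(estimated_resolution: float) -> str:
        """Mirror of the CASE buckets in get_particle_count_per_resolution."""
        if estimated_resolution < 1:
            return "<1"
        if estimated_resolution >= 9:
            return ">9"
        return str(int(estimated_resolution))  # integer bins covering [i, i + 1)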
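Custom model upload goes through POST /proposals/{proposalReference}/sessions/{visitNumber}/processingModel. The handler compares the first eight bytes of the upload against the HDF5 signature (b"\x89HDF\r\n\x1a\n") before writing into the visit's processing directory, so only genuine HDF5 files are accepted. A sketch reusing the hypothetical BASE_URL and HEADERS from above:

    with open("model.h5", "rb") as fh:  # must start with the 8-byte HDF5 signature
        resp = requests.post(
            f"{BASE_URL}/proposals/cm31111/sessions/5/processingModel",
            files={"file": ("model.h5", fh, "application/octet-stream")},
            headers=HEADERS,
        )
    # 200 on success, 415 for a non-HDF5 file, 500 if the server-side write fails
    print(resp.status_code)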
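The sample handling endpoint issues a redirect (RedirectResponse, 307 by default) to the configured sample_handling_url with the proposal/session path appended; a client can either follow it or read the Location header, as the new test does:

    r = requests.get(
        f"{BASE_URL}/proposals/cm14451/sessions/1/sampleHandling",
        headers=HEADERS,
        allow_redirects=False,  # inspect the redirect instead of following it
    )
    print(r.headers["location"])
    # -> https://ebic-sample-handling.diamond.ac.uk/proposals/cm14451/sessions/1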