Skip to content

Commit

Permalink
Merge pull request #153 from bertrandrigaud/wellknown
Browse files Browse the repository at this point in the history
add dirac well-known
  • Loading branch information
chrisburr authored Oct 25, 2023
2 parents 5036971 + 89dacc3 commit 4f990dc
Show file tree
Hide file tree
Showing 12 changed files with 626 additions and 8 deletions.
29 changes: 23 additions & 6 deletions src/diracx/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

import asyncio
import json
import os
from datetime import datetime, timedelta
from typing import Optional
from typing import Annotated, Optional

from typer import Option
import typer

from diracx.client.aio import DiracClient
from diracx.client.models import DeviceFlowErrorResponse
Expand All @@ -19,11 +17,30 @@
app = AsyncTyper()


async def installation_metadata():
async with DiracClient() as api:
return await api.well_known.installation_metadata()


def vo_callback(vo: str | None) -> str:
metadata = asyncio.run(installation_metadata())
vos = list(metadata.virtual_organizations)
if not vo:
raise typer.BadParameter(
f"VO must be specified, available options are: {' '.join(vos)}"
)
if vo not in vos:
raise typer.BadParameter(
f"Unknown VO {vo}, available options are: {' '.join(vos)}"
)
return vo


@app.async_command()
async def login(
vo: str,
vo: Annotated[Optional[str], typer.Argument(callback=vo_callback)] = None,
group: Optional[str] = None,
property: Optional[list[str]] = Option(
property: Optional[list[str]] = typer.Option(
None, help="Override the default(s) with one or more properties"
),
):
Expand Down
3 changes: 3 additions & 0 deletions src/diracx/cli/internal/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typer import Option

from diracx.core.config import Config
from diracx.core.config.schema import Field, SupportInfo

from ..utils import AsyncTyper

Expand All @@ -28,6 +29,7 @@ class VOConfig(BaseModel):
DefaultGroup: str
IdP: IdPConfig
UserSubjects: dict[str, str]
Support: SupportInfo = Field(default_factory=SupportInfo)


class ConversionConfig(BaseModel):
Expand Down Expand Up @@ -105,6 +107,7 @@ def _apply_fixes(raw, conversion_config: Path):
"DefaultGroup": vo_meta.DefaultGroup,
"Users": {},
"Groups": {},
"Support": vo_meta.Support,
}
if "DefaultStorageQuota" in original_registry:
raw["Registry"][vo]["DefaultStorageQuota"] = original_registry[
Expand Down
162 changes: 162 additions & 0 deletions src/diracx/client/aio/operations/_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,14 @@
build_jobs_get_single_job_status_request,
build_jobs_initiate_sandbox_upload_request,
build_jobs_kill_bulk_jobs_request,
build_jobs_reschedule_bulk_jobs_request,
build_jobs_reschedule_single_job_request,
build_jobs_search_request,
build_jobs_set_job_status_bulk_request,
build_jobs_set_single_job_status_request,
build_jobs_submit_bulk_jobs_request,
build_jobs_summary_request,
build_well_known_installation_metadata_request,
build_well_known_openid_configuration_request,
)
from .._vendor import raise_if_not_implemented
Expand Down Expand Up @@ -135,6 +138,57 @@ async def openid_configuration(self, **kwargs: Any) -> Any:

return deserialized

@distributed_trace_async
async def installation_metadata(self, **kwargs: Any) -> _models.Metadata:
"""Installation Metadata.
Installation Metadata.
:return: Metadata
:rtype: ~client.models.Metadata
:raises ~azure.core.exceptions.HttpResponseError:
"""
error_map = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
409: ResourceExistsError,
304: ResourceNotModifiedError,
}
error_map.update(kwargs.pop("error_map", {}) or {})

_headers = kwargs.pop("headers", {}) or {}
_params = kwargs.pop("params", {}) or {}

cls: ClsType[_models.Metadata] = kwargs.pop("cls", None)

request = build_well_known_installation_metadata_request(
headers=_headers,
params=_params,
)
request.url = self._client.format_url(request.url)

_stream = False
pipeline_response: PipelineResponse = (
await self._client._pipeline.run( # pylint: disable=protected-access
request, stream=_stream, **kwargs
)
)

response = pipeline_response.http_response

if response.status_code not in [200]:
map_error(
status_code=response.status_code, response=response, error_map=error_map
)
raise HttpResponseError(response=response)

deserialized = self._deserialize("Metadata", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {})

return deserialized


class AuthOperations:
"""
Expand Down Expand Up @@ -1472,6 +1526,114 @@ async def get_job_status_history_bulk(

return deserialized

@distributed_trace_async
async def reschedule_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any:
"""Reschedule Bulk Jobs.
Reschedule Bulk Jobs.
:keyword job_ids: Required.
:paramtype job_ids: list[int]
:return: any
:rtype: any
:raises ~azure.core.exceptions.HttpResponseError:
"""
error_map = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
409: ResourceExistsError,
304: ResourceNotModifiedError,
}
error_map.update(kwargs.pop("error_map", {}) or {})

_headers = kwargs.pop("headers", {}) or {}
_params = kwargs.pop("params", {}) or {}

cls: ClsType[Any] = kwargs.pop("cls", None)

request = build_jobs_reschedule_bulk_jobs_request(
job_ids=job_ids,
headers=_headers,
params=_params,
)
request.url = self._client.format_url(request.url)

_stream = False
pipeline_response: PipelineResponse = (
await self._client._pipeline.run( # pylint: disable=protected-access
request, stream=_stream, **kwargs
)
)

response = pipeline_response.http_response

if response.status_code not in [200]:
map_error(
status_code=response.status_code, response=response, error_map=error_map
)
raise HttpResponseError(response=response)

deserialized = self._deserialize("object", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {})

return deserialized

@distributed_trace_async
async def reschedule_single_job(self, job_id: int, **kwargs: Any) -> Any:
"""Reschedule Single Job.
Reschedule Single Job.
:param job_id: Required.
:type job_id: int
:return: any
:rtype: any
:raises ~azure.core.exceptions.HttpResponseError:
"""
error_map = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
409: ResourceExistsError,
304: ResourceNotModifiedError,
}
error_map.update(kwargs.pop("error_map", {}) or {})

_headers = kwargs.pop("headers", {}) or {}
_params = kwargs.pop("params", {}) or {}

cls: ClsType[Any] = kwargs.pop("cls", None)

request = build_jobs_reschedule_single_job_request(
job_id=job_id,
headers=_headers,
params=_params,
)
request.url = self._client.format_url(request.url)

_stream = False
pipeline_response: PipelineResponse = (
await self._client._pipeline.run( # pylint: disable=protected-access
request, stream=_stream, **kwargs
)
)

response = pipeline_response.http_response

if response.status_code not in [200]:
map_error(
status_code=response.status_code, response=response, error_map=error_map
)
raise HttpResponseError(response=response)

deserialized = self._deserialize("object", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {})

return deserialized

@overload
async def search(
self,
Expand Down
8 changes: 8 additions & 0 deletions src/diracx/client/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ._models import BodyAuthToken
from ._models import BodyAuthTokenGrantType
from ._models import GroupInfo
from ._models import HTTPValidationError
from ._models import InitiateDeviceFlowResponse
from ._models import InsertedJob
Expand All @@ -16,15 +17,18 @@
from ._models import JobSummaryParams
from ._models import JobSummaryParamsSearchItem
from ._models import LimitedJobStatusReturn
from ._models import Metadata
from ._models import SandboxDownloadResponse
from ._models import SandboxInfo
from ._models import SandboxUploadResponse
from ._models import ScalarSearchSpec
from ._models import SetJobStatusReturn
from ._models import SortSpec
from ._models import SortSpecDirection
from ._models import SupportInfo
from ._models import TokenResponse
from ._models import UserInfoResponse
from ._models import VOInfo
from ._models import ValidationError
from ._models import ValidationErrorLocItem
from ._models import VectorSearchSpec
Expand All @@ -48,6 +52,7 @@
__all__ = [
"BodyAuthToken",
"BodyAuthTokenGrantType",
"GroupInfo",
"HTTPValidationError",
"InitiateDeviceFlowResponse",
"InsertedJob",
Expand All @@ -58,15 +63,18 @@
"JobSummaryParams",
"JobSummaryParamsSearchItem",
"LimitedJobStatusReturn",
"Metadata",
"SandboxDownloadResponse",
"SandboxInfo",
"SandboxUploadResponse",
"ScalarSearchSpec",
"SetJobStatusReturn",
"SortSpec",
"SortSpecDirection",
"SupportInfo",
"TokenResponse",
"UserInfoResponse",
"VOInfo",
"ValidationError",
"ValidationErrorLocItem",
"VectorSearchSpec",
Expand Down
Loading

0 comments on commit 4f990dc

Please sign in to comment.