Skip to content

Commit

Permalink
feat(backend): credential swap to api keys in cred store (Significant…
Browse files Browse the repository at this point in the history
…-Gravitas#8403)

* feat(backend): credential swap

* ci: formatting

* fix: importing is hard okay

* fix: spelln' is hard

* feat: better credential provider handling

* docs: update the imports locations

* fix: test credentials + formatting

* feat: drop continuous read mode

* fix: lint

* feat: fallback credentials

* feat: charge for credential useage and have a bad backup mechnism

* fix: don't save default credentials + add d_id

* fix: formatting

* feat: basic encryption/decryption

* ref: move files around

* ref: sign all blocks out of their credentials

* ref: update target to match a new, and encrypted future

* wip: llm provider merger

* don't delete `credentials` input on nodes

* fix llm block ci issues

* updated get AICredentials

* fix fix

* insert migration to move integration credentials from `auth.user` metadata to `platform.User.metadata`

* fixed migration

* add migration for existing user integration credentials

* disabled reddit and email block

* fix credential handling in LLM blocks

* add other secret fields to credential scrubber migration

* add other secret fields to credential scrubber migration (vol. 2)

* fix: pr fixes

* fix: mock funciton

* add encrypted values

---------

Co-authored-by: Reinier van der Leer <[email protected]>
Co-authored-by: SwiftyOS <[email protected]>
Co-authored-by: Aarushi <[email protected]>
Co-authored-by: Aarushi <[email protected]>
  • Loading branch information
5 people authored Oct 31, 2024
1 parent 26caf1c commit 3aebed6
Show file tree
Hide file tree
Showing 34 changed files with 1,270 additions and 342 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING

from pydantic import SecretStr

if TYPE_CHECKING:
from redis import Redis
from backend.executor.database import DatabaseManager
Expand All @@ -10,13 +12,77 @@
from autogpt_libs.utils.synchronize import RedisKeyedMutex

from .types import (
APIKeyCredentials,
Credentials,
OAuth2Credentials,
OAuthState,
UserMetadata,
UserMetadataRaw,
UserIntegrations,
)

from backend.util.settings import Settings

settings = Settings()

revid_credentials = APIKeyCredentials(
id="fdb7f412-f519-48d1-9b5f-d2f73d0e01fe",
provider="revid",
api_key=SecretStr(settings.secrets.revid_api_key),
title="Use Credits for Revid",
expires_at=None,
)
ideogram_credentials = APIKeyCredentials(
id="760f84fc-b270-42de-91f6-08efe1b512d0",
provider="ideogram",
api_key=SecretStr(settings.secrets.ideogram_api_key),
title="Use Credits for Ideogram",
expires_at=None,
)
replicate_credentials = APIKeyCredentials(
id="6b9fc200-4726-4973-86c9-cd526f5ce5db",
provider="replicate",
api_key=SecretStr(settings.secrets.replicate_api_key),
title="Use Credits for Replicate",
expires_at=None,
)
openai_credentials = APIKeyCredentials(
id="53c25cb8-e3ee-465c-a4d1-e75a4c899c2a",
provider="llm",
api_key=SecretStr(settings.secrets.openai_api_key),
title="Use Credits for OpenAI",
expires_at=None,
)
anthropic_credentials = APIKeyCredentials(
id="24e5d942-d9e3-4798-8151-90143ee55629",
provider="llm",
api_key=SecretStr(settings.secrets.anthropic_api_key),
title="Use Credits for Anthropic",
expires_at=None,
)
groq_credentials = APIKeyCredentials(
id="4ec22295-8f97-4dd1-b42b-2c6957a02545",
provider="llm",
api_key=SecretStr(settings.secrets.groq_api_key),
title="Use Credits for Groq",
expires_at=None,
)
did_credentials = APIKeyCredentials(
id="7f7b0654-c36b-4565-8fa7-9a52575dfae2",
provider="d_id",
api_key=SecretStr(settings.secrets.did_api_key),
title="Use Credits for D-ID",
expires_at=None,
)

DEFAULT_CREDENTIALS = [
revid_credentials,
ideogram_credentials,
replicate_credentials,
openai_credentials,
anthropic_credentials,
groq_credentials,
did_credentials,
]


class SupabaseIntegrationCredentialsStore:
def __init__(self, redis: "Redis"):
Expand All @@ -27,10 +93,11 @@ def __init__(self, redis: "Redis"):
def db_manager(self) -> "DatabaseManager":
from backend.executor.database import DatabaseManager
from backend.util.service import get_service_client

return get_service_client(DatabaseManager)

def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_metadata(user_id):
with self.locked_user_integrations(user_id):
if self.get_creds_by_id(user_id, credentials.id):
raise ValueError(
f"Can not re-create existing credentials #{credentials.id} "
Expand All @@ -41,10 +108,23 @@ def add_creds(self, user_id: str, credentials: Credentials) -> None:
)

def get_all_creds(self, user_id: str) -> list[Credentials]:
user_metadata = self._get_user_metadata(user_id)
return UserMetadata.model_validate(
user_metadata.model_dump()
).integration_credentials
users_credentials = self._get_user_integrations(user_id).credentials
all_credentials = users_credentials
if settings.secrets.revid_api_key:
all_credentials.append(revid_credentials)
if settings.secrets.ideogram_api_key:
all_credentials.append(ideogram_credentials)
if settings.secrets.groq_api_key:
all_credentials.append(groq_credentials)
if settings.secrets.replicate_api_key:
all_credentials.append(replicate_credentials)
if settings.secrets.openai_api_key:
all_credentials.append(openai_credentials)
if settings.secrets.anthropic_api_key:
all_credentials.append(anthropic_credentials)
if settings.secrets.did_api_key:
all_credentials.append(did_credentials)
return all_credentials

def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
all_credentials = self.get_all_creds(user_id)
Expand All @@ -59,7 +139,7 @@ def get_authorized_providers(self, user_id: str) -> list[str]:
return list(set(c.provider for c in credentials))

def update_creds(self, user_id: str, updated: Credentials) -> None:
with self.locked_user_metadata(user_id):
with self.locked_user_integrations(user_id):
current = self.get_creds_by_id(user_id, updated.id)
if not current:
raise ValueError(
Expand Down Expand Up @@ -93,7 +173,7 @@ def update_creds(self, user_id: str, updated: Credentials) -> None:
self._set_user_integration_creds(user_id, updated_credentials_list)

def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
with self.locked_user_metadata(user_id):
with self.locked_user_integrations(user_id):
filtered_credentials = [
c for c in self.get_all_creds(user_id) if c.id != credentials_id
]
Expand All @@ -110,14 +190,14 @@ def store_state_token(self, user_id: str, provider: str, scopes: list[str]) -> s
scopes=scopes,
)

with self.locked_user_metadata(user_id):
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.integration_oauth_states
oauth_states.append(state.model_dump())
user_metadata.integration_oauth_states = oauth_states
with self.locked_user_integrations(user_id):
user_integrations = self._get_user_integrations(user_id)
oauth_states = user_integrations.oauth_states
oauth_states.append(state)
user_integrations.oauth_states = oauth_states

self.db_manager.update_user_metadata(
user_id=user_id, metadata=user_metadata
self.db_manager.update_user_integrations(
user_id=user_id, data=user_integrations
)

return token
Expand All @@ -132,63 +212,67 @@ def get_any_valid_scopes_from_state_token(
IS TO CHECK IF THE USER HAS GIVEN PERMISSIONS TO THE APPLICATION BEFORE EXCHANGING
THE CODE FOR TOKENS.
"""
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.integration_oauth_states
user_integrations = self._get_user_integrations(user_id)
oauth_states = user_integrations.oauth_states

now = datetime.now(timezone.utc)
valid_state = next(
(
state
for state in oauth_states
if state["token"] == token
and state["provider"] == provider
and state["expires_at"] > now.timestamp()
if state.token == token
and state.provider == provider
and state.expires_at > now.timestamp()
),
None,
)

if valid_state:
return valid_state.get("scopes", [])
return valid_state.scopes

return []

def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
with self.locked_user_metadata(user_id):
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.integration_oauth_states
with self.locked_user_integrations(user_id):
user_integrations = self._get_user_integrations(user_id)
oauth_states = user_integrations.oauth_states

now = datetime.now(timezone.utc)
valid_state = next(
(
state
for state in oauth_states
if state["token"] == token
and state["provider"] == provider
and state["expires_at"] > now.timestamp()
if state.token == token
and state.provider == provider
and state.expires_at > now.timestamp()
),
None,
)

if valid_state:
# Remove the used state
oauth_states.remove(valid_state)
user_metadata.integration_oauth_states = oauth_states
self.db_manager.update_user_metadata(user_id, user_metadata)
user_integrations.oauth_states = oauth_states
self.db_manager.update_user_integrations(user_id, user_integrations)
return True

return False

def _set_user_integration_creds(
self, user_id: str, credentials: list[Credentials]
) -> None:
raw_metadata = self._get_user_metadata(user_id)
raw_metadata.integration_credentials = [c.model_dump() for c in credentials]
self.db_manager.update_user_metadata(user_id, raw_metadata)
integrations = self._get_user_integrations(user_id)
# Remove default credentials from the list
credentials = [c for c in credentials if c not in DEFAULT_CREDENTIALS]
integrations.credentials = credentials
self.db_manager.update_user_integrations(user_id, integrations)

def _get_user_metadata(self, user_id: str) -> UserMetadataRaw:
metadata: UserMetadataRaw = self.db_manager.get_user_metadata(user_id=user_id)
return metadata
def _get_user_integrations(self, user_id: str) -> UserIntegrations:
integrations: UserIntegrations = self.db_manager.get_user_integrations(
user_id=user_id
)
return integrations

def locked_user_metadata(self, user_id: str):
key = (self.db_manager, f"user:{user_id}", "metadata")
def locked_user_integrations(self, user_id: str):
key = (self.db_manager, f"user:{user_id}", "integrations")
return self.locks.locked(key)
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ class UserMetadata(BaseModel):
integration_oauth_states: list[OAuthState] = Field(default_factory=list)


class UserMetadataRaw(BaseModel):
integration_credentials: list[dict] = Field(default_factory=list)
integration_oauth_states: list[dict] = Field(default_factory=list)
class UserMetadataRaw(TypedDict, total=False):
integration_credentials: list[dict]
integration_oauth_states: list[dict]


class UserIntegrations(BaseModel):
credentials: list[Credentials] = Field(default_factory=list)
oauth_states: list[OAuthState] = Field(default_factory=list)
3 changes: 3 additions & 0 deletions autogpt_platform/backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ PRISMA_SCHEMA="postgres/schema.prisma"

BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]

# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='

REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=password
Expand Down
Loading

0 comments on commit 3aebed6

Please sign in to comment.