Commit 128e12d

Rename ResultFactory to ResultStore (#15184)
desertaxle authored Sep 3, 2024
1 parent 057f1d3 commit 128e12d
Showing 23 changed files with 990 additions and 991 deletions.
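In short: `ResultFactory` becomes `ResultStore`, its `storage_block` field becomes `result_storage`, the `storage_block_id` property becomes `result_storage_block_id`, and `get_current_result_factory` becomes `get_current_result_store`. A rough before/after sketch of calling code, with names taken from the hunks below (run inside an async context; the payload is illustrative):

    # Before this commit
    from prefect.results import ResultFactory

    factory = ResultFactory(persist_result=True)
    result = await factory.create_result(obj={"test": "value"})

    # After this commit
    from prefect.results import ResultStore

    store = ResultStore(persist_result=True)
    result = await store.create_result(obj={"test": "value"})
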
6 changes: 3 additions & 3 deletions src/integrations/prefect-redis/tests/test_records.py
@@ -8,7 +8,7 @@
 from prefect_redis.records import RedisRecordStore

 from prefect.filesystems import LocalFileSystem
-from prefect.results import ResultFactory
+from prefect.results import ResultStore
 from prefect.settings import (
     PREFECT_DEFAULT_RESULT_STORAGE_BLOCK,
     temporary_settings,
@@ -30,10 +30,10 @@ def default_storage_setting(self, tmp_path):

     @pytest.fixture
     async def result(self, default_storage_setting):
-        factory = ResultFactory(
+        store = ResultStore(
             persist_result=True,
         )
-        result = await factory.create_result(obj={"test": "value"})
+        result = await store.create_result(obj={"test": "value"})
         return result

     @pytest.fixture
10 changes: 5 additions & 5 deletions src/prefect/context.py
@@ -40,7 +40,7 @@
 from prefect.client.schemas import FlowRun, TaskRun
 from prefect.events.worker import EventsWorker
 from prefect.exceptions import MissingContextError
-from prefect.results import ResultFactory
+from prefect.results import ResultStore
 from prefect.settings import PREFECT_HOME, Profile, Settings
 from prefect.states import State
 from prefect.task_runners import TaskRunner
@@ -340,7 +340,7 @@ class EngineContext(RunContext):
     detached: bool = False

     # Result handling
-    result_factory: ResultFactory
+    result_store: ResultStore

     # Counter for task calls allowing unique
     task_run_dynamic_keys: Dict[str, int] = Field(default_factory=dict)
@@ -369,7 +369,7 @@ def serialize(self):
                 "log_prints",
                 "start_time",
                 "input_keyset",
-                "result_factory",
+                "result_store",
             },
             exclude_unset=True,
         )
@@ -394,7 +394,7 @@ class TaskRunContext(RunContext):
     parameters: Dict[str, Any]

     # Result handling
-    result_factory: ResultFactory
+    result_store: ResultStore

     __var__ = ContextVar("task_run")

@@ -407,7 +407,7 @@ def serialize(self):
                 "log_prints",
                 "start_time",
                 "input_keyset",
-                "result_factory",
+                "result_store",
             },
             exclude_unset=True,
         )
19 changes: 9 additions & 10 deletions src/prefect/flow_engine.py
@@ -47,7 +47,7 @@
     get_run_logger,
     patch_print,
 )
-from prefect.results import BaseResult, ResultFactory, get_current_result_factory
+from prefect.results import BaseResult, ResultStore, get_current_result_store
 from prefect.settings import PREFECT_DEBUG_MODE
 from prefect.states import (
     Failed,
@@ -202,7 +202,7 @@ def begin_run(self) -> State:
             self.handle_exception(
                 exc,
                 msg=message,
-                result_factory=get_current_result_factory().update_for_flow(
+                result_store=get_current_result_store().update_for_flow(
                     self.flow, _sync=True
                 ),
             )
@@ -263,14 +263,14 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
         return _result

     def handle_success(self, result: R) -> R:
-        result_factory = getattr(FlowRunContext.get(), "result_factory", None)
-        if result_factory is None:
-            raise ValueError("Result factory is not set")
+        result_store = getattr(FlowRunContext.get(), "result_store", None)
+        if result_store is None:
+            raise ValueError("Result store is not set")
         resolved_result = resolve_futures_to_states(result)
         terminal_state = run_coro_as_sync(
             return_value_to_state(
                 resolved_result,
-                result_factory=result_factory,
+                result_store=result_store,
                 write_result=True,
             )
         )
@@ -282,15 +282,14 @@ def handle_exception(
         self,
         exc: Exception,
         msg: Optional[str] = None,
-        result_factory: Optional[ResultFactory] = None,
+        result_store: Optional[ResultStore] = None,
     ) -> State:
         context = FlowRunContext.get()
         terminal_state = run_coro_as_sync(
             exception_to_failed_state(
                 exc,
                 message=msg or "Flow run encountered an exception:",
-                result_factory=result_factory
-                or getattr(context, "result_factory", None),
+                result_store=result_store or getattr(context, "result_store", None),
                 write_result=True,
             )
         )
@@ -508,7 +507,7 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
                 flow_run=self.flow_run,
                 parameters=self.parameters,
                 client=client,
-                result_factory=get_current_result_factory().update_for_flow(
+                result_store=get_current_result_store().update_for_flow(
                     self.flow, _sync=True
                 ),
                 task_runner=task_runner,
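For orientation, the engine now resolves its store from the run context via `get_current_result_store().update_for_flow(...)` rather than the factory equivalent. A small sketch of that call outside the engine, under the assumption that it behaves the same there (`my_flow` is hypothetical; `_sync=True` mirrors the usage above):

    from prefect import flow
    from prefect.results import get_current_result_store

    @flow
    def my_flow():
        return 42

    # Outside a run context this falls back to a default ResultStore,
    # then applies my_flow's result settings.
    store = get_current_result_store().update_for_flow(my_flow, _sync=True)
    print(store.persist_result)
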
18 changes: 12 additions & 6 deletions src/prefect/records/result_store.py
@@ -3,16 +3,22 @@

 import pendulum

-from prefect.results import BaseResult, PersistedResult, ResultFactory
+from prefect.results import BaseResult, PersistedResult, ResultStore
 from prefect.transactions import IsolationLevel
 from prefect.utilities.asyncutils import run_coro_as_sync

 from .base import RecordStore, TransactionRecord


 @dataclass
-class ResultFactoryStore(RecordStore):
-    result_factory: ResultFactory
+class ResultRecordStore(RecordStore):
+    """
+    A record store for result records.
+
+    Collocates result metadata with result data.
+    """
+
+    result_store: ResultStore
     cache: Optional[PersistedResult] = None

     def exists(self, key: str) -> bool:
@@ -38,8 +44,8 @@ def read(self, key: str, holder: Optional[str] = None) -> TransactionRecord:
             return TransactionRecord(key=key, result=self.cache)
         try:
             result = PersistedResult(
-                serializer_type=self.result_factory.serializer.type,
-                storage_block_id=self.result_factory.storage_block_id,
+                serializer_type=self.result_store.serializer.type,
+                storage_block_id=self.result_store.result_storage_block_id,
                 storage_key=key,
             )
             return TransactionRecord(key=key, result=result)
@@ -52,7 +58,7 @@ def write(self, key: str, result: Any, holder: Optional[str] = None) -> None:
             # if the value is already a persisted result, write it
             result.write(_sync=True)
         elif not isinstance(result, BaseResult):
-            run_coro_as_sync(self.result_factory.create_result(obj=result, key=key))
+            run_coro_as_sync(self.result_store.create_result(obj=result, key=key))

     def supports_isolation_level(self, isolation_level: IsolationLevel) -> bool:
         return isolation_level == IsolationLevel.READ_COMMITTED
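
A rough sketch of wiring up the renamed wrapper, assuming the base `RecordStore` adds no required constructor arguments (the storage settings are illustrative):

    from prefect.records.result_store import ResultRecordStore
    from prefect.results import ResultStore
    from prefect.transactions import IsolationLevel

    # The record store now carries a ResultStore instead of a ResultFactory.
    record_store = ResultRecordStore(result_store=ResultStore(persist_result=True))

    # Per the hunk above, only READ_COMMITTED isolation is supported.
    assert record_store.supports_isolation_level(IsolationLevel.READ_COMMITTED)
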
88 changes: 42 additions & 46 deletions src/prefect/results.py
@@ -169,61 +169,61 @@ def _format_user_supplied_storage_key(key: str) -> str:
     return key.format(**runtime_vars, parameters=prefect.runtime.task_run.parameters)


-class ResultFactory(BaseModel):
+class ResultStore(BaseModel):
     """
     A utility to generate `Result` types.
     """

-    storage_block: Optional[WritableFileSystem] = Field(default=None)
+    result_storage: Optional[WritableFileSystem] = Field(default=None)
     persist_result: bool = Field(default_factory=get_default_persist_setting)
     cache_result_in_memory: bool = Field(default=True)
     serializer: Serializer = Field(default_factory=get_default_result_serializer)
     storage_key_fn: Callable[[], str] = Field(default=DEFAULT_STORAGE_KEY_FN)

     @property
-    def storage_block_id(self) -> Optional[UUID]:
-        if self.storage_block is None:
+    def result_storage_block_id(self) -> Optional[UUID]:
+        if self.result_storage is None:
             return None
-        return self.storage_block._block_document_id
+        return self.result_storage._block_document_id

     @sync_compatible
     async def update_for_flow(self, flow: "Flow") -> Self:
         """
-        Create a new result factory for a flow with updated settings.
+        Create a new result store for a flow with updated settings.

         Args:
-            flow: The flow to update the result factory for.
+            flow: The flow to update the result store for.

         Returns:
-            An updated result factory.
+            An updated result store.
         """
         update = {}
         if flow.result_storage is not None:
-            update["storage_block"] = await resolve_result_storage(flow.result_storage)
+            update["result_storage"] = await resolve_result_storage(flow.result_storage)
         if flow.result_serializer is not None:
             update["serializer"] = resolve_serializer(flow.result_serializer)
         if flow.persist_result is not None:
             update["persist_result"] = flow.persist_result
         if flow.cache_result_in_memory is not None:
             update["cache_result_in_memory"] = flow.cache_result_in_memory
-        if self.storage_block is None and update.get("storage_block") is None:
-            update["storage_block"] = await get_default_result_storage()
+        if self.result_storage is None and update.get("result_storage") is None:
+            update["result_storage"] = await get_default_result_storage()
         return self.model_copy(update=update)

     @sync_compatible
     async def update_for_task(self: Self, task: "Task") -> Self:
         """
-        Create a new result factory for a task.
+        Create a new result store for a task.

         Args:
-            task: The task to update the result factory for.
+            task: The task to update the result store for.

         Returns:
-            An updated result factory.
+            An updated result store.
         """
         update = {}
         if task.result_storage is not None:
-            update["storage_block"] = await resolve_result_storage(task.result_storage)
+            update["result_storage"] = await resolve_result_storage(task.result_storage)
         if task.result_serializer is not None:
             update["serializer"] = resolve_serializer(task.result_serializer)
         if task.persist_result is not None:
@@ -234,8 +234,8 @@ async def update_for_task(self: Self, task: "Task") -> Self:
             update["storage_key_fn"] = partial(
                 _format_user_supplied_storage_key, task.result_storage_key
             )
-        if self.storage_block is None and update.get("storage_block") is None:
-            update["storage_block"] = await get_default_result_storage()
+        if self.result_storage is None and update.get("result_storage") is None:
+            update["result_storage"] = await get_default_result_storage()
         return self.model_copy(update=update)

     @sync_compatible
@@ -252,10 +252,10 @@ async def _read(self, key: str) -> "ResultRecord":
         Returns:
             A result record.
         """
-        if self.storage_block is None:
-            self.storage_block = await get_default_result_storage()
+        if self.result_storage is None:
+            self.result_storage = await get_default_result_storage()

-        content = await self.storage_block.read_path(f"{key}")
+        content = await self.result_storage.read_path(f"{key}")
         return ResultRecord.deserialize(content)

     def read(self, key: str) -> "ResultRecord":
@@ -300,8 +300,8 @@ async def _write(
             obj: The object to write to storage.
             expiration: The expiration time for the result record.
         """
-        if self.storage_block is None:
-            self.storage_block = await get_default_result_storage()
+        if self.result_storage is None:
+            self.result_storage = await get_default_result_storage()
         key = key or self.storage_key_fn()

         record = ResultRecord(
@@ -344,10 +344,10 @@ async def _persist_result_record(self, result_record: "ResultRecord"):
         Args:
             result_record: The result record to persist.
         """
-        if self.storage_block is None:
-            self.storage_block = await get_default_result_storage()
+        if self.result_storage is None:
+            self.result_storage = await get_default_result_storage()

-        await self.storage_block.write_path(
+        await self.result_storage.write_path(
             result_record.metadata.storage_key, content=result_record.serialize()
         )

@@ -393,13 +393,13 @@ def key_fn():
         else:
             storage_key_fn = self.storage_key_fn

-        if self.storage_block is None:
-            self.storage_block = await get_default_result_storage()
+        if self.result_storage is None:
+            self.result_storage = await get_default_result_storage()

         return await PersistedResult.create(
             obj,
-            storage_block=self.storage_block,
-            storage_block_id=self.storage_block_id,
+            storage_block=self.result_storage,
+            storage_block_id=self.result_storage_block_id,
             storage_key_fn=storage_key_fn,
             serializer=self.serializer,
             cache_object=should_cache_object,
@@ -417,31 +417,31 @@ async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]):
                 serializer=self.serializer, storage_key=str(identifier)
             ),
         )
-        await self.storage_block.write_path(
+        await self.result_storage.write_path(
             f"parameters/{identifier}", content=record.serialize()
         )

     @sync_compatible
     async def read_parameters(self, identifier: UUID) -> Dict[str, Any]:
         record = ResultRecord.deserialize(
-            await self.storage_block.read_path(f"parameters/{identifier}")
+            await self.result_storage.read_path(f"parameters/{identifier}")
         )
         return record.result


-def get_current_result_factory() -> ResultFactory:
+def get_current_result_store() -> ResultStore:
     """
-    Get the current result factory.
+    Get the current result store.
     """
     from prefect.context import get_run_context

     try:
         run_context = get_run_context()
     except MissingContextError:
-        result_factory = ResultFactory()
+        result_store = ResultStore()
     else:
-        result_factory = run_context.result_factory
-    return result_factory
+        result_store = run_context.result_store
+    return result_store


 class ResultRecordMetadata(BaseModel):
@@ -722,15 +722,13 @@ async def get(
         if self.has_cached_object() and not ignore_cache:
             return self._cache

-        result_factory_kwargs = {}
+        result_store_kwargs = {}
         if self._serializer:
-            result_factory_kwargs["serializer"] = resolve_serializer(self._serializer)
+            result_store_kwargs["serializer"] = resolve_serializer(self._serializer)
         storage_block = await self._get_storage_block(client=client)
-        result_factory = ResultFactory(
-            storage_block=storage_block, **result_factory_kwargs
-        )
+        result_store = ResultStore(result_storage=storage_block, **result_store_kwargs)

-        record = await result_factory.aread(self.storage_key)
+        record = await result_store.aread(self.storage_key)
         self.expiration = record.expiration

         if self._should_cache_object:
@@ -777,10 +775,8 @@ async def write(self, obj: R = NotSet, client: "PrefectClient" = None) -> None:
             # this could error if the serializer requires kwargs
             serializer = Serializer(type=self.serializer_type)

-        result_factory = ResultFactory(
-            storage_block=storage_block, serializer=serializer
-        )
-        await result_factory.awrite(
+        result_store = ResultStore(result_storage=storage_block, serializer=serializer)
+        await result_store.awrite(
             obj=obj, key=self.storage_key, expiration=self.expiration
         )

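Pulling the results.py changes together, a minimal usage sketch of the renamed class (run inside an async context; the local storage block is illustrative, and the default storage is used when `result_storage` is omitted):

    from prefect.filesystems import LocalFileSystem
    from prefect.results import ResultStore

    # `result_storage` replaces the old `storage_block` field.
    store = ResultStore(
        result_storage=LocalFileSystem(basepath="/tmp/prefect-results"),
        persist_result=True,
    )

    # Write a result record and read it back by key.
    await store.awrite(obj={"answer": 42}, key="my-key")
    record = await store.aread("my-key")
    assert record.result == {"answer": 42}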
