Skip to content

Commit

Permalink
Add Preprocess.model_endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
allegroai committed Apr 18, 2022
1 parent 49e5acb commit 409fc15
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
3 changes: 2 additions & 1 deletion clearml_serving/preprocess/preprocess_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class Preprocess(object):

def __init__(self):
# set internal state, this will be called only once. (i.e. not per request)
pass
# it will also set the internal model_endpoint to reference the specific model endpoint object being served
self.model_endpoint = None # type: clearml_serving.serving.endpoints.ModelEndpoint

def load(self, local_file_name: str) -> Optional[Any]: # noqa
"""
Expand Down
3 changes: 2 additions & 1 deletion clearml_serving/serving/model_request_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ def add_endpoint(
if len(models) > 1:
print("Warning: Found multiple Models for \'{}\', selecting id={}".format(model_query, models[0].id))
endpoint.model_id = models[0].id
elif not endpoint.model_id:
elif not endpoint.model_id and endpoint.engine_type != "custom":
# if the "engine_type" is "custom" it might be there is no model_id attached
print("Warning: No Model provided for \'{}\'".format(url))

# upload as new artifact
Expand Down
4 changes: 4 additions & 0 deletions clearml_serving/serving/preprocess_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def _instantiate_custom_preprocess_cls(self, task: Task) -> None:
Preprocess.send_request = BasePreprocessRequest._preprocess_send_request
# create preprocess class
self._preprocess = Preprocess()
# update the model endpoint on the instance we created
self._preprocess.model_endpoint = self.model_endpoint
# custom model load callback function
if callable(getattr(self._preprocess, 'load', None)):
self._model = self._preprocess.load(self._get_local_model_file())
Expand Down Expand Up @@ -129,6 +131,8 @@ def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None
pass

def _get_local_model_file(self):
if not self.model_endpoint.model_id:
return None
model_repo_object = Model(model_id=self.model_endpoint.model_id)
return model_repo_object.get_local_copy()

Expand Down

0 comments on commit 409fc15

Please sign in to comment.