Skip to content

Commit

Permalink
Modified another same issue; modified with lint check
Browse files Browse the repository at this point in the history
1. Found and modified another 2 same issues in openai.py.
2. The previous changes did not pass the lint check. Now modified with lint check.
  • Loading branch information
MartinChen1973 committed Jan 15, 2025
1 parent 73ad987 commit 2f6dabc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
10 changes: 8 additions & 2 deletions libs/community/langchain_community/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,10 @@ def _get_len_safe_embeddings(
)
if not isinstance(average_embedded, dict):
average_embedded = average_embedded.dict()
average = average_embedded["data"][0]["embedding"]
if len(average_embedded["data"]) > 0:
average = average_embedded["data"][0]["embedding"]
else:
raise ValueError(average_embedded["message"])
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
Expand Down Expand Up @@ -645,7 +648,10 @@ async def _aget_len_safe_embeddings(
)
if not isinstance(average_embedded, dict):
average_embedded = average_embedded.dict()
average = average_embedded["data"][0]["embedding"]
if len(average_embedded["data"]) > 0:
average = average_embedded["data"][0]["embedding"]
else:
raise ValueError(average_embedded["message"])
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
Expand Down
7 changes: 5 additions & 2 deletions libs/partners/openai/langchain_openai/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,10 @@ def empty_embedding() -> List[float]:
)
if not isinstance(average_embedded, dict):
average_embedded = average_embedded.model_dump()
_cached_empty_embedding = average_embedded["data"][0]["embedding"]
if len(average_embedded["data"]) > 0:
_cached_empty_embedding = average_embedded["data"][0]["embedding"]
else:
raise ValueError(average_embedded["message"])
return _cached_empty_embedding

return [e if e is not None else empty_embedding() for e in embeddings]
Expand Down Expand Up @@ -552,7 +555,7 @@ async def empty_embedding() -> List[float]:
)
if not isinstance(average_embedded, dict):
average_embedded = average_embedded.model_dump()
if (len(average_embedded["data"]) > 0):
if len(average_embedded["data"]) > 0:
_cached_empty_embedding = average_embedded["data"][0]["embedding"]
else:
raise ValueError(average_embedded["message"])
Expand Down

0 comments on commit 2f6dabc

Please sign in to comment.