Skip to content

Commit

Permalink
feat: Add feedback (#38)
Browse files Browse the repository at this point in the history
* feat: Add feedback

* chore: Add multi-threading test of call vars logic

* feat: Remove langchain feedback related logic

* feat: Remove leftover langchain items
  • Loading branch information
natecanfield822 authored Nov 6, 2023
1 parent 3f854d6 commit 5f82ad5
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 5 deletions.
9 changes: 9 additions & 0 deletions src/nr_openai_observability/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
TransactionBeginEventName,
)
from nr_openai_observability.error_handling_decorator import handle_errors
from nr_openai_observability.call_vars import (
set_ai_message_ids,
create_ai_message_id,
get_conversation_id,
)


logger = logging.getLogger("nr_openai_observability")
Expand Down Expand Up @@ -352,9 +357,12 @@ def build_bedrock_events(response, event_dict, completion_id, time_delta):

if len(messages) > 0:
messages[-1]["is_final_response"] = True
ai_message_id = create_ai_message_id(messages[-1].get("id"))
set_ai_message_ids([ai_message_id])

summary = {
"id": completion_id,
"conversation_id": get_conversation_id(),
"timestamp": datetime.now(),
"response_time": int(time_delta * 1000),
"model": model,
Expand Down Expand Up @@ -426,6 +434,7 @@ def build_bedrock_result_message(
message = {
"id": message_id,
"content": content[:4095],
"conversation_id": get_conversation_id(),
"role": role,
"completion_id": completion_id,
"sequence": sequence,
Expand Down
49 changes: 45 additions & 4 deletions src/nr_openai_observability/build_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,41 @@

logger = logging.getLogger("nr_openai_observability")


def build_messages_events(messages, model, completion_id, tags={}, start_seq_num=0):
from nr_openai_observability.call_vars import (
get_conversation_id,
get_conversation_id,
)

def build_messages_events(
messages,
model,
completion_id,
message_id_override=None,
response_id=None,
tags={},
start_seq_num=0,
vendor=None
):
events = []
for index, message in enumerate(messages):
#Non-final messages (IE, user, system)
message_id = str(uuid.uuid4())
if message_id_override is not None:
#LangChain
message_id = message_id_override
elif response_id is not None:
#OpenAI
message_id = str(response_id) + "-" + str(index)
currMessage = {
"id": str(uuid.uuid4()),
"id": message_id,
"completion_id": completion_id,
"conversation_id": get_conversation_id(),
"content": (message.get("content") or "")[:4095],
"role": message.get("role"),
"sequence": index + start_seq_num,
# Grab the last populated model for langchain returned messages
**compat_fields(["model", "response.model"], model),
"vendor": "openAI",
"vendor": vendor,
"ingest_source": "PythonSDK",
**get_trace_details(),
}
Expand Down Expand Up @@ -127,6 +150,7 @@ def build_stream_completion_events(

completion = {
"id": completion_id,
"conversation_id": get_conversation_id(),
"api_key_last_four_digits": f"sk-{last_chunk.api_key[-4:]}",
"response_time": int(response_time * 1000),
"request.model": request.get("model") or request.get("engine"),
Expand Down Expand Up @@ -181,10 +205,12 @@ def build_completion_summary(
completion = {
"id": completion_id,
"request_id": response_headers.get("x-request-id", ""),
"conversation_id": get_conversation_id(),
"api_key_last_four_digits": f"sk-{response.api_key[-4:]}",
"response_time": int(response_time * 1000),
"request.model": request.get("model") or request.get("engine"),
"response.model": response.model,
"response.id": response.id,
**compat_fields(
["organization", "response.organization"], response.organization
),
Expand Down Expand Up @@ -348,3 +374,18 @@ def get_trace_details():

def compat_fields(keys, value):
return dict.fromkeys(keys, value)

def build_ai_feedback_event(category, rating, message_id, conversation_id, request_id, message):
    """Build an LlmChatFeedback event payload.

    Args:
        category: free-form feedback category supplied by the caller.
        rating: the user's rating for the message.
        message_id: id of the chat message the feedback refers to.
        conversation_id: id of the conversation the message belongs to.
        request_id: id of the originating completion request.
        message: free-form feedback text.

    Returns:
        A dict ready to be recorded as a custom event, stamped with a
        fresh uuid, the current timestamp and the SDK ingest source.
    """
    feedback_event = {
        "id": str(uuid.uuid4()),
        "timestamp": datetime.now(),
        "ingest_source": "PythonSDK",
    }
    feedback_event.update(
        conversation_id=conversation_id,
        message_id=message_id,
        request_id=request_id,
        rating=rating,
        message=message,
        category=category,
    )
    return feedback_event
35 changes: 35 additions & 0 deletions src/nr_openai_observability/call_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import contextvars

# Context-local storage shared by the instrumentation layers.  ContextVars
# give every thread / asyncio task its own isolated value, so concurrent
# completions cannot leak conversation or message ids into each other.
conversation_id = contextvars.ContextVar('conversation_id')
ai_message_ids = contextvars.ContextVar('ai_message_ids')

def get_ai_message_ids(response_id=None):
    """Return the ai-message-id records stored in the current context.

    With a response_id the OpenAI layout (dict keyed by response id) is
    consulted; without one the Bedrock layout (flat list) is returned.
    Missing data yields an empty list in both cases.
    """
    if response_id is None:
        #Bedrock
        return ai_message_ids.get([])
    #OpenAI
    return ai_message_ids.get({}).get(response_id, [])

def set_ai_message_ids(message_ids, response_id=None):
    """Store message-id records in the current context (see get_ai_message_ids)."""
    if response_id is None:
        #Bedrock: the context var holds the list directly
        ai_message_ids.set(message_ids)
    else:
        #OpenAI: the context var holds a response_id -> records mapping
        id_map = ai_message_ids.get({})
        id_map[response_id] = message_ids
        ai_message_ids.set(id_map)

def set_conversation_id(id):
    """Bind a conversation id to the current context."""
    conversation_id.set(id)

def get_conversation_id():
    """Return the conversation id bound to the current context, or None."""
    return conversation_id.get(None)

def create_ai_message_id(message_id, response_id=None):
    """Build the record tying a message to its response and conversation."""
    return {
        "conversation_id": get_conversation_id(),
        "response_id": response_id,
        "message_id": message_id,
    }
1 change: 1 addition & 0 deletions src/nr_openai_observability/consts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
EventName = "LlmCompletion"
MessageEventName = "LlmChatCompletionMessage"
SummaryEventName = "LlmChatCompletionSummary"
FeedbackEventName = "LlmChatFeedback"
EmbeddingEventName = "LlmEmbedding"
VectorSearchEventName = "LlmVectorSearch"
VectorSearchResultsEventName = "LlmVectorSearchResult"
Expand Down
11 changes: 10 additions & 1 deletion src/nr_openai_observability/langchain_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
ToolEventName,
)
import newrelic.agent

from nr_openai_observability.build_events import build_messages_events
from nr_openai_observability.consts import MessageEventName
from nr_openai_observability.call_vars import (
set_conversation_id,
)

class NewRelicCallbackHandler(BaseCallbackHandler):
def __init__(
Expand Down Expand Up @@ -40,6 +44,7 @@ def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
"""Run when LLM starts running."""
self._save_metadata(kwargs.get("metadata", {}))
model = self._get_model(serialized, **kwargs)

tags = {
Expand All @@ -57,6 +62,7 @@ def on_chat_model_start(
**kwargs: Any,
) -> Any:
"""Run when Chat Model starts running."""
self._save_metadata(kwargs.get("metadata", {}))
invocation_params = kwargs.get("invocation_params", {})
model = self._get_model(serialized, **kwargs)

Expand Down Expand Up @@ -250,3 +256,6 @@ def _get_model(self, serialized: Dict[str, Any], **kwargs: Any) -> str:
model = invocation_params.get("_type", "")

return model

def _save_metadata(self, metadata):
    """Bind the caller-supplied conversation id (if any) to the current context."""
    set_conversation_id(metadata.get("conversation_id"))
7 changes: 7 additions & 0 deletions src/nr_openai_observability/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from nr_openai_observability.patcher import perform_patch
from nr_openai_observability.openai_monitoring import monitor
from nr_openai_observability.build_events import build_ai_feedback_event
from nr_openai_observability.consts import FeedbackEventName

logger = logging.getLogger("nr_openai_observability")

Expand All @@ -23,3 +25,8 @@ def initialization(
)
perform_patch()
return monitor

def record_ai_feedback_event(rating, message_id, category=None, conversation_id=None, request_id=None, message=None):
    """Build an LlmChatFeedback event from the caller's feedback and record it.

    Only rating and message_id are required; category, conversation_id,
    request_id and message are optional context fields.
    """
    feedback = build_ai_feedback_event(
        category, rating, message_id, conversation_id, request_id, message
    )
    monitor.record_event(feedback, FeedbackEventName)
17 changes: 17 additions & 0 deletions src/nr_openai_observability/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
patcher_create_chat_completion_stream,
patcher_create_chat_completion_stream_async,
)
from nr_openai_observability.call_vars import (
create_ai_message_id,
get_ai_message_ids,
set_ai_message_ids
)

logger = logging.getLogger("nr_openai_observability")

Expand Down Expand Up @@ -156,6 +161,7 @@ def handle_start_completion(request, completion_id):
request.get("messages", []),
request.get("model") or request.get("engine"),
completion_id,
vendor = "openAI",
)
for event in message_events:
monitor.record_event(event, consts.MessageEventName)
Expand All @@ -180,10 +186,21 @@ def handle_finish_chat_completion(response, request, response_time, completion_i
[final_message],
response.model,
completion_id,
None,
response.id,
{"is_final_response": True},
len(initial_messages),
vendor = "openAI",
)[0]

ai_message_ids = get_ai_message_ids(response.get("id"))

ai_message_ids.append(
create_ai_message_id(response_message.get("id"), response.get("id"))
)

set_ai_message_ids(ai_message_ids, response.get("id"))

monitor.record_event(response_message, consts.MessageEventName)

monitor.record_event(completion, consts.SummaryEventName)
Expand Down
17 changes: 17 additions & 0 deletions src/nr_openai_observability/stream_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
)
from nr_openai_observability.error_handling_decorator import handle_errors
from nr_openai_observability.openai_monitoring import monitor
from nr_openai_observability.call_vars import (
create_ai_message_id,
get_ai_message_ids,
set_ai_message_ids
)


def patcher_create_chat_completion_stream(original_fn, *args, **kwargs):
Expand Down Expand Up @@ -127,6 +132,7 @@ def handle_start_completion(request, completion_id):
request.get("messages", []),
request.get("model") or request.get("engine"),
completion_id,
vendor = "openAI",
)
for event in message_events:
monitor.record_event(event, consts.MessageEventName)
Expand All @@ -152,10 +158,21 @@ def handle_finish_chat_completion(
[final_message],
last_chunk.model,
completion_id,
None,
last_chunk.id,
{"is_final_response": True},
len(initial_messages),
vendor = "openAI",
)[0]

ai_message_ids = get_ai_message_ids(response.get("id"))

ai_message_ids.append(
create_ai_message_id(response_message.get("id"), response.get("id"))
)

set_ai_message_ids(ai_message_ids, response.get("id"))

monitor.record_event(response_message, consts.MessageEventName)

monitor.record_event(completion, consts.SummaryEventName)
44 changes: 44 additions & 0 deletions tests/test_call_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest
import sys
import threading
import time
import asyncio
import uuid

from nr_openai_observability.call_vars import set_conversation_id, get_conversation_id

failures = 0

def test_set_conversation_id():
    """Hammer set/get_conversation_id from many threads.

    Each thread must (a) start with a fresh context where no conversation
    id is set, and (b) still read back its own id after a sleep that lets
    other threads interleave.  Failures are collected in a lock-guarded
    local list: the original module-level ``failures`` global used a
    non-atomic ``failures += 1`` from 1000 threads and leaked state
    between test runs.
    """
    failures = []
    failures_lock = threading.Lock()

    def worker():
        test_id = str(uuid.uuid4())
        # Check that each new thread gets a fresh context.
        if get_conversation_id() is not None:
            with failures_lock:
                failures.append("context was not fresh")
        set_conversation_id(test_id)
        time.sleep(0.001)
        # Check that the context hasn't been polluted by another thread.
        if get_conversation_id() != test_id:
            with failures_lock:
                failures.append("context polluted by another thread")

    # Greatly improve the chance of an operation being interrupted by a
    # thread switch.  contextvars requires Python 3.7+, so
    # sys.setswitchinterval always exists (the old sys.setcheckinterval
    # fallback was dead code and was removed in Python 3.12 anyway).
    old_interval = sys.getswitchinterval()
    sys.setswitchinterval(1e-12)
    try:
        threads = [threading.Thread(target=worker) for _ in range(1000)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
    finally:
        # Restore the interpreter default so later tests run at speed.
        sys.setswitchinterval(old_interval)

    assert not failures

0 comments on commit 5f82ad5

Please sign in to comment.