From 5f82ad5f8aecb98c12e1de0c04b82e90cc14b1e2 Mon Sep 17 00:00:00 2001
From: Nate C <89416973+natecanfield822@users.noreply.github.com>
Date: Mon, 6 Nov 2023 12:42:05 -0500
Subject: [PATCH] feat: Add feedback (#38)

* feat: Add feedback

* chore: Add multi-threading test of call vars logic

* feat: Remove langchain feedback related logic

* feat: Remove leftover langchain items
---
 src/nr_openai_observability/bedrock.py        |  9 ++++
 src/nr_openai_observability/build_events.py   | 48 ++++++++++++++++--
 src/nr_openai_observability/call_vars.py      | 35 ++++++++++++++
 src/nr_openai_observability/consts.py         |  1 +
 .../langchain_callback.py                     | 11 ++++-
 src/nr_openai_observability/monitor.py        |  7 +++
 src/nr_openai_observability/patcher.py        | 17 +++++++
 src/nr_openai_observability/stream_patcher.py | 17 +++++++
 tests/test_call_vars.py                       | 42 ++++++++++++++++
 9 files changed, 182 insertions(+), 5 deletions(-)
 create mode 100644 src/nr_openai_observability/call_vars.py
 create mode 100644 tests/test_call_vars.py

diff --git a/src/nr_openai_observability/bedrock.py b/src/nr_openai_observability/bedrock.py
index 0d9ae34..b0bc5a3 100644
--- a/src/nr_openai_observability/bedrock.py
+++ b/src/nr_openai_observability/bedrock.py
@@ -17,6 +17,11 @@
     TransactionBeginEventName,
 )
 from nr_openai_observability.error_handling_decorator import handle_errors
+from nr_openai_observability.call_vars import (
+    set_ai_message_ids,
+    create_ai_message_id,
+    get_conversation_id,
+)
 
 logger = logging.getLogger("nr_openai_observability")
 
@@ -352,9 +357,12 @@ def build_bedrock_events(response, event_dict, completion_id, time_delta):
 
     if len(messages) > 0:
         messages[-1]["is_final_response"] = True
+        ai_message_id = create_ai_message_id(messages[-1].get("id"))
+        set_ai_message_ids([ai_message_id])
 
     summary = {
         "id": completion_id,
+        "conversation_id": get_conversation_id(),
         "timestamp": datetime.now(),
         "response_time": int(time_delta * 1000),
         "model": model,
@@ -426,6 +434,7 @@ def build_bedrock_result_message(
     message = {
         "id": message_id,
         "content": content[:4095],
+        "conversation_id": get_conversation_id(),
         "role": role,
         "completion_id": completion_id,
         "sequence": sequence,
diff --git a/src/nr_openai_observability/build_events.py b/src/nr_openai_observability/build_events.py
index 15e26dd..fd68163 100644
--- a/src/nr_openai_observability/build_events.py
+++ b/src/nr_openai_observability/build_events.py
@@ -8,18 +8,40 @@
 
 logger = logging.getLogger("nr_openai_observability")
 
-
-def build_messages_events(messages, model, completion_id, tags={}, start_seq_num=0):
+from nr_openai_observability.call_vars import (
+    get_conversation_id,
+)
+
+def build_messages_events(
+    messages,
+    model,
+    completion_id,
+    message_id_override=None,
+    response_id=None,
+    tags={},
+    start_seq_num=0,
+    vendor=None,
+):
     events = []
     for index, message in enumerate(messages):
+        # Non-final messages (i.e. user and system messages)
+        message_id = str(uuid.uuid4())
+        if message_id_override is not None:
+            # LangChain
+            message_id = message_id_override
+        elif response_id is not None:
+            # OpenAI
+            message_id = str(response_id) + "-" + str(index)
         currMessage = {
-            "id": str(uuid.uuid4()),
+            "id": message_id,
             "completion_id": completion_id,
+            "conversation_id": get_conversation_id(),
             "content": (message.get("content") or "")[:4095],
             "role": message.get("role"),
             "sequence": index + start_seq_num,
+            # Grab the last populated model for LangChain-returned messages
             **compat_fields(["model", "response.model"], model),
-            "vendor": "openAI",
+            "vendor": vendor,
             "ingest_source": "PythonSDK",
"PythonSDK", **get_trace_details(), } @@ -127,6 +150,7 @@ def build_stream_completion_events( completion = { "id": completion_id, + "conversation_id": get_conversation_id(), "api_key_last_four_digits": f"sk-{last_chunk.api_key[-4:]}", "response_time": int(response_time * 1000), "request.model": request.get("model") or request.get("engine"), @@ -181,10 +205,12 @@ def build_completion_summary( completion = { "id": completion_id, "request_id": response_headers.get("x-request-id", ""), + "conversation_id": get_conversation_id(), "api_key_last_four_digits": f"sk-{response.api_key[-4:]}", "response_time": int(response_time * 1000), "request.model": request.get("model") or request.get("engine"), "response.model": response.model, + "response.id": response.id, **compat_fields( ["organization", "response.organization"], response.organization ), @@ -348,3 +374,18 @@ def get_trace_details(): def compat_fields(keys, value): return dict.fromkeys(keys, value) + +def build_ai_feedback_event(category, rating, message_id, conversation_id, request_id, message): + feedback_event = { + "id": str(uuid.uuid4()), + "conversation_id": conversation_id, + "message_id": message_id, + "request_id": request_id, + "rating": rating, + "message": message, + "category": category, + "ingest_source": "PythonSDK", + "timestamp": datetime.now(), + } + + return feedback_event diff --git a/src/nr_openai_observability/call_vars.py b/src/nr_openai_observability/call_vars.py new file mode 100644 index 0000000..b49e888 --- /dev/null +++ b/src/nr_openai_observability/call_vars.py @@ -0,0 +1,35 @@ +import contextvars + +conversation_id = contextvars.ContextVar('conversation_id') +ai_message_ids = contextvars.ContextVar('ai_message_ids') + +def get_ai_message_ids(response_id=None): + if response_id is not None: + #OpenAI + return ai_message_ids.get({}).get(response_id, []) + else: + #Bedrock + return ai_message_ids.get([]) + +def set_ai_message_ids(message_ids, response_id=None): + if response_id is not None: + #OpenAI + current_ids = ai_message_ids.get({}) + current_ids[response_id] = message_ids + ai_message_ids.set(current_ids) + else: + #Bedrock + ai_message_ids.set(message_ids) + +def set_conversation_id(id): + conversation_id.set(id) + +def get_conversation_id(): + return conversation_id.get(None) + +def create_ai_message_id(message_id, response_id=None): + return { + "conversation_id": get_conversation_id(), + "response_id": response_id, + "message_id": message_id, + } \ No newline at end of file diff --git a/src/nr_openai_observability/consts.py b/src/nr_openai_observability/consts.py index c1fec81..391617b 100644 --- a/src/nr_openai_observability/consts.py +++ b/src/nr_openai_observability/consts.py @@ -1,6 +1,7 @@ EventName = "LlmCompletion" MessageEventName = "LlmChatCompletionMessage" SummaryEventName = "LlmChatCompletionSummary" +FeedbackEventName = "LlmChatFeedback" EmbeddingEventName = "LlmEmbedding" VectorSearchEventName = "LlmVectorSearch" VectorSearchResultsEventName = "LlmVectorSearchResult" diff --git a/src/nr_openai_observability/langchain_callback.py b/src/nr_openai_observability/langchain_callback.py index 2590628..3c78c02 100644 --- a/src/nr_openai_observability/langchain_callback.py +++ b/src/nr_openai_observability/langchain_callback.py @@ -11,7 +11,11 @@ ToolEventName, ) import newrelic.agent - +from nr_openai_observability.build_events import build_messages_events +from nr_openai_observability.consts import MessageEventName +from nr_openai_observability.call_vars import ( + set_conversation_id, +) class 
     def __init__(
         self,
@@ -40,6 +44,7 @@ def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> Any:
         """Run when LLM starts running."""
+        self._save_metadata(kwargs.get("metadata", {}))
         model = self._get_model(serialized, **kwargs)
 
         tags = {
@@ -57,6 +62,7 @@ def on_chat_model_start(
         **kwargs: Any,
     ) -> Any:
         """Run when Chat Model starts running."""
+        self._save_metadata(kwargs.get("metadata", {}))
         invocation_params = kwargs.get("invocation_params", {})
         model = self._get_model(serialized, **kwargs)
 
@@ -250,3 +256,6 @@ def _get_model(self, serialized: Dict[str, Any], **kwargs: Any) -> str:
             model = invocation_params.get("_type", "")
 
         return model
+
+    def _save_metadata(self, metadata):
+        set_conversation_id(metadata.get("conversation_id", None))
diff --git a/src/nr_openai_observability/monitor.py b/src/nr_openai_observability/monitor.py
index 890efdf..0abac6d 100644
--- a/src/nr_openai_observability/monitor.py
+++ b/src/nr_openai_observability/monitor.py
@@ -3,6 +3,8 @@
 
 from nr_openai_observability.patcher import perform_patch
 from nr_openai_observability.openai_monitoring import monitor
+from nr_openai_observability.build_events import build_ai_feedback_event
+from nr_openai_observability.consts import FeedbackEventName
 
 logger = logging.getLogger("nr_openai_observability")
 
@@ -23,3 +25,8 @@ def initialization(
     )
     perform_patch()
     return monitor
+
+def record_ai_feedback_event(rating, message_id, category=None, conversation_id=None, request_id=None, message=None):
+    event = build_ai_feedback_event(category, rating, message_id, conversation_id, request_id, message)
+
+    monitor.record_event(event, FeedbackEventName)
diff --git a/src/nr_openai_observability/patcher.py b/src/nr_openai_observability/patcher.py
index a5ff8ec..16cb86b 100644
--- a/src/nr_openai_observability/patcher.py
+++ b/src/nr_openai_observability/patcher.py
@@ -23,6 +23,11 @@
     patcher_create_chat_completion_stream,
     patcher_create_chat_completion_stream_async,
 )
+from nr_openai_observability.call_vars import (
+    create_ai_message_id,
+    get_ai_message_ids,
+    set_ai_message_ids,
+)
 
 logger = logging.getLogger("nr_openai_observability")
 
@@ -156,6 +161,7 @@ def handle_start_completion(request, completion_id):
         request.get("messages", []),
         request.get("model") or request.get("engine"),
         completion_id,
+        vendor="openAI",
     )
     for event in message_events:
         monitor.record_event(event, consts.MessageEventName)
@@ -180,10 +186,21 @@ def handle_finish_chat_completion(response, request, response_time, completion_id):
 
     response_message = build_messages_events(
         [final_message],
         response.model,
         completion_id,
+        None,
+        response.id,
         {"is_final_response": True},
         len(initial_messages),
+        vendor="openAI",
     )[0]
+    ai_message_ids = get_ai_message_ids(response.get("id"))
+
+    ai_message_ids.append(
+        create_ai_message_id(response_message.get("id"), response.get("id"))
+    )
+
+    set_ai_message_ids(ai_message_ids, response.get("id"))
+
     monitor.record_event(response_message, consts.MessageEventName)
     monitor.record_event(completion, consts.SummaryEventName)
diff --git a/src/nr_openai_observability/stream_patcher.py b/src/nr_openai_observability/stream_patcher.py
index 96bf3c2..136f367 100644
--- a/src/nr_openai_observability/stream_patcher.py
+++ b/src/nr_openai_observability/stream_patcher.py
@@ -12,6 +12,11 @@
 )
 from nr_openai_observability.error_handling_decorator import handle_errors
 from nr_openai_observability.openai_monitoring import monitor
+from nr_openai_observability.call_vars import (
+    create_ai_message_id,
+    get_ai_message_ids,
+    set_ai_message_ids,
+)
 
 
 def patcher_create_chat_completion_stream(original_fn, *args, **kwargs):
@@ -127,6 +132,7 @@ def handle_start_completion(request, completion_id):
         request.get("messages", []),
         request.get("model") or request.get("engine"),
         completion_id,
+        vendor="openAI",
     )
     for event in message_events:
         monitor.record_event(event, consts.MessageEventName)
@@ -152,10 +158,21 @@ def handle_finish_chat_completion(
 
     response_message = build_messages_events(
         [final_message],
         last_chunk.model,
         completion_id,
+        None,
+        last_chunk.id,
         {"is_final_response": True},
         len(initial_messages),
+        vendor="openAI",
     )[0]
+    ai_message_ids = get_ai_message_ids(last_chunk.id)
+
+    ai_message_ids.append(
+        create_ai_message_id(response_message.get("id"), last_chunk.id)
+    )
+
+    set_ai_message_ids(ai_message_ids, last_chunk.id)
+
     monitor.record_event(response_message, consts.MessageEventName)
     monitor.record_event(completion, consts.SummaryEventName)
diff --git a/tests/test_call_vars.py b/tests/test_call_vars.py
new file mode 100644
index 0000000..748bdb4
--- /dev/null
+++ b/tests/test_call_vars.py
@@ -0,0 +1,42 @@
+import sys
+import threading
+import time
+import uuid
+
+from nr_openai_observability.call_vars import set_conversation_id, get_conversation_id
+
+failures = 0
+
+def test_set_conversation_id():
+    global failures
+
+    def set_conversation_id_thread():
+        global failures
+        test_id = str(uuid.uuid4())
+        # Check that each new thread starts with a fresh context
+        if get_conversation_id() is not None:
+            failures += 1
+        set_conversation_id(test_id)
+        time.sleep(0.001)
+        # Check that the context has not been polluted by a different thread
+        if test_id != get_conversation_id():
+            failures += 1
+
+    # Greatly improve the chance of an operation being interrupted
+    # by a thread switch.
+    try:
+        sys.setswitchinterval(1e-12)
+    except AttributeError:
+        # Fallback for interpreters without setswitchinterval (Python < 3.2)
+        sys.setcheckinterval(1)
+
+    threads = []
+    for _ in range(1000):
+        t = threading.Thread(target=set_conversation_id_thread)
+        threads.append(t)
+        t.start()
+
+    for thread in threads:
+        thread.join()
+
+    assert failures == 0
\ No newline at end of file
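-- 

Reviewer note: below is a minimal end-to-end sketch of the new feedback API, assuming
the openai v0.x SDK and the library's documented monitor.initialization(application_name=...)
entry point. The conversation id, rating, category, and message values are illustrative,
and passing the OpenAI response id as request_id is this sketch's assumption, not
something the patch enforces.

    import openai

    from nr_openai_observability import monitor
    from nr_openai_observability.call_vars import (
        get_ai_message_ids,
        set_conversation_id,
    )

    monitor.initialization(application_name="feedback-demo")

    # Group the completions that follow under one conversation; the id lives in
    # a contextvar, so each thread/async context keeps its own value.
    set_conversation_id("demo-conversation-1")

    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "What is observability?"}],
    )

    # For OpenAI calls, each recorded LlmChatCompletionMessage id has the form
    # "<response.id>-<index>"; get_ai_message_ids returns the dicts built by
    # create_ai_message_id for that response.
    for ids in get_ai_message_ids(response.id):
        monitor.record_ai_feedback_event(
            rating=1,                               # e.g. thumbs-up
            message_id=ids["message_id"],
            conversation_id=ids["conversation_id"],
            request_id=ids["response_id"],          # assumption: OpenAI response id
            category="informative",                 # optional, free-form
            message="Helpful answer.",              # optional free-text feedback
        )

With LangChain, the conversation id is instead picked up from run metadata by the new
_save_metadata hook, i.e. by passing metadata={"conversation_id": "demo-conversation-1"}
when running the chain with NewRelicCallbackHandler attached.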