From e266c345f4fd4211982a618d95401a3442d3a10a Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Thu, 14 Nov 2024 09:39:30 -0800 Subject: [PATCH] Add test coverage for LLM endpoints Fix errors in LLMResponse example --- neon_hana/schema/llm_requests.py | 6 +-- tests/test_app.py | 67 +++++++++++++++++++++----------- 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/neon_hana/schema/llm_requests.py b/neon_hana/schema/llm_requests.py index 1861f4d..4085a69 100644 --- a/neon_hana/schema/llm_requests.py +++ b/neon_hana/schema/llm_requests.py @@ -24,7 +24,7 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import List +from typing import List, Tuple from pydantic import BaseModel @@ -42,11 +42,11 @@ class LLMRequest(BaseModel): class LLMResponse(BaseModel): response: str - history: List[tuple] + history: List[Tuple[str, str]] model_config = { "json_schema_extra": { "examples": [{ - "query": "I am well, how about you?", + "response": "As a large language model, I do not feel", "history": [("user", "hello"), ("llm", "Hi, how can I help you today?"), ("user", "I am well, how about you?"), diff --git a/tests/test_app.py b/tests/test_app.py index 1080bd5..6fb9924 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -480,34 +480,55 @@ def test_backend_coupons(self, send_request): self.assertEqual(response.status_code, 403, response.text) @patch("neon_hana.mq_service_api.send_mq_request") - def test_llm_chatgpt(self, send_request): - send_request.return_value = {} - # TODO + def test_llm(self, send_request): + send_request.return_value = {"response": "MOCK_LLM_RESPONSE"} + valid_request = {"query": "how are you?", + "history": [("user", "hello"), + ("llm", "Hi, how can I help you today?")]} + # Responses are lists instead of tuples because Pydantic will auto-cast + # for JSON encoding + valid_response = {"response": "MOCK_LLM_RESPONSE", + "history": [["user", "hello"], + ["llm", "Hi, how can I help you today?"], + ["user", "how are you?"], + ["llm", "MOCK_LLM_RESPONSE"]]} + token = self._get_tokens()["access_token"] + # ChatGPT + response = self.test_app.post("/llm/chatgpt", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), valid_response) - @patch("neon_hana.mq_service_api.send_mq_request") - def test_llm_fastchat(self, send_request): - send_request.return_value = {} - # TODO + # Fastchat + response = self.test_app.post("/llm/fastchat", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), valid_response) - @patch("neon_hana.mq_service_api.send_mq_request") - def test_llm_gemini(self, send_request): - send_request.return_value = {} - # TODO + # Claude + response = self.test_app.post("/llm/claude", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), valid_response) - @patch("neon_hana.mq_service_api.send_mq_request") - def test_llm_claude(self, send_request): - send_request.return_value = {} - # TODO + # Palm + response = self.test_app.post("/llm/palm", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), valid_response) - @patch("neon_hana.mq_service_api.send_mq_request") - def test_llm_palm(self, send_request): - send_request.return_value = {} - # TODO + # Invalid requests + response = self.test_app.post("/llm/chatgpt", + json=valid_request) + self.assertEqual(response.status_code, 403, response.text) - @patch("neon_hana.mq_service_api.send_mq_request") - def test_util_client_ip(self, send_request): - send_request.return_value = {} - # TODO + response = self.test_app.post("/llm/chatgpt", + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 422, response.text) @patch("neon_hana.mq_service_api.send_mq_request") def test_util_headers(self, send_request):