Add test coverage for LLM endpoints
Fix errors in LLMResponse example
NeonDaniel committed Nov 14, 2024
1 parent 95a27cb commit e266c34
Showing 2 changed files with 47 additions and 26 deletions.
6 changes: 3 additions & 3 deletions neon_hana/schema/llm_requests.py
@@ -24,7 +24,7 @@
 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import List
+from typing import List, Tuple
 
 from pydantic import BaseModel
 
@@ -42,11 +42,11 @@ class LLMRequest(BaseModel):
 
 class LLMResponse(BaseModel):
     response: str
-    history: List[tuple]
+    history: List[Tuple[str, str]]
     model_config = {
         "json_schema_extra": {
             "examples": [{
                 "query": "I am well, how about you?",
                 "response": "As a large language model, I do not feel",
                 "history": [("user", "hello"),
                             ("llm", "Hi, how can I help you today?"),
                             ("user", "I am well, how about you?"),
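A note on the typing change above, since the new tests rely on it: JSON has no tuple type, so Pydantic serializes each `Tuple[str, str]` pair in `history` as a JSON array, which round-trips back to callers as a plain Python list. A minimal sketch (not part of the commit; the model is pared down to the two fields shown in the diff, and Pydantic v2 is assumed):

from typing import List, Tuple

from pydantic import BaseModel


class LLMResponse(BaseModel):
    response: str
    history: List[Tuple[str, str]]


# Tuples validate on input...
resp = LLMResponse(response="hi", history=[("user", "hello")])
# ...but serialize as JSON arrays, which decode as lists:
print(resp.model_dump_json())  # {"response":"hi","history":[["user","hello"]]}

This is why the assertions in the test diff below compare `history` against lists of lists rather than lists of tuples.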
67 changes: 44 additions & 23 deletions tests/test_app.py
@@ -480,34 +480,55 @@ def test_backend_coupons(self, send_request):
         self.assertEqual(response.status_code, 403, response.text)
 
     @patch("neon_hana.mq_service_api.send_mq_request")
-    def test_llm_chatgpt(self, send_request):
-        send_request.return_value = {}
-        # TODO
+    def test_llm(self, send_request):
+        send_request.return_value = {"response": "MOCK_LLM_RESPONSE"}
+        valid_request = {"query": "how are you?",
+                         "history": [("user", "hello"),
+                                     ("llm", "Hi, how can I help you today?")]}
+        # Responses are lists instead of tuples because Pydantic will auto-cast
+        # for JSON encoding
+        valid_response = {"response": "MOCK_LLM_RESPONSE",
+                          "history": [["user", "hello"],
+                                      ["llm", "Hi, how can I help you today?"],
+                                      ["user", "how are you?"],
+                                      ["llm", "MOCK_LLM_RESPONSE"]]}
+        token = self._get_tokens()["access_token"]
+        # ChatGPT
+        response = self.test_app.post("/llm/chatgpt",
+                                      json=valid_request,
+                                      headers={"Authorization": f"Bearer {token}"})
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.json(), valid_response)
 
-    @patch("neon_hana.mq_service_api.send_mq_request")
-    def test_llm_fastchat(self, send_request):
-        send_request.return_value = {}
-        # TODO
+        # Fastchat
+        response = self.test_app.post("/llm/fastchat",
+                                      json=valid_request,
+                                      headers={"Authorization": f"Bearer {token}"})
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.json(), valid_response)
 
-    @patch("neon_hana.mq_service_api.send_mq_request")
-    def test_llm_gemini(self, send_request):
-        send_request.return_value = {}
-        # TODO
+        # Claude
+        response = self.test_app.post("/llm/claude",
+                                      json=valid_request,
+                                      headers={"Authorization": f"Bearer {token}"})
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.json(), valid_response)
 
-    @patch("neon_hana.mq_service_api.send_mq_request")
-    def test_llm_claude(self, send_request):
-        send_request.return_value = {}
-        # TODO
+        # Palm
+        response = self.test_app.post("/llm/palm",
+                                      json=valid_request,
+                                      headers={"Authorization": f"Bearer {token}"})
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.json(), valid_response)
 
-    @patch("neon_hana.mq_service_api.send_mq_request")
-    def test_llm_palm(self, send_request):
-        send_request.return_value = {}
-        # TODO
+        # Invalid requests
+        response = self.test_app.post("/llm/chatgpt",
+                                      json=valid_request)
+        self.assertEqual(response.status_code, 403, response.text)
 
-    @patch("neon_hana.mq_service_api.send_mq_request")
-    def test_util_client_ip(self, send_request):
-        send_request.return_value = {}
-        # TODO
+        response = self.test_app.post("/llm/chatgpt",
+                                      headers={"Authorization": f"Bearer {token}"})
+        self.assertEqual(response.status_code, 422, response.text)
 
     @patch("neon_hana.mq_service_api.send_mq_request")
     def test_util_headers(self, send_request):
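For readers unfamiliar with the pattern above: the test patches `send_mq_request` at the point where `neon_hana.mq_service_api` looks it up, so the HTTP layer runs for real while the MQ/LLM backend is mocked. A standalone sketch of the same flow (not from the commit; `create_app` is a hypothetical factory name, and a real token from the service's auth flow is needed where the placeholder appears):

from unittest.mock import patch

from fastapi.testclient import TestClient

from neon_hana.app import create_app  # hypothetical import path

client = TestClient(create_app())

with patch("neon_hana.mq_service_api.send_mq_request") as send_request:
    # The mocked MQ reply stands in for the real LLM service
    send_request.return_value = {"response": "MOCK_LLM_RESPONSE"}
    reply = client.post("/llm/chatgpt",
                        json={"query": "how are you?", "history": []},
                        headers={"Authorization": "Bearer <access_token>"})
    assert reply.status_code == 200
    assert reply.json()["response"] == "MOCK_LLM_RESPONSE"

The diff's own version obtains the token via `self._get_tokens()["access_token"]` and additionally checks the merged conversation history.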
