add bloom api support #8

Open · wants to merge 1 commit into base: main
65 changes: 65 additions & 0 deletions prompt_lib/backends/bloom_api.py
@@ -0,0 +1,65 @@
import logging
import os
import time

from huggingface_hub import InferenceApi

hf_token = os.getenv('HF_API_TOKEN')
logging.basicConfig(level=logging.INFO)

# adapted from https://github.com/Sentdex/BLOOM_Examples/blob/main/BLOOM_api_example.ipynb

class BloomWrapper:
    @staticmethod
    def call_bloom_api(
        prompt: str,
        max_length: int = 248,
        top_k: int = 0,
        num_beams: int = 0,
        no_repeat_ngram_size: int = 0,
        top_p: float = 0.9,
        seed: int = 42,
        temperature: float = 0.7,
        greedy_decoding: bool = False,
        return_full_text: bool = False,
    ):
        # Normalize mutually exclusive decoding options: greedy decoding
        # disables beam search, and beam search disables sampling and
        # nucleus (top-p) filtering. A value of 0 means "not set".
        top_k = None if top_k == 0 else top_k
        do_sample = False if num_beams > 0 else not greedy_decoding
        num_beams = None if (greedy_decoding or num_beams == 0) else num_beams
        no_repeat_ngram_size = None if num_beams is None else no_repeat_ngram_size
        top_p = None if num_beams else top_p
        early_stopping = None if num_beams is None else num_beams > 0

        params = {
            "max_new_tokens": max_length,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
            "do_sample": do_sample,
            "seed": seed,
            "early_stopping": early_stopping,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "num_beams": num_beams,
            "return_full_text": return_full_text,
        }

        start = time.time()
        inference = InferenceApi("bigscience/bloom", token=hf_token)
        # Returns the raw API response: on success, a list of dicts with a
        # "generated_text" key; on failure, a dict with an "error" key.
        response = inference(prompt, params=params)
        logging.info(f"BLOOM API call took {time.time() - start:.2f}s")
        return response



if __name__ == "__main__":
    response = BloomWrapper.call_bloom_api(
        prompt="The thing that makes large language models interesting is",
        max_length=248,
        num_beams=0,
        greedy_decoding=True,
    )
    print(response)
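
A note for reviewers: the hosted Inference API often returns an error payload (for example while the model is loading) rather than generations, and this wrapper passes that payload straight through to the caller. A small retry helper around call_bloom_api would make it more robust. The sketch below is illustrative, not part of this diff; the response shapes it checks (a dict with an "error" key on failure, a list of dicts with "generated_text" on success) are assumptions about the hosted endpoint, not something this PR guarantees.

# Hypothetical helper, not part of this PR: retry call_bloom_api when the
# hosted Inference API returns an error payload instead of generations.
import logging
import time

def call_bloom_with_retries(prompt: str, max_retries: int = 3, backoff: float = 5.0, **kwargs) -> str:
    for attempt in range(max_retries):
        response = BloomWrapper.call_bloom_api(prompt=prompt, **kwargs)
        # Assumed failure shape: {"error": "Model bigscience/bloom is currently loading", ...}
        if isinstance(response, dict) and "error" in response:
            logging.warning(f"attempt {attempt + 1} failed: {response['error']}")
            time.sleep(backoff * (attempt + 1))  # linear backoff between retries
            continue
        # Assumed success shape: [{"generated_text": "..."}]
        return response[0]["generated_text"]
    raise RuntimeError(f"BLOOM API call failed after {max_retries} retries")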