-
Notifications
You must be signed in to change notification settings - Fork 16k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6352edf
commit 8d2990f
Showing
1 changed file
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,365 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import warnings\n", | ||
"\n", | ||
"warnings.filterwarnings('ignore')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import requests\n", | ||
"from dotenv import load_dotenv\n", | ||
"from langchain_core.documents import Document" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import faiss\n", | ||
"import numpy as np\n", | ||
"import re # For text cleaning\n", | ||
"from dotenv import load_dotenv\n", | ||
"from sentence_transformers import SentenceTransformer\n", | ||
"from langchain.vectorstores import VectorStore" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"github_token = os.getenv(\"GITHUB_TOKEN\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"load_dotenv()\n", | ||
"\n", | ||
"github_token = os.getenv(\"GITHUB_TOKEN\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def fetch_github(owner, repo, endpoint):\n", | ||
" url = f\"https://api.github.com/repos/{owner}/{repo}/{endpoint}\"\n", | ||
" headers = {\"Authorization\": f\"Bearer {github_token}\"}\n", | ||
" all_data = []\n", | ||
" page = 1\n", | ||
"\n", | ||
" while True:\n", | ||
" response = requests.get(url, headers=headers, params={\"page\": page})\n", | ||
" if response.status_code == 200:\n", | ||
" data = response.json()\n", | ||
" if not data: # Break if no more data\n", | ||
" break\n", | ||
" all_data.extend(data)\n", | ||
" page += 1\n", | ||
" else:\n", | ||
" print(\"Failed with status code:\", response.status_code)\n", | ||
" return []\n", | ||
"\n", | ||
" return all_data\n", | ||
"\n", | ||
"\n", | ||
"def fetch_github_issues(owner, repo,endpoint):\n", | ||
" data = fetch_github(owner, repo, endpoint)\n", | ||
" return load_issues(data,endpoint,repo)\n", | ||
"\n", | ||
"\n", | ||
"def load_issues(data,endpoint,repo):\n", | ||
" docs = []\n", | ||
" for entry in data:\n", | ||
" str_data = entry.get(\"title\", \"\") \n", | ||
" metadata = {\n", | ||
" \"type\": endpoint,\n", | ||
" \"repo\": repo,\n", | ||
" \"author\": entry[\"user\"][\"login\"],\n", | ||
" \"comments\": entry[\"comments\"],\n", | ||
" \"body\": entry[\"body\"],\n", | ||
" \"labels\": entry[\"labels\"],\n", | ||
" \"created_at\": entry[\"created_at\"][0:10], ## slicing the extra part\n", | ||
" }\n", | ||
" if entry['body']:\n", | ||
" str_data += \" \"\n", | ||
" str_data += entry['body']\n", | ||
" doc = Document(page_content=str_data, metadata=metadata)\n", | ||
" docs.append(doc)\n", | ||
Check failure on line 109 in cookbook/Github_Agent.ipynb GitHub Actions / cd . / make lint #3.12Ruff (E741)
|
||
"\n", | ||
" return docs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"owner = \"microsoft\"\n", | ||
"repo = \"DeepSpeed\"\n", | ||
"docs = fetch_github_issues(owner, repo, \"issues\") # Fetch issues from the specified repo\n", | ||
"\n", | ||
" # Extract and print the created date of each issue\n", | ||
"#for doc in docs:\n", | ||
" #created_at = doc.metadata.get('created_at')\n", | ||
" #print(f\"Issue created at: {created_at}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class FAISStore(VectorStore):\n", | ||
" def __init__(self):\n", | ||
" # Initialize FAISS index with a flat index type\n", | ||
" self._embeddings = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')\n", | ||
" d = 384 # Dimension of embeddings\n", | ||
" self.index = faiss.IndexFlatL2(d) # Use a flat index without clustering\n", | ||
" self.documents = []\n", | ||
"\n", | ||
" @property\n", | ||
" def embeddings(self):\n", | ||
" return self._embeddings\n", | ||
"\n", | ||
" def add_docs(self, docs):\n", | ||
" vectors_to_upsert = []\n", | ||
"\n", | ||
" for doc in docs:\n", | ||
" # Encode the cleaned document content into embeddings\n", | ||
" embed_docs = self.embeddings.encode(doc.page_content).astype('float32')\n", | ||
"\n", | ||
" # Create a unique ID for the document\n", | ||
" unique_id = doc.metadata.get(\"author\", \"unknown_author\") + \"_\" + doc.metadata.get(\"type\", \"unknown_type\")\n", | ||
"\n", | ||
" # Append vector and unique ID\n", | ||
" vectors_to_upsert.append((unique_id, embed_docs))\n", | ||
"\n", | ||
" # Store the document for future retrieval\n", | ||
" self.documents.append((unique_id, doc)) # Store Document object directly\n", | ||
"\n", | ||
" # Upsert vectors into FAISS\n", | ||
" embed_docs_array = np.array([vec for _, vec in vectors_to_upsert]).astype('float32')\n", | ||
" self.index.add(embed_docs_array) # Add vectors to the index\n", | ||
"\n", | ||
" def search(self, query, k=1):\n", | ||
" # Encode the query into an embedding\n", | ||
" query_embedding = self.embeddings.encode(query).astype('float32').reshape(1, -1)\n", | ||
"\n", | ||
" # Perform the similarity search\n", | ||
" D, I = self.index.search(query_embedding, k=k)\n", | ||
"\n", | ||
" # Retrieve metadata and content for the results\n", | ||
" results = []\n", | ||
" for idx in I[0]:\n", | ||
" if idx >= 0:\n", | ||
" unique_id, document = self.documents[idx]\n", | ||
" results.append(document)\n", | ||
"\n", | ||
" return results # Return Document objects\n", | ||
"\n", | ||
" def similarity_search(self, query, k=1):\n", | ||
" return self.search(query, k)\n", | ||
"\n", | ||
" def from_texts(self, texts, metadatas=None):\n", | ||
" \"\"\" Takes a list of texts and corresponding metadata, creates Documents, and adds them to the vector store. \"\"\"\n", | ||
" docs = [Document(page_content=self.preprocess_content(text), metadata=metadata)\n", | ||
" for text, metadata in zip(texts, metadatas or [{}]*len(texts))]\n", | ||
" self.add_docs(docs)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"store = FAISStore()\n", | ||
"owner = \"microsoft\"\n", | ||
"repo = \"DeepSpeed\"\n", | ||
"\n", | ||
" # Fetch GitHub pull requests and add them to FAISS\n", | ||
"docs = fetch_github_issues(owner, repo, \"issues\")\n", | ||
"store.add_docs(docs)\n", | ||
"\n", | ||
" # Query the FAISS index\n", | ||
"result = store.similarity_search(\"Fix bug with hybrid engine generation\")\n", | ||
"print(result)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 19, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_groq import ChatGroq # Assuming you are using Groq for chat\n", | ||
"from langchain.chains import RetrievalQA\n", | ||
"from langchain.memory import ConversationBufferMemory\n", | ||
"from langchain import hub\n", | ||
"from langchain.tools.retriever import create_retriever_tool\n", | ||
"from langchain.agents import initialize_agent\n", | ||
"from langchain.agents import create_tool_calling_agent\n", | ||
"from langchain.agents import AgentExecutor\n", | ||
"from langchain.prompts import PromptTemplate" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 27, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"FLAG_FILE = \"data_loaded.flag\"\n", | ||
"\n", | ||
"class Agent:\n", | ||
" def __init__(self):\n", | ||
" # Initialize FAISS store separately\n", | ||
" self.vector_store = FAISStore()\n", | ||
" \n", | ||
" # Initialize memory for conversation\n", | ||
" self.conversational_memory = ConversationBufferMemory(\n", | ||
" memory_key='chat_history',\n", | ||
" return_messages=True # Store messages as a list\n", | ||
" )\n", | ||
" \n", | ||
" # Initialize the LLM\n", | ||
" self.llm = ChatGroq(\n", | ||
" temperature=0.0,\n", | ||
" model='llama-3.1-70b-versatile',\n", | ||
" api_key=os.getenv('GROQ_API_KEY'),\n", | ||
" verbose=True\n", | ||
" )\n", | ||
" \n", | ||
" def _run(self, response):\n", | ||
" template = '''This is a response from github agent. Make the Response well Structured and formatted!!\n", | ||
" Here is the response from the agent: {response}'''\n", | ||
" \n", | ||
" prompt = PromptTemplate(template=template, input_variables=['response'])\n", | ||
" formatted_prompt = prompt.format(response=response)\n", | ||
" return self.llm.invoke(formatted_prompt)\n", | ||
" \n", | ||
" \n", | ||
" def initialize(self, owner, repo, endpoint):\n", | ||
" if not os.path.exists(FLAG_FILE): # Check if the flag file exists\n", | ||
" print(\"No data found in the FAISS store. Fetching data from GitHub...\")\n", | ||
" docs = fetch_github_issues(owner, repo, endpoint) # Fetch issues/pulls\n", | ||
" if docs: # Only add if documents were fetched\n", | ||
" self.vector_store.add_docs(docs) # Add docs to the FAISS store\n", | ||
" with open(FLAG_FILE, \"w\") as f: # Create a flag file to indicate data has been loaded\n", | ||
" f.write(\"Data loaded\")\n", | ||
" print(f\"Added {len(docs)} documents to the FAISS store.\")\n", | ||
" else:\n", | ||
" print(\"No documents fetched from GitHub.\")\n", | ||
" else:\n", | ||
" user_input = input(\"Data is already loaded. Do you want to re-fetch it from GitHub? (yes/no): \").strip().lower()\n", | ||
" if user_input == 'yes':\n", | ||
" print(\"Re-fetching data from GitHub...\")\n", | ||
" docs = fetch_github_issues(owner, repo, endpoint) # Fetch issues/pulls\n", | ||
" if docs:\n", | ||
" self.vector_store.add_docs(docs) # Add docs to the FAISS store\n", | ||
" print(f\"Added {len(docs)} documents to the FAISS store.\")\n", | ||
" else:\n", | ||
" print(\"No documents fetched from GitHub.\")\n", | ||
" else:\n", | ||
" print(\"Using existing data from the FAISS store.\")\n", | ||
"\n", | ||
" def make_agent(self):\n", | ||
" # Set up the retrieval-based question answering chain\n", | ||
" retriever = self.vector_store.as_retriever() # Use `as_retriever` to make it compatible with RetrievalQA\n", | ||
"\n", | ||
" # Create the retriever tool\n", | ||
" self.retriever_tool = create_retriever_tool(\n", | ||
" retriever,\n", | ||
" \"GitHub Search\",\n", | ||
" 'The user is asking question which is related to this tool .Use this tool for any question . It will search the GitHub repository for relevant issues and pull requests.'\n", | ||
" )\n", | ||
"\n", | ||
" # Initialize the agent\n", | ||
" tools = [self.retriever_tool]\n", | ||
" #prompt = hub.pull(\"hwchase17/openai-functions-agent\")\n", | ||
" #agent = create_tool_calling_agent(self.llm, tools, prompt)\n", | ||
" #self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n", | ||
" \n", | ||
" self.agent_executor = initialize_agent(\n", | ||
" llm=self.llm,\n", | ||
" agent='conversational-react-description', \n", | ||
" tools=tools,\n", | ||
" verbose=True,\n", | ||
" max_iterations=3,\n", | ||
" memory=self.conversational_memory\n", | ||
")\n", | ||
"\n", | ||
" def run_query(self, query):\n", | ||
" \"\"\"Run a query through the agent and return the response.\"\"\"\n", | ||
" response = self.agent_executor({\"input\": query})\n", | ||
" res=self._run(response)\n", | ||
" return res" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"agent = Agent()\n", | ||
" \n", | ||
" # Initialize the agent with appropriate parameters\n", | ||
"agent.initialize(owner='microsoft', repo='DeepSpeed', endpoint='issues')\n", | ||
"agent.make_agent() # Initialize the agent tools" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |