diff --git a/libs/community/langchain_community/tools/wikipedia/tool.py b/libs/community/langchain_community/tools/wikipedia/tool.py index a74d437538d42..66af94e41a055 100644 --- a/libs/community/langchain_community/tools/wikipedia/tool.py +++ b/libs/community/langchain_community/tools/wikipedia/tool.py @@ -1,13 +1,20 @@ """Tool for the Wikipedia API.""" -from typing import Optional +from typing import Optional, Type from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import BaseTool from langchain_community.utilities.wikipedia import WikipediaAPIWrapper +class WikipediaQueryInput(BaseModel): + """Input for the WikipediaQuery tool.""" + + query: str = Field(description="query to look up on wikipedia") + + class WikipediaQueryRun(BaseTool): """Tool that searches the Wikipedia API.""" @@ -20,6 +27,8 @@ class WikipediaQueryRun(BaseTool): ) api_wrapper: WikipediaAPIWrapper + args_schema: Type[BaseModel] = WikipediaQueryInput + def _run( self, query: str,