chore(types): Type-clean llm/ (27 errors) #1394
base: develop
Changes from all commits: 66ab985, bb67063, a7b65cb, 9c822ed, d929143, d064748, 09198df
```diff
@@ -13,18 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Type, Union
+from typing import List, Optional, Type
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.llms import LLM, BaseLLM
+from langchain_core.language_models.llms import LLM
 
 
-def get_llm_instance_wrapper(
-    llm_instance: Union[LLM, BaseLLM], llm_type: str
-) -> Type[LLM]:
+def get_llm_instance_wrapper(llm_instance: LLM, llm_type: str) -> Type[LLM]:
     """Wraps an LLM instance in a class that can be registered with LLMRails.
 
     This is useful to create specific types of LLMs using a generic LLM provider
@@ -47,7 +45,7 @@ def model_kwargs(self):
             These are needed to allow changes to the arguments of the LLM calls.
             """
             if hasattr(llm_instance, "model_kwargs"):
-                return llm_instance.model_kwargs
+                return getattr(llm_instance, "model_kwargs")
 
             return {}
 
         @property
@@ -66,26 +64,29 @@ def _modify_instance_kwargs(self):
             """
 
             if hasattr(llm_instance, "model_kwargs"):
-                if isinstance(llm_instance.model_kwargs, dict):
-                    llm_instance.model_kwargs["temperature"] = self.temperature
-                    llm_instance.model_kwargs["streaming"] = self.streaming
+                model_kwargs = getattr(llm_instance, "model_kwargs")
+                if isinstance(model_kwargs, dict):
+                    model_kwargs["temperature"] = self.temperature
+                    model_kwargs["streaming"] = self.streaming
 
         def _call(
             self,
             prompt: str,
             stop: Optional[List[str]] = None,
             run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs,
         ) -> str:
             self._modify_instance_kwargs()
-            return llm_instance._call(prompt, stop, run_manager)
+            return llm_instance._call(prompt, stop, run_manager, **kwargs)
 
         async def _acall(
             self,
             prompt: str,
             stop: Optional[List[str]] = None,
             run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+            **kwargs,
        ) -> str:
             self._modify_instance_kwargs()
-            return await llm_instance._acall(prompt, stop, run_manager)
+            return await llm_instance._acall(prompt, stop, run_manager, **kwargs)
 
     return WrapperLLM
```

Review comment on `return getattr(llm_instance, "model_kwargs")`: How does Pyright not understand this?
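For context on how this helper is used elsewhere: a concrete LLM instance is wrapped and then registered as a named provider. A minimal sketch, assuming the helper lives in `nemoguardrails.llm.helpers` and that `register_llm_provider` is importable from `nemoguardrails.llm.providers` (as in the upstream docs); `my_pipeline_llm` is a placeholder, not a name from this PR:

```python
from nemoguardrails.llm.helpers import get_llm_instance_wrapper
from nemoguardrails.llm.providers import register_llm_provider

# Placeholder: any concrete LangChain LLM instance. After this PR the
# argument must be an `LLM`, not merely a `BaseLLM`.
my_pipeline_llm = ...

# Wrap the instance in a class so it can be registered under a name.
MyWrappedLLM = get_llm_instance_wrapper(
    llm_instance=my_pipeline_llm, llm_type="my_custom_llm"
)

# The registered name can then be referenced as the `engine` in a rails config.
register_llm_provider("my_custom_llm", MyWrappedLLM)
```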
```diff
@@ -20,12 +20,15 @@
 from langchain_core.language_models import BaseChatModel
 from langchain_core.language_models.llms import BaseLLM
 
-from .langchain_initializer import ModelInitializationError, init_langchain_model
+from nemoguardrails.llm.models.langchain_initializer import (
+    ModelInitializationError,
+    init_langchain_model,
+)
 
 
 # later we can easily conver it to a class
 def init_llm_model(
-    model_name: Optional[str],
+    model_name: str,
     provider_name: str,
     mode: Literal["chat", "text"],
     kwargs: Dict[str, Any],
```

Review comment: This module is not used anymore; I should have deprecated it as part of #1387. Sorry for the pain.
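Given that comment, a module-level deprecation shim would be one option. A minimal sketch, assuming the standard warning pattern; the replacement path is taken from the absolute import in the diff above, while the shim itself is illustrative, not part of this PR:

```python
import warnings

# Emit a DeprecationWarning when this legacy module is imported; callers
# should switch to nemoguardrails.llm.models.langchain_initializer, the
# target of the absolute import shown in the diff above.
warnings.warn(
    "This module is deprecated; use "
    "nemoguardrails.llm.models.langchain_initializer instead.",
    DeprecationWarning,
    stacklevel=2,
)
```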
```diff
@@ -13,14 +13,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import asyncio
 from typing import Any, List, Optional
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain.schema.output import GenerationChunk
-from langchain_community.llms import HuggingFacePipeline
+
+# Import HuggingFacePipeline with fallbacks for different LangChain versions
+HuggingFacePipeline = None  # type: ignore[assignment]
+
+try:
+    from langchain_community.llms import (
+        HuggingFacePipeline,  # type: ignore[attr-defined,no-redef]
+    )
+except ImportError:
+    # Fallback for older versions of langchain
+    try:
+        from langchain.llms import (
+            HuggingFacePipeline,  # type: ignore[attr-defined,no-redef]
+        )
+    except ImportError:
+        # Create a dummy class if HuggingFacePipeline is not available
+        class HuggingFacePipeline:  # type: ignore[misc,no-redef]
+            def __init__(self, *args, **kwargs):
+                raise ImportError("HuggingFacePipeline is not available")
 
 
 class HuggingFacePipelineCompatible(HuggingFacePipeline):
@@ -47,12 +66,13 @@ def _call(
         )
 
         # Streaming for NeMo Guardrails is not supported in sync calls.
-        if self.model_kwargs and self.model_kwargs.get("streaming"):
-            raise Exception(
+        model_kwargs = getattr(self, "model_kwargs", {})
+        if model_kwargs and model_kwargs.get("streaming"):
+            raise NotImplementedError(
                 "Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
             )
 
-        llm_result = self._generate(
+        llm_result = getattr(self, "_generate")(
             [prompt],
             stop=stop,
             run_manager=run_manager,
```

Review comment: I think `getattr` just fools Pyright into ignoring the error; it is better to suppress it directly: `llm_result = self._generate(  # type: ignore`

Review comment: Same for all `getattr` usage below.
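To make the reviewer's point concrete, here is an illustrative sketch (not code from the PR) contrasting the two ways of silencing the checker when a member exists at runtime but is invisible to static analysis:

```python
from typing import Any


class Base:
    """Stands in for a dynamically imported base class (like the
    HuggingFacePipeline fallback above) whose members Pyright cannot see."""


# The member exists at runtime but is invisible to the type checker.
setattr(Base, "_generate", lambda self, prompts: [f"echo: {p}" for p in prompts])


class Child(Base):
    def call_via_getattr(self) -> Any:
        # Option 1 (what this PR does): getattr() returns Any, so Pyright
        # stops checking the call entirely, arguments included.
        return getattr(self, "_generate")(["prompt"])

    def call_via_ignore(self) -> Any:
        # Option 2 (the reviewer's suggestion): keep the plain attribute
        # access and suppress only the diagnostic on this one line.
        return self._generate(["prompt"])  # type: ignore


print(Child().call_via_getattr())  # ['echo: prompt']
print(Child().call_via_ignore())   # ['echo: prompt']
```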
```diff
@@ -78,11 +98,12 @@ async def _acall(
         )
 
         # Handle streaming, if the flag is set
-        if self.model_kwargs and self.model_kwargs.get("streaming"):
+        model_kwargs = getattr(self, "model_kwargs", {})
+        if model_kwargs and model_kwargs.get("streaming"):
             # Retrieve the streamer object, needs to be set in model_kwargs
-            streamer = self.model_kwargs.get("streamer")
+            streamer = model_kwargs.get("streamer")
             if not streamer:
-                raise Exception(
+                raise ValueError(
                     "Cannot stream, please add HuggingFace streamer object to model_kwargs!"
                 )
@@ -99,7 +120,7 @@ async def _acall(
             run_manager=run_manager,
             **kwargs,
         )
-        loop.create_task(self._agenerate(**generation_kwargs))
+        loop.create_task(getattr(self, "_agenerate")(**generation_kwargs))
 
         # And start waiting for the chunks to come in.
         completion = ""
@@ -111,7 +132,7 @@ async def _acall(
 
         return completion
 
-        llm_result = await self._agenerate(
+        llm_result = await getattr(self, "_agenerate")(
             [prompt],
             stop=stop,
             run_manager=run_manager,
```
Review comment: I think it is better to exclude `nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py` from Pyright checking. That file is a runtime patch for an optional dependency, and type-checking it provides minimal value compared to the cost of installing all extras in CI. Can we exclude it in `pyproject.toml`? Then we should restore the original state of the workflows.
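If the maintainers take this route: Pyright reads its settings from a `[tool.pyright]` table in `pyproject.toml`, and its `exclude` option takes a list of paths. A minimal sketch of what the exclusion could look like (the table would be merged into the project's existing configuration, which is not shown here):

```toml
[tool.pyright]
# Skip the runtime patch for the optional nvidia-ai-endpoints dependency;
# checking it would require installing all extras in CI for little benefit.
exclude = [
    "nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py",
]
```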