1 change: 1 addition & 0 deletions .github/workflows/full-tests.yml
Collaborator:
I think it is better to exclude nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py from pyright checking. This file is a runtime patch for an optional dependency, and type-checking it provides minimal value compared to the cost of installing all extras in CI.

Can we exclude it in pyproject.toml? Pyright reads its configuration from the [tool.pyright] table there:

  [tool.pyright]
  exclude = [
    "nemoguardrails/llm/providers/trtllm/**",
    "nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py"
  ]

Then we should restore the workflows to their original state.

@@ -32,6 +32,7 @@ jobs:
os: ${{ matrix.os }}
image: ${{ matrix.image }}
python-version: ${{ matrix.python-version }}
upgrade-deps: true
full-tests-summary:
name: Full Tests Summary
needs: full-tests-matrix
1 change: 1 addition & 0 deletions .github/workflows/pr-tests.yml
@@ -21,6 +21,7 @@ jobs:
os: ${{ matrix.os }}
image: ${{ matrix.image }}
python-version: ${{ matrix.python-version }}
upgrade-deps: true
pr-tests-summary:
name: PR Tests Summary
needs: pr-tests-matrix
2 changes: 1 addition & 1 deletion .github/workflows/test-coverage-report.yml
@@ -28,7 +28,7 @@ jobs:
run: poetry config virtualenvs.in-project true

- name: Install dependencies
run: poetry install --with dev
run: poetry install --with dev --all-extras

- name: Run pre-commit hooks
run: poetry run make pre_commit
8 changes: 4 additions & 4 deletions nemoguardrails/llm/filters.py
@@ -140,7 +140,7 @@ def to_messages(colang_history: str) -> List[dict]:
# a message from the user, and the rest gets translated to messages from the assistant.
lines = colang_history.split("\n")

bot_lines = []
bot_lines: list[str] = []
for i, line in enumerate(lines):
if line.startswith('user "'):
# If we have bot lines in the buffer, we first add a bot message.
@@ -181,8 +181,8 @@ def to_messages_v2(colang_history: str) -> List[dict]:
# a message from the user, and the rest gets translated to messages from the assistant.
lines = colang_history.split("\n")

user_lines = []
bot_lines = []
user_lines: list[str] = []
bot_lines: list[str] = []
for line in lines:
if line.startswith("user action:"):
if len(bot_lines) > 0:
Expand Down Expand Up @@ -275,7 +275,7 @@ def verbose_v1(colang_history: str) -> str:
return "\n".join(lines)


def to_chat_messages(events: List[dict]) -> str:
def to_chat_messages(events: List[dict]) -> List[dict]:
"""Filter that turns an array of events into a sequence of user/assistant messages.

Properly handles multimodal content by preserving the structure when the content
23 changes: 12 additions & 11 deletions nemoguardrails/llm/helpers.py
@@ -13,18 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Type, Union
from typing import List, Optional, Type

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM, BaseLLM
from langchain_core.language_models.llms import LLM


def get_llm_instance_wrapper(
llm_instance: Union[LLM, BaseLLM], llm_type: str
) -> Type[LLM]:
def get_llm_instance_wrapper(llm_instance: LLM, llm_type: str) -> Type[LLM]:
"""Wraps an LLM instance in a class that can be registered with LLMRails.

This is useful to create specific types of LLMs using a generic LLM provider
@@ -47,7 +45,7 @@ def model_kwargs(self):
These are needed to allow changes to the arguments of the LLM calls.
"""
if hasattr(llm_instance, "model_kwargs"):
return llm_instance.model_kwargs
return getattr(llm_instance, "model_kwargs")
Collaborator:
How does Pyright not understand this?
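
For reference, the pattern reduces to something like this (a minimal, hypothetical sketch; FakeLLM stands in for the langchain base class, which declares no model_kwargs field):

    class FakeLLM:
        """Stand-in for langchain's LLM base class; declares no model_kwargs."""

    def get_model_kwargs(llm: FakeLLM):
        if hasattr(llm, "model_kwargs"):
            # Depending on the Pyright version, this access is either flagged
            # ("model_kwargs" is not a known attribute of "FakeLLM") or accepted,
            # since recent releases narrow hasattr() checks.
            return llm.model_kwargs
        return {}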

return {}

@property
@@ -66,26 +64,29 @@ def _modify_instance_kwargs(self):
"""

if hasattr(llm_instance, "model_kwargs"):
if isinstance(llm_instance.model_kwargs, dict):
llm_instance.model_kwargs["temperature"] = self.temperature
llm_instance.model_kwargs["streaming"] = self.streaming
model_kwargs = getattr(llm_instance, "model_kwargs")
if isinstance(model_kwargs, dict):
model_kwargs["temperature"] = self.temperature
model_kwargs["streaming"] = self.streaming

def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs,
) -> str:
self._modify_instance_kwargs()
return llm_instance._call(prompt, stop, run_manager)
return llm_instance._call(prompt, stop, run_manager, **kwargs)

async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs,
) -> str:
self._modify_instance_kwargs()
return await llm_instance._acall(prompt, stop, run_manager)
return await llm_instance._acall(prompt, stop, run_manager, **kwargs)

return WrapperLLM
7 changes: 5 additions & 2 deletions nemoguardrails/llm/models/initializer.py
@@ -20,12 +20,15 @@
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM

from .langchain_initializer import ModelInitializationError, init_langchain_model
from nemoguardrails.llm.models.langchain_initializer import (
ModelInitializationError,
init_langchain_model,
)


# later we can easily conver it to a class
Member:
Suggested change:
- # later we can easily conver it to a class
+ # later we can easily convert it to a class

def init_llm_model(
model_name: Optional[str],
model_name: str,
provider_name: str,
mode: Literal["chat", "text"],
kwargs: Dict[str, Any],
15 changes: 8 additions & 7 deletions nemoguardrails/llm/params.py
Collaborator:
This module is not used anymore; I should have deprecated it as part of #1387. Sorry for the pain.

@@ -20,7 +20,7 @@
"""

import logging
from typing import Dict, Type
from typing import Any, Dict, Type

from langchain.base_language import BaseLanguageModel

@@ -33,18 +33,18 @@ class LLMParams:
def __init__(self, llm: BaseLanguageModel, **kwargs):
self.llm = llm
self.altered_params = kwargs
self.original_params = {}
self.original_params: dict[str, Any] = {}

def __enter__(self):
# Here we can access and modify the global language model parameters.
self.original_params = {}
for param, value in self.altered_params.items():
if hasattr(self.llm, param):
self.original_params[param] = getattr(self.llm, param)
setattr(self.llm, param, value)

elif hasattr(self.llm, "model_kwargs"):
if param not in self.llm.model_kwargs:
model_kwargs = getattr(self.llm, "model_kwargs", {})
if param not in model_kwargs:
log.warning(
"Parameter %s does not exist for %s. Passing to model_kwargs",
param,
@@ -53,9 +53,10 @@ def __enter__(self):

self.original_params[param] = None
else:
self.original_params[param] = self.llm.model_kwargs[param]
self.original_params[param] = model_kwargs[param]

self.llm.model_kwargs[param] = value
model_kwargs[param] = value
setattr(self.llm, "model_kwargs", model_kwargs)

else:
log.warning(
@@ -64,7 +65,7 @@
self.llm.__class__.__name__,
)

def __exit__(self, type, value, traceback):
def __exit__(self, exc_type, value, traceback):
# Restore original parameters when exiting the context
for param, value in self.original_params.items():
if hasattr(self.llm, param):
39 changes: 30 additions & 9 deletions nemoguardrails/llm/providers/huggingface/pipeline.py
@@ -13,14 +13,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from typing import Any, List, Optional

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.schema.output import GenerationChunk
from langchain_community.llms import HuggingFacePipeline

# Import HuggingFacePipeline with fallbacks for different LangChain versions
HuggingFacePipeline = None # type: ignore[assignment]

try:
from langchain_community.llms import (
HuggingFacePipeline, # type: ignore[attr-defined,no-redef]
)
except ImportError:
# Fallback for older versions of langchain
try:
from langchain.llms import (
HuggingFacePipeline, # type: ignore[attr-defined,no-redef]
)
except ImportError:
# Create a dummy class if HuggingFacePipeline is not available
class HuggingFacePipeline: # type: ignore[misc,no-redef]
def __init__(self, *args, **kwargs):
raise ImportError("HuggingFacePipeline is not available")


class HuggingFacePipelineCompatible(HuggingFacePipeline):
@@ -47,12 +66,13 @@ def _call(
)

# Streaming for NeMo Guardrails is not supported in sync calls.
if self.model_kwargs and self.model_kwargs.get("streaming"):
raise Exception(
model_kwargs = getattr(self, "model_kwargs", {})
if model_kwargs and model_kwargs.get("streaming"):
raise NotImplementedError(
"Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
)

llm_result = self._generate(
llm_result = getattr(self, "_generate")(
Collaborator:
I think getattr just fools Pyright into ignoring the error, and it is better to suppress it directly:

        llm_result = self._generate(  # type: ignore

Collaborator:
Same for all the getattr usages below.
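
A small self-contained sketch of the trade-off (Base and Pipeline are hypothetical stand-ins; in the real file the base class can be the dummy HuggingFacePipeline fallback, which defines no _generate):

    from typing import Any

    class Base:
        """Stand-in for the fallback class; defines no _generate."""

    class Pipeline(Base):
        def call_via_getattr(self, prompt: str) -> Any:
            # getattr() hides the lookup from Pyright entirely: the error
            # disappears, but so does all checking of arguments and returns.
            return getattr(self, "_generate")([prompt])

        def call_with_ignore(self, prompt: str) -> Any:
            # A targeted suppression documents the known gap while the rest
            # of the call site stays visible to the type checker.
            return self._generate([prompt])  # type: ignore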

[prompt],
stop=stop,
run_manager=run_manager,
@@ -78,11 +98,12 @@ async def _acall(
)

# Handle streaming, if the flag is set
if self.model_kwargs and self.model_kwargs.get("streaming"):
model_kwargs = getattr(self, "model_kwargs", {})
if model_kwargs and model_kwargs.get("streaming"):
# Retrieve the streamer object, needs to be set in model_kwargs
streamer = self.model_kwargs.get("streamer")
streamer = model_kwargs.get("streamer")
if not streamer:
raise Exception(
raise ValueError(
"Cannot stream, please add HuggingFace streamer object to model_kwargs!"
)

@@ -99,7 +120,7 @@
run_manager=run_manager,
**kwargs,
)
loop.create_task(self._agenerate(**generation_kwargs))
loop.create_task(getattr(self, "_agenerate")(**generation_kwargs))

# And start waiting for the chunks to come in.
completion = ""
@@ -111,7 +132,7 @@

return completion

llm_result = await self._agenerate(
llm_result = await getattr(self, "_agenerate")(
[prompt],
stop=stop,
run_manager=run_manager,
26 changes: 22 additions & 4 deletions nemoguardrails/llm/providers/huggingface/streamers.py
@@ -14,11 +14,27 @@
# limitations under the License.

import asyncio
from typing import TYPE_CHECKING, Optional

from transformers.generation.streamers import TextStreamer
TRANSFORMERS_AVAILABLE = True
try:
from transformers.generation.streamers import ( # type: ignore[import-untyped]
TextStreamer,
)
except ImportError:
# Fallback if transformers is not available
TRANSFORMERS_AVAILABLE = False

class TextStreamer: # type: ignore[no-redef]
def __init__(self, *args, **kwargs):
pass

class AsyncTextIteratorStreamer(TextStreamer):

if TYPE_CHECKING:
from transformers import AutoTokenizer # type: ignore[import-untyped]


class AsyncTextIteratorStreamer(TextStreamer): # type: ignore[misc]
"""
Simple async implementation for HuggingFace Transformers streamers.

@@ -30,12 +46,14 @@ def __init__(
self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs
):
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
self.text_queue = asyncio.Queue()
self.text_queue: asyncio.Queue[str] = asyncio.Queue()
self.stop_signal = None
self.loop = None
self.loop: Optional[asyncio.AbstractEventLoop] = None

def on_finalized_text(self, text: str, stream_end: bool = False):
"""Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
if self.loop is None:
return
if len(text) > 0:
asyncio.run_coroutine_threadsafe(self.text_queue.put(text), self.loop)
