
Commit 6bfe98b

Integration of Jinja2 Templating (#875)
* feat: Add support for jinja templating

  Signed-off-by: teleprint-me <[email protected]>

* fix: Refactor chat formatter and update interface for jinja templates

  - Simplify the `llama2_template` in `llama_jinja_format.py` by removing unnecessary line breaks for readability without affecting functionality.
  - Update the `ChatFormatterInterface` constructor to accept a more generic `Optional[object]` type for the template parameter, enhancing flexibility.
  - Introduce a `template` property on `ChatFormatterInterface` for standardized access to the template string.
  - Replace the `MetaSingleton` metaclass with `Singleton` for the `ChatFormatterFactory` to streamline the singleton implementation.

  These changes enhance code readability, maintain usability, and ensure consistency in the chat formatter's design pattern usage.

* Add outline for Jinja2 templating integration documentation

  Signed-off-by: teleprint-me <[email protected]>

* Add jinja2 as a dependency with a version range for Hugging Face transformers compatibility

  Signed-off-by: teleprint-me <[email protected]>

* Update jinja2 version constraint for mkdocs-material compatibility

  Signed-off-by: teleprint-me <[email protected]>

* Fix attribute name in AutoChatFormatter

  - Changed the attribute name from `self._renderer` to `self._environment`

---------

Signed-off-by: teleprint-me <[email protected]>
1 parent 52adc23 commit 6bfe98b

File tree

4 files changed: +243 -1 lines changed


docs/templates.md

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
# Templates

This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model.

## Introduction

- Brief explanation of the `llama-cpp-python` project's need for a templating system.
- Overview of the `llama-2` model's interaction with templating.
## Jinja2 Dependency Integration

- Rationale for choosing Jinja2 as the templating engine.
  - Compatibility with Hugging Face's `transformers`.
  - Desire for advanced templating features and simplicity.
- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management.
## Template Management Refactor

- Summary of the refactor and the motivation behind it.
- Description of the new chat handler selection logic (sketched after this list):
  1. Preference for a user-specified `chat_handler`.
  2. Fallback to a user-specified `chat_format`.
  3. Defaulting to a chat format from a `.gguf` file if available.
  4. Utilizing the `llama2` default chat format as the final fallback.
- Ensuring backward compatibility throughout the refactor.
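A minimal sketch of that fallback chain, based only on this outline rather than the merged implementation; the `chat_handler` and `chat_format` parameter names and the `tokenizer.chat_template` metadata key are assumptions:

```python
from llama_cpp.llama_jinja_format import (
    AutoChatFormatter,
    ChatFormatterFactory,
    Llama2Formatter,
)

def resolve_chat_formatter(chat_handler=None, chat_format=None, gguf_metadata=None):
    if chat_handler is not None:
        # 1. A user-specified chat handler always wins.
        return chat_handler
    if chat_format is not None:
        # 2. Otherwise look up a registered formatter by name.
        return ChatFormatterFactory().get_formatter_by_name(chat_format)
    if gguf_metadata and "tokenizer.chat_template" in gguf_metadata:
        # 3. Fall back to a template embedded in the .gguf file
        #    (the metadata key is an assumption for illustration).
        return AutoChatFormatter(template=gguf_metadata["tokenizer.chat_template"])
    # 4. Final fallback: the llama-2 default format.
    return Llama2Formatter()
```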
## Implementation Details

- In-depth look at the new `AutoChatFormatter` class.
- Example code snippets showing how to utilize the Jinja2 environment and templates (one such snippet follows below).
- Guidance on how to provide custom templates or use defaults.
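One possible snippet for this section, using the `AutoChatFormatter` added in `llama_cpp/llama_jinja_format.py`; the custom template here is illustrative only:

```python
from llama_cpp.llama_jinja_format import AutoChatFormatter

# Pass a custom Jinja2 template, or omit it to use the llama-2 default.
custom_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

formatter = AutoChatFormatter(template=custom_template)
response = formatter(messages=[{"role": "user", "content": "Hello!"}])
print(response.prompt)  # -> "user: Hello!"
```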
## Testing and Validation

- Outline of the testing strategy to ensure seamless integration.
- Steps for validating backward compatibility with existing implementations.

## Benefits and Impact

- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience.
- Discussion of the potential impact on current users and contributors.

## Future Work

- Exploration of how templating can evolve within the project.
- Consideration of additional features or optimizations for the templating engine.
- Mechanisms for community feedback on the templating system.

## Conclusion

- Final thoughts on the integration of Jinja2 templating.
- Call to action for community involvement and feedback.

llama_cpp/llama_jinja_format.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
"""
llama_cpp/llama_jinja_format.py
"""
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Protocol, Union

import jinja2
from jinja2 import Template

# NOTE: We sacrifice readability for usability.
# It will fail to work as expected if we attempt to format it in a readable way.
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""


class MetaSingleton(type):
    """
    Metaclass for implementing the Singleton pattern.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class Singleton(object, metaclass=MetaSingleton):
    """
    Base class for implementing the Singleton pattern.
    """

    def __init__(self):
        super(Singleton, self).__init__()


@dataclasses.dataclass
class ChatFormatterResponse:
    prompt: str
    stop: Optional[Union[str, List[str]]] = None


# Base Chat Formatter Protocol
class ChatFormatterInterface(Protocol):
    def __init__(self, template: Optional[object] = None):
        ...

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs,
    ) -> ChatFormatterResponse:
        ...

    @property
    def template(self) -> str:
        ...


class AutoChatFormatter(ChatFormatterInterface):
    def __init__(
        self,
        template: Optional[str] = None,
        template_class: Optional[Template] = None,
    ):
        if template is not None:
            self._template = template
        else:
            self._template = llama2_template  # default template

        self._environment = jinja2.Environment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(
            self._template,
            template_class=template_class,
        )

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        formatted_sequence = self._environment.render(messages=messages, **kwargs)
        return ChatFormatterResponse(prompt=formatted_sequence)

    @property
    def template(self) -> str:
        return self._template


class FormatterNotFoundException(Exception):
    pass


class ChatFormatterFactory(Singleton):
    _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}

    def register_formatter(
        self,
        name: str,
        formatter_callable: Callable[[], ChatFormatterInterface],
        overwrite=False,
    ):
        if not overwrite and name in self._chat_formatters:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_formatters[name] = formatter_callable

    def unregister_formatter(self, name: str):
        if name in self._chat_formatters:
            del self._chat_formatters[name]
        else:
            raise ValueError(f"No formatter registered under the name '{name}'.")

    def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
        try:
            formatter_callable = self._chat_formatters[name]
            return formatter_callable()
        except KeyError:
            raise FormatterNotFoundException(
                f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
            )


# Define a chat format class
class Llama2Formatter(AutoChatFormatter):
    def __init__(self):
        super().__init__(llama2_template)


# With the Singleton pattern applied, regardless of where or how many times
# ChatFormatterFactory() is called, it will always return the same instance
# of the factory, ensuring that the factory's state is consistent throughout
# the application.
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
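For reference, a minimal sketch of extending the factory with a custom format; `MyFormatter` and the `"my-format"` name are hypothetical, chosen only to illustrate the registration flow:

```python
from llama_cpp.llama_jinja_format import AutoChatFormatter, ChatFormatterFactory


# Hypothetical formatter for illustration; its template simply
# concatenates message contents, one per line.
class MyFormatter(AutoChatFormatter):
    def __init__(self):
        super().__init__("{% for m in messages %}{{ m['content'] }}\n{% endfor %}")


factory = ChatFormatterFactory()
factory.register_formatter("my-format", MyFormatter)
formatter = factory.get_formatter_by_name("my-format")

# Because the factory is a singleton, every call site sees the same registry.
assert ChatFormatterFactory() is factory
```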

pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -11,10 +11,13 @@ license = { text = "MIT" }
 authors = [
     { name = "Andrei Betlen", email = "[email protected]" },
 ]
+# mkdocs-material requires "jinja2~=3.0"
+# transformers requires "jinja2>=2.11.3"
 dependencies = [
     "typing-extensions>=4.5.0",
     "numpy>=1.20.0",
     "diskcache>=5.6.1",
+    "jinja2>=2.11.3",
 ]
 requires-python = ">=3.8"
 classifiers = [
@@ -72,4 +75,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"
 
 [tool.pytest.ini_options]
 addopts = "--ignore=vendor"
-

tests/test_llama_chat_format.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
from typing import List

import pytest

from llama_cpp import ChatCompletionMessage
from llama_cpp.llama_jinja_format import Llama2Formatter


@pytest.fixture
def sequence_of_messages() -> List[ChatCompletionMessage]:
    return [
        ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
        ChatCompletionMessage(
            role="user", content="Hi there! I need some help with Python."
        ),
        ChatCompletionMessage(
            role="assistant", content="Of course! What do you need help with in Python?"
        ),
        ChatCompletionMessage(
            role="user",
            content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
        ),
        ChatCompletionMessage(
            role="assistant",
            content="I can help with that! Would you like a recursive or iterative solution?",
        ),
        ChatCompletionMessage(
            role="user", content="Let's go with a recursive solution."
        ),
    ]


def test_llama2_formatter(sequence_of_messages):
    expected_prompt = (
        "<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
        "[INST] Hi there! I need some help with Python. [/INST]\n"
        "Of course! What do you need help with in Python?\n"
        "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
        "I can help with that! Would you like a recursive or iterative solution?\n"
        "[INST] Let's go with a recursive solution. [/INST]\n"
    )

    llama2_formatter_instance = Llama2Formatter()
    formatter_response = llama2_formatter_instance(sequence_of_messages)
    assert (
        expected_prompt == formatter_response.prompt
    ), "The formatted prompt does not match the expected output."


# Optionally, include a test for the 'stop' field if it's part of the functionality.
