29
29
# Module-level logger, following the project convention of one logger per module.
logger = init_logger(__name__)

# Type variable constrained to the two prompt representations:
# raw text or a list of token IDs.
_S = TypeVar("_S", str, list[int])

PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""
33
35
34
36
35
37
@dataclass
class PromptReplacementDetails:
    """Details about the replacement token sequence or text."""

    full: PromptSeq
    """The full replacement."""

    features: PromptSeq
    """
    The part of the replacement that corresponds to feature placeholders;
    this will be replaced by the output of the vision encoder during model
    inference.
    """

    @staticmethod
    def from_seq(seq: PromptSeq) -> "PromptReplacementDetails":
        """Construct details where the entire replacement consists of
        feature placeholders (i.e. ``full`` and ``features`` coincide)."""
        return PromptReplacementDetails(seq, features=seq)
48
54
49
55
50
PromptRepl = Union[PromptSeq, PromptReplacementDetails]
"""
The replacement token sequence or text.

If only part of the replacement corresponds to feature placeholders, you can
use :class:`PromptReplacementDetails` to specify which part.
"""
51
63
52
64
53
65
@dataclass
54
66
class PromptReplacement :
55
67
"""
56
68
Defines how to replace portions of an input prompt with placeholder tokens.
69
+
70
+ Example:
71
+
72
+ For each image, replace one ``<image>`` input placeholder in the prompt
73
+ with a number of ``<image>`` feature placeholders
74
+ equal to the feature size of the vision encoder:
75
+
76
+ .. code-block:: python
77
+
78
+ PromptReplacement(
79
+ modality="image",
80
+ target="<image>",
81
+ replacement="<image>" * image_feature_size,
82
+ )
83
+
84
+ As above, but further pad the feature placeholders with ``<image_bos>``
85
+ and ``<image_eos>``, which are not supposed to be passed to the vision
86
+ encoder:
87
+
88
+ .. code-block:: python
89
+
90
+ PromptReplacement(
91
+ modality="image",
92
+ target="<image>",
93
+ replacement=PromptReplacementDetails(
94
+ full="".join([
95
+ "<image_bos>",
96
+ "<image>" * image_feature_size,
97
+ "<image_eos>",
98
+ ]),
99
+ features="<image>" * image_feature_size,
100
+ ),
101
+ )
102
+
103
+ To avoid unnecessary tokenization during prompt replacement,
104
+ we recommend passing token sequences instead of text:
105
+
106
+ .. code-block:: python
107
+
108
+ PromptReplacement(
109
+ modality="image",
110
+ target=[image_token_id],
111
+ replacement=PromptReplacementDetails(
112
+ full=([image_bos_id] + [image_token_id] * image_feature_size
113
+ + [image_eos_id]),
114
+ features=[image_token_id] * image_feature_size,
115
+ ),
116
+ )
57
117
"""
58
118
59
119
modality : str
60
120
"""The modality for which the replacement is made."""
61
121
62
- target : _PromptSeq
122
+ target : PromptSeq
63
123
"""The token sequence (or text) to find and replace."""
64
124
65
- replacement : Union [Callable [[int ], _PromptRepl ],
66
- _PromptRepl ] = field (repr = False )
125
+ replacement : Union [Callable [[int ], PromptRepl ],
126
+ PromptRepl ] = field (repr = False )
67
127
"""
68
128
Given the index of the processed item within :attr:`modality`,
69
129
output the replacement token sequence (or text).
@@ -126,6 +186,10 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
126
186
127
187
@dataclass
128
188
class _BoundPromptSequence :
189
+ """
190
+ A :data:`PromptSeq` bound to a tokenizer to automatically
191
+ convert between token sequence and text representations.
192
+ """
129
193
tokenizer : AnyTokenizer = field (repr = False )
130
194
131
195
_text : Optional [str ]
@@ -134,7 +198,7 @@ class _BoundPromptSequence:
134
198
@staticmethod
135
199
def from_seq (
136
200
tokenizer : AnyTokenizer ,
137
- seq : _PromptSeq ,
201
+ seq : PromptSeq ,
138
202
) -> "_BoundPromptSequence" :
139
203
return _BoundPromptSequence (
140
204
tokenizer = tokenizer ,
@@ -180,9 +244,9 @@ class BoundPromptReplacement:
180
244
tokenizer : AnyTokenizer = field (repr = False )
181
245
modality : str
182
246
183
- _target : _PromptSeq
184
- _replacement : Union [Callable [[int ], _PromptRepl ],
185
- _PromptRepl ] = field (repr = False )
247
+ _target : PromptSeq
248
+ _replacement : Union [Callable [[int ], PromptRepl ],
249
+ PromptRepl ] = field (repr = False )
186
250
187
251
def __post_init__ (self ) -> None :
188
252
self ._replacement_cache = dict [int , _BoundPromptReplacementGroup ]()
@@ -350,7 +414,7 @@ def find_text_matches(
350
414
351
415
352
416
def _resolve_matches (
353
- prompt : _PromptSeq ,
417
+ prompt : PromptSeq ,
354
418
mm_matches : Mapping [str , Sequence [_PromptReplacementMatch ]],
355
419
) -> list [_PromptReplacementMatch ]:
356
420
"""
0 commit comments