Skip to content

Commit b59e2b5

Browse files
authored
Merge pull request #10 from VowpalWabbit/dot_prods_auto_embed
Dot prods auto embed
2 parents a9ba6a8 + ae5edef commit b59e2b5

File tree

6 files changed

+298
-93
lines changed

6 files changed

+298
-93
lines changed

libs/langchain/langchain/chains/rl_chain/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from langchain.chains.rl_chain.pick_best_chain import (
1616
PickBest,
1717
PickBestEvent,
18+
PickBestFeatureEmbedder,
1819
PickBestSelected,
1920
)
2021

@@ -37,6 +38,7 @@ def configure_logger() -> None:
3738
"PickBest",
3839
"PickBestEvent",
3940
"PickBestSelected",
41+
"PickBestFeatureEmbedder",
4042
"Embed",
4143
"BasedOn",
4244
"ToSelectFrom",

libs/langchain/langchain/chains/rl_chain/base.py

Lines changed: 11 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -118,8 +118,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]
118118

119119
if not to_select_from:
120120
raise ValueError(
121-
"No variables using 'ToSelectFrom' found in the inputs. \
122-
Please include at least one variable containing a list to select from."
121+
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." # noqa: E501
123122
)
124123

125124
based_on = {
@@ -229,6 +228,9 @@ def save(self) -> None:
229228

230229

231230
class Embedder(Generic[TEvent], ABC):
231+
def __init__(self, *args: Any, **kwargs: Any):
232+
pass
233+
232234
@abstractmethod
233235
def format(self, event: TEvent) -> str:
234236
...
@@ -300,9 +302,7 @@ def score_response(
300302
return resp
301303
except Exception as e:
302304
raise RuntimeError(
303-
f"The auto selection scorer did not manage to score the response, \
304-
there is always the option to try again or tweak the reward prompt.\
305-
Error: {e}"
305+
f"The auto selection scorer did not manage to score the response, there is always the option to try again or tweak the reward prompt. Error: {e}" # noqa: E501
306306
)
307307

308308

@@ -316,7 +316,7 @@ class RLChain(Chain, Generic[TEvent]):
316316
- selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
317317
- policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
318318
- auto_embed (bool): Determines if embedding should be automatic. Default is False.
319-
- metrics (Optional[MetricsTracker]): Tracker for metrics, can be set to None.
319+
- metrics (Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]]): Tracker for metrics, can be set to None.
320320
321321
Initialization Attributes:
322322
- feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
@@ -325,7 +325,8 @@ class RLChain(Chain, Generic[TEvent]):
325325
- vw_cmd (List[str], optional): Command line arguments for the VW model.
326326
- policy (Type[VwPolicy]): Policy used by the chain.
327327
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
328-
- metrics_step (int): Step for the metrics tracker. Default is -1.
328+
- metrics_step (int): Step for the metrics tracker. Default is -1. If set without metrics_window_size, average metrics will be tracked, otherwise rolling window metrics will be tracked.
329+
- metrics_window_size (int): Window size for the metrics tracker. Default is -1. If set, rolling window metrics will be tracked.
329330
330331
Notes:
331332
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
@@ -423,8 +424,7 @@ def update_with_delayed_score(
423424
""" # noqa: E501
424425
if self._can_use_selection_scorer() and not force_score:
425426
raise RuntimeError(
426-
"The selection scorer is set, and force_score was not set to True. \
427-
Please set force_score=True to use this function."
427+
"The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function." # noqa: E501
428428
)
429429
if self.metrics:
430430
self.metrics.on_feedback(score)
@@ -458,9 +458,7 @@ def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
458458
or self.selected_based_on_input_key in inputs.keys()
459459
):
460460
raise ValueError(
461-
f"The rl chain does not accept '{self.selected_input_key}' \
462-
or '{self.selected_based_on_input_key}' as input keys, \
463-
they are reserved for internal use during auto reward."
461+
f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward." # noqa: E501
464462
)
465463

466464
def _can_use_selection_scorer(self) -> bool:
@@ -498,9 +496,6 @@ def _call(
498496
) -> Dict[str, Any]:
499497
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
500498

501-
if self.auto_embed:
502-
inputs = prepare_inputs_for_autoembed(inputs=inputs)
503-
504499
event: TEvent = self._call_before_predict(inputs=inputs)
505500
prediction = self.active_policy.predict(event=event)
506501
if self.metrics:
@@ -573,8 +568,7 @@ def embed_string_type(
573568

574569
if namespace is None:
575570
raise ValueError(
576-
"The default namespace must be \
577-
provided when embedding a string or _Embed object."
571+
"The default namespace must be provided when embedding a string or _Embed object." # noqa: E501
578572
)
579573

580574
return {namespace: keep_str + encoded}

libs/langchain/langchain/chains/rl_chain/pick_best_chain.py

Lines changed: 141 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -53,21 +53,24 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
5353
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
5454
""" # noqa E501
5555

56-
def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any):
56+
def __init__(
57+
self, auto_embed: bool, model: Optional[Any] = None, *args: Any, **kwargs: Any
58+
):
5759
super().__init__(*args, **kwargs)
5860

5961
if model is None:
6062
from sentence_transformers import SentenceTransformer
6163

62-
model = SentenceTransformer("bert-base-nli-mean-tokens")
64+
model = SentenceTransformer("all-mpnet-base-v2")
6365

6466
self.model = model
67+
self.auto_embed = auto_embed
6568

66-
def format(self, event: PickBestEvent) -> str:
67-
"""
68-
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
69-
"""
69+
@staticmethod
70+
def _str(embedding: List[float]) -> str:
71+
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
7072

73+
def get_label(self, event: PickBestEvent) -> tuple:
7174
cost = None
7275
if event.selected:
7376
chosen_action = event.selected.index
@@ -77,7 +80,11 @@ def format(self, event: PickBestEvent) -> str:
7780
else None
7881
)
7982
prob = event.selected.probability
83+
return chosen_action, cost, prob
84+
else:
85+
return None, None, None
8086

87+
def get_context_and_action_embeddings(self, event: PickBestEvent) -> tuple:
8188
context_emb = base.embed(event.based_on, self.model) if event.based_on else None
8289
to_select_from_var_name, to_select_from = next(
8390
iter(event.to_select_from.items()), (None, None)
@@ -97,6 +104,95 @@ def format(self, event: PickBestEvent) -> str:
97104
raise ValueError(
98105
"Context and to_select_from must be provided in the inputs dictionary"
99106
)
107+
return context_emb, action_embs
108+
109+
def get_indexed_dot_product(self, context_emb: List, action_embs: List) -> Dict:
110+
import numpy as np
111+
112+
unique_contexts = set()
113+
for context_item in context_emb:
114+
for ns, ee in context_item.items():
115+
if isinstance(ee, list):
116+
for ea in ee:
117+
unique_contexts.add(f"{ns}={ea}")
118+
else:
119+
unique_contexts.add(f"{ns}={ee}")
120+
121+
encoded_contexts = self.model.encode(list(unique_contexts))
122+
context_embeddings = dict(zip(unique_contexts, encoded_contexts))
123+
124+
unique_actions = set()
125+
for action in action_embs:
126+
for ns, e in action.items():
127+
if isinstance(e, list):
128+
for ea in e:
129+
unique_actions.add(f"{ns}={ea}")
130+
else:
131+
unique_actions.add(f"{ns}={e}")
132+
133+
encoded_actions = self.model.encode(list(unique_actions))
134+
action_embeddings = dict(zip(unique_actions, encoded_actions))
135+
136+
action_matrix = np.stack([v for k, v in action_embeddings.items()])
137+
context_matrix = np.stack([v for k, v in context_embeddings.items()])
138+
dot_product_matrix = np.dot(context_matrix, action_matrix.T)
139+
140+
indexed_dot_product: Dict = {}
141+
142+
for i, context_key in enumerate(context_embeddings.keys()):
143+
indexed_dot_product[context_key] = {}
144+
for j, action_key in enumerate(action_embeddings.keys()):
145+
indexed_dot_product[context_key][action_key] = dot_product_matrix[i, j]
146+
147+
return indexed_dot_product
148+
149+
def format_auto_embed_on(self, event: PickBestEvent) -> str:
150+
chosen_action, cost, prob = self.get_label(event)
151+
context_emb, action_embs = self.get_context_and_action_embeddings(event)
152+
indexed_dot_product = self.get_indexed_dot_product(context_emb, action_embs)
153+
154+
action_lines = []
155+
for i, action in enumerate(action_embs):
156+
line_parts = []
157+
dot_prods = []
158+
if cost is not None and chosen_action == i:
159+
line_parts.append(f"{chosen_action}:{cost}:{prob}")
160+
for ns, action in action.items():
161+
line_parts.append(f"|{ns}")
162+
elements = action if isinstance(action, list) else [action]
163+
nsa = []
164+
for elem in elements:
165+
line_parts.append(f"{elem}")
166+
ns_a = f"{ns}={elem}"
167+
nsa.append(ns_a)
168+
for k, v in indexed_dot_product.items():
169+
dot_prods.append(v[ns_a])
170+
nsa_str = " ".join(nsa)
171+
line_parts.append(f"|# {nsa_str}")
172+
173+
line_parts.append(f"|dotprod {self._str(dot_prods)}")
174+
action_lines.append(" ".join(line_parts))
175+
176+
shared = []
177+
for item in context_emb:
178+
for ns, context in item.items():
179+
shared.append(f"|{ns}")
180+
elements = context if isinstance(context, list) else [context]
181+
nsc = []
182+
for elem in elements:
183+
shared.append(f"{elem}")
184+
nsc.append(f"{ns}={elem}")
185+
nsc_str = " ".join(nsc)
186+
shared.append(f"|@ {nsc_str}")
187+
188+
return "shared " + " ".join(shared) + "\n" + "\n".join(action_lines)
189+
190+
def format_auto_embed_off(self, event: PickBestEvent) -> str:
191+
"""
192+
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
193+
"""
194+
chosen_action, cost, prob = self.get_label(event)
195+
context_emb, action_embs = self.get_context_and_action_embeddings(event)
100196

101197
example_string = ""
102198
example_string += "shared "
@@ -120,6 +216,12 @@ def format(self, event: PickBestEvent) -> str:
120216
# Strip the last newline
121217
return example_string[:-1]
122218

219+
def format(self, event: PickBestEvent) -> str:
220+
if self.auto_embed:
221+
return self.format_auto_embed_on(event)
222+
else:
223+
return self.format_auto_embed_off(event)
224+
123225

124226
class PickBest(base.RLChain[PickBestEvent]):
125227
"""
@@ -154,50 +256,60 @@ def __init__(
154256
*args: Any,
155257
**kwargs: Any,
156258
):
157-
vw_cmd = kwargs.get("vw_cmd", [])
158-
if not vw_cmd:
159-
vw_cmd = [
160-
"--cb_explore_adf",
161-
"--quiet",
162-
"--interactions=::",
163-
"--coin",
164-
"--squarecb",
165-
]
259+
auto_embed = kwargs.get("auto_embed", False)
260+
261+
feature_embedder = kwargs.get("feature_embedder", None)
262+
if feature_embedder:
263+
if "auto_embed" in kwargs:
264+
logger.warning(
265+
"auto_embed will take no effect when explicit feature_embedder is provided" # noqa E501
266+
)
267+
# turning auto_embed off for cli setting below
268+
auto_embed = False
166269
else:
270+
feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed)
271+
kwargs["feature_embedder"] = feature_embedder
272+
273+
vw_cmd = kwargs.get("vw_cmd", [])
274+
if vw_cmd:
167275
if "--cb_explore_adf" not in vw_cmd:
168276
raise ValueError(
169277
"If vw_cmd is specified, it must include --cb_explore_adf"
170278
)
171-
kwargs["vw_cmd"] = vw_cmd
279+
else:
280+
interactions = ["--interactions=::"]
281+
if auto_embed:
282+
interactions = [
283+
"--interactions=@#",
284+
"--ignore_linear=@",
285+
"--ignore_linear=#",
286+
]
287+
vw_cmd = interactions + [
288+
"--cb_explore_adf",
289+
"--coin",
290+
"--squarecb",
291+
"--quiet",
292+
]
172293

173-
feature_embedder = kwargs.get("feature_embedder", None)
174-
if not feature_embedder:
175-
feature_embedder = PickBestFeatureEmbedder()
176-
kwargs["feature_embedder"] = feature_embedder
294+
kwargs["vw_cmd"] = vw_cmd
177295

178296
super().__init__(*args, **kwargs)
179297

180298
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
181299
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
182300
if not actions:
183301
raise ValueError(
184-
"No variables using 'ToSelectFrom' found in the inputs. \
185-
Please include at least one variable containing \
186-
a list to select from."
302+
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." # noqa E501
187303
)
188304

189305
if len(list(actions.values())) > 1:
190306
raise ValueError(
191-
"Only one variable using 'ToSelectFrom' can be provided in the inputs \
192-
for the PickBest chain. Please provide only one variable \
193-
containing a list to select from."
307+
"Only one variable using 'ToSelectFrom' can be provided in the inputs for the PickBest chain. Please provide only one variable containing a list to select from." # noqa E501
194308
)
195309

196310
if not context:
197311
raise ValueError(
198-
"No variables using 'BasedOn' found in the inputs. \
199-
Please include at least one variable containing information \
200-
to base the selected of ToSelectFrom on."
312+
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on." # noqa E501
201313
)
202314

203315
event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)

0 commit comments

Comments
 (0)