Skip to content

Commit 86c2642

Browse files
gbdrtmandel
andauthored
SMC inference (#1230)
* WIP SMC Signed-off-by: Guillaume Baudart <[email protected]> * Bug fix: HMM example Signed-off-by: Guillaume Baudart <[email protected]> * Patch performance SMC Signed-off-by: Guillaume Baudart <[email protected]> * Add hmm_nl_priors Signed-off-by: Guillaume Baudart <[email protected]> * feat: replay an execution (#1211) Signed-off-by: Louis Mandel <[email protected]> Signed-off-by: Guillaume Baudart <[email protected]> * refactor: performance optimizations (#1228) Signed-off-by: Louis Mandel <[email protected]> Signed-off-by: Guillaume Baudart <[email protected]> * WIP: Parallel loop SMC Signed-off-by: Guillaume Baudart <[email protected]> --------- Signed-off-by: Guillaume Baudart <[email protected]> Signed-off-by: Louis Mandel <[email protected]> Co-authored-by: Louis Mandel <[email protected]>
1 parent 8c30795 commit 86c2642

File tree

5 files changed

+234
-7
lines changed

5 files changed

+234
-7
lines changed

examples/ppdl/hmm.pdl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
defs:
2+
step:
3+
function:
4+
pre_x: number
5+
y: number
6+
return:
7+
defs:
8+
x:
9+
lang: python
10+
code: |
11+
from mu_ppl import Gaussian
12+
result = Gaussian(pre_x, 1).sample()
13+
score:
14+
lang: python
15+
code: |
16+
from mu_ppl import Gaussian
17+
result = Gaussian(x, 1).log_prob(y)
18+
lastOf:
19+
- factor: ${score}
20+
- data: ${x}
21+
obs:
22+
array: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30]
23+
24+
pre_x: 0
25+
26+
for:
27+
y: ${obs}
28+
repeat:
29+
defs:
30+
pre_x:
31+
call: ${step}
32+
args:
33+
pre_x: ${pre_x}
34+
y: ${y}
35+
data: ${pre_x}
36+
join:
37+
as: array
38+
39+
40+
41+

examples/ppdl/hmm_nl_priors.pdl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
defs:
2+
step:
3+
function:
4+
pre_x: number
5+
y: number
6+
return:
7+
defs:
8+
x:
9+
model: ollama_chat/granite3.3:2b
10+
parameters:
11+
temperature: 1
12+
input: We are modeling a random walk along a line. Generate a random number that is Gaussian distributed around ${pre_x}. We do not know the parameters of the random walk but we suspect that the new value should be greater than ${pre_x}. DO NOT GENERATE A PYTHON CODE, JUST ANSWER WITH THE NUMBER
13+
parser: json
14+
spec: number
15+
fallback:
16+
lang: python
17+
code: |
18+
from mu_ppl import Gaussian
19+
result = Gaussian(pre_x, 1).sample()
20+
score:
21+
lang: python
22+
code: |
23+
from mu_ppl import Gaussian
24+
result = Gaussian(x, 1).log_prob(y)
25+
lastOf:
26+
- factor: ${score}
27+
- data: ${x}
28+
obs:
29+
array: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
30+
31+
pre_x: 0
32+
33+
for:
34+
y: ${obs}
35+
repeat:
36+
defs:
37+
pre_x:
38+
call: ${step}
39+
args:
40+
pre_x: ${pre_x}
41+
y: ${y}
42+
data: ${pre_x}
43+
join:
44+
as: array
45+
46+
47+
48+

src/pdl/pdl_infer.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from .pdl import InterpreterConfig, exec_program
99
from .pdl_ast import get_default_model_parameters
1010
from .pdl_parser import parse_file
11+
from .pdl_smc import infer_smc
1112
from .pdl_utils import validate_scope
13+
from matplotlib import pyplot as plt
1214

1315

1416
def main():
@@ -33,7 +35,11 @@ def main():
3335
default=5,
3436
)
3537
parser.add_argument(
36-
"-v", "--viz", help="Display the distribution of results", default=False
38+
"-v",
39+
"--viz",
40+
help="Display the distribution of results",
41+
default=False,
42+
action="store_true",
3743
)
3844
parser.add_argument(
3945
"--version",
@@ -67,12 +73,22 @@ def main():
6773
yield_result=False, yield_background=False, batch=1, cwd=Path(args.pdl).parent
6874
)
6975
program, loc = parse_file(args.pdl)
70-
with ImportanceSampling(num_particles=args.num_particles):
71-
dist = infer(
72-
lambda: exec_program(program, config, initial_scope, loc, "result")
73-
)
76+
# with ImportanceSampling(num_particles=args.num_particles):
77+
# dist = infer(
78+
# lambda: exec_program(program, config, initial_scope, loc, "result")
79+
# )
80+
81+
def model(replay):
82+
config["replay"] = replay
83+
result = exec_program(program, config, initial_scope, loc, "all")
84+
state = result["replay"]
85+
return result["result"], state
86+
87+
dist = infer_smc(args.num_particles, model)
88+
7489
if args.viz:
7590
viz(dist)
91+
plt.show()
7692
print(dist.sample())
7793
return 0
7894

src/pdl/pdl_interpreter.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@
144144
)
145145
from .pdl_schema_utils import get_json_schema # noqa: E402
146146
from .pdl_schema_validator import type_check_args, type_check_spec # noqa: E402
147+
from .pdl_smc import Resample
147148
from .pdl_utils import ( # noqa: E402
148149
GeneratorWrapper,
149150
apply_defaults,
@@ -314,7 +315,7 @@ def process_prog(
314315
loc,
315316
)
316317

317-
stdlib_scope = scope | PdlDict({"stdlib": stdlib_dict})
318+
stdlib_scope = scope # | PdlDict({"stdlib": stdlib_dict})
318319

319320
result, document, final_scope, trace = process_block(
320321
state, stdlib_scope, block=prog.root, loc=loc
@@ -567,6 +568,8 @@ def process_advance_block_retry( # noqa: C901
567568
break
568569
except KeyboardInterrupt as exc:
569570
raise exc from exc
571+
except Resample as exc:
572+
raise exc from exc
570573
except Exception as exc:
571574
do_retry = block.retry and trial_idx + 1 < trial_total
572575
if block.fallback is None and not do_retry:
@@ -610,7 +613,8 @@ def process_advance_block_retry( # noqa: C901
610613
trace=trace,
611614
)
612615
result = lazy_apply(checker, result)
613-
factor(score)
616+
if score != 0:
617+
factor(score)
614618
return result, background, new_scope, trace
615619

616620

@@ -1147,6 +1151,9 @@ def loop_body(iidx, items):
11471151
factor(weight)
11481152
result = PdlConst(None)
11491153
background = DependentContext([])
1154+
assert block.pdl__id is not None
1155+
state.replay[block.pdl__id] = None
1156+
raise Resample(state.replay)
11501157
case EmptyBlock():
11511158
result = PdlConst("")
11521159
background = DependentContext([])

src/pdl/pdl_smc.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from typing import TypeVar, ParamSpec, Callable, Any
2+
from mu_ppl.distributions import Categorical
3+
from mu_ppl import ImportanceSampling
4+
from tqdm import tqdm
5+
from copy import deepcopy
6+
from concurrent.futures import ThreadPoolExecutor
7+
import asyncio
8+
9+
10+
T = TypeVar("T")
11+
P = ParamSpec("P")
12+
13+
14+
class Resample(Exception):
    """Exception used to interrupt a particle's execution at a resampling point.

    The instance carries the state handed over by the raiser; handlers read it
    back through the ``state`` attribute.
    """

    def __init__(self, state):
        # Passing ``state`` to the base initializer leaves ``args`` as
        # ``(state,)``, exactly as it is without the explicit call.
        super().__init__(state)
        self.state = state
17+
18+
19+
def resample(particles: list[Any], scores: list[float]) -> list[Any]:
    """Draw a new population with as many members as ``particles``.

    A ``Categorical`` distribution is built from the ``(particle, score)``
    pairs and sampled once per input particle.

    Args:
        particles: Current particle states.
        scores: Score associated with each particle, in the same order.

    Returns:
        A list of ``len(particles)`` samples drawn from the distribution.
    """
    weighted = Categorical(list(zip(particles, scores)))
    # NOTE(review): presumably a particle drawn several times is the same
    # object each time (no copy is made here) -- confirm this is intended.
    return [weighted.sample() for _ in particles]
24+
25+
26+
def _process_particle(state, model, num_particles):
    """Run ``model`` on one particle and return ``(result, state, score)``.

    Args:
        state: Particle state passed to ``model``.
        model: Callable taking a state and returning ``(result, new_state)``.
        num_particles: Accepted for signature compatibility; not used here.

    Returns:
        ``(result, new_state, score)`` when the model completes, or
        ``(None, exn.state, score)`` when it raises ``Resample``. In both
        cases ``score`` is read from the enclosing ``ImportanceSampling``
        context.
    """
    with ImportanceSampling(0) as sampler:
        try:
            outcome = model(state)
        except Resample as interruption:
            # The model stopped early: hand back the state carried by the
            # exception, with no result.
            return None, interruption.state, sampler.score
        result, new_state = outcome
        return result, new_state, sampler.score
34+
35+
36+
def infer_smc(num_particles: int, model) -> Categorical[Any]:
    """Sequential SMC driver.

    Starts from ``num_particles`` empty-dict states and repeats rounds of
    "run every particle, then resample" until every particle of a round has
    produced a result.

    Args:
        num_particles: Number of particles in the population.
        model: Callable taking a state and returning ``(result, new_state)``;
            it may raise ``Resample`` to stop at a resampling point.

    Returns:
        A ``Categorical`` built from the ``(result, score)`` pairs of the
        last round.
    """
    particles = [{} for _ in range(num_particles)]  # initialise the particles
    results: list[Any] = []
    scores: list[float] = []
    while len(results) < num_particles:
        # Run the particles one after the other, in order.
        outcomes = [
            _process_particle(particle, model, num_particles)
            for particle in particles
        ]
        # NOTE(review): a ``None`` result marks a particle that has not
        # finished, so a model whose final result is ``None`` would never be
        # counted as finished -- confirm this cannot happen.
        results = [result for result, _, _ in outcomes if result is not None]
        states = [state for _, state, _ in outcomes]
        scores = [score for _, _, score in outcomes]
        particles = resample(states, scores)
    return Categorical(list(zip(results, scores)))
53+
54+
55+
# Warning: the parallel versions conflict with the context managers used for inference and need to be fixed.
56+
57+
# def infer_smc(num_particles:int, model) -> Categorical[Any]:
58+
# """Parallelized version using ThreadPoolExecutor"""
59+
# particles = [{} for _ in range(num_particles)] # initialise the particles
60+
# results: list[Any] = []
61+
# scores: list[float] = []
62+
# while len(results) < num_particles:
63+
# states = []
64+
# scores = []
65+
# results = []
66+
# with ThreadPoolExecutor() as executor:
67+
# future_to_particle = {
68+
# executor.submit(_process_particle, state, model, num_particles): state
69+
# for state in particles
70+
# }
71+
# for future in future_to_particle:
72+
# result, state, score = future.result()
73+
# if result is not None:
74+
# results.append(result) # execute all the particles
75+
# states.append(state)
76+
# scores.append(score)
77+
# particles = resample(states, scores)
78+
# return Categorical(list(zip(results, scores)))
79+
80+
81+
# async def _process_particle_async(state, model, num_particles):
82+
# with ImportanceSampling(num_particles) as sampler:
83+
# try:
84+
# loop = asyncio.get_event_loop()
85+
# result, new_state = await loop.run_in_executor(None, lambda: model(state))
86+
# return result, new_state, sampler.score
87+
# except Resample as exn:
88+
# return None, exn.state, sampler.score
89+
90+
91+
# async def infer_smc_async(num_particles: int, model) -> Categorical[Any]:
92+
# """Parallelized version using Async"""
93+
# particles = [{} for _ in range(num_particles)] # initialise the particles
94+
# results: list[Any] = []
95+
# scores: list[float] = []
96+
# while len(results) < num_particles:
97+
# states = []
98+
# scores = []
99+
# results = []
100+
# tasks = [
101+
# _process_particle_async(state, model, num_particles)
102+
# for state in particles
103+
# ]
104+
# particle_results = await asyncio.gather(*tasks)
105+
# for result, state, score in particle_results:
106+
# if result is not None:
107+
# results.append(result)
108+
# states.append(state)
109+
# scores.append(score)
110+
# particles = resample(states, scores)
111+
# return Categorical(list(zip(results, scores)))
112+
113+
114+
# def infer_smc_async(num_particles: int, model) -> Categorical[Any]:
115+
# return asyncio.run(infer_smc_async(num_particles, model))

0 commit comments

Comments
 (0)