Use Numba Generators for random graphs and deprecate shared RandomState variables #316

Comments
I'm trying to understand this issue to start a PR. What actually needs to be done? Numba functions can take numpy random generators without any hassle now. For example, this works:

```python
import numba as nb
import numpy as np


@nb.njit
def draw_nb(rng, loc, scale, size):
    return rng.normal(loc=loc, scale=scale, size=size)


rng = np.random.default_rng()
draw_nb(rng, 0.0, 1.0, (10,))
```

So would it be enough to just chop out all of the extra machinery from […]?

---

Numba still doesn't support broadcasting from parameters, so all the rest of the machinery seems like it needs to stay (though I personally find it quite difficult to follow; it would be nice to refactor it to make it clearer).
---

This is the big hurdle. Generators without broadcasting are pretty useless, so we have to support it if we want to really support RVs in the numba backend (and phase out RandomState).
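To make the limitation concrete, here is a minimal sketch (my illustration, not from the thread) of what "broadcasting from parameters" means: plain numpy happily broadcasts an array `loc` against a larger `size`, while the equivalent call inside an `@nb.njit` function is what reportedly fails to compile:

```python
import numba as nb
import numpy as np

rng = np.random.default_rng()
loc = np.zeros((5, 5))

# Plain numpy broadcasts the (5, 5) loc against size=(10, 5, 5):
samples = rng.normal(loc=loc, scale=1.0, size=(10, 5, 5))


@nb.njit
def draw_nb(rng, loc, scale, size):
    # Per the discussion above, numba's Generator support does not
    # broadcast array parameters, so this is expected to fail:
    return rng.normal(loc=loc, scale=scale, size=size)


# draw_nb(rng, loc, 1.0, (10, 5, 5))  # expected to fail to compile
```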
---

This is the PR where Generator support was added to Aesara. The main complexity is writing the broadcasting logic with Python strings, which is the usual numba backend PITA: aesara-devs/aesara#1245
---

But I'm saying I'm pretty sure you can directly plug a generator into what they already have? There's even a numba datatype for numpy random generators ([…]).
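For what it's worth, a quick way to see that numba recognizes Generator objects (a sketch; assumes a numba version recent enough to ship Generator support):

```python
import numba as nb
import numpy as np

rng = np.random.default_rng()
# numba infers a native type for np.random.Generator instances,
# so they can be passed straight into jitted functions:
print(nb.typeof(rng))
```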
---

I think the only challenge is the broadcasting logic, which I think can't be written as a Python function without writing it for every Op? I don't remember exactly where things broke. Also, the strict API for RVs requires copying the RNG if the Op is not inplace. Not sure if this is relevant.
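As a side note on that last point, copying a Generator is cheap and yields an independent stream with the same state at the time of the copy (a sketch, not from the thread):

```python
from copy import copy

import numpy as np

rng = np.random.default_rng(0)
rng_copy = copy(rng)  # independent Generator with identical state

# Both draw the same value because they start from the same state,
# but advancing one does not advance the other:
assert rng.normal() == rng_copy.normal()
```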
---

What did you mean by "what they already have"?
---

Feel free to open a PR if it seems like some minimal changes do the job (or even if they don't). Unfortunately, I've lost the context on this issue, so I can't help just from thinking about it.
---

The broadcasting is done with a loop, so it's actually not too bad. Here is a basic sketch:

```python
import numba as nb
import numpy as np


@nb.njit
def draw_nb(rng, loc, scale, size):
    loc_bcast = np.broadcast_to(loc, size)
    scale_bcast = np.broadcast_to(scale, size)
    bcast_samples = np.empty(size, dtype=np.float64)
    for idx in np.ndindex(size):
        bcast_samples[idx] = rng.normal(loc_bcast[idx], scale_bcast[idx])
    return bcast_samples


rng = np.random.default_rng(1)
loc = np.zeros((5, 5))
scale = 1.0
size = (10, 5, 5)
samples_np = rng.normal(loc=loc, scale=scale, size=size)

rng = np.random.default_rng(1)
samples_nb = draw_nb(rng, loc, scale, size)
np.allclose(samples_np, samples_nb)  # True
```

I guess it just seems like all the special packing/unpacking of the random state that is done in the numba linker can just go. But is that the only thing causing problems? Sure, I'll open a PR. I'm also curious why this all causes problems only with certain models (like scan) but not others.
---

Can you write that function in a way that works for any RV Op, regardless of the number of parameters and ndim_supp/ndims_params? I imagine that's the abstraction that complicates things.
---

Yes, the boxing and unboxing of RandomState is supposed to go away completely.
---

No, you have to write overloads on a case-by-case basis. But that should be a hassle, not a blocker. Also, it already exists. The current code jumps through a lot of hoops to make a function that generates generalized boilerplate code. I guess it's elegant, but it makes it really hard to read. A bunch of individual functions would be easier to maintain, IMO.
---

It's also much more error prone, though. For instance, this is what the full function for NormalRV may look like (functionality-wise):

```python
from copy import copy

import numpy as np


def normal_rv_dispatch(op, node):
    def normal(rng, size, dtype_code, loc, scale):
        if size is not None:
            loc = np.broadcast_to(loc, size)
            scale = np.broadcast_to(scale, size)
        else:
            loc, scale = np.broadcast_arrays(loc, scale)
            size = loc.shape
        if not op.inplace:
            rng = copy(rng)
        out = np.empty(size, dtype=node.outputs[1].dtype)
        # populate out
        return rng, out

    return normal
```

Can we use any helpers to avoid repeating all this boilerplate? Also, it gets slightly more tricky with: […]
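One possible answer to the helper question, as a sketch only: the shared boilerplate could live in a factory that takes a per-distribution scalar sampler. Here `make_rv_fn` and `core_sample` are invented names, and whether numba can actually compile a `*params` closure like this is exactly the kind of open question raised below:

```python
from copy import copy

import numpy as np


def make_rv_fn(core_sample, out_dtype, inplace):
    # core_sample(rng, *params) draws a single value from the distribution.
    def rv_fn(rng, size, *params):
        if size is not None:
            params = tuple(np.broadcast_to(p, size) for p in params)
        else:
            params = tuple(np.broadcast_arrays(*params))
            size = params[0].shape
        if not inplace:
            rng = copy(rng)
        out = np.empty(size, dtype=out_dtype)
        for idx in np.ndindex(*size):
            out[idx] = core_sample(rng, *(p[idx] for p in params))
        return rng, out

    return rv_fn


# Usage sketch: the NormalRV case from above becomes one line.
normal_fn = make_rv_fn(lambda rng, loc, scale: rng.normal(loc, scale),
                       np.float64, inplace=False)
```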
---

So I guess there are two issues being discussed here:

1. The special packing/unpacking (boxing/unboxing) of the random state done in the numba linker.
2. The string-manipulation machinery that generates the per-Op broadcasting code.

As I understand it, only issue (1) is causing problems. My motivation is to be able to use nutpie on any PyMC model -- currently I get an error about how random generators are not supported for some (I think scan-based? But I don't have an example off the top of my head). Anyway, I'll open a PR that dumps the random streams and see what it breaks when I have some time.
---

I also bet numba will complain about the use of size or something, and will perhaps require that we define two functions depending on whether size is provided. And size probably has to be converted to a fixed-size tuple. Small things that add up to more boilerplate. Then multiply that by all the RV Ops we have between PyTensor and PyMC. For me, unreadable string manipulation doesn't seem too bad in comparison (but I can be persuaded). Maybe we can make it more readable? Regardless, hopefully numba will just do the work and implement the overloads properly so we don't have to -.-
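A sketch of the "two functions" concern (my illustration, not from the thread): numba types each argument, so a branch where `size` is sometimes `None` and sometimes a tuple generally cannot be compiled as a single function, suggesting one specialization per case:

```python
import numba as nb
import numpy as np


@nb.njit
def normal_with_size(rng, loc, scale, size):
    # size is a fixed-length tuple known to the compiler
    loc_b = np.broadcast_to(loc, size)
    scale_b = np.broadcast_to(scale, size)
    out = np.empty(size, dtype=np.float64)
    for idx in np.ndindex(size):
        out[idx] = rng.normal(loc_b[idx], scale_b[idx])
    return out


@nb.njit
def normal_no_size(rng, loc, scale):
    # assumes loc and scale were already broadcast to a common array shape
    out = np.empty(loc.shape, dtype=np.float64)
    for idx in np.ndindex(loc.shape):
        out[idx] = rng.normal(loc[idx], scale[idx])
    return out
```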
---

Thanks!
---

Description

RandomState is legacy in numpy, and we can save some complexity by letting go of it in PyTensor.

We were "obliged" to keep it because that was the only kind of RNG that Numba supported until now.
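For context, a minimal illustration of the two numpy RNG APIs (my example):

```python
import numpy as np

# Legacy API: discouraged for new code, kept for backwards compatibility.
legacy_rng = np.random.RandomState(42)
legacy_rng.normal(size=3)

# Modern API: returns a np.random.Generator, the kind numba now supports.
modern_rng = np.random.default_rng(42)
modern_rng.normal(size=3)
```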