-
Notifications
You must be signed in to change notification settings - Fork 151
Closed
Description
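The innermost JAX dispatch for `RandomVariable`s, `jax_sample_fn`, currently looks like this (each implementation splits the RNG key itself and writes the new key back into `rng["jax_state"]`):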
pytensor/pytensor/link/jax/dispatch/random.py
Lines 146 to 172 in 964cccb
| @jax_sample_fn.register(ptr.CauchyRV) | |
| @jax_sample_fn.register(ptr.GumbelRV) | |
| @jax_sample_fn.register(ptr.LaplaceRV) | |
| @jax_sample_fn.register(ptr.LogisticRV) | |
| @jax_sample_fn.register(ptr.NormalRV) | |
| def jax_sample_fn_loc_scale(op, node): | |
| """JAX implementation of random variables in the loc-scale families. | |
| JAX only implements the standard version of random variables in the | |
| loc-scale family. We thus need to translate and rescale the results | |
| manually. | |
| """ | |
| name = op.name | |
| jax_op = getattr(jax.random, name) | |
| def sample_fn(rng, size, dtype, *parameters): | |
| rng_key = rng["jax_state"] | |
| rng_key, sampling_key = jax.random.split(rng_key, 2) | |
| loc, scale = parameters | |
| if size is None: | |
| size = jax.numpy.broadcast_arrays(loc, scale)[0].shape | |
| sample = loc + jax_op(sampling_key, size, dtype) * scale | |
| rng["jax_state"] = rng_key | |
| return (rng, sample) | |
| return sample_fn |
All of this RNG-key handling (splitting the key and storing the updated key back in `rng["jax_state"]`) could instead be done once, in the outermost dispatch `jax_funcify_RandomVariable`:
pytensor/pytensor/link/jax/dispatch/random.py
Lines 104 to 117 in 964cccb
| if None in static_size: | |
| assert_size_argument_jax_compatible(node) | |
| def sample_fn(rng, size, *parameters): | |
| return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters) | |
| else: | |
| def sample_fn(rng, size, *parameters): | |
| return jax_sample_fn(op, node=node)( | |
| rng, static_size, out_dtype, *parameters | |
| ) | |
| return sample_fn |
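With that change, each `jax_sample_fn` implementation would receive a ready-to-use sampling key and return only the sample. If an implementation needs more than one key (i.e. a split other than 2), it can still split the key it is given.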