diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index a81d34d5533..b1ae52e0301 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -35,6 +35,9 @@ from pymc.backends.arviz import dict_to_dataset, to_inference_data from pymc.backends.base import MultiTrace +from pymc.distributions.custom import CustomDistRV, CustomSymbolicDistRV +from pymc.distributions.distribution import _support_point +from pymc.logprob.abstract import _icdf, _logcdf, _logprob from pymc.model import Model, modelcontext from pymc.sampling.parallel import _cpu_count from pymc.smc.kernels import IMH @@ -346,11 +349,18 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): # main process and our worker functions _progress = manager.dict() + # check if model contains CustomDistributions defined without dist argument + custom_methods = _find_custom_dist_dispatch_methods(params[3]) + # "manually" (de)serialize params before/after multiprocessing params = tuple(cloudpickle.dumps(p) for p in params) kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()} - with ProcessPoolExecutor(max_workers=cores) as executor: + with ProcessPoolExecutor( + max_workers=cores, + initializer=_register_custom_methods, + initargs=(custom_methods,), + ) as executor: for c in range(chains): # iterate over the jobs we need to run # set visible false so we don't have a lot of bars all at once: task_id = progress.add_task(f"Chain {c}", status="Stage: 0 Beta: 0") @@ -383,3 +393,32 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): ) return tuple(cloudpickle.loads(r.result()) for r in done) + + +def _find_custom_dist_dispatch_methods(model): + custom_methods = {} + for rv in model.basic_RVs: + rv_type = rv.owner.op + cls = type(rv_type) + if isinstance(rv_type, CustomDistRV | CustomSymbolicDistRV): + custom_methods[cloudpickle.dumps(cls)] = ( + cloudpickle.dumps(_logprob.registry.get(cls, None)), + cloudpickle.dumps(_logcdf.registry.get(cls, None)), + cloudpickle.dumps(_icdf.registry.get(cls, None)), + cloudpickle.dumps(_support_point.registry.get(cls, None)), + ) + + return custom_methods + + +def _register_custom_methods(custom_methods): + for cls, (logprob, logcdf, icdf, support_point) in custom_methods.items(): + cls = cloudpickle.loads(cls) + if logprob is not None: + _logprob.register(cls, cloudpickle.loads(logprob)) + if logcdf is not None: + _logcdf.register(cls, cloudpickle.loads(logcdf)) + if icdf is not None: + _icdf.register(cls, cloudpickle.loads(icdf)) + if support_point is not None: + _support_point.register(cls, cloudpickle.loads(support_point)) diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 3aa687459ee..09ba48dc7d2 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -134,6 +134,21 @@ def test_unobserved_categorical(self): assert np.all(np.median(trace["mu"], axis=0) == [1, 2]) + def test_parallel_custom(self): + def _logp(value, mu): + return -((value - mu) ** 2) + + def _random(mu, rng=None, size=None): + return rng.normal(loc=mu, scale=1, size=size) + + def _dist(mu, size=None): + return pm.Normal.dist(mu, 1, size=size) + + with pm.Model(): + mu = pm.CustomDist("mu", 0, logp=_logp, dist=_dist) + pm.CustomDist("y", mu, logp=_logp, class_name="", random=_random, observed=[1, 2]) + pm.sample_smc(draws=6, cores=2) + def test_marginal_likelihood(self): """ Verifies that the log marginal likelihood function