Jupyter notebooks and multiprocessing - How to spawn and fork(server)

Today I came up with a clever solution for a dumb problem. When using multiprocessing or concurrent.futures in a Jupyter notebook, one is generally limited to the fork start method. This is because the spawn and forkserver methods pickle the target function by reference (its module and name), so the child process must be able to import it from a module. In a Jupyter notebook the target function very likely lives in the notebook itself, which is not importable. So, how do we make it importable?
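
To make the "pickled by reference" part concrete, here is a minimal sketch using only the standard library. The helper `is_picklable` is a hypothetical name introduced for illustration; it is not part of multiprocessing.

```python
import pickle
from math import sqrt


def is_picklable(obj):
    """Return True if pickle can serialize obj (hypothetical helper)."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


# A function is pickled as a reference: its module name plus its own name.
# The function's code is not part of the payload.
payload = pickle.dumps(sqrt)
print(b"math" in payload, b"sqrt" in payload)

# A lambda has no importable name, so pickling it fails.
print(is_picklable(sqrt), is_picklable(lambda x: x))
```

The receiving process has to resolve that reference by importing the module, which is why a function that only exists in the notebook cannot be handed to a spawned worker.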

Easy: we get its source code, write it to a temporary file, add that file's directory to sys.path, and import it dynamically! The code below does all of that inside a context manager.

import inspect
import tempfile
import sys

from contextlib import contextmanager
from importlib import import_module, invalidate_caches
from pathlib import Path


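@contextmanager
def enable_spawn(func):
    # Only the function's own source ends up in the temporary module.
    source = inspect.getsource(func)
    # Note: importing a still-open NamedTemporaryFile by name works on Unix,
    # but the file cannot be reopened this way on Windows.
    with tempfile.NamedTemporaryFile(suffix=".py", mode="w") as f:
        f.write(source)
        f.flush()
        path = Path(f.name)
        directory = str(path.parent)
        sys.path.append(directory)
        # Invalidate the finder caches once the file exists, so it can be found.
        invalidate_caches()
        try:
            module = import_module(path.stem)
            yield getattr(module, func.__name__)
        finally:
            # Restore sys.path even if the body of the with block raises.
            sys.path.remove(directory)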

Now for a demo of how it actually works. The main limitation is that only the source of the target function itself is written to the file, so everything it needs, including its imports, has to be contained in the function body.

from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing

def target_func(arg):
    from math import sqrt
    return arg, sqrt(arg)

context = multiprocessing.get_context("spawn")
with ProcessPoolExecutor(max_workers=10, mp_context=context) as pool, enable_spawn(target_func) as target:
    futures = [pool.submit(target, i) for i in range(20)]
    for future in as_completed(futures):
        print(future.result())

(0, 0.0)
(1, 1.0)
(2, 1.4142135623730951)
(3, 1.7320508075688772)
(4, 2.0)
(5, 2.23606797749979)
(6, 2.449489742783178)
(7, 2.6457513110645907)
(8, 2.8284271247461903)
(9, 3.0)
(10, 3.1622776601683795)
(11, 3.3166247903554)
(12, 3.4641016151377544)
(13, 3.605551275463989)
(14, 3.7416573867739413)
(15, 3.872983346207417)
(16, 4.0)
(17, 4.123105625617661)
(18, 4.242640687119285)
(19, 4.358898943540674)
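
One possible way around the "imports must live inside the function" limitation is to prepend a block of import lines to the generated module. The sketch below is an assumption of mine, not part of the code above: `import_from_source` and its `preamble` parameter are hypothetical names, and it takes the source as a string so that it also runs outside a notebook. In a notebook you would pass `inspect.getsource(target_func)`.

```python
import sys
import tempfile
import uuid
from contextlib import contextmanager
from importlib import import_module, invalidate_caches
from pathlib import Path
from textwrap import dedent


@contextmanager
def import_from_source(source, name, preamble=""):
    """Import `name` from a temporary module built from source text.

    Hypothetical helper: `preamble` holds module-level lines (such as
    imports) that are written before the function source.
    """
    module_name = f"spawn_target_{uuid.uuid4().hex}"
    with tempfile.TemporaryDirectory() as directory:
        path = Path(directory) / f"{module_name}.py"
        path.write_text(preamble + "\n" + dedent(source))
        sys.path.insert(0, directory)
        invalidate_caches()
        try:
            module = import_module(module_name)
            yield getattr(module, name)
        finally:
            # Undo the changes to the import state on exit.
            sys.path.remove(directory)
            sys.modules.pop(module_name, None)


source = """
def hypotenuse(a, b):
    return sqrt(a * a + b * b)
"""

with import_from_source(source, "hypotenuse", preamble="from math import sqrt") as func:
    result = func(3, 4)

print(result)  # 5.0
```

As with enable_spawn, the returned function is only usable by worker processes while the with block is active, because the temporary module disappears afterwards.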

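Doing this might not always be a good idea, since the start methods behave differently: spawn, for example, starts a fresh interpreter for every worker, so workers are slower to start and do not inherit the notebook's state. YMMV.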
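
Since the trade-offs depend on the platform, it can help to check which start methods are available before picking one. A small sketch using only the standard multiprocessing API:

```python
import multiprocessing

# Start methods supported on this platform; "spawn" is available everywhere,
# while "fork" and "forkserver" are limited to POSIX systems.
available = multiprocessing.get_all_start_methods()
print(available)

# The start method currently in effect, or None if it has not been fixed yet.
print(multiprocessing.get_start_method(allow_none=True))
```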