
jax does not work with the multiprocessing "fork" strategy. #1805

Closed
@dchatterjee172

Description

import numpy as np
from functools import partial
import jax.numpy as jnp
from jax import jit, random
from collections import namedtuple
from multiprocessing import Process, Queue
from pickle import dumps


class Brain(namedtuple("Brain", ("w1", "b1", "w2", "b2"))):
    def __sub__(self, other):
        return Brain(
            w1=self.w1 - other.w1,
            b1=self.b1 - other.b1,
            w2=self.w2 - other.w2,
            b2=self.b2 - other.b2,
        )

    def __mul__(self, scalar):
        return Brain(
            w1=self.w1 * scalar,
            b1=self.b1 * scalar,
            w2=self.w2 * scalar,
            b2=self.b2 * scalar,
        )

    __rmul__ = __mul__


def get_brain(
    input_size: int, hidden_size: int, output_size: int, max_memory: int, seed: int
):
    key = random.PRNGKey(seed)
    w1 = random.truncated_normal(
        key, lower=0, upper=0.1, shape=(input_size, hidden_size)
    )
    w2 = random.truncated_normal(
        key, lower=0, upper=0.1, shape=(hidden_size, output_size)
    )
    b1 = jnp.zeros(shape=(hidden_size,))
    b2 = jnp.zeros(shape=(output_size,))
    return Brain(w1=w1, b1=b1, w2=w2, b2=b2)


@jit
def forward(brain: Brain, data: np.ndarray):
    o1 = jnp.matmul(data, brain.w1) + brain.b1
    a1 = jnp.tanh(o1)
    o2 = jnp.matmul(a1, brain.w2) + brain.b2
    a2 = o2 - jnp.expand_dims(jnp.log(jnp.exp(o2).sum(axis=1)), 1)
    return a2


def worker(queue):
    import jax.numpy as jnp
    from jax import grad, jit

    @jit
    def forward(brain: Brain, data: np.ndarray):
        o1 = jnp.matmul(data, brain.w1) + brain.b1
        a1 = jnp.tanh(o1)
        o2 = jnp.matmul(a1, brain.w2) + brain.b2
        a2 = o2 - jnp.expand_dims(jnp.log(jnp.exp(o2).sum(axis=1)), 1)
        return a2

    @jit
    def loss(brain: Brain, data: np.ndarray, labels: np.ndarray):
        pred = forward(brain, data)
        loss = jnp.mean(-(labels * pred).sum(1))
        return loss

    @jit
    def grad_loss(brain: Brain, data: np.ndarray, labels: np.ndarray):
        return partial(grad(loss), data=data, labels=labels)(brain)

    @jit
    def sgd(brain: Brain, data: np.ndarray, labels: np.ndarray, learning_rate: float):
        g = grad_loss(brain, data, labels)
        brain = brain - g * learning_rate
        return brain

    while True:
        brain, data, labels, epoch, learning_rate = queue.get()
        for i in range(epoch):
            brain = sgd(brain, data, labels, learning_rate)
        break  # under multiprocessing, control flow does not even get here; nothing gets returned from forward


if __name__ == "__main__":
    brain = get_brain(100, 200, 9, 1000, 1)
    data = np.random.normal(size=(1000, 100))
    labels = np.random.uniform(0, 1, size=(1000, 9))
    queue = Queue(10)
    workers = []
    for i in range(2):
        p = Process(target=worker, args=(queue,))
        p.start()  # does not work
        workers.append(p)
    for i in range(10):
        queue.put((brain, data, labels, 1000, 1))
    worker(queue)  # works

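A workaround that is often suggested for this kind of failure (an assumption here, since the original report does not include one): with the default "fork" start method the child process inherits the parent's already-initialized JAX/XLA runtime, including its background threads, which can hang, while the "spawn" start method gives each worker a fresh interpreter that re-initializes JAX on first use. A minimal sketch against the repro above:

import multiprocessing as mp

if __name__ == "__main__":
    # Use a "spawn" context so workers do not inherit forked XLA state.
    # Assumes `worker`, `brain`, `data`, and `labels` are defined as in the
    # repro above and that everything placed on the queue is picklable.
    ctx = mp.get_context("spawn")
    queue = ctx.Queue(10)
    workers = [ctx.Process(target=worker, args=(queue,)) for _ in range(2)]
    for p in workers:
        p.start()
    for _ in range(10):
        queue.put((brain, data, labels, 1000, 1))
    for p in workers:
        p.join()

Note that under "spawn" the worker function must be defined at the top level of an importable module, and objects sent through the queue are pickled, so the namedtuple of arrays in Brain has to survive a pickle round-trip.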