Closed
Description
import numpy as np
from functools import partial
import jax.numpy as jnp
from jax import jit, random
from collections import namedtuple
from multiprocessing import Process, Queue
from pickle import dumps
class Brain(namedtuple("Brain", ("w1", "b1", "w2", "b2"))):
def __sub__(self, other):
return Brain(
w1=self.w1 - other.w1,
b1=self.b1 - other.b1,
w2=self.w2 - other.w2,
b2=self.b2 - other.b2,
)
def __mul__(self, scalar):
return Brain(
w1=self.w1 * scalar,
b1=self.b1 * scalar,
w2=self.w2 * scalar,
b2=self.b2 * scalar,
)
__rmul__ = __mul__
def get_brain(
input_size: int, hidden_size: int, output_size: int, max_memory: int, seed: int
):
key = random.PRNGKey(seed)
w1 = random.truncated_normal(
key, lower=0, upper=0.1, shape=(input_size, hidden_size)
)
w2 = random.truncated_normal(
key, lower=0, upper=0.1, shape=(hidden_size, output_size)
)
b1 = jnp.zeros(shape=(hidden_size,))
b2 = jnp.zeros(shape=(output_size,))
return Brain(w1=w1, b1=b1, w2=w2, b2=b2)
@jit
def forward(brain: Brain, data: np.ndarray):
o1 = jnp.matmul(data, brain.w1) + brain.b1
a1 = jnp.tanh(o1)
o2 = jnp.matmul(a1, brain.w2) + brain.b2
a2 = o2 - jnp.expand_dims(jnp.log(jnp.exp(o2).sum(axis=1)), 1)
return a2
def worker(queue):
import jax.numpy as jnp
from jax import grad, jit
@jit
def forward(brain: Brain, data: np.ndarray):
o1 = jnp.matmul(data, brain.w1) + brain.b1
a1 = jnp.tanh(o1)
o2 = jnp.matmul(a1, brain.w2) + brain.b2
a2 = o2 - jnp.expand_dims(jnp.log(jnp.exp(o2).sum(axis=1)), 1)
return a2
@jit
def loss(brain: Brain, data: np.ndarray, labels: np.ndarray):
pred = forward(brain, data)
loss = jnp.mean(-(labels * pred).sum(1))
return loss
@jit
def grad_loss(brain: Brain, data: np.ndarray, labels: np.ndarray):
return partial(grad(loss), data=data, labels=labels)(brain)
@jit
def sgd(brain: Brain, data: np.ndarray, labels: np.ndarray, learning_rate: float):
g = grad_loss(brain, data, labels)
brain = brain - g * learning_rate
return brain
while True:
brain, data, label, epoch, learning_rate = queue.get()
for i in range(epoch):
brain = sgd(brain, data, labels, learning_rate)
break # if multiprocess the control flow dows not even come here, nothing get's returned from forward
if __name__ == "__main__":
brain = get_brain(100, 200, 9, 1000, 1)
data = np.random.normal(size=(1000, 100))
labels = np.random.uniform(0, 1, size=(1000, 9))
queue = Queue(10)
workers = []
for i in range(2):
p = Process(target=worker, args=(queue,))
p.start() # does not work
workers.append(p)
for i in range(10):
queue.put((brain, data, labels, 1000, 1))
worker(queue) # works