# Source code for mandelbrot_benchmark.backends.jax

import jax
import jax.numpy as jnp
import torch
from dlpack import asdlpack


@jax.jit
def _mandelbrot_jax(c: jnp.ndarray) -> jnp.ndarray:
    """Compute Mandelbrot escape iteration counts with JAX.

    Iterates ``z <- z**2 + c`` starting from ``z = 0`` for a fixed number of
    steps and records, per element, the first step at which ``|z| > 2``.

    Parameters
    ----------
    c : jnp.ndarray
        Array of complex numbers.

    Returns
    -------
    jnp.ndarray
        ``int32`` array with the same shape as ``c``. Each entry is the
        zero-based index of the first iteration at which ``|z|`` exceeded 2,
        or 200 if it never did within the 200 iterations performed.
    """
    max_iter = 200
    # ``max_iter`` doubles as the "has not escaped yet" sentinel.
    counter = jnp.full(c.shape, max_iter, dtype=jnp.int32)
    z = jnp.zeros_like(c)
    for i in range(max_iter):
        z = z**2 + c
        # The parentheses matter: ``&`` binds tighter than ``>``, so without
        # them the test degenerates to ``|z| > 0`` and the counter would be
        # overwritten on every step instead of only at the first escape.
        first_escape = (jnp.abs(z) > 2) & (counter == max_iter)
        counter = jnp.where(first_escape, i, counter)
    return counter


def mandelbrot_jax(c: torch.Tensor) -> torch.Tensor:
    """
    JAX implementation of the Mandelbrot set.

    Converts the input tensor to a JAX array via DLPack, evaluates
    ``_mandelbrot_jax`` on it, and converts the result back to a
    ``torch.Tensor``.

    Parameters
    ----------
    c : torch.Tensor
        Input array of complex numbers.

    Returns
    -------
    torch.Tensor
        Output array of integers representing the Mandelbrot set.
    """
    # Tensors are wrapped with ``asdlpack`` before being handed across
    # frameworks; see https://github.com/jax-ml/jax/issues/1100
    c = jnp.from_dlpack(asdlpack(c))
    out = _mandelbrot_jax(c)
    return torch.from_dlpack(asdlpack(out))