import jax
import jax.numpy as jnp
import torch
from dlpack import asdlpack


@jax.jit
def _mandelbrot_jax(c: jnp.ndarray) -> jnp.ndarray:
    """JAX implementation of the Mandelbrot set.

    Iterates ``z <- z**2 + c`` starting from ``z = 0`` for a fixed number of
    steps (200) and records, for every element of ``c``, the first iteration
    index at which ``|z| > 2``.

    Parameters
    ----------
    c : jnp.ndarray
        Array of complex starting points (any shape, including 0-d).

    Returns
    -------
    jnp.ndarray
        int32 array with the same shape as ``c``. Each entry is the 0-based
        iteration at which the orbit first exceeded magnitude 2, or 200 if it
        never did within the iteration budget.
    """
    max_iter = 200
    # Sentinel value ``max_iter`` means "has not escaped yet".
    counter = jnp.full(c.shape, max_iter, dtype=jnp.int32)
    z = jnp.zeros_like(c)
    for i in range(max_iter):
        z = z**2 + c
        # Parenthesise the comparison: in Python ``&`` binds tighter than
        # ``>``, so ``abs(z) > 2 & mask`` would compare against ``2 & mask``
        # (always 0) instead of AND-ing two boolean masks.
        escaped_now = (jnp.abs(z) > 2) & (counter == max_iter)
        # Only record the first escape; later iterations keep that value.
        counter = jnp.where(escaped_now, i, counter)
    return counter
def mandelbrot_jax(c: torch.Tensor) -> torch.Tensor:
    """
    Run the JAX Mandelbrot kernel on a PyTorch tensor.

    The tensor is exported to JAX via DLPack, evaluated by the jitted
    ``_mandelbrot_jax`` kernel, and the result is imported back into
    PyTorch via DLPack.

    Parameters
    ----------
    c : torch.Tensor
        Input array of complex numbers.

    Returns
    -------
    torch.Tensor
        Integer tensor of per-point iteration counts as produced by
        ``_mandelbrot_jax``.
    """
    # Tensors cross the framework boundary through DLPack; see
    # https://github.com/jax-ml/jax/issues/1100
    jax_points = jnp.from_dlpack(asdlpack(c))
    jax_counts = _mandelbrot_jax(jax_points)
    torch_counts = torch.from_dlpack(asdlpack(jax_counts))
    return torch_counts