"""Source code for mandelbrot_benchmark.backends.taichi (Taichi backend of the Mandelbrot benchmark)."""

from typing import Any

import taichi as ti
import taichi.math as tm
from array_api_compat import array_namespace, device


# Maximum number of z <- z**2 + c iterations; also the value returned for
# points whose orbit never escapes. Taichi treats this global as a
# compile-time constant inside the kernel.
_MAX_ITER = 200


@ti.func
def _mandelbrot_func(c: tm.vec2) -> ti.i32:
    """Return the Mandelbrot escape iteration for one point.

    ``c`` is a complex number stored as ``(real, imag)``. Starting from
    ``z = 0``, iterate ``z <- z**2 + c`` at most ``_MAX_ITER`` times. Return
    the 0-based index of the first iteration where ``|z|**2 >= 4``, or
    ``_MAX_ITER`` if that never happens.
    """
    counter = ti.i32(_MAX_ITER)
    z = tm.vec2(0, 0)
    for i in range(_MAX_ITER):
        # Square z with plain arithmetic: (x + iy)**2 = (x*x - y*y) + i(2xy).
        # tm.cpow is a general complex-power routine. The explicit product
        # is cheaper in this hot loop and rounds like an ordinary complex z*z.
        z = tm.vec2(z.x * z.x - z.y * z.y, 2.0 * z.x * z.y) + c
        if z.x * z.x + z.y * z.y >= 4.0:  # |z| >= 2: the orbit escapes
            counter = i
            break
    return counter


@ti.kernel
def _mandelbrot_kernel(c: ti.types.ndarray(), out: ti.types.ndarray()):  # type: ignore
    """Write the escape count of every point in ``c`` into ``out``.

    ``c`` has one more trailing axis than ``out``. Along that axis, index 0 is
    the real part and index 1 is the imaginary part of the point at the same
    leading index in ``out``.
    """
    # Visit every index of ``out``. Per the Taichi docs, the outermost loop of
    # a kernel is the one Taichi parallelizes.
    for idx in ti.grouped(out):
        point = tm.vec2(c[idx, 0], c[idx, 1])
        out[idx] = _mandelbrot_func(point)


[docs] def mandebrot_taichi(c: Any) -> Any: """ Taichi implementation of the Mandelbrot set. Since Taichi does not support complex numbers directly, internally the input is stacked as a +1D array with real and imaginary parts. Taichi's from_numpy() and to_numpy() are !!NOT!! zero-copy, so we pass non-Taichi arrays directly to the kernel. (See "Note" in https://docs.taichi-lang.org/docs/external) See Also -------- https://docs.taichi-lang.org/docs/external """ xp = array_namespace(c) out = xp.empty(c.shape, dtype=xp.int32, device=c.device) c = xp.stack([c.real, c.imag], axis=-1) _mandelbrot_kernel(c, out) return out