Source code for mandelbrot_benchmark.backends.torch

from typing import Any

from array_api_compat import array_namespace


# @torch.compile
def mandelbrot_torch(c: Any, max_iter: int = 200) -> Any:
    """Compute Mandelbrot escape-iteration counts for an array of points.

    The recurrence ``z = z * z + c`` is iterated element-wise, starting from
    ``z = 0``, using whichever array namespace ``array_namespace(c)``
    resolves to. This is a vectorized array implementation, not a
    pure-Python per-point loop.

    Args:
        c: Array of points to test. It must expose ``.shape`` and ``.device``
            and support in-place boolean-mask assignment.
            NOTE(review): presumably a complex-valued torch tensor, given the
            function name -- confirm against callers.
        max_iter: Number of iterations to run. It is also the value reported
            for points that never escape. Defaults to 200.

    Returns:
        An ``int32`` array with the same shape and device as ``c``. Each
        element holds the zero-based index of the first iteration at which
        ``abs(z) > 2``, or ``max_iter`` if that never happened.
    """
    xp = array_namespace(c)
    # Every point starts at the sentinel ``max_iter``, meaning "not escaped".
    counter = xp.full(c.shape, max_iter, dtype=xp.int32, device=c.device)
    z = xp.zeros_like(c)
    for i in range(max_iter):
        z = z * z + c
        # Record only the first escape: a point whose counter has already
        # left the sentinel value must not be overwritten by later iterations.
        idx = (xp.abs(z) > 2) & (counter == max_iter)
        counter[idx] = i
    return counter