# Source listing of module mandelbrot_benchmark.backends.warp
from typing import Any
import torch
import warp as wp
from array_api_compat import array_namespace
@wp.func
def _mandelbrot_func(c: wp.vec2f) -> wp.int32:
    """Escape-time count for one point, single-precision overload.

    ``c`` carries a complex number as ``(real, imag)``. Starting from a
    zero vector, ``z`` is replaced by ``z*z + c`` up to 200 times. The
    result is the zero-based index of the first step after which
    ``z[0]**2 + z[1]**2 >= 4``, or 200 if that never happens.
    """
    # Value reported when no step satisfies the escape test.
    escape_step = wp.int32(200)
    z = wp.vec2f()
    for step in range(200):
        zr = z[0]
        zi = z[1]
        # Complex square written out on (real, imag) components, then + c.
        z = wp.vec2f(zr * zr - zi * zi, 2.0 * zr * zi) + c
        if z[0] * z[0] + z[1] * z[1] >= 4.0:
            escape_step = step
            break
    return escape_step
@wp.func  # type: ignore
def _mandelbrot_func(c: wp.vec2d) -> wp.int32:
    """Escape-time count for one point, double-precision overload.

    ``c`` carries a complex number as ``(real, imag)``. Starting from a
    zero vector, ``z`` is replaced by ``z*z + c`` up to 200 times. The
    result is the zero-based index of the first step after which
    ``z[0]**2 + z[1]**2 >= 4``, or 200 if that never happens.
    """
    # Literals are wrapped in wp.float64 so they match the vec2d components.
    two = wp.float64(2.0)
    bailout = wp.float64(4.0)
    # Value reported when no step satisfies the escape test.
    escape_step = wp.int32(200)
    z = wp.vec2d()
    for step in range(200):
        zr = z[0]
        zi = z[1]
        # Complex square written out on (real, imag) components, then + c.
        z = wp.vec2d(zr * zr - zi * zi, two * zr * zi) + c
        if z[0] * z[0] + z[1] * z[1] >= bailout:
            escape_step = step
            break
    return escape_step
@wp.kernel
def _mandelbrot_kernel(
    c: wp.array2d(dtype=Any),  # type: ignore
    out: wp.array2d(dtype=wp.int32),  # type: ignore
) -> None:
    """Write the escape count of every element of ``c`` into ``out``.

    Each thread takes its two-dimensional thread index, reads that element
    of ``c``, and stores the value returned by ``_mandelbrot_func`` at the
    same index of ``out``. The element type of ``c`` is left generic
    (``Any``).
    """
    row, col = wp.tid()
    point = c[row, col]
    out[row, col] = _mandelbrot_func(point)
# Declares the single-precision (wp.vec2f input) signature of the generic
# `_mandelbrot_kernel` via `wp.overload`. The body is intentionally only
# `...`; the kernel's code is the generic definition.
@wp.overload  # type: ignore
def _mandelbrot_kernel(
    c: wp.array2d(dtype=wp.vec2f),  # type: ignore
    out: wp.array2d(dtype=wp.int32),  # type: ignore
) -> None: ...
# Declares the double-precision (wp.vec2d input) signature of the generic
# `_mandelbrot_kernel` via `wp.overload`. The body is intentionally only
# `...`; the kernel's code is the generic definition.
@wp.overload  # type: ignore
def _mandelbrot_kernel(
    c: wp.array2d(dtype=wp.vec2d),  # type: ignore
    out: wp.array2d(dtype=wp.int32),  # type: ignore
) -> None: ...
# "[docs]" back-link from the rendered source listing (not part of the module code).
def mandelbrot_warp(c: torch.Tensor) -> torch.Tensor:
    """
    Warp implementation of the Mandelbrot set.

    Parameters
    ----------
    c : torch.Tensor
        Complex-valued tensor of points (its ``.real`` and ``.imag`` parts
        are read). It must be two-dimensional, because
        ``_mandelbrot_kernel`` takes 2-D arrays and a 2-D thread index.

    Returns
    -------
    torch.Tensor
        ``int32`` tensor with the same shape as ``c`` holding the value
        computed by ``_mandelbrot_kernel`` for each point. It is obtained
        from the Warp output array via ``wp.to_torch`` with
        ``requires_grad=False``.

    Notes
    -----
    The real and imaginary parts are stacked along a new trailing axis and
    handed to Warp as an array of 2-vectors: ``wp.vec2d`` when the
    components are ``float64``, ``wp.vec2f`` otherwise.

    The original author notes that Warp converts external arrays to its own
    format zero-copy. NOTE(review): the stacking step itself builds a new
    tensor, so confirm which conversions are actually copy-free.

    See Also
    --------
    https://nvidia.github.io/warp/modules/interoperability.html
    """
    # Use the tensor's own device, including the CUDA ordinal, so the kernel
    # runs where the data lives and the result comes back on that device.
    device = wp.device_from_torch(c.device)
    xp = array_namespace(c)

    # Capture the grid shape BEFORE stacking: the stacked tensor has an extra
    # trailing axis of length 2 and must not be used as the launch dimension.
    grid_shape = tuple(c.shape)
    out = wp.empty(shape=grid_shape, dtype=wp.int32, device=device)

    # (real, imag) pairs along the last axis -> one 2-vector per point.
    components = xp.stack([c.real, c.imag], axis=-1)
    vec_dtype = wp.vec2d if components.dtype == torch.float64 else wp.vec2f
    field = wp.array(components, dtype=vec_dtype, device=device)

    # One thread per point of the original grid, matching the kernel's
    # two-index `wp.tid()`.
    wp.launch(
        kernel=_mandelbrot_kernel,
        dim=grid_shape,
        inputs=[field],
        outputs=[out],
        device=device,
    )
    return wp.to_torch(out, requires_grad=False)