"""Numba backend for the Mandelbrot benchmark (``mandelbrot_benchmark.backends.numba``)."""
from typing import Any

import numba
from numba.cuda import as_cuda_array
from numba.cuda.cudadrv.error import CudaSupportError, NvvmSupportError
# Iteration cap shared by the loop bound and the "never escaped" result, so the
# two values cannot drift apart. Numba treats module-level globals as
# compile-time constants inside compiled functions.
_MAX_ITER = 200


def _mandelbrot_32(c: complex) -> int:
    """Return the escape iteration of a single point ``c`` (scalar ufunc kernel).

    This is plain Python that Numba compiles (via the ``numba.vectorize``
    calls in this module) into ufuncs with signature ``int32(complex64)``.
    Starting from ``z = 0``, ``z`` is iterated as ``z = z * z + c``. The
    0-based iteration index at which ``|z|**2 >= 4`` first holds is returned,
    or ``_MAX_ITER`` if the orbit does not escape within ``_MAX_ITER`` steps.
    """
    counter = numba.int32(_MAX_ITER)
    # Start from a complex64 zero so the iteration stays in single precision,
    # matching the complex64 ufunc signature.
    z = numba.complex64(0)
    for i in range(_MAX_ITER):
        z = z * z + c
        # Compare the squared magnitude against 2**2 to avoid a square root.
        if z.real**2 + z.imag**2 >= 4:
            counter = i
            break
    return counter
# CPU ufunc built from the scalar kernel for the single signature
# complex64 -> int32. target="parallel" selects Numba's multi-core CPU ufunc
# target; fastmath=True enables relaxed (non-strict IEEE) floating-point flags.
_mandelbrot_numba = numba.vectorize(
    [numba.int32(numba.complex64)],
    fastmath=True,
    nopython=True,
    target="parallel",
)(_mandelbrot_32)
# Optional GPU ufunc. Passing explicit signatures makes Numba compile eagerly
# at import time, and for target="cuda" that can fail with CudaSupportError
# (no usable CUDA driver) or NvvmSupportError (libNVVM from the CUDA toolkit
# not found). In either case fall back to None so the module still imports on
# CPU-only hosts; callers should check for None before using this ufunc.
try:
    _mandelbrot_numba_cuda = numba.vectorize(
        [numba.int32(numba.complex64)],
        target="cuda",
    )(_mandelbrot_32)
except (CudaSupportError, NvvmSupportError):
    # NvvmSupportError is not a subclass of CudaSupportError, so it must be
    # listed explicitly.
    _mandelbrot_numba_cuda = None
# Public API: mandelbrot_numba (this line was a Sphinx "[docs]" link marker).
def mandelbrot_numba(c: Any) -> Any:
    """
    Numba implementation of the Mandelbrot set.

    Each element of ``c`` is passed through the scalar escape-time kernel
    (``z = z * z + c`` starting from ``z = 0``, at most 200 iterations),
    compiled as an ``int32(complex64)`` ufunc.

    Parameters
    ----------
    c : Any
        Input array of complex numbers. If ``str(c.device)`` mentions CUDA
        (case-insensitively), the GPU ufunc is used. Otherwise the parallel
        CPU ufunc is used, including when ``c`` has no ``device`` attribute.

    Returns
    -------
    Any
        Output array of integers representing the Mandelbrot set. On the GPU
        path this is whatever the Numba CUDA ufunc returns for a Numba device
        array input.

    Raises
    ------
    RuntimeError
        If ``c`` is on a CUDA device but the Numba CUDA ufunc could not be
        built when this module was imported.
    """
    # Objects without a ``device`` attribute (e.g. NumPy < 2.0 arrays) are
    # treated as host data instead of raising AttributeError. Lower-casing
    # also matches device strings spelled "CUDA".
    device = str(getattr(c, "device", "cpu")).lower()
    if "cuda" not in device:
        return _mandelbrot_numba(c)
    if _mandelbrot_numba_cuda is None:
        raise RuntimeError(
            "mandelbrot_numba: input is on a CUDA device, but the Numba CUDA "
            "ufunc is unavailable (no usable CUDA driver/toolkit was found "
            "when mandelbrot_benchmark.backends.numba was imported)."
        )
    # as_cuda_array wraps the object's __cuda_array_interface__ as a Numba
    # device array that views the same GPU buffer (no copy to the host).
    return _mandelbrot_numba_cuda(as_cuda_array(c))