Kernels
This guide covers declaring kernels and functions, expressions and pointer operations, casts, control flow, raw code and inline PTX, builtins, memory allocation, and rendering/binding.
Declaring kernels and functions
import cheetah.api as ch
from cheetah.api import ty
# Kernel declaration: parameters are declared with ch.Params, here a u32 `n`
# and a mutable i32 pointer `x` (ty.ptr_mut).
@ch.kernel(ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32)))
def mul2(n, x):
    ...
# Helper function taking two i32 values and returning an i32 (ret=ty.i32).
@ch.fn(
    ch.Params(a=ty.i32, b=ty.i32),
    device=True,  # __device__
    host=False,  # presumably: no __host__ variant is emitted
    force_inline=True,  # presumably maps to __forceinline__
    ret=ty.i32,
)
def add(a, b):
    return a + b
# Kernel with tensor parameters: `x` is declared via tensor_mut (f32) and
# `y` via tensor_const (f16) -- presumably mutable vs. read-only access.
@ch.kernel(ch.Params(x=ty.tensor_mut(ty.f32), y=ty.tensor_const(ty.f16)))
def tensor_kernel(x, y):
    ...
Expressions and pointer ops
ptr = ch.alloc(ty.i32)  # presumably a single i32 (contrast ch.alloc_array)
ptr.val = 42  # store
val = ptr.val  # load
# Offsetting a single-element allocation by 3 would point past its end, so the
# pointer-arithmetic demo uses a 4-element array; offset(3) is its last slot.
# NOTE(review): assumes the alloc_array result supports .offset like a pointer.
arr = ch.alloc_array(ty.i32, 4)
ptr2 = arr.offset(3)  # pointer arithmetic
ptr2[0] = 7  # index sugar for load/store
Casts
f = ch.const(1.5, ty.f32)  # typed f32 constant
i = f.cast(ty.i32)  # primitive cast
# Null pointer constant. ptr_const(None) is a pointer-to-const with no pointee
# type, i.e. const void* (ptr_mut would presumably give plain void*).
vp = ch.const(0, ty.ptr_const(None))
Control flow
# Single-branch conditional: the body of the `with` is guarded by `cond`.
with ch.if_(cond):
    ...

# Conditional expression; each branch is supplied as a zero-argument lambda.
res = ch.cond(cond, then=lambda: a + 1, else_=lambda: b - 1)

# Two-branch conditional: statements go inside h.then() / h.else_().
with ch.if_else(cond) as h:
    with h.then():
        ...
    with h.else_():
        ...

# Counted loop bound to `i`; arguments presumably (start, stop, step).
with ch.for_(0, n, 1, unroll=True) as i:
    ...

# Return from the enclosing kernel/function.
ch.ret()
Raw code and inline PTX
# Verbatim statement; "$0" is presumably replaced by the first positional arg.
ch.raw_stmt("// marker: $0", expr)
# Verbatim i32-typed expression; "$a" / "$b" come from the keyword arguments.
a_plus_b = ch.raw_expr("($a + $b)", ty.i32, a=a, b=b)
# Inline PTX returning a sequence of outputs ("$0" presumably the u32 output),
# unpacked here into out.val.
ptx_outputs = ch.asm("add.u32 $0, $a, $b;", ty.u32, a=a, b=b)
(out.val,) = ptx_outputs
Builtins
# Thread/block index builtins. Do not name a result `ty`: that would rebind the
# `ty` types module imported from cheetah.api, breaking later `ty.i32` etc.
tid_x = ch.thread_idx_x()
tid_y = ch.thread_idx_y()
bid_x = ch.block_idx_x()
bdim_x = ch.block_dim_x()
Memory
# Allocation helpers; the memory spaces below are inferred from the helper
# names -- confirm against the cheetah API reference.
local_i32 = ch.alloc(ty.i32)  # single i32
local_f32_arr = ch.alloc_array(ty.f32, 128)  # 128 x f32
shared_u32 = ch.alloc_shared(ty.u32)  # presumably __shared__
extern_shared_i32 = ch.alloc_extern_shared(ty.i32)  # presumably extern __shared__
tensor_f32 = ch.alloc_tensor(ty.f32, 16)  # f32 tensor, size argument 16
Rendering and binding
# Render prototypes and definitions only; no binding is generated here.
src = ch.render(
    mul2,
    headers=["cuda_runtime.h"],
)
# Torch binding via host wrapper + device kernel.
# The host wrapper gets its own Python name (mul2_host) so it does not shadow
# the `mul2` kernel declared earlier; the strings in the render() list below
# set the emitted names ("mul2" for the wrapper, "mul2_kernel" for the kernel).
@ch.fn(ch.Params(n=ty.u32, x=ty.tensor_mut(ty.i32)), host=True)
def mul2_host(n, x):
    # View the tensor argument as an i32 pointer to pass to the kernel.
    ptr = x.cast(ty.ptr_mut(ty.i32))
    block = ch.const(256, ty.u32)
    # Ceil-divide so a partial final block still covers the tail of n.
    grid = (n + block - ch.const(1, ty.u32)) // block
    ch.raw_stmt("mul2_kernel<<<$0, $1>>>($2, $3);", grid, block, n, ptr)

src2 = ch.render([
    ("mul2", mul2_host),  # bound host wrapper (torch::Tensor param)
    ("mul2_kernel", mul2),  # device kernel definition (the @ch.kernel mul2)
], headers=["cuda_runtime.h", "torch/extension.h"], bind=True, bind_name="mul2")