Skip to content

Kernels

This guide covers declaring kernels and functions, expressions and pointer ops, casts, control flow, raw code and inline PTX, builtins, memory, and rendering/binding.

Declaring kernels and functions

import cheetah.api as ch
from cheetah.api import ty

# Kernel declared with ch.kernel; ch.Params fixes each parameter's type:
# n is a u32 and x is a mutable pointer to i32 (presumably `int*` — confirm).
@ch.kernel(ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32)))
def mul2(n, x):
    ...  # body elided in this example

# Helper function (not a kernel) declared with ch.fn: takes two i32 values
# and returns their sum as an i32.
@ch.fn(
    ch.Params(a=ty.i32, b=ty.i32),
    device=True,  # __device__
    host=False,  # presumably: no __host__ qualifier emitted — confirm
    force_inline=True,  # presumably __forceinline__ — confirm
    ret=ty.i32,  # declared return type
)
def add(a, b):
    return a + b

# Kernel taking tensor parameters: x is a mutable f32 tensor and y a
# (presumably read-only) f16 tensor.
@ch.kernel(ch.Params(x=ty.tensor_mut(ty.f32), y=ty.tensor_const(ty.f16)))
def tensor_kernel(x, y):
    ...  # body elided in this example

Expressions and pointer ops

ptr = ch.alloc(ty.i32)       # a single i32; ptr is used as a pointer to it
ptr.val = 42                 # store
val = ptr.val                # load
# Pointer arithmetic needs storage past the pointee: ch.alloc(ty.i32) holds a
# single element, so offsetting *it* by 3 and storing would write out of
# bounds. Use an array allocation instead.
arr = ch.alloc_array(ty.i32, 4)
ptr2 = arr.offset(3)         # pointer arithmetic: points at the last element
ptr2[0] = 7                  # index sugar for load/store (stores into that element)

Casts

f = ch.const(1.5, ty.f32)    # f32 constant
# Primitive cast f32 -> i32 (presumably C-style, i.e. truncating — confirm).
i = f.cast(ty.i32)           # primitive cast
# Null pointer constant. ptr_const(None) is presumably `const void*` rather
# than plain `void*` (ptr_mut(None) would be the non-const form) — confirm.
vp = ch.const(0, ty.ptr_const(None))

Control flow

# Statement-level `if`: the with-block body is the branch body.
with ch.if_(cond):
    ...

# Value-producing conditional: `a + 1` when cond holds, otherwise `b - 1`.
# Both branches are supplied as zero-argument callables.
res = ch.cond(cond, then=lambda: a + 1, else_=lambda: b - 1)

# if/else: h.then() and h.else_() open the two branch bodies.
with ch.if_else(cond) as h:
    with h.then():
        ...
    with h.else_():
        ...

# Counted loop with arguments (start=0, end=n, step=1); i is the loop
# variable (presumably 0 <= i < n — confirm bound inclusivity).
# unroll=True presumably requests `#pragma unroll` — confirm.
with ch.for_(0, n, 1, unroll=True) as i:
    ...

# Early return from the current kernel/function.
ch.ret()

Raw code and inline PTX

# Raw statement: $0 is substituted with the first positional argument (expr).
ch.raw_stmt("// marker: $0", expr)
# Raw expression of declared type i32; $a / $b are bound via keyword arguments.
r = ch.raw_expr("($a + $b)", ty.i32, a=a, b=b)

# Inline PTX: $a / $b are the keyword inputs and $0 presumably the u32 output
# operand. The result is unpacked as a one-element sequence and stored
# through `out`.
(out.val,) = ch.asm("add.u32 $0, $a, $b;", ty.u32, a=a, b=b)

Builtins

# Thread/block index builtins. Avoid binding any of these to `ty`: that would
# shadow the `ty` type module imported from cheetah.api and break every later
# `ty.<type>` reference (e.g. ty.i32 in the memory examples).
tid_x = ch.thread_idx_x()    # presumably threadIdx.x
tid_y = ch.thread_idx_y()    # presumably threadIdx.y
bid_x = ch.block_idx_x()     # presumably blockIdx.x
bdim_x = ch.block_dim_x()    # presumably blockDim.x

Memory

buf = ch.alloc(ty.i32)                      # single i32
buf_arr = ch.alloc_array(ty.f32, 128)       # array of 128 f32
# Presumably a `__shared__` u32 — confirm.
sh = ch.alloc_shared(ty.u32)
# Presumably `extern __shared__` (dynamically sized, set at launch) — confirm.
extern_sh = ch.alloc_extern_shared(ty.i32)

# f32 tensor; 16 is presumably its element count — confirm.
t = ch.alloc_tensor(ty.f32, 16)

Rendering and binding

# Prototypes + definitions (no binding)
# Renders the mul2 kernel to a source string; `headers` presumably become
# #include lines — confirm.
src = ch.render(mul2, headers=["cuda_runtime.h"])

# Torch binding via host wrapper + device kernel.
# The wrapper gets its own Python name so it does not shadow the `mul2`
# kernel declared with @ch.kernel; the emitted C++ names come from the
# (name, fn) tuples passed to ch.render below.
@ch.fn(ch.Params(n=ty.u32, x=ty.tensor_mut(ty.i32)), host=True)
def mul2_host(n, x):
    ptr = x.cast(ty.ptr_mut(ty.i32))  # tensor -> raw int* for the launch
    threads = 256  # threads per block (plain Python int, baked in as a constant)
    # ceil(n / threads) blocks so every element is covered by a thread.
    grid = (n + ch.const(threads - 1, ty.u32)) // ch.const(threads, ty.u32)
    block = ch.const(threads, ty.u32)
    ch.raw_stmt("mul2_kernel<<<$0, $1>>>($2, $3);", grid, block, n, ptr)

src2 = ch.render([
    ("mul2", mul2_host),       # bound host wrapper (torch::Tensor param)
    ("mul2_kernel", mul2),     # device kernel definition (the @ch.kernel mul2)
], headers=["cuda_runtime.h", "torch/extension.h"], bind=True, bind_name="mul2")