API Cookbook

Practical instructions and examples for the main public APIs.

cheetah.api

Types (ty)

  • Primitives: ty.bool, ty.i8, ty.i16, ty.i32, ty.i64, ty.u8, ty.u16, ty.u32, ty.u64, ty.bf16, ty.f16, ty.f32, ty.f64 (plus vector pairs like ty.f32x2 if supported)
  • Pointers: ty.ptr_const(elem), ty.ptr_mut(elem), ty.ptr_volatile(elem); pass elem=None for a void* pointer
  • Tensors (PyTorch): ty.tensor_const(elem), ty.tensor_mut(elem)

Parameters and function declaration

import cheetah.api as ch
from cheetah.api import ty

params = ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32))

@ch.kernel(params)  # __global__
def k(n, x):
    ...

@ch.fn(ch.Params(a=ty.i32, b=ty.i32), device=True, ret=ty.i32)
def add(a, b):
    return a + b

Constants and naming

c = ch.const(42, ty.i32)
ch.set_name("counter", c)  # propagates into codegen names when possible

Expressions and pointer operations

p = ch.alloc(ty.i32)
p.val = 7             # store
v = p.val             # load
p2 = p.offset(4)      # pointer arithmetic
p2[0] = v + 1         # index sugar for load/store

f = ch.const(1.5, ty.f32)
i = f.cast(ty.i32)    # primitive cast
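For a float-to-int cast like the one above, C-style truncation toward zero is the likely semantics (an assumption; the library may round differently). A plain-Python analogue of that behavior:

```python
# Plain-Python analogue of a truncating f32 -> i32 cast.
# Assumes C-style truncation toward zero; not the library's implementation.
def f32_to_i32(x: float) -> int:
    return int(x)  # Python's int() truncates toward zero, like C's (int) cast

f32_to_i32(1.5)   # -> 1
f32_to_i32(-1.5)  # -> -1
```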

Control flow

  • If:
with ch.if_(cond):
    ...
  • If/Else returning a value:
x = ch.cond(cond, then=lambda: a, else_=lambda: b)
  • Structured if/else blocks:
with ch.if_else(cond) as h:
    with h.then():
        ...
    with h.else_():
        ...
  • For loops:
with ch.for_(0, n, 1, unroll=True) as i:
    ...
  • Return:
ch.ret()      # for void
ch.ret(expr)  # for value-returning functions

Logical helpers

x = ch.and_(a > 0, b < 10, c == 3)
y = ch.or_(flag1, flag2)
z = ch.not_(ok)

Memory utilities

ptr = ch.alloc(ty.i32)
arr = ch.alloc_array(ty.f32, 128)
sh = ch.alloc_shared(ty.u32)
sh_arr = ch.alloc_shared_array(ty.i32, 256)
extern_sh = ch.alloc_extern_shared(ty.i32)

t = ch.alloc_tensor(ty.f32, 16)  # torch::Tensor

Raw code and inline PTX

ch.raw_stmt("// marker: $0", expr)
r = ch.raw_expr("($a + $b)", ty.i32, a=a, b=b)

(out.val,) = ch.asm("add.u32 $0, $a, $b;", ty.u32, a=a, b=b)
# volatile and clobber
v = ch.asm("mov.u32 $0, %clock;", ch.AsmVolatile(), ty.u32)
ch.asm("st.u32 [$0], $1;", ch.AsmMemoryClobber(), out, a)

CUDA builtins

tx = ch.thread_idx_x(); ty_ = ch.thread_idx_y(); tz = ch.thread_idx_z()  # ty_ avoids shadowing the ty module
bx = ch.block_idx_x();  by  = ch.block_idx_y();  bz = ch.block_idx_z()
bdx = ch.block_dim_x(); bdy = ch.block_dim_y();  bdz = ch.block_dim_z()
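These builtins compose into the standard CUDA global-index idiom, e.g. ch.block_idx_x() * ch.block_dim_x() + ch.thread_idx_x(). A plain-Python analogue of that arithmetic, with integers standing in for the values the builtins would yield at runtime:

```python
# Plain-Python analogue of the global thread index arithmetic;
# block_idx/block_dim/thread_idx mirror blockIdx.x/blockDim.x/threadIdx.x.
def global_idx(block_idx: int, block_dim: int, thread_idx: int) -> int:
    return block_idx * block_dim + thread_idx

# Thread 3 of block 2, with 256 threads per block:
gid = global_idx(2, 256, 3)  # -> 515
```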

Rendering

src = ch.render(k)  # or ch.render(("export_name", k)) to set the exported name
src = ch.render([k1, ("name2", k2)], headers=["cuda_runtime.h"])  # optional headers

# Torch extension (pybind11): host wrapper + device kernel
@ch.kernel(ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32)))
def mul2_kernel(n, x):
    ...

@ch.fn(ch.Params(n=ty.u32, x=ty.tensor_mut(ty.i32)), host=True)
def mul2(n, x):
    ptr = x.cast(ty.ptr_mut(ty.i32))
    block = ch.const(256, ty.u32)                          # threads per block
    grid = ch.raw_expr("(($n + 255) / 256)", ty.u32, n=n)  # ceil(n / 256) blocks
    ch.raw_stmt("mul2_kernel<<<$0, $1>>>($2, $3);", grid, block, n, ptr)

src = ch.render([
    ("mul2", mul2),
    ("mul2_kernel", mul2_kernel),
], headers=["cuda_runtime.h", "torch/extension.h"], bind=True, bind_name="mul2")

cheetah.index_tools

Creating dimensions and equations

import cheetah.index_tools as ixt

dims = ixt.Dims()
N = dims.new_dim("N", 1 << 20)  # known size
B = dims.new_dim("B")            # unknown size (inferred)
T = dims.new_dim("T", 1024)      # known size

# multiplicative equation: N = B * T
dims.eq(N, (B, T))

# include strides: X * 4 = Y
X = dims.new_dim("X"); Y = dims.new_dim("Y")
dims.eq((X, 4), Y)
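On an equation like N = B * T, the solver's job amounts to exact integer division: with N and T known, B must be N / T, and the division has to come out even (sizes must be integral). A plain-Python sketch of that inference, for intuition only, not the library's solver:

```python
# Sketch of inferring the unknown dimension in N = B * T.
# This is an illustration of the arithmetic, not the library's algorithm.
def solve_product(n: int, t: int) -> int:
    b, rem = divmod(n, t)
    if rem != 0:
        raise ValueError(f"{n} is not divisible by {t}; sizes must be integral")
    return b

B_size = solve_product(1 << 20, 1024)  # -> 1024, matching the dims above
```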

Scopes and attaching equations

s0 = dims.new_scope("tile")
with dims.scope(s0):
    dims.eq(N, (B, T))

s1 = dims.new_scope("swap")
with dims.scope(s1):
    dims.eq(N, (T, B))

Solving sizes and generating indices

ix = dims.init()  # solves sizes and prepares runtime index state

# set some indices from CUDA builtins
ix.set_index(T, ch.thread_idx_x())
ix.set_index(B, ch.block_idx_x())

# access linearized index for a dimension
addr = ix[N]

Using scopes at runtime

with ix.scope(s0):
    ch.raw_stmt("// tile: $0", ix[N])
with ix.scope(s1):
    ch.raw_stmt("// swap: $0", ix[N])
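The two scopes linearize N in opposite factor orders. Assuming row-major linearization of the factor list (an assumption about the library's convention), ix[N] under the tile scope (N = B * T) corresponds to b * T + t, while under the swap scope (N = T * B) it corresponds to t * B + b. In plain Python:

```python
# Plain-Python analogue of scope-dependent linearization of N.
# Assumes row-major order over the factor list; sizes match the dims above.
T_SIZE = 1024
B_SIZE = 1024  # B was inferred as N / T = 1024

def tile_addr(b: int, t: int) -> int:  # scope s0: N = B * T
    return b * T_SIZE + t

def swap_addr(b: int, t: int) -> int:  # scope s1: N = T * B
    return t * B_SIZE + b

tile_addr(2, 5)  # -> 2053
swap_addr(2, 5)  # -> 5122
```

The same (b, t) pair therefore addresses different linear offsets depending on which scope is active, which is the point of attaching alternative equations to scopes.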

Looping and vectorizing

with ix.loop(T, unroll=True):
    dst[ix[N]] = src[ix[N]]

with ix.vectorize(T, size=4):
    dst[ix[N]] = src[ix[N]]
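Vectorizing with size=4 presumably groups the T iteration into width-4 chunks so each step moves four contiguous elements (e.g. as a single float4 access in generated CUDA; that codegen detail is an assumption). A plain-Python analogue of the chunked copy:

```python
# Plain-Python analogue of a width-4 vectorized copy over dimension T.
# Illustrates the access pattern only; not the library's code generation.
def vectorized_copy(src: list, dst: list, t_size: int, width: int = 4) -> None:
    assert t_size % width == 0, "T must be divisible by the vector width"
    for base in range(0, t_size, width):
        dst[base:base + width] = src[base:base + width]  # one "vector" move

src = list(range(16))
dst = [0] * 16
vectorized_copy(src, dst, 16)  # dst now equals src
```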

Diagnostics

print(ix.why_partial(B))  # explains which equation leaves B only partially determined
print(ix.why_solved(T))   # explains where T's solved size came from

Notes

  • All DimNames must come from the same Dims instance.
  • Equations must lead to integral sizes.
  • Indices may be unresolved until sufficient upstream indices are set and the right scope is active.