API Cookbook
Practical instructions and examples for the main public APIs.
cheetah.api
Types (ty)
- Primitives: ty.bool, ty.i8, ty.i16, ty.i32, ty.i64, ty.u8, ty.u16, ty.u32, ty.u64, ty.bf16, ty.f16, ty.f32, ty.f64 (plus vector pairs like ty.f32x2 if supported)
- Pointers: ty.ptr_const(elem), ty.ptr_mut(elem), ty.ptr_volatile(elem); elem=None for void*
- Tensors (PyTorch): ty.tensor_const(elem), ty.tensor_mut(elem)
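For concreteness, a few type values built from these constructors (a minimal sketch; the variable names are illustrative):
from cheetah.api import ty
f32_in = ty.ptr_const(ty.f32)     # const float*
i32_out = ty.ptr_mut(ty.i32)      # int*
raw = ty.ptr_mut(None)            # void*
t_in = ty.tensor_const(ty.f16)    # read-only torch::Tensor of half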
Parameters and function declaration
import cheetah.api as ch
from cheetah.api import ty
params = ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32))
@ch.kernel(params) # __global__
def k(n, x):
    ...
@ch.fn(ch.Params(a=ty.i32, b=ty.i32), device=True, ret=ty.i32)
def add(a, b):
    return a + b
Constants and naming
c = ch.const(42, ty.i32)
ch.set_name("counter", c) # propagates into codegen names when possible
Expressions and pointer operations
p = ch.alloc(ty.i32)
p.val = 7 # store
v = p.val # load
p2 = p.offset(4) # pointer arithmetic
p2[0] = v + 1 # index sugar for load/store
f = ch.const(1.5, ty.f32)
i = f.cast(ty.i32) # primitive cast
Control flow
- If:
with ch.if_(cond):
    ...
- If/Else returning a value:
x = ch.cond(cond, then=lambda: a, else_=lambda: b)
- Structured if/else blocks:
with ch.if_else(cond) as h:
    with h.then():
        ...
    with h.else_():
        ...
- For loops:
with ch.for_(0, n, 1, unroll=True) as i:
    ...
- Return:
ch.ret() # for void
ch.ret(expr) # for value-returning functions
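A minimal sketch combining these constructs in a device function (it assumes the usual arithmetic and comparison operators on expressions, as in the earlier snippets):
@ch.fn(ch.Params(n=ty.i32, lo=ty.i32), device=True, ret=ty.i32)
def sum_above(n, lo):
    acc = ch.alloc(ty.i32)
    acc.val = 0
    with ch.for_(0, n, 1) as i:
        with ch.if_(i > lo):
            acc.val = acc.val + i   # accumulate only indices above lo
    ch.ret(acc.val)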
Logical helpers
x = ch.and_(a > 0, b < 10, c == 3)
y = ch.or_(flag1, flag2)
z = ch.not_(ok)
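These compose directly with the control-flow helpers above, e.g. as a guard condition (reusing the names from the snippet above):
with ch.if_(ch.or_(ch.and_(a > 0, b < 10), ch.not_(ok))):
    ...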
Memory utilities
ptr = ch.alloc(ty.i32)
arr = ch.alloc_array(ty.f32, 128)
sh = ch.alloc_shared(ty.u32)
sh_arr = ch.alloc_shared_array(ty.i32, 256)
extern_sh = ch.alloc_extern_shared(ty.i32)
t = ch.alloc_tensor(ty.f32, 16) # torch::Tensor
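A sketch of staging data through shared memory with a barrier; the barrier goes through raw_stmt since no dedicated sync helper is shown here, and it assumes arrays from alloc_shared_array support the same [] load/store sugar as pointers (tx, gid, src, dst are illustrative names):
tile = ch.alloc_shared_array(ty.f32, 256)
tile[tx] = src[gid]               # each thread stages one element
ch.raw_stmt("__syncthreads();")   # block-wide barrier via raw CUDA
dst[gid] = tile[255 - tx]         # read an element staged by another thread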
Raw code and inline PTX
ch.raw_stmt("// marker: $0", expr)
r = ch.raw_expr("($a + $b)", ty.i32, a=a, b=b)
(out.val,) = ch.asm("add.u32 $0, $a, $b;", ty.u32, a=a, b=b)
# volatile and clobber
v = ch.asm("mov.u32 $0, %clock;", ch.AsmVolatile(), ty.u32)
ch.asm("st.u32 [$0], $1;", ch.AsmMemoryClobber(), out, a)
CUDA builtins
tx = ch.thread_idx_x(); ty_ = ch.thread_idx_y(); tz = ch.thread_idx_z()
bx = ch.block_idx_x(); by = ch.block_idx_y(); bz = ch.block_idx_z()
bdx = ch.block_dim_x(); bdy = ch.block_dim_y(); bdz = ch.block_dim_z()
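A common pattern combines these into a global linear index with a bounds check (a sketch assuming n and x are kernel parameters as in the examples above):
gid = ch.block_idx_x() * ch.block_dim_x() + ch.thread_idx_x()
with ch.if_(gid < n):
    x[gid] = x[gid] + 1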
Rendering
src = ch.render(k) # or ("export_name", k)
src = ch.render([k1, ("name2", k2)], headers=["cuda_runtime.h"]) # optional headers
# Torch extension (pybind11): host wrapper + device kernel
@ch.kernel(ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32)))
def mul2_kernel(n, x):
    ...
@ch.fn(ch.Params(n=ty.u32, x=ty.tensor_mut(ty.i32)), host=True)
def mul2(n, x):
    ptr = x.cast(ty.ptr_mut(ty.i32))
    block = ch.const(256, ty.u32)                             # illustrative launch config
    grid = ch.raw_expr("(($n + 255u) / 256u)", ty.u32, n=n)   # ceil(n / 256)
    ch.raw_stmt("mul2_kernel<<<$0, $1>>>($2, $3);", grid, block, n, ptr)
src = ch.render([
    ("mul2", mul2),
    ("mul2_kernel", mul2_kernel),
], headers=["cuda_runtime.h", "torch/extension.h"], bind=True, bind_name="mul2")
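One way to build the rendered source, assuming the bind=True output is a self-contained pybind11 translation unit whose module name matches the extension name, is to write it to disk and hand it to PyTorch's extension loader (a sketch, not the only option):
import pathlib
from torch.utils.cpp_extension import load
pathlib.Path("mul2_ext.cu").write_text(src)
ext = load(name="mul2", sources=["mul2_ext.cu"], verbose=True)
# ext.mul2(n, x) should then be callable from Python if the binding matches.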
cheetah.index_tools
Creating dimensions and equations
import cheetah.index_tools as ixt
dims = ixt.Dims()
N = dims.new_dim("N", 1 << 20) # known size
B = dims.new_dim("B") # unknown size (inferred)
T = dims.new_dim("T", 1024) # known size
# multiplicative equation: N = B * T
dims.eq(N, (B, T))
# include strides: X * 4 = Y
X = dims.new_dim("X"); Y = dims.new_dim("Y")
dims.eq((X, 4), Y)
Scopes and attaching equations
s0 = dims.new_scope("tile")
with dims.scope(s0):
    dims.eq(N, (B, T))
s1 = dims.new_scope("swap")
with dims.scope(s1):
    dims.eq(N, (T, B))
Solving sizes and generating indices
ix = dims.init() # solves sizes and prepares runtime index state
# set some indices from CUDA builtins
ix.set_index(T, ch.thread_idx_x())
ix.set_index(B, ch.block_idx_x())
# access linearized index for a dimension
addr = ix[N]
Using scopes at runtime
with ix.scope(s0):
ch.raw_stmt("// tile: $0", ix[N])
with ix.scope(s1):
ch.raw_stmt("// swap: $0", ix[N])
Looping and vectorizing
with ix.loop(T, unroll=True):
    dst[ix[N]] = src[ix[N]]
with ix.vectorize(T, size=4):
    dst[ix[N]] = src[ix[N]]
Diagnostics
print(ix.why_partial(B)) # explains constraining equation
print(ix.why_solved(T)) # explains where T came from
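Putting the pieces together, a sketch of a copy kernel driven by index_tools; it only combines the calls shown above, and the dimension sizes are illustrative:
import cheetah.api as ch
import cheetah.index_tools as ixt
from cheetah.api import ty

dims = ixt.Dims()
N = dims.new_dim("N", 1 << 20)
B = dims.new_dim("B")
T = dims.new_dim("T", 1024)
dims.eq(N, (B, T))                       # N = B * T, so B is inferred

@ch.kernel(ch.Params(dst=ty.ptr_mut(ty.f32), src=ty.ptr_const(ty.f32)))
def copy_kernel(dst, src):
    ix = dims.init()                     # solve sizes, prepare index state
    ix.set_index(T, ch.thread_idx_x())
    ix.set_index(B, ch.block_idx_x())
    dst[ix[N]] = src[ix[N]]              # linearized N index

kernel_src = ch.render(copy_kernel)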
Notes
- All DimName values must come from the same Dims instance.
- Equations must lead to integral sizes.
- Indices may be unresolved until sufficient upstream indices are set and the right scope is active.