Getting started
Requirements
- Python 3.11+ recommended (project is tested on 3.11/3.12)
- CUDA toolchain only required if you compile generated code
Install
pip install -e .
Optional extras:
# tests
pip install -e .[test]
# docs
pip install -e .[docs]
Run tests
pytest
# update golden files when generator output intentionally changes
pytest --update-golden
Build & view docs
# serve locally with live-reload
mkdocs serve
# or build the static site
mkdocs build
Minimal example
import cheetah.api as ch
from cheetah.api import ty

@ch.kernel(ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32)))
def mul2(n, x):
    i = ch.block_idx_x() * ch.block_dim_x() + ch.thread_idx_x()
    with ch.if_(i < n):
        x[i] *= 2

print(ch.render(mul2))
PyTorch binding example
from pathlib import Path

from torch.utils.cpp_extension import load

import cheetah.api as ch
from cheetah.api import ty

@ch.kernel(ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32)))
def mul2_kernel(n, x):
    i = ch.block_idx_x() * ch.block_dim_x() + ch.thread_idx_x()
    with ch.if_(i < n):
        x[i] *= 2

@ch.fn(ch.Params(n=ty.u32, x=ty.tensor_mut(ty.i32)), host=True)
def mul2(n, x):
    ptr = x.cast(ty.ptr_mut(ty.i32))
    grid = (n + ch.const(255, ty.u32)) // ch.const(256, ty.u32)
    block = ch.const(256, ty.u32)
    ch.raw_stmt("mul2_kernel<<<$0, $1>>>($2, $3);", grid, block, n, ptr)

code = ch.render([
    ("mul2", mul2),
    ("mul2_kernel", mul2_kernel),
], headers=["cuda_runtime.h", "torch/extension.h"], bind=True, bind_name="mul2")

cu = Path("/tmp/cheetah_mul2.cu")
cu.write_text(code)
module = load(name="cheetah_mul2", sources=[str(cu)], verbose=True)
# module.mul2(n: int, x: torch.Tensor[int32, cuda]) now available
Tip: if no GPUs are visible, set TORCH_CUDA_ARCH_LIST (e.g., 8.0) before building. To run the repo's Torch extension test, set CHEETAH_ENABLE_TORCH_TEST=1.