Skip to content

Getting started

Requirements

  • Python 3.11 or newer is recommended (the project is tested on 3.11 and 3.12).
  • A CUDA toolchain is only required if you compile the generated code.

Install

pip install -e .

Optional extras:

# tests (quote the extras so shells such as zsh don't expand the brackets as a glob)
pip install -e ".[test]"
# docs
pip install -e ".[docs]"

Run tests

pytest
# update golden files when generator output intentionally changes
pytest --update-golden

Build & view docs

# serve locally with live-reload
mkdocs serve
# or build the static site
mkdocs build

Minimal example

import cheetah.api as ch
from cheetah.api import ty

# Element-wise doubling kernel: `n` is the element count (u32) and `x` is a
# mutable pointer to i32 values that are updated in place.
@ch.kernel(ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32)))
def mul2(n, x):
    # Flat 1-D global thread index: block index * block size + thread index.
    i = ch.block_idx_x() * ch.block_dim_x() + ch.thread_idx_x()
    # Bounds guard: threads whose index is >= n do nothing, so the launch may
    # use more threads than there are elements. (ch.if_ presumably emits the
    # condition into the generated code, which a plain Python `if` could not.)
    with ch.if_(i < n):
        x[i] *= 2

# Render the kernel to source text and print it.
print(ch.render(mul2))

PyTorch binding example

import tempfile
from pathlib import Path

from torch.utils.cpp_extension import load

import cheetah.api as ch
from cheetah.api import ty

# Device kernel: doubles x[0..n) in place. It is launched by the host
# wrapper `mul2` below via a raw `<<<grid, block>>>` statement.
@ch.kernel(ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32)))
def mul2_kernel(n, x):
    # Flat 1-D global thread index: block index * block size + thread index.
    i = ch.block_idx_x() * ch.block_dim_x() + ch.thread_idx_x()
    # Required guard: the host rounds the grid up to whole 256-thread blocks,
    # so the last block can contain threads with i >= n.
    with ch.if_(i < n):
        x[i] *= 2

# Host-side wrapper (host=True) that becomes the bound `mul2` entry point:
# it takes an element count and a mutable i32 tensor and launches
# `mul2_kernel` with at least one thread per element.
# NOTE(review): no check is visible here that x is a CUDA int32 tensor with at
# least n elements; the caller must guarantee it.
@ch.fn(ch.Params(n=ty.u32, x=ty.tensor_mut(ty.i32)), host=True)
def mul2(n, x):
    # View the tensor argument as a raw mutable i32 pointer for the kernel.
    ptr = x.cast(ty.ptr_mut(ty.i32))
    # Ceiling division: number of 256-thread blocks needed to cover n elements.
    # Constants are typed u32 explicitly to match n.
    # NOTE(review): n + 255 wraps in u32 for n > 2**32 - 256.
    # NOTE(review): when n == 0 this yields grid == 0, which CUDA rejects as an
    # invalid launch configuration; the launch result is not checked here.
    grid = (n + ch.const(255, ty.u32)) // ch.const(256, ty.u32)
    block = ch.const(256, ty.u32)
    # Emit the launch statement verbatim; the $N placeholders appear to map
    # positionally to the following arguments (grid, block, n, ptr).
    # NOTE(review): no stream argument is passed, so the kernel runs on the CUDA
    # default stream rather than PyTorch's current stream — confirm intended.
    ch.raw_stmt("mul2_kernel<<<$0, $1>>>($2, $3);", grid, block, n, ptr)

# Render both functions into a single CUDA source file that includes a
# PyTorch binding exposing the host wrapper under the name `mul2`.
code = ch.render([
    ("mul2", mul2),
    ("mul2_kernel", mul2_kernel),
], headers=["cuda_runtime.h", "torch/extension.h"], bind=True, bind_name="mul2")
# Write the source to the platform's temp directory instead of a hard-coded
# /tmp (which does not exist on Windows). On typical Linux setups this is still
# /tmp, and the fixed file name keeps the path stable across runs.
cu = Path(tempfile.gettempdir()) / "cheetah_mul2.cu"
cu.write_text(code, encoding="utf-8")
# JIT-compile the .cu source and import it as a Python extension module
# (needs the CUDA toolchain and a host C++ compiler).
module = load(name="cheetah_mul2", sources=[str(cu)], verbose=True)
# module.mul2(n: int, x: torch.Tensor[int32, cuda]) now available

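Tip: if no GPU is visible at build time, set TORCH_CUDA_ARCH_LIST (e.g., 8.0) before building so the extension is compiled for an explicitly chosen GPU architecture. To run the repository's Torch extension test, set CHEETAH_ENABLE_TORCH_TEST=1 (e.g., CHEETAH_ENABLE_TORCH_TEST=1 pytest).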