Torch bindings
This guide shows how to expose generated code to Python as a PyTorch extension, compiled with `torch.utils.cpp_extension.load` and bound through pybind11.
Pattern: host wrapper + device kernel
- Device kernel: declared with `@ch.kernel` and pointer types.
- Host wrapper: declared with `@ch.fn(..., host=True)` and `torch::Tensor` types using `ty.tensor_mut`/`ty.tensor_const`. The wrapper casts the tensor to a pointer and launches the device kernel.
```python
import cheetah.api as ch
from cheetah.api import ty

# Device kernel: one thread per element, operating on a raw int32 pointer.
@ch.kernel(ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32)))
def mul2_kernel(n, x):
    i = ch.block_idx_x() * ch.block_dim_x() + ch.thread_idx_x()
    with ch.if_(i < n):  # the last block may contain threads past the end of x
        x[i] *= 2

# Host wrapper: takes a torch::Tensor, extracts its data pointer, and launches the kernel.
@ch.fn(ch.Params(n=ty.u32, x=ty.tensor_mut(ty.i32)), host=True)
def mul2(n, x):
    ptr = x.cast(ty.ptr_mut(ty.i32))
    grid = (n + ch.const(255, ty.u32)) // ch.const(256, ty.u32)  # ceil(n / 256) blocks
    block = ch.const(256, ty.u32)  # 256 threads per block
    ch.raw_stmt("mul2_kernel<<<$0, $1>>>($2, $3);", grid, block, n, ptr)

# Emit both the bound host wrapper and the device kernel definition
src = ch.render(
    [("mul2", mul2), ("mul2_kernel", mul2_kernel)],
    headers=["cuda_runtime.h", "torch/extension.h"],
    bind=True,
    bind_name="mul2",
)
```
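To make the host/device split concrete, the following pure-Python model mimics what the generated code does when `mul2` runs: the wrapper rounds the grid up to whole 256-thread blocks, and every simulated thread doubles at most one element. It is a hypothetical illustration, not cheetah output. It needs no GPU, runs the threads sequentially, and uses Python integers, which do not wrap on overflow the way `int32` does.

```python
# Hypothetical model of the mul2 launch (not generated by cheetah); runs without CUDA.
BLOCK = 256  # threads per block, matching the host wrapper


def ceil_div(a: int, b: int) -> int:
    """Number of b-sized chunks needed to cover a items, i.e. (a + b - 1) // b."""
    return (a + b - 1) // b


def mul2_kernel_model(block_idx: int, block_dim: int, thread_idx: int, n: int, x: list) -> None:
    """What one CUDA thread does: compute a global index, guard it, double one element."""
    i = block_idx * block_dim + thread_idx
    if i < n:  # threads past the end of x in the last block do nothing
        x[i] *= 2


def mul2_host_model(n: int, x: list) -> int:
    """What the host wrapper does: size the grid, then 'launch' every thread. Returns the grid size."""
    grid = ceil_div(n, BLOCK)
    for b in range(grid):
        for t in range(BLOCK):
            mul2_kernel_model(b, BLOCK, t, n, x)
    return grid


data = list(range(300))
grid = mul2_host_model(len(data), data)
print(grid)               # 2 blocks cover 300 elements
print(data[:4], data[-1])  # [0, 2, 4, 6] 598
```

Because the grid is rounded up, the `i < n` guard is what keeps the extra threads of the last block from writing past the end of `x`. For `n == 0` the formula yields a grid of 0 blocks, which CUDA rejects as an invalid launch configuration, so empty inputs are worth filtering out before calling the wrapper.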
Build and import with PyTorch
```python
from pathlib import Path
from torch.utils.cpp_extension import load

cu = Path("/tmp/cheetah_mul2.cu")
cu.write_text(src)
module = load(name="cheetah_mul2", sources=[str(cu)], verbose=True)

# Call from Python
# module.mul2(n: int, x: torch.Tensor[int32, cuda])
```
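The fixed `/tmp` path above is fine for a one-off run. A minimal sketch of a reusable helper, assuming you want each build to write its source into its own directory, might look like this. `write_cuda_source` and `build_extension` are hypothetical names, and only `build_extension` needs PyTorch and `nvcc`.

```python
import tempfile
from pathlib import Path

# Hypothetical helpers, not part of cheetah.


def write_cuda_source(src: str, name: str, build_dir=None) -> Path:
    """Write generated source to <build_dir>/<name>.cu and return its path.

    build_dir defaults to a fresh temporary directory, so concurrent runs
    do not overwrite each other's files.
    """
    out_dir = Path(build_dir) if build_dir is not None else Path(tempfile.mkdtemp(prefix="cheetah_"))
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"{name}.cu"  # the .cu suffix is what makes load() compile it with nvcc
    path.write_text(src)
    return path


def build_extension(src: str, name: str, build_dir=None):
    """Write src and compile/import it as a PyTorch extension (requires torch and nvcc)."""
    # Imported lazily so write_cuda_source works even where torch is not installed.
    from torch.utils.cpp_extension import load

    path = write_cuda_source(src, name, build_dir)
    return load(name=name, sources=[str(path)], verbose=True)
```

With these helpers, `module = build_extension(src, "cheetah_mul2")` stands in for the three build lines above.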
Tensor types
- `ty.tensor_mut(elem)` / `ty.tensor_const(elem)` map to `torch::Tensor` with the corresponding dtype.
- Use `.cast(ty.ptr_*(elem))` to retrieve the raw data pointer for device launches.
- Allocate a tensor inside device code (e.g., for scratch) with `ch.alloc_tensor(elem, count)`.
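Because `.cast(...)` hands the kernel a bare pointer, the kernel's `x[i]` indexing is only valid if the tensor holds the expected element type, lives in GPU memory, is contiguous, and has at least `n` elements. A minimal caller-side guard, assuming these checks are done in Python before calling `module.mul2`, could look like the following. `check_mul2_args` is a hypothetical helper; it is duck-typed so the sketch runs without importing torch.

```python
from types import SimpleNamespace


def check_mul2_args(n: int, x, expected_dtype: str = "torch.int32") -> None:
    """Hypothetical guard: raise ValueError if x cannot safely back a raw pointer of n elements.

    x is anything tensor-like exposing dtype, is_cuda, is_contiguous() and numel(),
    as torch.Tensor does. The dtype is compared by its string form
    (str(torch.int32) == "torch.int32"), so this sketch never imports torch.
    """
    if str(x.dtype) != expected_dtype:
        raise ValueError(f"expected {expected_dtype}, got {x.dtype}")
    if not x.is_cuda:
        raise ValueError("tensor must live on a CUDA device for a kernel launch")
    if not x.is_contiguous():
        raise ValueError("raw pointer indexing assumes a contiguous tensor")
    if not 0 <= n <= x.numel():
        raise ValueError(f"n={n} is outside 0..{x.numel()}")


# A stand-in object with the same attributes as a contiguous int32 CUDA tensor of 8 elements.
fake = SimpleNamespace(dtype="torch.int32", is_cuda=True, is_contiguous=lambda: True, numel=lambda: 8)
check_mul2_args(8, fake)  # passes silently
```

With a real tensor, `check_mul2_args(x.numel(), x)` fails before a bad pointer reaches the GPU. Calling `x.contiguous()` satisfies the layout check but may return a copy, in which case a mutating wrapper like `mul2` writes into that copy rather than the original tensor.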
Tips and troubleshooting
- Ensure the CUDA toolkit (`nvcc`) is installed; PyTorch locates it through `CUDA_HOME`, so set that variable if the toolkit is not found automatically. If no GPU is present, set `TORCH_CUDA_ARCH_LIST` (e.g., `8.0`) so the architecture flags can be computed without a device to query.
- Pass the required headers: always include `"torch/extension.h"` when `bind=True`, and CUDA headers such as `"cuda_runtime.h"` when launching kernels.
- In CI and local testing, the import test is skipped unless `CHEETAH_ENABLE_TORCH_TEST=1` is set.
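The test gate in the last bullet and the `TORCH_CUDA_ARCH_LIST` fallback can be sketched in plain Python. Both helpers are hypothetical, and the exact check performed by cheetah's own test suite is assumed here (the variable must equal `"1"`), not taken from its source.

```python
import os


def torch_test_enabled(env=None) -> bool:
    """Hypothetical gate: True only when CHEETAH_ENABLE_TORCH_TEST is exactly "1".

    Assumption: the real test suite may accept other values; this sketch does not.
    """
    env = os.environ if env is None else env
    return env.get("CHEETAH_ENABLE_TORCH_TEST") == "1"


def ensure_arch_list(default: str = "8.0", env=None) -> str:
    """Keep a user-provided TORCH_CUDA_ARCH_LIST, otherwise set a default; return the value in effect.

    Intended for machines without a GPU, where the architecture cannot be detected.
    Call it before torch.utils.cpp_extension.load.
    """
    env = os.environ if env is None else env
    return env.setdefault("TORCH_CUDA_ARCH_LIST", default)


# Example use with pytest (pytest is not imported here):
#   @pytest.mark.skipif(not torch_test_enabled(), reason="set CHEETAH_ENABLE_TORCH_TEST=1")
#   def test_mul2(): ...
```

Passing an explicit `env` mapping keeps the helpers easy to test without touching the real process environment.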