
Torch bindings

This guide shows how to expose generated code to Python as a PyTorch extension. The bindings use pybind11, and torch.utils.cpp_extension.load compiles and imports the result.

Pattern: host wrapper + device kernel

  • Device kernel: declared with @ch.kernel, taking raw pointer parameters such as ty.ptr_mut(ty.i32).
  • Host wrapper: declared with @ch.fn(..., host=True), taking torch::Tensor parameters declared with ty.tensor_mut or ty.tensor_const. The wrapper casts each tensor to a raw pointer and launches the device kernel.
import cheetah.api as ch
from cheetah.api import ty

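# Device kernel: each thread doubles one element of x
@ch.kernel(ch.Params(n=ty.u32, x=ty.ptr_mut(ty.i32)))
def mul2_kernel(n, x):
    # Global thread index
    i = ch.block_idx_x() * ch.block_dim_x() + ch.thread_idx_x()
    # The last block may extend past the end of x, so guard the write
    with ch.if_(i < n):
        x[i] *= 2

# Host wrapper: takes a torch::Tensor, extracts its data pointer, and launches the kernel
@ch.fn(ch.Params(n=ty.u32, x=ty.tensor_mut(ty.i32)), host=True)
def mul2(n, x):
    ptr = x.cast(ty.ptr_mut(ty.i32))
    # Round n up to a whole number of 256-thread blocks
    grid = (n + ch.const(255, ty.u32)) // ch.const(256, ty.u32)
    block = ch.const(256, ty.u32)
    # $0..$3 are filled in with grid, block, n, and ptr, in that order
    ch.raw_stmt("mul2_kernel<<<$0, $1>>>($2, $3);", grid, block, n, ptr)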

# Emit both the bound host wrapper and the device kernel definition
src = ch.render(
    [("mul2", mul2), ("mul2_kernel", mul2_kernel)],
    headers=["cuda_runtime.h", "torch/extension.h"],
    bind=True,
    bind_name="mul2",
)
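The host wrapper rounds n up to a whole number of 256-thread blocks, so when n is not a multiple of 256 some threads in the last block have an index past the end of x; the `i < n` guard in the kernel is what stops them from writing out of bounds. The pure-Python sketch below replays that indexing sequentially on a list to make the coverage explicit. It is an illustration only: `ceil_div` and `emulate_mul2` are hypothetical helpers, not part of cheetah, and they do not model GPU concurrency.

```python
from typing import List, Tuple


def ceil_div(a: int, b: int) -> int:
    """Round a / b up, matching (n + 255) // 256 in the host wrapper."""
    return (a + b - 1) // b


def emulate_mul2(x: List[int], block: int = 256) -> Tuple[List[int], int, int]:
    """Replay the mul2 launch on the CPU, one simulated thread at a time.

    Returns the doubled values, the grid size, and the number of threads
    that the i < n guard masked off.
    """
    n = len(x)
    out = list(x)
    grid = ceil_div(n, block)
    masked = 0
    for block_idx in range(grid):
        for thread_idx in range(block):
            # Same global index as block_idx_x() * block_dim_x() + thread_idx_x()
            i = block_idx * block + thread_idx
            if i < n:
                out[i] *= 2
            else:
                masked += 1
    return out, grid, masked


out, grid, masked = emulate_mul2(list(range(1000)))
print(grid, masked)  # prints: 4 24 (1024 threads launched for 1000 elements)
```

Two edge cases follow from this arithmetic. With n == 0 the computed grid is 0, which CUDA rejects as an invalid launch configuration, so a wrapper may want to return early for empty tensors. And because the wrapper computes n + 255 in u32, that sum wraps for n within 255 of 2^32.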

Build and import with PyTorch

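import tempfile
from pathlib import Path

import torch
from torch.utils.cpp_extension import load

# Write the rendered source to a .cu file so that load() compiles it with nvcc
cu = Path(tempfile.gettempdir()) / "cheetah_mul2.cu"
cu.write_text(src)
module = load(name="cheetah_mul2", sources=[str(cu)], verbose=True)

# Call from Python: module.mul2(n: int, x: CUDA int32 tensor) doubles x[:n] in place
x = torch.arange(8, dtype=torch.int32, device="cuda")
module.mul2(x.numel(), x)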

Tensor types

  • ty.tensor_mut(elem) / ty.tensor_const(elem) map to torch::Tensor with the corresponding dtype.
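  • Use .cast(ty.ptr_*(elem)), for example x.cast(ty.ptr_mut(ty.i32)), to retrieve the tensor's raw data pointer for kernel launches.
  • Allocate a tensor in the host wrapper (e.g., scratch space for a kernel) with ch.alloc_tensor(elem, count); torch tensors cannot be created from device code.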

Tips and troubleshooting

  • Ensure the CUDA toolkit (nvcc) is installed; PyTorch locates it via CUDA_HOME (or, if that is unset, nvcc on PATH). If no GPU is visible at build time, set TORCH_CUDA_ARCH_LIST (e.g., TORCH_CUDA_ARCH_LIST="8.0") so the architecture flags can be computed.
  • Pass the required headers to ch.render: always include "torch/extension.h" when bind=True, and CUDA headers such as "cuda_runtime.h" when the code launches kernels.
  • In CI and local test runs, the PyTorch import test is skipped unless CHEETAH_ENABLE_TORCH_TEST=1 is set in the environment.
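A small preflight run before load can surface the first and last points above early. The sketch uses only the standard library. `find_nvcc`, `prepare_build_env`, and `torch_test_enabled` are hypothetical helpers, not part of cheetah or PyTorch. The nvcc lookup only approximates PyTorch's own search and assumes a Unix-style `bin/nvcc` layout, and the "8.0" default is an assumption you should replace with your target GPU's compute capability.

```python
import os
import shutil
from typing import Optional


def find_nvcc() -> Optional[str]:
    """Locate nvcc: $CUDA_HOME/bin/nvcc first, then nvcc on PATH."""
    cuda_home = os.environ.get("CUDA_HOME")
    if cuda_home:
        candidate = os.path.join(cuda_home, "bin", "nvcc")
        if os.path.isfile(candidate):
            return candidate
    return shutil.which("nvcc")


def prepare_build_env(gpu_visible: bool, default_arch: str = "8.0") -> None:
    """Set TORCH_CUDA_ARCH_LIST when no GPU is visible, keeping any value the user set."""
    if not gpu_visible:
        os.environ.setdefault("TORCH_CUDA_ARCH_LIST", default_arch)


def torch_test_enabled() -> bool:
    """Opt-in gate for GPU tests: true only when CHEETAH_ENABLE_TORCH_TEST=1."""
    return os.environ.get("CHEETAH_ENABLE_TORCH_TEST") == "1"


if __name__ == "__main__":
    nvcc = find_nvcc()
    print("nvcc:", nvcc or "not found (install the CUDA toolkit or set CUDA_HOME)")
    print("torch test enabled:", torch_test_enabled())
```

In practice you might call prepare_build_env(torch.cuda.is_available()) just before load(...), and decorate a test with unittest.skipUnless(torch_test_enabled(), "set CHEETAH_ENABLE_TORCH_TEST=1"). Setting the arch list only when no GPU is visible avoids pinning builds to one architecture on machines where PyTorch could detect the real one.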