Index tools
cheetah.index_tools provides dimension algebra and scoped constraints that drive structured index generation.
Concepts
- Dimensions: named size domains; may be known upfront or inferred via equations.
- Equations: multiplicative relationships between dimensions, optionally with integer stride factors (for example, N = B * T).
- Scopes: attach subsets of equations to named scopes; activate them when generating code.
- Indices: runtime expressions derived from solved sizes and active equations.
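To make the dimension algebra concrete, here is a minimal plain-Python model of size inference. ToyDims, its solve() method, and the why dictionary are hypothetical and say nothing about how cheetah.index_tools is implemented. The sketch assumes only what the concepts above state: sizes are either given upfront or inferred from multiplicative equations, integer literals act as stride factors, and non-integral results are errors. It repeatedly looks for an equation with exactly one unknown factor and divides that factor out.

```python
from math import prod


class ToyDims:
    """Hypothetical model of dimension-size inference (not cheetah's Dims)."""

    def __init__(self):
        self.sizes = {}  # dim name -> size, or None while unknown
        self.eqs = []    # (lhs factors, rhs factors); ints are literal strides
        self.why = {}    # dim name -> how its size became known

    def new_dim(self, name, size=None):
        self.sizes[name] = size
        if size is not None:
            self.why[name] = "set externally"
        return name

    @staticmethod
    def _factors(side):
        # A bare dim is a one-factor product; a tuple multiplies its factors.
        return side if isinstance(side, tuple) else (side,)

    def eq(self, lhs, rhs):
        self.eqs.append((self._factors(lhs), self._factors(rhs)))

    def _size(self, factor):
        return factor if isinstance(factor, int) else self.sizes[factor]

    def solve(self):
        progress = True
        while progress:
            progress = False
            for lhs, rhs in self.eqs:
                unknown = [f for f in lhs + rhs if self._size(f) is None]
                if len(unknown) != 1:
                    continue  # already solved, or not yet solvable
                u = unknown[0]
                same, other = (lhs, rhs) if u in lhs else (rhs, lhs)
                known = prod(self._size(f) for f in same if f != u)
                total = prod(self._size(f) for f in other)
                if total % known:
                    raise ValueError(f"size of {u} would be {total}/{known}")
                self.sizes[u] = total // known
                self.why[u] = f"solved from {lhs} = {rhs}"
                progress = True
        return self.sizes


dims = ToyDims()
N = dims.new_dim("N", 1 << 20)
B = dims.new_dim("B")
T = dims.new_dim("T", 1024)
dims.eq(N, (B, T))
dims.solve()
print(dims.sizes[B], dims.why[B])  # 1024 solved from ('N',) = ('B', 'T')
```

Changing T to a size that does not divide N, say 1000, makes solve() raise ValueError, mirroring the integral-size rule in the pitfalls.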
Basic example
import cheetah.index_tools as ixt
import cheetah.api as ch
from cheetah.api import ty
dims = ixt.Dims()
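N = dims.new_dim("N", 1 << 20)   # size known upfront
B = dims.new_dim("B")            # size unknown; inferred from the equation below
T = dims.new_dim("T", 1024)
dims.eq(N, (B, T))               # N = B * T, so B = 2**20 / 1024 = 1024
@ch.kernel(ch.Params(out=ty.ptr_mut(ty.f32), a=ty.ptr_const(ty.f32), b=ty.ptr_const(ty.f32)))
def add_vec(out, a, b):
    ix = dims.init()
    ix.set_index(T, ch.thread_idx_x())   # T is indexed by the thread index
    ix.set_index(B, ch.block_idx_x())    # B is indexed by the block index
    out[ix[N]] = a[ix[N]] + b[ix[N]]     # ix[N] is derived from the B and T indices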
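To see why one (block, thread) pair per element is enough, assume (this page does not spell it out) that the tuple (B, T) lists factors outermost first, so ix[N] evaluates to block index * size(T) + thread index. The plain-Python check below uses hypothetical scaled-down sizes and confirms that this mapping touches each element of N exactly once; it models the arithmetic only, not cheetah's code generation.

```python
# Hypothetical model of the assumed index arithmetic; not cheetah's API.
def flat_index(b, t, size_t):
    """Row-major combination of an outer index b and an inner index t."""
    return b * size_t + t


size_n, size_t = 1 << 10, 32      # scaled-down stand-ins for 1 << 20 and 1024
size_b = size_n // size_t         # inferred the same way B is: N / T

hits = sorted(flat_index(b, t, size_t) for b in range(size_b) for t in range(size_t))
covers_exactly_once = hits == list(range(size_n))
print(size_b, covers_exactly_once)  # 32 True
```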
Strides in equations
# (X, 4) = Y reads as X * 4 = Y: Y is four times the size of X, so X steps along Y with stride 4
X = dims.new_dim("X")
Y = dims.new_dim("Y")
dims.eq((X, 4), Y)
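As a worked check of that equation, take a hypothetical size(Y) = 32. Reading (X, 4) as the product X * 4 fixes size(X) = 32 / 4 = 8, and under the same outer-to-inner reading as the basic example each X index lands on a multiple of 4 in Y. Treating the literal 4 as contributing no index of its own is an assumption of this sketch.

```python
# Hypothetical arithmetic for X * 4 = Y; not cheetah's API.
size_y = 32
stride = 4
if size_y % stride:
    raise ValueError("X * 4 = Y requires size(Y) to be a multiple of 4")
size_x = size_y // stride                         # 8
y_offsets = [x * stride for x in range(size_x)]   # Y index reached by each X index
print(size_x, y_offsets)  # 8 [0, 4, 8, 12, 16, 20, 24, 28]
```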
Multiple scopes
s_tile = dims.new_scope("tile")
with dims.scope(s_tile):
    dims.eq(N, (B, T))
s_swap = dims.new_scope("swap")
with dims.scope(s_swap):
    dims.eq(N, (T, B))  # same size fact, different index order
@ch.kernel(ch.Params())
def k():
    ix = dims.init()
    ix.set_index(T, ch.thread_idx_x())
    ix.set_index(B, ch.block_idx_x())
    with ix.scope(s_tile):
        ch.raw_stmt("// tile: $0", ix[N])
    with ix.scope(s_swap):
        ch.raw_stmt("// swap: $0", ix[N])
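The two scopes state the same size fact, N = B * T, but order the factors differently. Assuming outer-to-inner order as before, s_tile yields n = b * size(T) + t while s_swap yields n = t * size(B) + b. The plain-Python sketch below, with small hypothetical sizes, checks that both are valid one-to-one layouts of N that still assign different elements to the same (b, t) pair.

```python
# Hypothetical model of the two assumed index orders; not cheetah's API.
size_b, size_t = 4, 8
size_n = size_b * size_t


def tile_index(b, t):
    return b * size_t + t   # N = (B, T): T is the fast-varying factor


def swap_index(b, t):
    return t * size_b + b   # N = (T, B): B is the fast-varying factor


pairs = [(b, t) for b in range(size_b) for t in range(size_t)]
tile = [tile_index(b, t) for b, t in pairs]
swap = [swap_index(b, t) for b, t in pairs]

both_bijective = sorted(tile) == sorted(swap) == list(range(size_n))
print(both_bijective, tile[:3], swap[:3])  # True [0, 1, 2] [0, 4, 8]
```

With T mapped to the x thread index, the tile order gives adjacent threads adjacent elements, which is usually the coalescing-friendly layout on GPUs; the swap order spaces them size(B) elements apart.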
Diagnostics
# ix is the object returned by dims.init() inside a kernel
ix.why_partial(B)  # which equation(s) constrain B
ix.why_solved(T)   # whether T was solved from an equation or set externally
Pitfalls
- Dimensions must belong to the same Dims instance.
- Equations must imply integral sizes; fractional results raise errors.
- Reading ix[dim] before enough indices are set produces a diagnostic error that lists the associated variables.
- Scopes must share ancestry with the current scope.
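The third pitfall can be pictured with a toy evaluator: a flat index over N = (B, T) needs both component indices, and a helpful error names every missing one rather than stopping at the first. read_flat_index is hypothetical and not part of cheetah.index_tools; the real diagnostic's wording and contents may differ.

```python
# Hypothetical model of the "not enough indices set" diagnostic.
def read_flat_index(order, sizes, indices):
    """Combine per-dimension indices outer-to-inner; report all unset ones."""
    missing = [d for d in order if d not in indices]
    if missing:
        raise LookupError(f"cannot compute index: unset {', '.join(missing)}")
    flat = 0
    for d in order:
        flat = flat * sizes[d] + indices[d]  # Horner form of b * size(T) + t
    return flat


sizes = {"B": 1024, "T": 1024}
print(read_flat_index(("B", "T"), sizes, {"B": 3, "T": 5}))  # 3077

try:
    read_flat_index(("B", "T"), sizes, {"T": 5})
except LookupError as err:
    print(err)  # cannot compute index: unset B
```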
Loops and vectorization
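# Inside a kernel body, after ix = dims.init(); dst and src are assumed to be pointer parameters
with ix.loop(T, unroll=True):   # iterate over T in a loop, requesting unrolling
    dst[ix[N]] = src[ix[N]]
with ix.vectorize(T, size=4):   # access T in vectors of 4 elements
    dst[ix[N]] = src[ix[N]]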
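One way to picture ix.vectorize(T, size=4) is as a copy that walks T in groups of four contiguous elements, issuing one 4-wide load and store per group instead of four scalar ones (similar in spirit to float4 accesses in CUDA). The plain-Python model below shows that access pattern for a single B index. The code cheetah actually emits is not shown on this page, and the divisibility check is an assumption of this sketch, in line with the integral-size rule in the pitfalls.

```python
# Hypothetical model of a 4-wide vectorized copy along T; not cheetah output.
size_t, width = 16, 4
b = 2                                   # one fixed outer (B) index
src = [float(i) for i in range(64)]     # 4 rows of 16, flattened as N = (B, T)
dst = [0.0] * len(src)

if size_t % width:
    raise ValueError("T must be a multiple of the vector width")

base = b * size_t
for t0 in range(0, size_t, width):      # one iteration per 4-wide vector
    lo, hi = base + t0, base + t0 + width
    dst[lo:hi] = src[lo:hi]             # models a single vector load + store

print(dst[base:base + size_t] == src[base:base + size_t])  # True
```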