Index tools

cheetah.index_tools provides dimension algebra and scoped constraints that drive structured index generation.

Concepts

  • Dimensions: named size domains; may be known upfront or inferred via equations.
  • Equations: multiplicative relationships between dimensions and integer strides.
  • Scopes: attach subsets of equations to named scopes; activate them when generating code.
  • Indices: runtime expressions derived from solved sizes and active equations.
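The size-inference idea behind dimensions and equations can be sketched in plain Python. This is an illustration only, not the cheetah.index_tools implementation: solve_sizes and _side_product are hypothetical helpers, and equations are modelled as pairs of tuples whose entries are dimension names or integer factors.

```python
from math import prod

def _side_product(side, sizes, skip):
    """Multiply every factor on one side of an equation, leaving out `skip`."""
    return prod(sizes[f] if isinstance(f, str) else f for f in side if f != skip)

def solve_sizes(known, equations):
    """Infer unknown dimension sizes from multiplicative equations.

    known:     dict mapping dimension name -> size
    equations: list of (lhs, rhs) tuples; each side is a tuple of
               dimension names (str) and integer factors (int)
    """
    sizes = dict(known)
    progress = True
    while progress:
        progress = False
        for lhs, rhs in equations:
            unknown = [f for f in lhs + rhs if isinstance(f, str) and f not in sizes]
            if len(unknown) != 1:
                continue  # already solved, or not solvable yet
            name = unknown[0]
            own, other = (lhs, rhs) if name in lhs else (rhs, lhs)
            num = _side_product(other, sizes, name)
            den = _side_product(own, sizes, name)
            if num % den != 0:
                raise ValueError(f"{name} would not be an integer: {num} / {den}")
            sizes[name] = num // den
            progress = True
    return sizes

# N = B * T with N and T known, so B follows from the equation.
sizes = solve_sizes({"N": 1 << 20, "T": 1024}, [(("N",), ("B", "T"))])
print(sizes["B"])
```

The sketch solves one unknown per equation and repeats until nothing changes, which is enough for chains such as N = B * T followed by X * 4 = T.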

Basic example

import cheetah.index_tools as ixt
import cheetah.api as ch
from cheetah.api import ty

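dims = ixt.Dims()
N = dims.new_dim("N", 1 << 20)  # known size: 2**20 elements
B = dims.new_dim("B")           # unknown size, inferred from the equation below
T = dims.new_dim("T", 1024)     # known size: 1024
dims.eq(N, (B, T))              # N = B * T, so B = 2**20 / 1024 = 1024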

@ch.kernel(ch.Params(out=ty.ptr_mut(ty.f32), a=ty.ptr_const(ty.f32), b=ty.ptr_const(ty.f32)))
def add_vec(out, a, b):
    ix = dims.init()
    ix.set_index(T, ch.thread_idx_x())
    ix.set_index(B, ch.block_idx_x())
    out[ix[N]] = a[ix[N]] + b[ix[N]]
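To make the index arithmetic concrete, here is a plain-Python sketch of what ix[N] plausibly evaluates to once T and B are set. It assumes that in N = (B, T) the first factor varies slowest, so the flat index is b * T + t; that ordering is an assumption, not something this page states. flat_index is a hypothetical helper.

```python
def flat_index(b, t, t_size):
    """Flat index into N for block b and thread t, assuming B is the outer factor."""
    return b * t_size + t

# Small sizes so the full mapping can be listed.
B_SIZE, T_SIZE = 3, 4
visited = [flat_index(b, t, T_SIZE) for b in range(B_SIZE) for t in range(T_SIZE)]
print(visited)
```

Under that assumption every (block, thread) pair maps to a distinct element of N, which is why the kernel can write out[ix[N]] without further bounds arithmetic.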

Strides in equations

# (X, 4) on the left-hand side states X * 4 = Y, i.e. X carries a stride of 4 relative to Y.
X = dims.new_dim("X")
Y = dims.new_dim("Y")
dims.eq((X, 4), Y)
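One way to read the stride equation, sketched in plain Python: each step in X advances 4 positions in Y. This interpretation is an assumption drawn from the comment above, and y_index is a hypothetical helper rather than part of the library.

```python
def y_index(x, stride=4):
    """Position in Y reached by index x of X, assuming X * stride = Y."""
    return x * stride

# The first four X indices land on every fourth Y position.
positions = [y_index(x) for x in range(4)]
print(positions)
```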

Multiple scopes

s_tile = dims.new_scope("tile")
with dims.scope(s_tile):
    dims.eq(N, (B, T))

s_swap = dims.new_scope("swap")
with dims.scope(s_swap):
    dims.eq(N, (T, B))  # same size fact, different index order

@ch.kernel(ch.Params())
def k():
    ix = dims.init()
    ix.set_index(T, ch.thread_idx_x())
    ix.set_index(B, ch.block_idx_x())
    with ix.scope(s_tile):
        ch.raw_stmt("// tile: $0", ix[N])
    with ix.scope(s_swap):
        ch.raw_stmt("// swap: $0", ix[N])
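The difference between the two scopes can be illustrated with plain arithmetic. The sketch below assumes the first factor in each equation is the slowest-varying one, so "tile" yields b * T + t and "swap" yields t * B + b; the function names are hypothetical and the ordering is an assumption.

```python
def index_tile(b, t, b_size, t_size):
    """N index under N = (B, T): B outer, T inner (assumed ordering)."""
    return b * t_size + t

def index_swap(b, t, b_size, t_size):
    """N index under N = (T, B): T outer, B inner (assumed ordering)."""
    return t * b_size + b

B_SIZE, T_SIZE = 3, 4
tile = [index_tile(b, t, B_SIZE, T_SIZE) for b in range(B_SIZE) for t in range(T_SIZE)]
swap = [index_swap(b, t, B_SIZE, T_SIZE) for b in range(B_SIZE) for t in range(T_SIZE)]

# Both orderings cover the same B * T elements, just in a different order.
print(sorted(tile) == sorted(swap))
```

Both equations state the same size fact (N = B * T = T * B), so the solved sizes agree; only the mapping from (b, t) to a position in N changes with the active scope.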

Diagnostics

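# ix is the index context returned by dims.init() inside a kernel.
ix.why_partial(B)  # report which equation constrains B
ix.why_solved(T)   # report where T came from: solved from an equation or set externally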

Pitfalls

  • All dimensions used together must belong to the same Dims instance.
  • Equations must yield integer sizes; a fractional result raises an error.
  • Reading ix[dim] before enough indices have been set raises a diagnostic error that lists the variables involved.
  • Scopes must share ancestry with the current scope.

Loops and vectorization

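# ix, src and dst come from an enclosing kernel, as in the basic example.
# Loop over T instead of binding it to a thread index; unroll=True requests unrolling.
with ix.loop(T, unroll=True):
    dst[ix[N]] = src[ix[N]]

# Vectorize over T with a vector size of 4.
with ix.vectorize(T, size=4):
    dst[ix[N]] = src[ix[N]]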
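As a rough mental model of the vectorized copy, the plain-Python sketch below moves 4 consecutive T elements per step. It assumes that vectorize(T, size=4) groups adjacent T indices and that N is laid out as b * T + t; both are assumptions, and copy_vectorized is a hypothetical helper, not generated code.

```python
def copy_vectorized(src, dst, b, t_size, width=4):
    """Copy one block's T range from src to dst, `width` elements at a time."""
    if t_size % width != 0:
        raise ValueError("T size must be a multiple of the vector width")
    for t0 in range(0, t_size, width):
        base = b * t_size + t0  # flat N index of the first lane
        dst[base:base + width] = src[base:base + width]

T_SIZE, B_SIZE = 8, 2
src = list(range(B_SIZE * T_SIZE))
dst = [0] * len(src)
for b in range(B_SIZE):
    copy_vectorized(src, dst, b, T_SIZE)
print(dst == src)
```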