Dims solver
`cheetah.index_tools` provides a small, scope-aware algebra for describing shapes and deriving runtime indices in a structured way.
Core concepts
- Dimensions (compile-time sizes)
  - A `DimName` represents a named size domain within a specific `Dims` object (e.g., `N`, `B`, `T`).
  - A dimension may have a known size (e.g., `N = 1<<20`) or be inferred from equations.
  - Dimensions belong to the `Dims` instance that created them and cannot be mixed across different `Dims` instances.
- Indices (runtime values)
  - After sizes are solved, an `Indices` object provides a mapping from each `DimName` to a runtime expression (`ch.Expr`) representing the current index in that dimension.
  - The user sets some indices (e.g., `B = blockIdx.x`, `T = threadIdx.x`), and the solver derives the rest from the equations and active scopes.
  - Accessing `ix[dim]` yields the linearized runtime expression for that dimension in the current scope context.
- Equations (shape/stride relationships)
  - Equations have multiplicative form on each side, combining dimensions and integer stride constants: `lhs = rhs`.
  - Example: `N = B * T`, or `X * 4 = Y` (indicating a stride of 4 on `X`).
  - Equations can be attached to scopes to apply only when those scopes are active.
- Scopes (local facts)
  - `Dims.new_scope(name)` creates a child scope; `with dims.scope(s): ...` attaches equations to that scope.
  - When generating indices, `with ix.scope(s): ...` activates the scope’s equations for solving and index derivation.
Relationship between dimensions and indices
- Dimensions capture static size facts and factorization (e.g., `N = B * T`).
- Indices are dynamic values chosen at runtime that must respect those static facts. For a factorization `N = B * T`, any linearized index for `N` must decompose uniquely into `(B_index, T_index)` using the sizes `|B|` and `|T|`.
- The solver ensures size consistency first (all dimensions referenced in equations have integral sizes). Only then are index values propagated between dimensions according to the active equations.
Solving sizes (compile-time)
- Collect equations from all scopes for `Dims.init()` size solving. (Actual index generation will use scope subsets.)
- Convert each equation to a normalized form, `_SolverEquation`, tracking unknowns (dimensions with integer exponents) and a known rational product of constants.
- Repeatedly solve equations with a single unknown by taking exact integer roots; propagate solved values to other equations.
- If any equation yields a fractional size or remains underdetermined, raise a `DimsError` with a reduced form for debugging.

Notes:
- Sizes must be integers. Fractional sizes or roots are errors.
- Dimensions with size 1 are automatically considered solved and their index defaults to 0.
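The steps above can be sketched in a few dozen lines of standalone Python. This is a minimal illustration under stated assumptions, not the library's implementation: equations are given as `(lhs, rhs)` pairs of dimension names and integer constants, `SizeError` stands in for `DimsError`, and `normalize`, `exact_root`, and `solve_sizes` are hypothetical helper names.

```python
from fractions import Fraction


class SizeError(ValueError):
    """Hypothetical stand-in for the library's DimsError."""


def normalize(equation, sizes):
    """Reduce lhs = rhs to: product(unknown ** exponent) == known constant."""
    lhs, rhs = equation
    exponents, const = {}, Fraction(1)
    for side, sign in ((lhs, 1), (rhs, -1)):
        for term in side:
            value = term if isinstance(term, int) else sizes.get(term)
            if value is None:  # unknown dimension: track its exponent
                exponents[term] = exponents.get(term, 0) + sign
            elif sign > 0:     # known factor on the left moves to the right
                const /= value
            else:              # known factor on the right stays there
                const *= value
    return {d: e for d, e in exponents.items() if e != 0}, const


def exact_root(value, degree):
    """Return the positive integer r with r ** degree == value, or raise."""
    if value.denominator != 1 or value < 1:
        raise SizeError(f"fractional size: {value}")
    n, lo, hi = value.numerator, 1, value.numerator
    while lo < hi:  # integer bisection, no floating point involved
        mid = (lo + hi) // 2
        if mid ** degree < n:
            lo = mid + 1
        else:
            hi = mid
    if lo ** degree != n:
        raise SizeError(f"{n} has no exact integer root of degree {degree}")
    return lo


def solve_sizes(equations, known):
    """Repeatedly solve single-unknown equations and propagate the results."""
    sizes, pending, progress = dict(known), list(equations), True
    while pending and progress:
        progress = False
        for eq in list(pending):
            exponents, const = normalize(eq, sizes)
            if len(exponents) > 1:
                continue  # still more than one unknown; retry later
            if not exponents and const != 1:
                raise SizeError(f"inconsistent equation: {eq}")
            if exponents:
                (name, exp), = exponents.items()
                target = const if exp > 0 else 1 / const
                sizes[name] = exact_root(target, abs(exp))
            pending.remove(eq)
            progress = True
    if pending:
        raise SizeError(f"underdetermined equations: {pending}")
    return sizes
```

For example, `solve_sizes([(["N"], ["B", "T"])], {"N": 1 << 20, "T": 256})` derives `B = 4096`, while `{"N": 10, "T": 4}` raises because `B` would be 5/2.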
Preparing `Indices` (runtime view)
`ix = dims.init()` returns an `Indices` object bound to the solved sizes and scope tree. Internally, equations in the current scope are grouped to allow incremental propagation when indices are set.
- Setting an index: `ix.set_index(dim, expr)` binds the runtime index for `dim` to `expr` and propagates information through equations in the current scope.
- Reading an index: `ix[dim]` returns the derived runtime expression for `dim` if it is solved in the current scope; otherwise, a diagnostic `DimsError` explains why it is not yet resolved.
Information flow and consistency
- Each equation induces a direction of information flow depending on which side has known/solved indices.
- The solver tracks, per equation, whether information flows left→right or right→left. If both sides would need to flow simultaneously (due to multiple constraints), a precise `DimsError` is raised explaining the conflict and its origin.
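One simplified way to picture the direction tracking is the sketch below. It is an interpretation, not the library's logic: it assumes a side is "ready" only when all of its dimensions already have indices from some other source, and it ignores partially solved sides. `Flow`, `FlowConflict` (a stand-in for `DimsError`), and `update_flow` are hypothetical names.

```python
from enum import Enum


class Flow(Enum):
    UNDECIDED = "undecided"
    LEFT_TO_RIGHT = "left->right"
    RIGHT_TO_LEFT = "right->left"


class FlowConflict(ValueError):
    """Hypothetical stand-in for the library's DimsError."""


def update_flow(current, lhs, rhs, solved):
    """Return the flow direction for one equation.

    lhs, rhs: dimension names on each side of the equation.
    solved: dimensions whose indices were fixed by something other than
    this equation (set by the user or derived elsewhere).
    """
    lhs_ready = all(d in solved for d in lhs)
    rhs_ready = all(d in solved for d in rhs)
    if lhs_ready and rhs_ready:
        raise FlowConflict(f"both sides already determined: {lhs} = {rhs}")
    if lhs_ready:
        wanted = Flow.LEFT_TO_RIGHT
    elif rhs_ready:
        wanted = Flow.RIGHT_TO_LEFT
    else:
        return current  # not enough information yet
    if current not in (Flow.UNDECIDED, wanted):
        raise FlowConflict(
            f"direction flipped from {current.value} to {wanted.value}"
        )
    return wanted
```

With `N = B * T`, setting `B` and `T` gives `RIGHT_TO_LEFT` (the solver derives `N`), setting only `N` gives `LEFT_TO_RIGHT`, and supplying all three independently is reported as a conflict.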
Linearization and unlinearization
- Linearization: a tuple of dimensions `(d0, d1, ..., dk)` forms a linearized index by the usual row-major rule using strides computed from sizes. The solver represents the linearized side as a sum of `idx(di) * stride(di)` expressions.
- Unlinearization: given the linearized expression from the solved side, the solver derives the indices for each unknown dimension on the other side by dividing/modding by the appropriate strides and sizes.
- Integer terms (stride constants) constrain the arithmetic but are never solved as indices themselves.
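The row-major rule and its inverse can be written out for any number of dimensions. The sketch below works on plain integers, whereas the solver builds symbolic `ch.Expr` values; the function names are illustrative and not taken from the library.

```python
def row_major_strides(sizes):
    """stride(di) is the product of the sizes of all dimensions after di."""
    strides = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        strides[i] = strides[i + 1] * sizes[i + 1]
    return strides


def linearize(indices, sizes):
    """Sum of idx(di) * stride(di) over all dimensions."""
    return sum(i * s for i, s in zip(indices, row_major_strides(sizes)))


def unlinearize(linear, sizes):
    """Recover each index by dividing by its stride, then reducing modulo its size."""
    return [(linear // s) % n for s, n in zip(row_major_strides(sizes), sizes)]


sizes = (2, 3, 4)
print(row_major_strides(sizes))     # [12, 4, 1]
print(linearize((1, 2, 3), sizes))  # 23
print(unlinearize(23, sizes))       # [1, 2, 3]
```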
Scopes during runtime
- Enter a scope to activate its equations for index solving:

  ```python
  with ix.scope(some_scope):
      ...  # equations in some_scope guide propagation
  ```
- Scopes must obey ancestry: entering a scope is allowed only if that scope’s parent has been entered. The scope's parent does not need to be the most recently entered scope.
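The ancestry rule can be modeled with a small standalone tracker. This is a sketch under assumptions: the root scope is treated as always entered, `ValueError` stands in for `DimsError`, and `Scope` and `ScopeStack` are hypothetical classes unrelated to the library's own types.

```python
from contextlib import contextmanager


class Scope:
    """Hypothetical scope node; only the parent link matters here."""

    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent


class ScopeStack:
    """Tracks entered scopes; the root counts as entered from the start."""

    def __init__(self, root):
        self.entered = [root]

    @contextmanager
    def scope(self, s):
        # The parent must have been entered, but need not be the latest entry.
        if s.parent not in self.entered:
            raise ValueError(f"parent of scope {s.name!r} has not been entered")
        self.entered.append(s)
        try:
            yield s
        finally:
            self.entered.remove(s)


root = Scope("root")
tile = Scope("tile", root)
load = Scope("load", tile)
store = Scope("store", tile)

stack = ScopeStack(root)
# "store" is accepted even though "load", not its parent "tile", was entered last.
with stack.scope(tile), stack.scope(load), stack.scope(store):
    active = [s.name for s in stack.entered]
print(active)  # ['root', 'tile', 'load', 'store']
```

Entering `load` on a fresh `ScopeStack(root)` without first entering `tile` raises, which mirrors the parent-must-be-entered requirement.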
Loops and vectorization helpers
- `with ix.loop(dim, unroll=...)`: introduces a `for` loop bound to the size of `dim` and sets the dimension’s index to the loop variable within the scope:

  ```python
  with ix.loop(T, unroll=True):
      dst[ix[N]] = src[ix[N]]
  ```
- `with ix.vectorize(dim, size)`: requires `|dim| == size` and fixes that index to 0 for the scope, indicating a vectorized region.
Diagnostics and debugging
- `ix.why_partial(dim)` explains which equation constrains but doesn’t fully solve `dim` in the current scope.
- `ix.why_solved(dim)` indicates where the value came from (set externally or solved by a specific equation).
- When an index is unavailable, errors enumerate associated variables and their solved state.
Common pitfalls
- Mixing `DimName`s from different `Dims` objects is illegal and will raise a `DimsError`.
- Equations must imply integral sizes; fractional solutions (including non-integer roots) are rejected.
- Indices may be unsolved if you read them before setting any necessary upstream indices in the active scope.
- Scopes must share ancestry; attempting to use a scope from a different branch raises an error.
End-to-end pattern
- Define dimensions and equations (optionally organized in scopes).
- Call `dims.init()` to validate and freeze sizes; keep the returned `Indices` for runtime.
- In kernels, set some indices from CUDA builtins (e.g., `threadIdx.x`, `blockIdx.x`) or runtime values. Let the solver propagate the rest.
- Use `ix[dim]` to access linearized indices.
- Use `ix.loop(...)` / `ix.vectorize(...)` as appropriate.