
Dims solver

cheetah.index_tools provides a small, scope-aware algebra for describing shapes and deriving runtime indices in a structured way.

Core concepts

  • Dimensions (compile-time sizes)
      • A DimName represents a named size domain within a specific Dims object (e.g., N, B, T).
      • A dimension may have a known size (e.g., N = 1<<20) or have its size inferred from equations.
      • Dimensions belong to the Dims instance that created them and cannot be mixed across different Dims instances.

  • Indices (runtime values)
      • After sizes are solved, an Indices object maps each DimName to a runtime expression (ch.Expr) representing the current index in that dimension.
      • The user sets some indices (e.g., B = blockIdx.x, T = threadIdx.x), and the solver derives the rest from the equations in the active scopes.
      • Accessing ix[dim] yields the linearized runtime expression for that dimension in the current scope context.

  • Equations (shape/stride relationships)
      • An equation has the form lhs = rhs, where each side is a product of dimensions and integer stride constants.
      • Examples: N = B * T, or X * 4 = Y (indicating a stride of 4 on X).
      • Equations can be attached to scopes so that they apply only while those scopes are active.

  • Scopes (local facts)
      • Dims.new_scope(name) creates a child scope; with dims.scope(s): ... attaches equations to that scope.
      • When generating indices, with ix.scope(s): ... activates the scope’s equations for solving and index derivation.
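To make the scope mechanics concrete, here is a minimal pure-Python toy model, not the cheetah API. The `Scope` and `Equation` classes, the `active_equations` helper, and the extra dimension `W` are hypothetical names chosen for illustration; the sketch assumes an equation simply records the scope it was attached to and is used only while that scope is active.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Term = Union[str, int]  # a dimension name or an integer stride constant


@dataclass(frozen=True)
class Scope:
    # Hypothetical toy scope: a name plus an optional parent scope.
    name: str
    parent: Optional["Scope"] = None


@dataclass(frozen=True)
class Equation:
    # Toy equation prod(lhs) == prod(rhs), tagged with the scope it belongs to.
    lhs: Tuple[Term, ...]
    rhs: Tuple[Term, ...]
    scope: Scope


def active_equations(equations, active_scopes):
    """Keep only the equations whose scope is currently active."""
    return [eq for eq in equations if eq.scope in active_scopes]


root = Scope("root")
tiled = Scope("tiled", parent=root)
equations = [
    Equation(("N",), ("B", "T"), root),  # N = B * T holds everywhere
    Equation(("T",), ("W", 4), tiled),   # T = W * 4 only inside "tiled"
]
```

With only `root` active, the solver in this model would see just N = B * T; entering `tiled` adds T = W * 4.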

Relationship between dimensions and indices

  • Dimensions capture static size facts and factorization (e.g., N = B * T).
  • Indices are dynamic values chosen at runtime that must respect those static facts. For a factorization N = B * T, any linearized index for N must decompose uniquely into (B_index, T_index) using the sizes |B| and |T|.
  • The solver ensures size consistency first (all dimensions referenced in equations have integral sizes). Only then are index values propagated between dimensions according to the active equations.
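The uniqueness claim can be checked with plain integers in place of runtime expressions. `compose` and `decompose` are hypothetical helpers, and the sketch assumes the row-major convention (T varies fastest) described under Linearization and unlinearization.

```python
def compose(b_index: int, t_index: int, size_t: int) -> int:
    """Row-major linear index for N = B * T (T varies fastest)."""
    return b_index * size_t + t_index


def decompose(n_index: int, size_b: int, size_t: int) -> tuple:
    """Recover the unique (b_index, t_index) pair from a linear index of N."""
    if not 0 <= n_index < size_b * size_t:
        raise ValueError(f"index {n_index} is outside N = {size_b} * {size_t}")
    # Quotient is the B index, remainder is the T index.
    return divmod(n_index, size_t)
```

For example, with |B| = 3 and |T| = 4, index 7 of N decomposes into (1, 3), since 1 * 4 + 3 = 7.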

Solving sizes (compile-time)

  1. For size solving in Dims.init(), collect the equations from all scopes. (Index generation later uses only the equations of the active scopes.)
  2. Convert each equation to a normalized form _SolverEquation, tracking unknowns (dimensions with integer exponents) and a known rational product of constants.
  3. Repeatedly solve equations with a single unknown by taking exact integer roots; propagate solved values to other equations.
  4. If any equation yields a fractional size or remains underdetermined, raise a DimsError with a reduced form for debugging.

Notes:
  • Sizes must be integers; fractional sizes or roots are errors.
  • Dimensions with size 1 are automatically considered solved, and their index defaults to 0.
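The size-solving steps can be sketched in pure Python. This is an illustration under assumptions, not the library's code: `solve_sizes`, `normalize`, and `exact_root` are hypothetical names, equations are given as (lhs, rhs) lists of dimension-name strings and integer constants, and the normalized form (a dict of exponents plus a `Fraction` constant) only mirrors the idea of `_SolverEquation`. The size-1 rule is not modeled.

```python
from collections import Counter
from fractions import Fraction


class DimsError(Exception):
    """Stand-in for the library's DimsError in this sketch."""


def exact_root(value: Fraction, k: int) -> int:
    """Return the positive integer r with r**k == value, or raise DimsError."""
    if value.denominator != 1 or value <= 0:
        raise DimsError(f"size {value} is not a positive integer")
    n = value.numerator
    lo, hi = 1, n
    while lo <= hi:  # binary search keeps the arithmetic exact
        mid = (lo + hi) // 2
        power = mid ** k
        if power == n:
            return mid
        if power < n:
            lo = mid + 1
        else:
            hi = mid - 1
    raise DimsError(f"{n} has no exact integer {k}-th root")


def normalize(lhs, rhs):
    """Rewrite prod(lhs) == prod(rhs) as prod(d ** e) == constant."""
    exponents, constant = Counter(), Fraction(1)
    for term in lhs:
        if isinstance(term, int):
            constant /= term
        else:
            exponents[term] += 1
    for term in rhs:
        if isinstance(term, int):
            constant *= term
        else:
            exponents[term] -= 1
    return {d: e for d, e in exponents.items() if e != 0}, constant


def solve_sizes(equations, known):
    """Solve single-unknown equations until nothing changes, then verify all."""
    sizes = dict(known)
    normalized = [normalize(lhs, rhs) for lhs, rhs in equations]
    progress = True
    while progress:
        progress = False
        for exponents, constant in normalized:
            unknown = [d for d in exponents if d not in sizes]
            if len(unknown) != 1:
                continue
            target = unknown[0]
            rest = constant
            for d, e in exponents.items():
                if d != target:
                    rest /= Fraction(sizes[d]) ** e  # divide out known factors
            e = exponents[target]
            sizes[target] = exact_root(rest if e > 0 else 1 / rest, abs(e))
            progress = True
    for exponents, constant in normalized:
        missing = sorted(d for d in exponents if d not in sizes)
        if missing:
            raise DimsError(f"underdetermined: {missing}")
        product = Fraction(1)
        for d, e in exponents.items():
            product *= Fraction(sizes[d]) ** e
        if product != constant:
            raise DimsError(f"inconsistent equation: {exponents} == {constant}")
    return sizes
```

For example, N = B * T with |N| = 1<<20 and |T| = 256 yields |B| = 4096, while X * 4 = Y with |Y| = 6 is rejected because |X| would be 3/2.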

Preparing Indices (runtime view)

ix = dims.init() returns an Indices object bound to the solved sizes and scope tree. Internally, equations in the current scope are grouped to allow incremental propagation when indices are set.

  • Setting an index: ix.set_index(dim, expr) binds the runtime index for dim to expr and propagates information through equations in the current scope.
  • Reading an index: ix[dim] returns the derived runtime expression for dim if it is solved in the current scope; otherwise, a diagnostic DimsError explains why it is not yet resolved.
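The set/read cycle can be modeled for the single equation N = B * T. `ToyIndices` is a hypothetical class, not cheetah's `Indices`: it uses strings in place of `ch.Expr`, hard-codes one equation to keep the propagation visible, and assumes C-style integer `/` and `%` in the emitted expressions.

```python
class DimsError(Exception):
    """Stand-in for the library's DimsError in this sketch."""


class ToyIndices:
    """Toy set_index / ix[dim] for the single equation N = B * T.

    Strings stand in for ch.Expr runtime expressions.
    """

    def __init__(self, sizes):
        self.sizes = sizes
        self.exprs = {}

    def set_index(self, dim, expr):
        self.exprs[dim] = expr
        self._propagate()

    def _propagate(self):
        e, size_t = self.exprs, self.sizes["T"]
        if "B" in e and "T" in e and "N" not in e:
            # Both factors known: linearize into N (T varies fastest).
            e["N"] = f"({e['B']}) * {size_t} + ({e['T']})"
        elif "N" in e and "B" not in e and "T" not in e:
            # Product known: unlinearize into the factors.
            e["B"] = f"({e['N']}) / {size_t}"
            e["T"] = f"({e['N']}) % {size_t}"

    def __getitem__(self, dim):
        if dim not in self.exprs:
            raise DimsError(f"index of {dim} is not solved yet; solved: {sorted(self.exprs)}")
        return self.exprs[dim]
```

After `set_index("B", "blockIdx.x")` alone, reading N fails; once T is also set, N becomes `(blockIdx.x) * 256 + (threadIdx.x)` for |T| = 256.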

Information flow and consistency

  • Each equation induces a direction of information flow depending on which side has known/solved indices.
  • The solver tracks, per equation, whether information flows left→right or right→left. If both sides would need to flow simultaneously (due to multiple constraints), a precise DimsError is raised explaining the conflict and its origin.
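One plausible reading of this rule, as a hedged sketch: choose the direction from whichever side of the equation is fully solved, and report a conflict, naming the origin of every index involved, when both sides were already solved independently. `flow_direction` is a hypothetical helper; the library may detect conflicts in other ways, for example across several equations at once.

```python
class DimsError(Exception):
    """Stand-in for the library's DimsError in this sketch."""


def flow_direction(name, lhs, rhs, origin):
    """Decide which way index information moves through `lhs = rhs` (toy model).

    `origin` maps each solved dimension to a note on where its index came from.
    Integer terms are stride constants, never indices, so they are skipped.
    """
    left = [t for t in lhs if isinstance(t, str)]
    right = [t for t in rhs if isinstance(t, str)]
    left_solved = all(d in origin for d in left)
    right_solved = all(d in origin for d in right)
    if left_solved and right_solved:
        sources = ", ".join(f"{d} <- {origin[d]}" for d in left + right)
        raise DimsError(f"{name}: both sides are already solved ({sources}); "
                        "information would have to flow both ways")
    if right_solved:
        return "right->left"  # e.g. B and T known, so N is derived
    if left_solved:
        return "left->right"  # e.g. N known, so B and T are derived
    return None               # not enough information yet
```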

Linearization and unlinearization

  • Linearization: a tuple of dimensions (d0, d1, ..., dk) forms a linearized index by the usual row-major rule, with strides computed from sizes: stride(dk) = 1 and stride(di) = |d(i+1)| * ... * |dk|, so dk varies fastest. The solver represents the linearized side as a sum of idx(di) * stride(di) expressions.
  • Unlinearization: given the linearized expression from the solved side, the solver derives the indices for each unknown dimension on the other side by dividing/modding by the appropriate strides and sizes.
  • Integer terms (stride constants) constrain the arithmetic but are never solved as indices themselves.
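Both directions can be sketched with plain integers. The helper names are hypothetical, and treating an integer term as a fixed extent whose own index is always 0 is an assumption that reproduces the "stride of 4 on X" reading of X * 4 = Y.

```python
def row_major_strides(extents):
    """Stride of each position: the product of the extents to its right."""
    strides, acc = [], 1
    for extent in reversed(extents):
        strides.append(acc)
        acc *= extent
    return strides[::-1]


def extents_of(terms, sizes):
    # Integer terms contribute their value as an extent; names look up |d|.
    return [t if isinstance(t, int) else sizes[t] for t in terms]


def linearize(terms, sizes, index):
    """Sum of idx(d) * stride(d) over dimension terms (integer terms add 0)."""
    strides = row_major_strides(extents_of(terms, sizes))
    return sum(index[t] * s for t, s in zip(terms, strides) if isinstance(t, str))


def unlinearize(linear, terms, sizes):
    """Recover each dimension's index by dividing by its stride, then modding by its size."""
    extents = extents_of(terms, sizes)
    strides = row_major_strides(extents)
    return {t: (linear // s) % n
            for t, n, s in zip(terms, extents, strides) if isinstance(t, str)}
```

With |B| = 3 and |T| = 4, the pair (2, 1) linearizes to 9 and unlinearizes back; for X * 4 = Y, X's index 5 maps to Y's index 20.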

Scopes during runtime

  • Enter a scope to activate its equations for index solving:

    ```python
    with ix.scope(some_scope):
        ...  # equations in some_scope guide propagation
    ```
  • Scopes must obey ancestry: entering a scope is allowed only if that scope’s parent has been entered. The scope's parent does not need to be the most recently entered scope.
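The ancestry rule can be illustrated with a toy scope stack. `Scope` and `ScopeStack` are hypothetical classes, and representing the active scopes as a stack is an assumption about how such a check might work.

```python
from contextlib import contextmanager


class DimsError(Exception):
    """Stand-in for the library's DimsError in this sketch."""


class Scope:
    def __init__(self, name, parent=None):
        self.name, self.parent = name, parent


class ScopeStack:
    """Toy tracker for `with ix.scope(s):` that enforces the ancestry rule."""

    def __init__(self, root):
        self.active = [root]

    @contextmanager
    def scope(self, s):
        # The parent must be active somewhere on the stack, not necessarily on top.
        if s.parent not in self.active:
            raise DimsError(f"cannot enter {s.name}: parent scope is not active")
        self.active.append(s)
        try:
            yield s
        finally:
            self.active.pop()
```

With `a` and `b` both children of `root` and `c` a child of `a`, entering `a`, then `b`, then `c` succeeds: `c`'s parent `a` is still active even though `b` was entered more recently. Entering `c` from `root` alone fails.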

Loops and vectorization helpers

  • with ix.loop(dim, unroll=...): introduces a for loop bound to the size of dim and sets the dimension’s index to the loop variable within the scope:

    ```python
    with ix.loop(T, unroll=True):
        dst[ix[N]] = src[ix[N]]
    ```
  • with ix.vectorize(dim, size): requires |dim| == size and fixes that index to 0 for the scope, indicating a vectorized region.
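A toy emitter can mirror the two behaviors stated above: the loop's trip count is |dim| and the loop variable becomes dim's index, while a vectorized dimension must have the requested size and gets index 0. `ToyKernel` is hypothetical, and the generated CUDA-like text (including mapping `unroll=True` to `#pragma unroll`) is an assumption, not what cheetah necessarily emits.

```python
from contextlib import contextmanager


class DimsError(Exception):
    """Stand-in for the library's DimsError in this sketch."""


class ToyKernel:
    """Toy emitter mimicking ix.loop / ix.vectorize by producing CUDA-like text."""

    def __init__(self, sizes):
        self.sizes = sizes
        self.index = {}  # dim -> runtime expression (a string here)
        self.lines = []
        self.depth = 0

    def emit(self, text):
        self.lines.append("    " * self.depth + text)

    @contextmanager
    def loop(self, dim, unroll=False):
        var = f"i_{dim}"
        if unroll:
            self.emit("#pragma unroll")
        self.emit(f"for (int {var} = 0; {var} < {self.sizes[dim]}; ++{var}) {{")
        self.index[dim] = var  # the loop variable becomes dim's index
        self.depth += 1
        try:
            yield var
        finally:
            self.depth -= 1
            del self.index[dim]
            self.emit("}")

    @contextmanager
    def vectorize(self, dim, size):
        if self.sizes[dim] != size:
            raise DimsError(f"|{dim}| = {self.sizes[dim]}, expected {size}")
        self.index[dim] = "0"  # one vector operation covers the whole extent
        try:
            yield
        finally:
            del self.index[dim]
```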

Diagnostics and debugging

  • ix.why_partial(dim) explains which equation constrains but doesn’t fully solve dim in the current scope.
  • ix.why_solved(dim) indicates where the value came from (set externally or solved by a specific equation).
  • When an index is unavailable, errors enumerate associated variables and their solved state.
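The provenance behind these diagnostics can be modeled by recording, for each index, whether it was set externally or solved by an equation. `ToyDiagnostics` and its methods are hypothetical and only echo the names `why_solved` and `why_partial`; the message formats are invented for illustration.

```python
class DimsError(Exception):
    """Stand-in for the library's DimsError in this sketch."""


class ToyDiagnostics:
    """Record where each index came from, to answer why_solved / why_partial queries."""

    def __init__(self, equations):
        self.equations = equations  # equation name -> dims appearing in it
        self.origin = {}            # dim -> provenance note

    def set_external(self, dim, expr):
        self.origin[dim] = f"set externally to {expr}"

    def solved_by(self, dim, equation):
        self.origin[dim] = f"solved by equation {equation}"

    def why_solved(self, dim):
        if dim not in self.origin:
            raise DimsError(self.why_partial(dim))
        return self.origin[dim]

    def why_partial(self, dim):
        # Report the first equation that mentions dim but still has unsolved partners.
        for name, dims in self.equations.items():
            missing = [d for d in dims if d not in self.origin and d != dim]
            if dim in dims and missing:
                return f"{dim} is constrained by {name} but still needs {missing}"
        return f"{dim} appears in no active equation"
```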

Common pitfalls

  • Mixing DimNames from different Dims objects is illegal and will raise a DimsError.
  • Equations must imply integral sizes; fractional solutions (including non-integer roots) are rejected.
  • An index may be unsolved if you read it before setting the upstream indices it depends on in the active scope.
  • Scopes must share ancestry; attempting to use a scope from a different branch raises an error.
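The first pitfall follows from each DimName remembering the Dims that created it. `ToyDims` and `ToyDimName` are hypothetical classes showing one way such an ownership check could reject a foreign dimension.

```python
class DimsError(Exception):
    """Stand-in for the library's DimsError in this sketch."""


class ToyDimName:
    def __init__(self, owner, name):
        self.owner, self.name = owner, name


class ToyDims:
    """Toy container showing why DimNames from another Dims are rejected."""

    def __init__(self, label):
        self.label = label
        self.equations = []

    def dim(self, name):
        return ToyDimName(self, name)

    def equation(self, lhs, rhs):
        # Every dimension term must have been created by this container.
        for term in (*lhs, *rhs):
            if isinstance(term, ToyDimName) and term.owner is not self:
                raise DimsError(f"{term.name} belongs to {term.owner.label}, not {self.label}")
        self.equations.append((tuple(lhs), tuple(rhs)))
```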

End-to-end pattern

  1. Define dimensions and equations (optionally organized in scopes).
  2. Call dims.init() to validate and freeze sizes; keep the returned Indices for runtime.
  3. In kernels, set some indices from CUDA builtins (e.g., threadIdx.x, blockIdx.x) or from other runtime values, and let the solver propagate the rest.
  4. Use ix[dim] to access linearized indices.
  5. Use ix.loop(...)/ix.vectorize(...) as appropriate.
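For the common 1-D case N = B * T, the pattern can be sanity-checked by emulating a launch in plain Python: solve |B| from |N| and |T|, derive the index of N from stand-ins for blockIdx.x and threadIdx.x, and confirm every element is visited exactly once. `simulate_launch` is a hypothetical helper, not part of cheetah, and raising `ValueError` stands in for the library's size error.

```python
def simulate_launch(n_size, threads_per_block):
    """Emulate a grid where N = B * T, with B from blockIdx.x and T from threadIdx.x."""
    if n_size % threads_per_block:
        raise ValueError("N = B * T has no integral solution for B")
    blocks = n_size // threads_per_block  # step 2: size solving
    visited = [0] * n_size
    for block_idx in range(blocks):  # stand-in for blockIdx.x
        for thread_idx in range(threads_per_block):  # stand-in for threadIdx.x
            n = block_idx * threads_per_block + thread_idx  # step 4: derived ix[N]
            visited[n] += 1
    return blocks, visited
```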