Tensor shapes with both const generic and run time dimensions

The goal

Most ndarray/tensor crates have shapes known either at runtime or at compile time. My own dfdx crate used to support only compile-time shapes. This blog post investigates how to mix the two approaches (spoiler: I did!).

Here are some examples of what we’d want to be possible:

Tensor<(2, 3)> // 2d tensor with shape 2x3 known at compile time
Tensor<(usize, usize)> // 2d tensor with shape known at runtime
Tensor<(usize, 4, 5)> // 3d tensor with mixed dimensions
Tensor<(1, usize, 2, usize)> // 4d tensor with mixed dimensions

Notably we can have:

  1. Shapes known entirely at compile time
  2. Shapes known entirely at run time
  3. Shapes made up of both compile time and run time dimensions

The problem here is that numbers are not types, and Rust will complain if you try the above.

Turning numbers into types

The first thing we need to figure out is how to use numbers as types.

What I mean by that is not “the type of x is int”, rather “the type of x is 5” (for example):

let a: 5 = ???;

That’s what we mean by numbers as types. So how do we do this? With const generics of course!

Here’s the magic trick:

struct Const<const N: usize>;

This is a zero sized type - it has no runtime data associated with it. However we have tagged it with a const generic value, meaning it has data at compile time!

Now numbers can be types:

// its type is `5`! 🤯
let a: Const<5> = Const;

It’s simple yet very powerful. I’ve used this a lot in dfdx so far, let me know if you find other uses!
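To make the trick concrete, here is a minimal sketch. The `value` method and the `takes_five` and `const_size` helpers are names made up for illustration; they are not part of dfdx.

```rust
/// Zero-sized marker whose only "data" is the const generic parameter `N`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Const<const N: usize>;

impl<const N: usize> Const<N> {
    /// Hypothetical helper: recover the compile-time number as a runtime value.
    fn value(&self) -> usize {
        N
    }
}

/// Only accepts the "number type" 5. Passing a `Const<4>` here would be a
/// mismatched-types error at compile time.
fn takes_five(five: Const<5>) -> usize {
    five.value()
}

/// Size in bytes of `Const<5>`, which carries no runtime data.
fn const_size() -> usize {
    std::mem::size_of::<Const<5>>()
}
```

Calling `takes_five(Const)` lets the compiler infer `N = 5` from the parameter type, the same way `let a: Const<5> = Const;` does.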

Anyway, now that we have our magical trick to turn numbers into types, we can start building our shapes.

Shapes that mix compile and run time dimensions

The Dim trait

All shapes are made up of some number of dimensions, but what is a dimension exactly? Well it’s just something that has a size:

trait Dim {
    fn size(&self) -> usize;
}

It’s pretty easy to see how this fits with our Const struct:

impl<const N: usize> Dim for Const<N> {
    fn size(&self) -> usize {
        N
    }
}

For runtime dimensions, we can just use usize:

impl Dim for usize {
    fn size(&self) -> usize {
        *self
    }
}
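With both impls in place, code can be written once against `Dim` and accept either kind of dimension. A small sketch of that idea follows; `num_elements` is a hypothetical helper, not something from dfdx.

```rust
trait Dim {
    fn size(&self) -> usize;
}

/// Compile-time dimension: the size lives in the type.
struct Const<const N: usize>;

impl<const N: usize> Dim for Const<N> {
    fn size(&self) -> usize {
        N
    }
}

/// Runtime dimension: the size lives in the value.
impl Dim for usize {
    fn size(&self) -> usize {
        *self
    }
}

/// Hypothetical helper: element count of a 2d shape, whatever mix of
/// compile-time and runtime dimensions it is made of.
fn num_elements<D1: Dim, D2: Dim>(shape: &(D1, D2)) -> usize {
    shape.0.size() * shape.1.size()
}
```

A nice side effect is that compile-time dimensions take no space: a `(usize, Const<4>)` is the same size as a single `usize`, and a `(Const<2>, Const<3>)` is zero sized.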

Our example from the beginning turns into the following:

Tensor<(Const<2>, Const<3>)>
Tensor<(usize, usize)>
Tensor<(usize, Const<4>, Const<5>)>
Tensor<(Const<1>, usize, Const<2>, usize)>

This isn’t quite as clean as our ideal interface, but I think this is a price well worth paying. We can also simplify this a bit when we have all Const dimensions with the following type aliases:

type Rank0 = ();
type Rank1<const M: usize> = (Const<M>, );
type Rank2<const M: usize, const N: usize> = (Const<M>, Const<N>);
// etc...
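As a quick illustration of how the aliases read in practice, here is a sketch; `compile_time_shape` and `unwrap_alias` are made-up function names used only to show that the alias and the spelled-out tuple are the same type.

```rust
struct Const<const N: usize>;

type Rank1<const M: usize> = (Const<M>,);
type Rank2<const M: usize, const N: usize> = (Const<M>, Const<N>);

/// Builds a fully compile-time 2x3 shape. The `Const` values are inferred
/// from the return type, so no numbers appear in the body.
fn compile_time_shape() -> Rank2<2, 3> {
    (Const, Const)
}

/// `Rank2<2, 3>` is only an alias, so it is interchangeable with the tuple.
fn unwrap_alias(shape: Rank2<2, 3>) -> (Const<2>, Const<3>) {
    shape
}
```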

To review:

🎉

Tying it all together - the Shape trait

The internals of the Shape trait depend on how arrays are actually laid out in memory, but at a high level it’s pretty much just:

trait Shape {
    const NUM_DIMS: usize;
    fn strides(&self) -> [usize; Self::NUM_DIMS];
}

Note: astute readers may notice that I’ve used a nightly feature (generic_const_exprs) with [usize; Self::NUM_DIMS]. My actual implementation doesn’t use this, but it’s clearer here.

And since I really want shapes to be tuples, here’s how you can impl this trait for tuples of different sizes:

impl Shape for () {
    const NUM_DIMS: usize = 0;
    fn strides(&self) -> [usize; Self::NUM_DIMS] {
        []
    }
}

impl<D1: Dim> Shape for (D1, ) {
    const NUM_DIMS: usize = 1;
    fn strides(&self) -> [usize; Self::NUM_DIMS] {
        [1]
    }
}

impl<D1: Dim, D2: Dim> Shape for (D1, D2) {
    const NUM_DIMS: usize = 2;
    fn strides(&self) -> [usize; Self::NUM_DIMS] {
        [self.1.size(), 1]
    }
}

// etc

I like that the implementations are generic over Dim, which means we don’t have to write them separately for const vs runtime dimensions. Yay generics!
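The impls above rely on the nightly-only `[usize; Self::NUM_DIMS]` return type. One way to get the same shape on stable Rust is an associated type for the strides array. This is a sketch under that assumption, not necessarily how dfdx does it, and the name `Strides` is invented here.

```rust
trait Dim {
    fn size(&self) -> usize;
}

struct Const<const N: usize>;

impl<const N: usize> Dim for Const<N> {
    fn size(&self) -> usize {
        N
    }
}

impl Dim for usize {
    fn size(&self) -> usize {
        *self
    }
}

trait Shape {
    const NUM_DIMS: usize;
    /// Assumed stand-in for `[usize; Self::NUM_DIMS]`: each impl names its
    /// own fixed-size array type.
    type Strides: AsRef<[usize]>;
    fn strides(&self) -> Self::Strides;
}

impl Shape for () {
    const NUM_DIMS: usize = 0;
    type Strides = [usize; 0];
    fn strides(&self) -> Self::Strides {
        []
    }
}

impl<D1: Dim, D2: Dim> Shape for (D1, D2) {
    const NUM_DIMS: usize = 2;
    type Strides = [usize; 2];
    fn strides(&self) -> Self::Strides {
        // Row-major: stepping one row skips a full row of `D2` elements.
        [self.1.size(), 1]
    }
}

impl<D1: Dim, D2: Dim, D3: Dim> Shape for (D1, D2, D3) {
    const NUM_DIMS: usize = 3;
    type Strides = [usize; 3];
    fn strides(&self) -> Self::Strides {
        [self.1.size() * self.2.size(), self.2.size(), 1]
    }
}
```

The same impl serves every mix of dimensions: `(Const::<2>, 3usize)` and `(2usize, Const::<4>, Const::<5>)` both go through the generic code above.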

And of course, you could also impl Shape for anything else, which is the beauty of Rust.

Why is this useful though?

We now have the ultimate flexibility with shapes. Here are some motivating examples from how I’ve used it.

A clean matrix multiplication API

Matrix multiplication between two tensors requires 3 different dimensions (M, K, N):

  1. The left matrix is of shape (M, K)
  2. The right matrix is of shape (K, N)
  3. The resulting matrix is of shape (M, N)

The only requirement is that the K dimension is the same between both - we don’t care what M and N are.

This is where our compile time dimensions come into play! We can represent these like so:

fn matmul<M: Dim, const K: usize, N: Dim>(
    lhs: Tensor<(M, Const<K>)>,
    rhs: Tensor<(Const<K>, N)>,
) -> Tensor<(M, N)> {
    // ...
}

Now the K dimension is enforced to be the same at compile time, but the other dimensions can be whatever you want them to be (e.g. both usizes, both Const, or some mix of the two).
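To show the signature doing real work, here is a naive, runnable sketch. The `Tensor` struct below (a shape value plus row-major `Vec<f32>` data) is an assumption made for this example and is much simpler than the real dfdx tensor.

```rust
trait Dim {
    fn size(&self) -> usize;
}

struct Const<const N: usize>;

impl<const N: usize> Dim for Const<N> {
    fn size(&self) -> usize {
        N
    }
}

impl Dim for usize {
    fn size(&self) -> usize {
        *self
    }
}

/// Hypothetical minimal tensor: a shape value and row-major data.
struct Tensor<S> {
    shape: S,
    data: Vec<f32>,
}

fn matmul<M: Dim, const K: usize, N: Dim>(
    lhs: Tensor<(M, Const<K>)>,
    rhs: Tensor<(Const<K>, N)>,
) -> Tensor<(M, N)> {
    // Take the outer dimensions and the buffers out of the inputs.
    let Tensor { shape: (m, _), data: a } = lhs;
    let Tensor { shape: (_, n), data: b } = rhs;
    let (rows, cols) = (m.size(), n.size());

    let mut data = vec![0.0f32; rows * cols];
    for i in 0..rows {
        for j in 0..cols {
            // `K` is a plain constant here because both inputs share it.
            for k in 0..K {
                data[i * cols + j] += a[i * K + k] * b[k * cols + j];
            }
        }
    }
    Tensor { shape: (m, n), data }
}
```

Multiplying a `Tensor<(usize, Const<3>)>` by a `Tensor<(Const<3>, Const<2>)>` compiles and yields a `Tensor<(usize, Const<2>)>`, while swapping in a `Const<4>` inner dimension on one side is rejected by the type checker.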

Generic shapes & dimensions, regardless of compile vs runtime

This is also really nice when you don’t care about enforcing shape constraints, like an element-wise operation:

fn relu<S: Shape>(t: Tensor<S>) -> Tensor<S> {
    // ...
}

Here the shape could be anything!
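A runnable sketch of that idea is below. It assumes a toy `Tensor` (shape plus `Vec<f32>`) and reduces `Shape` to a marker trait, since an element-wise op never needs to look at strides; both simplifications are mine.

```rust
trait Dim {
    fn size(&self) -> usize;
}

struct Const<const N: usize>;

impl<const N: usize> Dim for Const<N> {
    fn size(&self) -> usize {
        N
    }
}

impl Dim for usize {
    fn size(&self) -> usize {
        *self
    }
}

/// Reduced to a marker trait for this sketch.
trait Shape {}
impl Shape for () {}
impl<D1: Dim> Shape for (D1,) {}
impl<D1: Dim, D2: Dim> Shape for (D1, D2) {}

/// Hypothetical minimal tensor.
struct Tensor<S: Shape> {
    shape: S,
    data: Vec<f32>,
}

/// Element-wise max(x, 0). The shape passes through untouched, so the
/// output has exactly the input's shape type.
fn relu<S: Shape>(t: Tensor<S>) -> Tensor<S> {
    Tensor {
        shape: t.shape,
        data: t.data.into_iter().map(|x| x.max(0.0)).collect(),
    }
}
```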

You can also write code generic over specific dimensions, like a batch dimension for machine learning applications:

impl MnistDataset {
    fn get<Batch: Dim>(&self, b: Batch) -> Tensor<(Batch, Const<28>, Const<28>)> {
        // ...
    }
}
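Here is one way such a method could look when filled in. Everything about this `MnistDataset` is assumed for illustration: it stores flattened 28x28 images in a single `pixels` buffer, and `get` simply copies the first `b` images (panicking if there are fewer).

```rust
trait Dim {
    fn size(&self) -> usize;
}

struct Const<const N: usize>;

impl<const N: usize> Dim for Const<N> {
    fn size(&self) -> usize {
        N
    }
}

impl Dim for usize {
    fn size(&self) -> usize {
        *self
    }
}

/// Hypothetical minimal tensor.
struct Tensor<S> {
    shape: S,
    data: Vec<f32>,
}

/// Hypothetical dataset: all images flattened back to back, 28 * 28 each.
struct MnistDataset {
    pixels: Vec<f32>,
}

impl MnistDataset {
    fn get<Batch: Dim>(&self, b: Batch) -> Tensor<(Batch, Const<28>, Const<28>)> {
        let len = b.size() * 28 * 28;
        Tensor {
            // The two `Const`s are inferred as `Const<28>` from the return type.
            shape: (b, Const, Const),
            data: self.pixels[..len].to_vec(),
        }
    }
}
```

The caller decides whether the batch size is part of the type: `ds.get(Const::<2>)` gives a `Tensor<(Const<2>, Const<28>, Const<28>)>`, while `ds.get(3usize)` gives a `Tensor<(usize, Const<28>, Const<28>)>`.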

Actual implementation

You can find my actual implementation (of up to 6 dimensions) in my dfdx crate. This was added recently along with a big refactor, and will be available on crates.io in the next release.

In addition to being able to create tensors with shapes like I’ve described above, there are a ton of operations you can do on tensors like matrix multiplication, elementwise operations, etc., that all utilize the shape traits.

Actually, 99% of the operations I’ve implemented are just generic over shape. Really only matrix multiplication, convolutions, and multi head attention (from transformers) require compile time checks.

Final words & thank you

The end result is very satisfying to me, and I constantly feel like I’m discovering something new Rust can do. Hopefully this is useful to people; let me know if you have other ideas!

Finally, thank you to @skinner, @iExalt, @scooter-dangle, @zeroows, @quietlychris, and the rest of my generous sponsors. Your continued support is very much appreciated and helps motivate me!