dfdx

Blog on Deep Learning in Rust

First of all, thank you to @skinner, @iExalt, and the rest of my generous sponsors. Your support goes a long way toward letting me spend time on dfdx - I seriously appreciate it!

Second, here are the detailed release notes with all the changes included:

The rest of the post is a higher-level overview of the release, plus some thoughts on the next releases.

My favorite things about the release

BatchNorm2D, broadcasts, and reductions

This release added the nn::BatchNorm2D layer, which is often used in convolutional networks. This actually required a bit of infrastructure work to support broadcasts/reductions along multiple axes.

For those unfamiliar with axis reductions, it’s the same idea as Rust’s Iterator::reduce, but applied along a specific axis (or axes) of a tensor. An axis of a tensor refers to one of its dimensions. For example, a [[f32; 5]; 3] is a 2d array with 2 axes. If you reduce axis 0, you get a [f32; 5]. If you reduce axis 1, you get a [f32; 3].
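To make the two cases concrete, here is a minimal sketch in plain Rust arrays (no dfdx involved) that sums a [[f32; 5]; 3] along each axis. The function names sum_axis0 and sum_axis1 are made up for this illustration and are not part of dfdx.

```rust
/// Sum over axis 0 of a 3x5 array: the 3 rows collapse into a single row of 5.
/// (Hypothetical helper for illustration, not a dfdx function.)
fn sum_axis0(x: &[[f32; 5]; 3]) -> [f32; 5] {
    let mut out = [0.0; 5];
    for row in x.iter() {
        // Add each row element-wise into the accumulator.
        for (o, v) in out.iter_mut().zip(row.iter()) {
            *o += *v;
        }
    }
    out
}

/// Sum over axis 1 of a 3x5 array: each row of 5 collapses into one value.
/// (Hypothetical helper for illustration, not a dfdx function.)
fn sum_axis1(x: &[[f32; 5]; 3]) -> [f32; 3] {
    let mut out = [0.0; 3];
    for (o, row) in out.iter_mut().zip(x.iter()) {
        *o = row.iter().sum();
    }
    out
}
```

With x = [[1.0; 5], [2.0; 5], [3.0; 5]], sum_axis0(&x) gives [6.0; 5] and sum_axis1(&x) gives [5.0, 10.0, 15.0].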

I also added axis permutations for re-ordering the axes of a tensor, which are used in nn::MultiHeadAttention.

I went through a lot of designs, but what I ended up with feels very consistent across broadcasting/reducing/reshape/permute. You can see more detailed examples at examples/08-tensor-broadcast-reduce.rs and examples/09-tensor-permute.rs, but here’s a preview:

let a: Tensor1D<3> = tensor([1.0, 2.0, 3.0]);

// broadcast into a 2d tensor
let b: Tensor2D<5, 3> = a.broadcast();
assert_eq!(b.data(), &[[1.0, 2.0, 3.0]; 5]);

// permute (reverse) the axes
let c: Tensor2D<3, 5> = b.permute();
assert_eq!(c.data(), &[[1.0; 5], [2.0; 5], [3.0; 5]]);

// reduce to a 1d tensor by summing the 0th axis
let d: Tensor1D<5> = c.sum();
assert_eq!(d.data(), &[6.0; 5]);

Note that you don’t need to specify which axes are involved anywhere - Rust determines that based on the output type. I think this is super ergonomic, but I also like this design because the code acts as documentation.
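For readers wondering how an output type can select the operation, here is a rough sketch of the general technique using plain arrays and const generics. This is not how dfdx implements it; the BroadcastTo trait is a hypothetical name used only for this illustration.

```rust
/// Hypothetical trait: the target type `Out` selects which broadcast happens.
trait BroadcastTo<Out> {
    fn broadcast(self) -> Out;
}

/// 1d -> 2d: repeat the whole array M times along a new leading axis.
impl<const M: usize, const N: usize> BroadcastTo<[[f32; N]; M]> for [f32; N] {
    fn broadcast(self) -> [[f32; N]; M] {
        [self; M]
    }
}

/// 0d -> 1d: repeat a scalar N times.
impl<const N: usize> BroadcastTo<[f32; N]> for f32 {
    fn broadcast(self) -> [f32; N] {
        [self; N]
    }
}
```

Given a: [f32; 3], writing let b: [[f32; 3]; 5] = a.broadcast(); compiles because the annotation on b is enough for the compiler to pick the impl and infer M = 5. Changing the annotation to an incompatible shape such as [[f32; 4]; 5] is a compile error rather than a runtime one.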

I can’t count the number of times I’ve seen comments in python code that say what shape the tensors are at specific points. You can see an example of this in karpathy’s minGPT code:

k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

There’s nothing there that’s enforcing the comment’s accuracy, or that the tensor is actually that shape!

Here’s how a similar code snippet looks in dfdx:

let k: Tensor4D<B, S2, H, { K / H }, _> = k.reshape();
let k: Tensor4D<B, H, S2, { K / H }, _> = k.permute();

no_std support & Fall Cleaning

Another big win was depending on no_std_compat to enable turning off std. This was first requested by @antimora, so thanks for bringing this up!

It’s always nice to clean things up, and as part of the no_std changes, I thinned out the features of all of dfdx’s dependencies. I also improved the intel-mkl features, and even added a feature flags documentation page!

Using Arc in tensors

@caelunshun opened a really nice PR about using Arc instead of Rc.

I like this change a lot because:

  1. The diff is satisfyingly small
  2. It has a big impact because now you can share tensors and things that contain tensors across threads!
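As a small illustration of why this matters, here is a sketch using only the standard library, with a plain array standing in for tensor data. The parallel_sums function is a made-up name for this example.

```rust
use std::sync::Arc;
use std::thread;

/// Sums a shared buffer on several threads; each thread gets its own `Arc` handle
/// to the same underlying data. (Hypothetical helper, not part of dfdx.)
fn parallel_sums(data: Arc<[f32; 3]>, num_threads: usize) -> Vec<f32> {
    let handles: Vec<_> = (0..num_threads)
        .map(|_| {
            // Cloning an Arc only bumps an atomic reference count; the data is not copied.
            let shared = Arc::clone(&data);
            thread::spawn(move || shared.iter().sum::<f32>())
        })
        .collect();
    // Wait for every thread and collect its result.
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```

Calling parallel_sums(Arc::new([1.0, 2.0, 3.0]), 4) returns four copies of 6.0. Swapping Arc for std::rc::Rc in this sketch fails to compile, because Rc is not Send and so cannot be moved into a spawned thread.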

Better cloning semantics

I had a great discussion with @M1ngXU about how cloning of tensors works, and it led to some really nice simplifications of how clone is implemented.

A unit test a day keeps the bugs away

I changed conv2d’s implementation to use matrix multiplications, and the unit tests for conv2d caught quite a few issues I had during the rewrite.

I encourage everyone to write more tests, your future self will definitely thank you!
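To give a flavor of this kind of test, here is a simplified sketch of a matmul-based convolution checked against a direct one. It assumes a single channel, stride 1, and no padding, and it is not dfdx’s actual conv2d code; all function names are made up for the illustration.

```rust
/// Direct 2d convolution (cross-correlation) of an h x w image with a k x k kernel.
/// Stride 1, no padding. Output is (h-k+1) x (w-k+1), flattened row-major.
fn conv2d_direct(img: &[f32], h: usize, w: usize, kernel: &[f32], k: usize) -> Vec<f32> {
    let (oh, ow) = (h - k + 1, w - k + 1);
    let mut out = vec![0.0; oh * ow];
    for oy in 0..oh {
        for ox in 0..ow {
            for ky in 0..k {
                for kx in 0..k {
                    out[oy * ow + ox] += img[(oy + ky) * w + (ox + kx)] * kernel[ky * k + kx];
                }
            }
        }
    }
    out
}

/// Unfold every k x k patch into one row of an (oh*ow) x (k*k) matrix ("im2col").
fn im2col(img: &[f32], h: usize, w: usize, k: usize) -> Vec<f32> {
    let (oh, ow) = (h - k + 1, w - k + 1);
    let mut patches = Vec::with_capacity(oh * ow * k * k);
    for oy in 0..oh {
        for ox in 0..ow {
            for ky in 0..k {
                for kx in 0..k {
                    patches.push(img[(oy + ky) * w + (ox + kx)]);
                }
            }
        }
    }
    patches
}

/// Naive row-major matrix multiply: (m x n) * (n x p) -> (m x p).
fn matmul(a: &[f32], b: &[f32], m: usize, n: usize, p: usize) -> Vec<f32> {
    let mut out = vec![0.0; m * p];
    for i in 0..m {
        for j in 0..p {
            for x in 0..n {
                out[i * p + j] += a[i * n + x] * b[x * p + j];
            }
        }
    }
    out
}

/// Convolution expressed as one matrix multiply: patches (oh*ow x k*k) times
/// the flattened kernel (k*k x 1).
fn conv2d_matmul(img: &[f32], h: usize, w: usize, kernel: &[f32], k: usize) -> Vec<f32> {
    let (oh, ow) = (h - k + 1, w - k + 1);
    let patches = im2col(img, h, w, k);
    matmul(&patches, kernel, oh * ow, k * k, 1)
}
```

A unit test then only needs to assert that conv2d_matmul and conv2d_direct agree on a few inputs. For a 3x3 image holding 1.0 through 9.0 and the 2x2 kernel [1.0, 0.0, 0.0, 1.0], both return [6.0, 8.0, 12.0, 14.0].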

State of dfdx

I wanted to briefly discuss where dfdx currently is, and where it still needs to go.

As of 0.10.0, you can use dfdx to represent MLPs, resnets, and simple transformers. You can train these using standard deep learning optimizers, and save them to a .npz format. Honestly there’s a ton of features behind all this!

That being said, there’s still a lot missing, and dfdx is definitely still in a pre-alpha state. The biggest things I’m thinking about and planning for are:

  1. GPU acceleration via CUDA
  2. Strided arrays

The next few releases will focus on the above, and will use GATs (generic associated types, which will be stabilized in Rust 1.65) under the hood to keep the internals clean.

For example this is what I want the internal device trait to look like:

trait Device {
    type Storage<T>;
}

and these are a few example impls of this:

// current Cpu device that stores things on heap
impl Device for Cpu {
    type Storage<T> = Box<T>;
}

// cpu device that stores things as strided arrays instead of rust arrays
impl Device for CpuStrided {
    type Storage<T> = StridedArray<T>;
}

// cuda device
impl Device for Cuda {
    type Storage<T> = CudaBox<T>;
}
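To show how code could be written once against such a trait, here is a self-contained sketch that compiles on Rust 1.65 or later. The alloc method, the CpuVec device, and the zeros function are all hypothetical additions for this illustration and are not part of the planned dfdx API.

```rust
/// Device trait with a generic associated type (GAT) for storage.
trait Device {
    type Storage<T>;

    /// Hypothetical helper: move a value into this device's storage.
    fn alloc<T>(&self, value: T) -> Self::Storage<T>;
}

/// Device that stores values on the heap in a Box.
struct Cpu;

impl Device for Cpu {
    type Storage<T> = Box<T>;

    fn alloc<T>(&self, value: T) -> Self::Storage<T> {
        Box::new(value)
    }
}

/// Hypothetical second device that wraps each value in a one-element Vec,
/// just to show a different storage type behind the same trait.
struct CpuVec;

impl Device for CpuVec {
    type Storage<T> = Vec<T>;

    fn alloc<T>(&self, value: T) -> Self::Storage<T> {
        vec![value]
    }
}

/// Generic over the device: the return type changes with the device,
/// but the body is written only once.
fn zeros<D: Device>(device: &D) -> D::Storage<[f32; 3]> {
    device.alloc([0.0; 3])
}
```

Here zeros(&Cpu) returns a Box<[f32; 3]> while zeros(&CpuVec) returns a Vec<[f32; 3]>. Without GATs, Storage could not take the type parameter T, so each device would need a separate associated type (or trait) per stored type.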

A note on breaking changes

I am planning for the next release to be a breaking release as well, since it will be a refactor of the backend devices API. Notably, creating tensors and modules will require a specific device:

// or `device: Cuda`
let device: Cpu = Default::default();
let a: Tensor2D<2, 3> = TensorCreator::randn(&device);
let m: Linear<3, 5> = ModuleCreator::zeros(&device);

I went back and forth about whether to combine this release with the next to minimize breaking changes, but ultimately decided to break them into separate releases for a number of reasons:

I am definitely trying to keep breaking changes to a minimum; however, I’m still prioritizing ergonomics and clean internals over API stability.

I’m certainly open to feedback on this, so if you have thoughts, reach out on the Discord!