jasku.xyz

Consolidating Rust generics with associated types

Published on October 30, 2024

TL;DR: Use degeneric-macros to consolidate generics in your APIs.

Introduction to generics

In Rust, generics help us avoid code duplication. A classic example of a generic structure is a list, which is Vec in Rust. Vec can hold (almost) any type of element. With generics, the programmer only needs to write the code for Vec once.

In Rust, the following language constructs can use generics:

  1. types
  2. functions
  3. impl blocks

Generic types

This is what generics look like in practice:

pub struct Vec<T> {
    // TODO
}

In the example above, the T is a generic parameter. The value of a generic parameter is a type. A Vec of strings can be created using the turbofish syntax:

let my_strings = Vec::<String>::new();

In the example above, the value of the T parameter is String.
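The list earlier also mentions impl blocks, which declare generic parameters of their own. Here is a minimal sketch, assuming a hypothetical Stack type that is not part of the standard library or of this article's API:

```rust
// Hypothetical container used only for illustration.
pub struct Stack<T> {
    items: Vec<T>,
}

// The impl block declares its own generic parameter T,
// so these methods exist for every Stack<T>.
impl<T> Stack<T> {
    pub fn new() -> Self {
        Stack { items: Vec::new() }
    }

    pub fn push(&mut self, item: T) {
        self.items.push(item);
    }

    pub fn pop(&mut self) -> Option<T> {
        self.items.pop()
    }

    pub fn len(&self) -> usize {
        self.items.len()
    }
}

fn main() {
    // T is inferred as String from the first push.
    let mut names = Stack::new();
    names.push(String::from("alice"));
    names.push(String::from("bob"));
    assert_eq!(names.pop(), Some(String::from("bob")));

    // The same code works for integers, with T spelled out via turbofish.
    let mut numbers = Stack::<i32>::new();
    numbers.push(1);
    assert_eq!(numbers.len(), 1);
}
```

The methods are written once and the compiler generates a concrete version for each T that is actually used.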

Generic functions

In order to process generic types, functions support generics too. Here's an example of a function that prints all the elements of a Vec:

use std::fmt::Display;

/// Prints all elements of the supplied Vec, one per line
fn print_all<T>(v: &Vec<T>)
    // Ensure that T implements Display
    where T: Display,
{
    // v is already a reference, so iterate over it directly
    for element in v {
        println!("{element}");
    }
}

In this example, print_all uses a generic parameter T. The where clause ensures that all the elements of v implement the Display trait.
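To see the bound at work, it may help to call the function with a few element types. The sketch below restates print_all and adds render_all, a hypothetical helper with the same bound that returns the lines instead of printing them. Opaque is also a made-up type that deliberately lacks a Display implementation.

```rust
use std::fmt::Display;

/// Prints all elements of the supplied Vec, one per line
fn print_all<T>(v: &Vec<T>)
where
    T: Display,
{
    for element in v {
        println!("{element}");
    }
}

/// Hypothetical helper with the same bound: collects the lines instead of printing them.
fn render_all<T: Display>(v: &Vec<T>) -> Vec<String> {
    v.iter().map(|element| element.to_string()).collect()
}

// A type that does not implement Display.
struct Opaque;

fn main() {
    print_all(&vec![1, 2, 3]); // T = i32
    print_all(&vec!["a", "b"]); // T = &str
    assert_eq!(render_all(&vec![10, 20]), vec!["10", "20"]);

    let _hidden = vec![Opaque];
    // print_all(&_hidden); // rejected at compile time: Opaque does not implement Display
}
```

Uncommenting the last call should produce a compile error that points at the missing Display implementation, which is the where clause doing its job.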

Functions can specify multiple generic parameters if needed. For example, here's a snippet that prints three Vec-s in a table-like layout:

/// Treat the supplied vectors as columns of a table and print the full table
fn print_table<T, U, V>(vec1: &Vec<T>, vec2: &Vec<U>, vec3: &Vec<V>)
    where T: Display, U: Display, V: Display
{
    let max_len = vec1.len().max(vec2.len()).max(vec3.len());

    for i in 0..max_len {
        let col1 = vec1.get(i).map_or("".to_string(), |v| v.to_string());
        let col2 = vec2.get(i).map_or("".to_string(), |v| v.to_string());
        let col3 = vec3.get(i).map_or("".to_string(), |v| v.to_string());

        println!("{:<10} {:<10} {:<10}", col1, col2, col3);
    }
}

Snippet created with the help of Copilot, my AI coding companion.

Observe the differences between signatures of print_all vs print_table:

  1. print_table has more parameters
  2. print_table has more generic parameters
  3. print_table has more bounds on generic parameters

The signature of print_table is arguably more complex.

Creating an API

Let's use those generic functions to build an API:

pub struct Api;

impl Api {
    fn print_all<T>(&self, v: &Vec<T>)
        // Ensure that T implements Display
        where T: Display,
    {
        // v is already a reference, so iterate over it directly
        for element in v {
            println!("{element}");
        }
    }

    fn print_table<T, U, V>(&self, vec1: &Vec<T>, vec2: &Vec<U>, vec3: &Vec<V>)
        where T: Display, U: Display, V: Display
    {
        let max_len = vec1.len().max(vec2.len()).max(vec3.len());

        for i in 0..max_len {
            let col1 = vec1.get(i).map_or("".to_string(), |v| v.to_string());
            let col2 = vec2.get(i).map_or("".to_string(), |v| v.to_string());
            let col3 = vec3.get(i).map_or("".to_string(), |v| v.to_string());

            println!("{:<10} {:<10} {:<10}", col1, col2, col3);
        }
    }
}

Observe that the API currently has two functions. Imagine the programmer wants to add another function that prints the elements in ascending order. This adds the requirement that the elements need to be sortable (Ord). If the programmer then decides to start using async code, the elements will possibly have to implement Send + Sync + 'static as well.
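A rough sketch of how such an addition could look, assuming a hypothetical sorted_lines method (it returns the lines rather than printing them, to keep the example checkable) and a hypothetical caller named report:

```rust
use std::fmt::Display;

pub struct Api;

impl Api {
    /// Hypothetical addition: the elements rendered in ascending order.
    /// Sorting requires Ord on top of the Display bound used so far.
    fn sorted_lines<T>(&self, v: &Vec<T>) -> Vec<String>
    where
        T: Display + Ord,
    {
        // Sort references so the caller's Vec is left untouched.
        let mut refs: Vec<&T> = v.iter().collect();
        refs.sort();
        refs.into_iter().map(|element| element.to_string()).collect()
    }
}

/// Hypothetical caller: it only forwards to the API,
/// yet it has to repeat exactly the same bounds.
fn report<T>(api: &Api, v: &Vec<T>) -> String
where
    T: Display + Ord,
{
    api.sorted_lines(v).join("\n")
}

fn main() {
    assert_eq!(report(&Api, &vec![3, 1, 2]), "1\n2\n3");
}
```

Every function between the API and the place where the concrete type is known needs the same where clause, so each new bound touches all of them.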

As a result, the programmer must update all the code that interacts with the API to respect the new trait bounds. That includes builders and factories, but also impl blocks for the Api struct and its consumers.

It is easy to forget a place or two while performing the updates.

What if there was a way to consolidate?

There is: associated types.

The following code:

fn something<T,U,V,W>(_: T, _: U, _: V, _: W)
    where
        T: Send + Sync + 'static,
        U: Display + PartialEq + Send + Sync + 'static,
        V: Ord + Eq + Send + Sync + 'static,
        W: Clone + Display + Eq + Send + Sync + 'static,
{
    todo!()
}

can be replaced by:

fn something<C>(_: C)
    where
        C: MyValuesTrait,
{
    todo!()
}

But what is MyValuesTrait? Isn't this cheating? Not really. The trait bounds have to exist somewhere. They don't necessarily have to exist near the function header:

trait MyValuesTrait {
    type T: Send + Sync + 'static;
    type U: Display + PartialEq + Send + Sync + 'static;
    type V: Ord + Eq + Send + Sync + 'static;
    type W: Clone + Display + Eq + Send + Sync + 'static;

    fn get_t(&self) -> &Self::T;
    fn get_u(&self) -> &Self::U;
    fn get_v(&self) -> &Self::V;
    fn get_w(&self) -> &Self::W;
}

struct MyValues<T, U, V, W>
    where
        T: Send + Sync + 'static,
        U: Display + PartialEq + Send + Sync + 'static,
        V: Ord + Eq + Send + Sync + 'static,
        W: Clone + Display + Eq + Send + Sync + 'static,
{
    t: T,
    u: U,
    v: V,
    w: W,
}

// TODO impl MyValuesTrait for MyValues

You just added a lot of extra boilerplate code! How is this any better??

This is better for a number of reasons:

  1. the bounds can be reused!!!
  2. the implementation logic is uncluttered
  3. the struct carries only the type bounds and nothing else
  4. the bounds can (and should) be extracted to their own module

Look at reason 1 again. That is the whole point: consolidation.

Let's talk about these again because this is quite a big deal.

Whenever you need to use those four values with those exact constraints, you can use MyValuesTrait.

Whenever those bounds need to be updated, they can be updated on the trait, and every function bounded by MyValuesTrait picks up the change automatically.
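The propagation works because bounds declared on an associated type are implied wherever the trait itself is used as a bound. Below is a minimal sketch, assuming a trimmed-down MyValuesTrait with two associated types instead of four. Pair, describe and larger_v are hypothetical names introduced only for this illustration.

```rust
use std::fmt::Display;

// Trimmed-down version of MyValuesTrait: two associated types instead of four.
trait MyValuesTrait {
    type U: Display + PartialEq + Send + Sync + 'static;
    type V: Ord + Eq + Send + Sync + 'static;

    fn get_u(&self) -> &Self::U;
    fn get_v(&self) -> &Self::V;
}

// No bounds on U or V are repeated here; C: MyValuesTrait implies them.
fn describe<C>(values: &C) -> String
where
    C: MyValuesTrait,
{
    // Display is available thanks to the bound on the associated type U.
    format!("u = {}", values.get_u())
}

fn larger_v<'a, C: MyValuesTrait>(a: &'a C, b: &'a C) -> &'a C::V {
    // Ord is available thanks to the bound on the associated type V.
    std::cmp::max(a.get_v(), b.get_v())
}

// Hypothetical implementer with concrete types.
struct Pair {
    name: String,
    rank: u32,
}

impl MyValuesTrait for Pair {
    type U = String;
    type V = u32;

    fn get_u(&self) -> &String { &self.name }
    fn get_v(&self) -> &u32 { &self.rank }
}

fn main() {
    let a = Pair { name: String::from("alice"), rank: 1 };
    let b = Pair { name: String::from("bob"), rank: 2 };

    assert_eq!(describe(&a), "u = alice");
    assert_eq!(*larger_v(&a, &b), 2);
}
```

If U later needs an extra bound such as Clone, it is added once in the trait definition. describe and larger_v keep their signatures, and only implementers whose types do not satisfy the new bound stop compiling.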

The struct itself might seem a little overwhelming. However, its only responsibility is to hold all of those values. It serves no other purpose. It is critical not to clutter the struct with other things. The struct now becomes a container.

Okay, but what about the boilerplate code?

  1. it is manageable
  2. it can be generated via a macro

This is what it looks like in its full glory:

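use std::fmt::Display;

trait MyValuesTrait {
    type T: Send + Sync + 'static;
    type U: Display + PartialEq + Send + Sync + 'static;
    type V: Ord + Eq + Send + Sync + 'static;
    type W: Clone + Display + Eq + Send + Sync + 'static;

    // The getters must be declared here, otherwise the impl below does not compile
    fn get_t(&self) -> &Self::T;
    fn get_u(&self) -> &Self::U;
    fn get_v(&self) -> &Self::V;
    fn get_w(&self) -> &Self::W;
}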

struct MyValues<T, U, V, W>
    where
        T: Send + Sync + 'static,
        U: Display + PartialEq + Send + Sync + 'static,
        V: Ord + Eq + Send + Sync + 'static,
        W: Clone + Display + Eq + Send + Sync + 'static,
{
    t: T,
    u: U,
    v: V,
    w: W,
}

impl<_T,_U,_V,_W> MyValuesTrait for MyValues<_T, _U, _V, _W>
    where
        _T: Send + Sync + 'static,
        _U: Display + PartialEq + Send + Sync + 'static,
        _V: Ord + Eq + Send + Sync + 'static,
        _W: Clone + Display + Eq + Send + Sync + 'static,
{
    type T = _T;
    type U = _U;
    type V = _V;
    type W = _W;

    fn get_t(&self) -> &Self::T { &self.t }
    fn get_u(&self) -> &Self::U { &self.u }
    fn get_v(&self) -> &Self::V { &self.v }
    fn get_w(&self) -> &Self::W { &self.w }
}

In the listing above, the only thing that isn't boilerplate is the struct itself. Everything else can be generated based on the struct.

This is the core idea of degeneric-macros.