Structs From Scratch: Trie - Part 1

Introduction

Today we're going to write a data structure called a trie. A trie is a tree-like structure for storing keys drawn from a fixed alphabet, such as strings. You can use it to build a dictionary keyed by strings.

So, our tree-like structure will be made up of a set of nodes. Each node will contain a piece of data that we'll be storing in the trie, along with an array of pointers to child nodes.

Part 1 will go over making the basic, safe data structure, which you can add keys to and look up values in. We'll also write a spell check program using the data structure we wrote. In later parts, we'll extend it to provide a more comprehensive API. That way, if you just want to know how to create the data structure, you can stop after part 1.

This article expects that you have Rust and Cargo installed through rustup.

Implementation

For our implementation, we'll use a bool as our piece of data. As for our alphabet, we'll use lowercase a-z (where the 0th index is a and the 25th index is z). This will allow us to store a collection of words and quickly check if a word exists. The bool will be used to represent if the word exists in our collection, with false being no and true being yes. So by default, the bool will be false, signifying that the node is not a word in our collection.
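The letter-to-index mapping described above can be sketched as a small helper. The name `letter_index` is hypothetical and not part of this article's code, which does the same arithmetic inline and asserts instead of returning an `Option`.

```rust
/// Map a lowercase ASCII letter to its slot in a 26-element array.
/// Returns `None` for anything outside a-z.
/// (`letter_index` is a hypothetical helper used only for illustration.)
fn letter_index(c: char) -> Option<usize> {
    if c.is_ascii_lowercase() {
        // 'a' maps to 0, 'b' to 1, and so on up to 'z' at 25
        Some((c as u8 - b'a') as usize)
    } else {
        None
    }
}
```

Returning `None` rather than panicking is just one possible design; the methods below assert on invalid characters instead.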

Our Data Types

Let's start by creating the basic data structure types we'll be using.

// src/lib.rs
pub mod trie;  
// src/trie.rs
#[derive(Debug)]
/// ADT that represents Trie data structure
pub struct Trie {
    root: TrieNode,
}

#[derive(Debug)]
/// Node structure which gets allocated to heap
struct TrieNode {
    data: bool,
    nodes: [*mut TrieNode; 26],
}

ADT stands for Abstract Data Type. Trie is the interface the user works with through its public methods, while TrieNode stays a private implementation detail.

You may be wondering why I'm using a second type for the root (Trie) instead of a regular TrieNode, like you'd see with linked lists in C. There are a few reasons for that.

For one, most operations on our Trie should be restricted to the root node. You shouldn't be inserting entries from a node other than the root node. Second, the user doesn't need to be aware of the underlying representation of each node. Third, we want a single "owner" of the entire data structure so we can deallocate all the nodes once it goes out of scope. If we just had TrieNode and no root type, it'd be more difficult to implement the Drop trait later.

Constructors

Now that the explanation is out of the way, let's move to our constructor functions for Trie and TrieNode.

// src/trie.rs
use std::ptr;

impl Trie {
    pub fn new() -> Self {
        Self {
            root: TrieNode::new(),
        }
    }
}

impl TrieNode {
    fn new() -> Self {
        Self {
            data: false,
            nodes: [ptr::null_mut(); 26],
        }
    }
}

To better visualize this, a freshly constructed Trie and TrieNode look like so:

Trie:

root
TrieNode

TrieNode:

data  | nodes
false | [null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null]

Initial Insertion

Now that we have our data types and constructors, we need to be able to insert data into our trie, so let's write a method for that:

// src/trie.rs
/* New imports */
use std::alloc::{alloc_zeroed, Layout};

impl Trie {
    /* constructor code ... */
    
    pub fn insert(&mut self, key: &str, value: bool) {
        // Keep track of current node while traversing trie
        let mut cur_node = &mut self.root;
        // Iterate through characters
        for c in key.chars() {
            assert!('a' <= c && c <= 'z', "assert char is a-z");
            let idx = (c as u8 - 'a' as u8) as usize;
            // Get next node
            cur_node = if let Some(next) = unsafe { cur_node.nodes[idx].as_mut() } {
                // If pointer is not null, get the next node
                next
            } else {
                // If pointer is null, allocate a new node and return ref to it
                unsafe {
                    cur_node.nodes[idx] = alloc_zeroed(Layout::new::<TrieNode>()).cast();
                    cur_node.nodes[idx].as_mut().expect("memory alloced")
                }
            };
        }
        // Once at final node, set value
        cur_node.data = value;
    }
}
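The "follow the pointer, or allocate a node if it's null" step in the loop above can be looked at in isolation. The sketch below is not this article's code: `Node` and `child_or_insert` are hypothetical names, and it swaps `alloc_zeroed` for `Box::into_raw`, so each node is initialized by its constructor instead of relying on zeroed bytes.

```rust
use std::ptr;

/// Hypothetical stand-in for the article's `TrieNode`.
struct Node {
    is_word: bool,
    children: [*mut Node; 26],
}

impl Node {
    fn new() -> Self {
        Node { is_word: false, children: [ptr::null_mut(); 26] }
    }
}

/// Return the child in slot `idx`, allocating it first if the slot is null.
fn child_or_insert(node: &mut Node, idx: usize) -> &mut Node {
    if node.children[idx].is_null() {
        // Box::new initializes the node; into_raw hands ownership to the pointer
        node.children[idx] = Box::into_raw(Box::new(Node::new()));
    }
    // SAFETY: the slot is non-null and was produced by Box::into_raw,
    // either just now or on an earlier call
    unsafe { &mut *node.children[idx] }
}

impl Drop for Node {
    fn drop(&mut self) {
        for child in self.children {
            if !child.is_null() {
                // SAFETY: each non-null child came from Box::into_raw and is
                // turned back into a Box (and freed) exactly once
                drop(unsafe { Box::from_raw(child) });
            }
        }
    }
}
```

Calling `child_or_insert` once per character of the key and then setting the flag on the final node gives the same traversal as `insert`.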

Lookup (get)

We can insert words into our trie, but there's no way to check if the word exists. In order to make sure our insert solution works correctly, we need to write a get method to get the value stored at that node. So let's write that:

// src/trie.rs
impl Trie {
    /* constructor code ... */

    /* insert code ... */

    pub fn get(&self, key: &str) -> Option<bool> {
        let mut cur_node = &self.root;
        for c in key.chars() {
            assert!('a' <= c && c <= 'z', "assert char is a-z");
            let idx = (c as u8 - 'a' as u8) as usize;
            cur_node = unsafe { cur_node.nodes[idx].as_ref() }?;
        }
        if cur_node.data {
            Some(cur_node.data)
        } else {
            None
        }
    }
}

Let's go over this method. First we create a variable, cur_node, to track our current node, starting with the root node. Then, we iterate over all the characters in the key. For each character we assert that it's valid, then we calculate the index idx. We use the index to get the pointer cur_node.nodes[idx] and get a reference to the node with as_ref, which returns an Option. If our pointer is null, it returns None, otherwise it returns Some with a reference to the node. We use ? to return early in the case where the pointer is null, signifying that the key does not exist. After the loop, we return Some(true) only if the final node is marked as a word. A node that exists only as a prefix of a longer word still has data set to false, so it returns None.
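To see the `as_ref` plus `?` combination on its own, here is a minimal sketch outside of the trie. The function `read_or_none` is a hypothetical name chosen for this illustration.

```rust
/// Read the value behind `p`, or return `None` if `p` is null.
/// (`read_or_none` is a hypothetical helper used only for illustration.)
fn read_or_none(p: *const u32) -> Option<u32> {
    // SAFETY: callers pass either a null pointer or a pointer to a live `u32`
    let value = unsafe { p.as_ref() }?; // a null pointer makes the function return `None` here
    Some(*value)
}
```

Passing `std::ptr::null()` takes the early return, while passing a pointer to a live value reaches the last line, which is the same split `get` relies on for each character of the key.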

Testing insert & get

Now that we hopefully have a working Trie, let's write some quick unit tests to check!

// src/trie.rs
/* at the bottom of our file trie.rs */
#[cfg(test)]
mod tests {
    use super::Trie;

    #[test]
    /// Make sure we can create the Trie
    fn creation() {
        Trie::new();
    }

    #[test]
    /// Test inserting and getting value
    fn insert_get() {
        let mut trie = Trie::new();
        trie.insert("apples", true);
        assert!(matches!(trie.get("apples"), Some(true)));
    }
}

So, we've written 2 tests: one to make sure we can create a Trie without panicking, and one to test our insert and get methods. Let's run our tests with cargo test and see what we get.

running 2 tests
test trie::tests::creation ... ok
test trie::tests::insert_get ... ok  

Yay, our tests pass!

Safety

Our implementation may work, but there's one problem: we allocate memory for our data structure, but we never deallocate it, so we have a memory leak. In garbage-collected languages like Python or Java, memory is automatically reclaimed by the garbage collector. Rust normally frees memory for us too, through ownership: a value is dropped when its owner goes out of scope. That doesn't apply to memory we allocate ourselves behind raw pointers, though, so here we have to manage it manually. Most developers never run into this because they rarely deal with raw pointers.

Ok, so if we have a memory leak, why doesn't Rust tell us? Raw pointers carry no ownership information, so the compiler has no idea the allocation needs to be freed, and we told the compiler "we know what we're doing" when we used the unsafe blocks. A leak on its own isn't undefined behavior (UB), but mistakes in unsafe code can easily lead to UB. There are tools that can help us find both kinds of problem. One of them is Miri, which can detect memory leaks and many kinds of UB in our code. Let's run it on our tests with cargo +nightly miri test to see if it picks up on anything.

You can install Miri through rustup: rustup toolchain install nightly --component miri

Our tests pass, but we get the following error 6 times:

error: memory leaked: alloc33848 (Rust heap, size: 216, align: 8), allocated here:
   --> C:\Users\hpmas\.rustup\toolchains\nightly-x86_64-pc-windows-msvc\lib\rustlib\src\rust\library\alloc\src\alloc.rs:170:14
    |
170 |     unsafe { __rust_alloc_zeroed(layout.size(), layout.align()) }
    |              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |
    = note: inside `std::alloc::alloc_zeroed` at C:\Users\hpmas\.rustup\toolchains\nightly-x86_64-pc-windows-msvc\lib\rustlib\src\rust\library\alloc\src\alloc.rs:170:14: 170:64
note: inside `trie::Trie::insert`
   --> src\trie.rs:49:43
    |
49  |                     cur_node.nodes[idx] = alloc_zeroed(Layout::new::<TrieNode>()).cast();
    |                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
note: inside `trie::tests::insert_get`
   --> src\trie.rs:93:9
    |
93  |         trie.insert("apples", true);
    |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
note: inside closure
   --> src\trie.rs:91:20
    |
90  |     #[test]
    |     ------- in this procedural macro expansion
91  |     fn insert_get() {
    |                    ^
    = note: this error originates in the attribute macro `test` (in Nightly builds, run with -Z macro-backtrace for more info)

Looks like Miri was able to find some memory leaks. Let's try to fix these.

Deallocating

We allocated memory for our Trie, so let's properly deallocate it. In order to do this, we have to go to each node we've allocated on the heap and manually deallocate it using the dealloc function. We want to make sure we don't deallocate a node while we're still using it, since that would be a "use after free", which is a form of UB that leads to security vulnerabilities. So, to avoid use after free, we will traverse the nodes in postorder. This ensures that we deallocate every child node before deallocating the parent.

Rust allows us to write code that runs when a data structure is dropped, with the Drop trait. We can do this recursively to safely deallocate all of our nodes:

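// src/trie.rs
/* New imports */
use std::alloc::dealloc;

impl Drop for TrieNode {
    fn drop(&mut self) {
        // Raw pointers are Copy, so this iterates over a copy of the array
        for node in self.nodes {
            if !node.is_null() {
                // SAFETY: every non-null pointer was allocated in `insert` with this
                // same layout, and each one is dropped and freed exactly once
                unsafe {
                    // Run the child's own `drop` first, which frees its children
                    node.drop_in_place();
                    // Then free the child's memory
                    dealloc(node as *mut _, Layout::new::<Self>());
                };
            }
        }
    }
}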

This Drop implementation will cause a TrieNode to iterate through each of its child nodes and force them to be dropped before deallocating the child node object. drop_in_place causes the Drop implementation to be called on the object a pointer points to. We then deallocate the object so the memory is no longer leaked.

There are 2 reasons we call drop_in_place instead of just deallocating the object. First, we use it to recursively drop child nodes. Second, if our data attribute is a type that needs to be properly cleaned up, its drop method will also be called when the node's drop is called.

Since Trie owns the root node, once it goes out of scope its drop method is called and the data structure will be properly deallocated.
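To check that this really frees children before their parents, here is a small sketch that records the order in which nodes finish dropping. It is a stand-in and not this article's code: `Node`, `add_child`, `drop_order`, and the shared `Log` are hypothetical, each node has only 2 child slots, and allocation uses `alloc` followed by `write` rather than `alloc_zeroed`.

```rust
use std::alloc::{alloc, dealloc, Layout};
use std::cell::RefCell;
use std::ptr;
use std::rc::Rc;

/// Shared list of node names, in the order the nodes finished dropping.
type Log = Rc<RefCell<Vec<&'static str>>>;

/// Hypothetical two-child node that reports when it is dropped.
struct Node {
    name: &'static str,
    log: Log,
    children: [*mut Node; 2],
}

impl Node {
    fn new(name: &'static str, log: Log) -> Self {
        Node { name, log, children: [ptr::null_mut(); 2] }
    }

    /// Allocate a child in slot `idx` and return a reference to it.
    fn add_child(&mut self, idx: usize, name: &'static str) -> &mut Node {
        unsafe {
            let p: *mut Node = alloc(Layout::new::<Node>()).cast();
            assert!(!p.is_null(), "allocation failed");
            // Initialize the fresh memory before anything reads it
            p.write(Node::new(name, Rc::clone(&self.log)));
            self.children[idx] = p;
            &mut *p
        }
    }
}

impl Drop for Node {
    fn drop(&mut self) {
        for child in self.children {
            if !child.is_null() {
                // SAFETY: non-null children were allocated in `add_child`
                // with this layout and are freed exactly once
                unsafe {
                    child.drop_in_place();
                    dealloc(child.cast(), Layout::new::<Node>());
                }
            }
        }
        // Record this node only after all of its children are gone
        self.log.borrow_mut().push(self.name);
    }
}

/// Build root -> {a -> {ab}, b}, drop it, and return the recorded order.
fn drop_order() -> Vec<&'static str> {
    let log: Log = Rc::new(RefCell::new(Vec::new()));
    {
        let mut root = Node::new("root", Rc::clone(&log));
        let a = root.add_child(0, "a");
        a.add_child(0, "ab");
        root.add_child(1, "b");
    } // root goes out of scope here, which triggers the recursive drop
    let order = log.borrow().clone();
    order
}
```

`drop_order()` returns `["ab", "a", "b", "root"]`: every node is recorded after all of its descendants, which is the postorder described above.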

Now let's try to run cargo +nightly miri test:

running 2 tests
test trie::tests::creation ... ok
test trie::tests::insert_get ... ok

Success! Miri didn't catch any memory errors. Now that we have a safe and working data structure, let's improve it so it is more usable.

Improvements

Our current implementation only acts like a set that can hold a-z words, but we can improve it to store arbitrary types within it.

Generics

Our current struct has a fixed data type, a bool, which has limited functionality. It'd be much more useful if we could allow for arbitrary data types to be stored as values in our Trie. Luckily Rust has generics, so our Trie can store arbitrary types. The data attribute of the TrieNode will use an Option<T> type, where Some signifies a value that is set and None takes the place of false. Let's update our code to use generics:

// src/trie.rs
#[derive(Debug)]
/// ADT that represents Trie data structure
pub struct Trie<T> {
    root: TrieNode<T>,
}

#[derive(Debug)]
/// Node structure which gets allocated to heap
struct TrieNode<T> {
    data: Option<T>,
    nodes: [*mut TrieNode<T>; 26],
}

impl<T> TrieNode<T> {
    /* constructor ... */
}

impl<T> Trie<T> {
    /* constructor ... */

    pub fn insert(&mut self, key: &str, value: T) {
        /* -- snip -- */
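                unsafe {
                    let node: *mut TrieNode<T> = alloc_zeroed(Layout::new::<TrieNode<T>>()).cast();
                    assert!(!node.is_null(), "memory alloced");
                    // All-zero bytes are not guaranteed to be a valid `None` for every `T`,
                    // so initialize the node explicitly before creating a reference to it
                    node.write(TrieNode::new());
                    cur_node.nodes[idx] = node;
                    &mut *node
                }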
        /* -- snip -- */
    }

    pub fn get(&self, key: &str) -> Option<&T> {
        /* -- snip -- */
    }
}
impl<T> Drop for TrieNode<T> {
    /* fn drop() { ... }*/
}

We also need to update our tests:

// src/trie.rs
#[cfg(test)]
mod tests {
    use super::Trie;

    #[test]
    fn creation() {
        Trie::<()>::new();
    }

    #[test]
    fn insert_get() {
        let mut trie = Trie::new();
        trie.insert("apples", true);
        assert!(matches!(trie.get("apples"), Some(true)));
    }

    #[test]
    fn get_non_existent() {
        let trie = Trie::<()>::new();
        assert!(matches!(trie.get("oranges"), None));
    }

    #[test]
    fn insert_box() {
        let mut trie = Trie::<Box<String>>::new();
        trie.insert("apples", Box::new(String::from("oranges")));
    }
}

For some of our tests, we use () as our value type. () is a zero-sized type, so our TrieNode will only take up the space needed to hold the Option enum itself, making it equivalent to our original implementation with the bool.
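The size claim can be checked with `std::mem::size_of`. This is a sketch under one assumption: the layout of `Option<()>` is not formally guaranteed, although with current compilers it is typically one byte, the same as `bool`. The function name `marker_sizes` is hypothetical.

```rust
use std::mem::size_of;

/// Sizes in bytes of `()`, `Option<()>`, and `bool` on the current target.
/// (`marker_sizes` is a hypothetical helper used only for illustration.)
fn marker_sizes() -> (usize, usize, usize) {
    (size_of::<()>(), size_of::<Option<()>>(), size_of::<bool>())
}
```

The first value is always 0. The second and third are expected to match, which is what makes `Option<()>` a drop-in replacement for the bool flag.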

contains

There's one more function we need for our example: the contains method. It does what you'd expect. It checks if the Trie contains a certain key and returns true if it does, otherwise false.

// src/trie.rs
impl<T> Trie<T> {
    /* -- snip -- */
    pub fn contains(&self, key: &str) -> bool {
        self.get(key).is_some()
    }
}

The method is really simple: it uses get to try to find the value from its key and uses is_some to check if it was found.

Building Our Spell Check

Ok, now that we have our basic data structure that can hold any value we need and check if the key exists, we have everything we need to write a spell check program. Our spell check program will accept 2 arguments, a dictionary file and a text file to spell check.

Let's first write a usage message to tell our user how to run the command. We'll make the first argument the dictionary file and the second argument the text file.

In this example, we're using words_alpha.txt from https://github.com/dwyl/english-words

// examples/spellcheck.rs
use std::path::PathBuf;

fn print_usage(exe_path: &PathBuf) {
    let exe = exe_path.file_name().expect("can get filename");
    println!("{} <dictionary> <txt_file>", exe.to_string_lossy());
}

Here we use PathBuf because the first argument of our program will always be the path to the program being run. This way we don't hard-code the program name.

Now, let's write our simple argument parsing.

// examples/spellcheck.rs
use std::{env, process};

fn main() {
    let mut args = env::args();
    let prog_path = PathBuf::from(args.next()
        .expect("program name is always the first arg"));
    let Some(dictionary_file) = args.next() else {
        println!("Not enough args");
        print_usage(&prog_path);
        process::exit(1);
    };
    let Some(text_file) = args.next() else {
        println!("Not enough args");
        print_usage(&prog_path);
        process::exit(1);
    };
    if args.next().is_some() {
        println!("Too many args");
        print_usage(&prog_path);
        process::exit(1);
    }
}

In Rust, Args is an iterator over the arguments passed to the program. The first is always the path/name of the program being run, so we can expect it to always be there. Then we try to get the next two arguments. If an argument is not found when we call args.next(), the else block runs, where we give a brief error message and print the proper usage. The last part of our argument parsing checks for an additional argument and, if there is one, quits early with a similar message.
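The same checks can also be written as a function that takes any iterator of strings, which makes them testable without launching the program. This is an alternative sketch and not the code used in this article: `parse_args` is a hypothetical name, and it returns an error message instead of printing and exiting.

```rust
/// Split an argument list into (program, dictionary, text file).
/// (`parse_args` is a hypothetical helper; the article parses inline in `main`.)
fn parse_args<I>(mut args: I) -> Result<(String, String, String), &'static str>
where
    I: Iterator<Item = String>,
{
    // The first item is the path/name of the program itself
    let program = args.next().ok_or("Missing program name")?;
    let dictionary = args.next().ok_or("Not enough args")?;
    let text_file = args.next().ok_or("Not enough args")?;
    // Anything left over means the caller passed too much
    if args.next().is_some() {
        return Err("Too many args");
    }
    Ok((program, dictionary, text_file))
}
```

In `main`, this would be called as `parse_args(env::args())`, with the `Err` case printing the message and the usage before exiting.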

Now, let's build our Trie with all the words in the dictionary.

// examples/spellcheck.rs
use std::fs::read_to_string;

use structs_from_scratch::trie::Trie;

fn main() {
    /* arg parsing code ... */
    let dict_raw = read_to_string(dictionary_file).unwrap();

    let mut dict = Trie::new();
    for word in dict_raw.split_whitespace() {
        dict.insert(word, ());
    }
}    

The dictionary building is pretty simple: read the dictionary_file into a string, then iterate over the words split on whitespace (this way the file can be a list of words separated by spaces, tabs, and/or newlines). Then, for each word, we insert it into the dictionary with the word as the key and () as the value (again a zero-sized type, so it lets the Option act like our original bool implementation). Keep in mind that insert panics on any character outside a-z, so the dictionary file has to contain only lowercase a-z words.
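As a quick illustration of why `split_whitespace` is convenient here, the sketch below collects the words from a string that mixes spaces, tabs, and newlines. The function name `dictionary_words` is hypothetical.

```rust
/// Collect the whitespace-separated words from raw dictionary text.
/// (`dictionary_words` is a hypothetical helper used only for illustration.)
fn dictionary_words(raw: &str) -> Vec<&str> {
    // Runs of spaces, tabs, and newlines all count as a single separator,
    // and leading or trailing whitespace produces no empty entries
    raw.split_whitespace().collect()
}
```

For the input `"apple banana\tcherry\r\n\n  date\n"` this yields `["apple", "banana", "cherry", "date"]`.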

Once our dictionary is built we have to read through our input file.

// examples/spellcheck.rs
fn main() {
    /* snip code ... */
    let file = read_to_string(text_file).unwrap();
    for word in file.split_whitespace() {
        let trimmed = word
            .trim_matches(|c| match c {
                '-' | '`' | '\'' | '"' | '(' | ')' => true,
                _ => false,
            })
            .to_lowercase();
        if trimmed.is_empty() {
            continue;
        }
        let alpha_only: Vec<&str> = trimmed.matches(|c| matches!(c, 'a'..='z')).collect();

        if alpha_only.len() == trimmed.len() {
            if !dict.contains(&trimmed) {
                println!("{trimmed:?} misspelt");
            }
        }
    }
}

Getting the word out of the file is a little bit interesting. We only want to spell check words made up entirely of a-z characters. One thing to keep in mind is that a word may sit next to a quote, parenthesis, or other punctuation. So we use trim_matches to trim those characters if they are found: if a word starts or ends with '-', '`', ''', '"', '(', or ')', they get trimmed off. Then we lowercase what's left. Next, we check whether our trimmed string is empty. We could have empty quotes, a random '-', or other weird instances, so we skip those. The last thing we do before we pass the value to our dictionary is make sure our word is made of only a-z characters. Strings don't have an is_alpha method, so we collect the a-z characters of the trimmed string and compare their count to the trimmed string's length.

Now that we have a string that is valid for our dictionary lookup, we check whether the trimmed word is a key in our dictionary with dict.contains(&trimmed). If it's not there, we print out the word saying it's misspelt.
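The trimming and filtering steps can be bundled into one helper that is easy to test by itself. This is a sketch and not the code from the listing above: `normalize` is a hypothetical name, and it swaps the length comparison for `chars().all(...)`, which expresses the same "only a-z" check.

```rust
/// Strip surrounding punctuation and lowercase the word.
/// Returns `None` if nothing is left or the result is not purely a-z.
/// (`normalize` is a hypothetical helper used only for illustration.)
fn normalize(word: &str) -> Option<String> {
    let trimmed = word
        // Only characters at the start and end are removed
        .trim_matches(|c: char| matches!(c, '-' | '`' | '\'' | '"' | '(' | ')'))
        .to_lowercase();
    if !trimmed.is_empty() && trimmed.chars().all(|c| c.is_ascii_lowercase()) {
        Some(trimmed)
    } else {
        None
    }
}
```

With this helper, the loop body would reduce to checking `dict.contains(&w)` for every `Some(w)` and skipping every `None`. A word like `don't` is skipped because the apostrophe sits in the middle, where `trim_matches` does not reach.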

Running our program

I'm going to run our program on this article to see what I may have misspelt.

Command:

cargo run --example spellcheck -- .\words_alpha.txt ..\..\content\data_structs\tries-p1.md

Results:

"structs" misspelt
"trie" misspelt
"trie" misspelt
"toc" misspelt
"trie" misspelt
"trie" misspelt
"spellcheck" misspelt
"adt" misspelt
"trie" misspelt
"trie" misspelt
"trienode" misspelt
"adt" misspelt
"trie" misspelt
"trie" misspelt
"trie" misspelt
"trienode" misspelt
"trie" misspelt
"impl" misspelt
"trie" misspelt
"impl" misspelt
"trienode" misspelt
"trienode" misspelt
"impl" misspelt
"trie" misspelt
"trie" misspelt
"idx" misspelt
"alloced" misspelt
"impl" misspelt
"trie" misspelt
"idx" misspelt
"trie" misspelt
"trie" misspelt
"trie" misspelt
"ub" misspelt
"rustup" misspelt
"toolchain" misspelt
"dealloc" misspelt
"ub" misspelt
"impl" misspelt
"trienode" misspelt
"trienode" misspelt
"trie" misspelt
"abitrary" misspelt
"trie" misspelt
"trienode" misspelt
"adt" misspelt
"trie" misspelt
"alloced" misspelt
"trie" misspelt
"trie" misspelt
"trie" misspelt
"trienode" misspelt
"trie" misspelt
"spellcheck" misspelt
"exe" misspelt
"pathbuf" misspelt
"args" misspelt
"trie" misspelt
"seperated" misspelt
"dicionary" misspelt
"spellcheck" misspelt
"spellcheck" misspelt

It works!

So, there are a few words we're using that our dictionary file doesn't have. Some are shorthands ("exe", "UB", "ADT", "idx"). It's also parsing our Rust code sections, so seeing identifiers in the list is expected. It also caught a few misspelt things I'll have to fix before I publish this!

What's next?

We have the basic data structure working, but it can be further improved. For example, there's no way to remove keys/values so anything we add is in our Trie until it's dropped. A pain point that can be seen in our spell check is that in order to use our Trie as a set, we specify () as the value type.

We can create a new type and interface for the Trie so that it more closely acts as a set. Rust also has a bunch of helpful traits that we could implement for our type. There are traits like IntoIterator to iterate over keys/values, Clone to create a deep copy of our data structure, and Extend to add one Trie's keys/values to another.

There are many other ways to improve our implementation, but this is a brief overview. We'll save these additions for later parts of this article, so if you just wanted to understand the data structure you can stop here.

Part 2 isn't fully written yet, but you can check back later for when it's published. This link will point to the article when it's done.

Full source

src/trie.rs

examples/spellcheck.rs