Introduction
Today we're going to be writing a data structure called a trie. A trie is a tree-like structure useful for storing data keyed by strings over a fixed alphabet. You can use it to build a dictionary whose keys are strings.
So, our tree-like structure will be made up of a set of nodes. Each node will contain a piece of data that we'll be storing in the trie along with an array of pointers to its child nodes.
Part 1 will go over making the basic, safe data structure which you can add keys to and look up values. We'll also write a spell check program using the data structure we wrote. In later parts, we'll add to it to provide a more comprehensive API. That way if you just want to know how to create the data structure you can stop after part 1.
This article expects that you have rust/cargo installed through rustup.
Implementation
For our implementation, we'll use a bool as our piece of data. As for our
alphabet, we'll use lowercase a-z (where the 0th index is a and the 25th
index is z). This will allow us to store a collection of words and quickly
check if a word exists. The bool will be used to represent if the word exists
in our collection, with false being no and true being yes. So by default,
the bool will be false, signifying that the node is not a word in our
collection.
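To make the index mapping concrete, here is a minimal sketch of the character-to-index conversion. The helper name `char_index` is hypothetical; the code later in this article does the same arithmetic inline and asserts on invalid input instead of returning an `Option`.

```rust
/// Maps a lowercase ASCII letter to its slot in a 26-element child array.
/// `char_index` is a hypothetical helper name used only for this sketch.
fn char_index(c: char) -> Option<usize> {
    if c.is_ascii_lowercase() {
        // Subtracting b'a' yields 0 for 'a' and 25 for 'z'.
        Some((c as u8 - b'a') as usize)
    } else {
        // Anything outside a-z has no slot in our alphabet.
        None
    }
}
```

Returning `None` rather than panicking is just one possible design; it pushes the decision about invalid characters to the caller.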
Our Data Types
Let's start by creating the basic data structure types we'll be using.
// src/lib.rs
pub mod trie;

// src/trie.rs
#[derive(Debug)]
/// ADT that represents Trie data structure
pub struct Trie {
    root: TrieNode,
}

#[derive(Debug)]
/// Node structure which gets allocated to heap
struct TrieNode {
    data: bool,
    nodes: [*mut TrieNode; 26],
}
ADT stands for Abstract Data Type. The Trie structure is the interface the user
works with through its public methods.
You may be wondering why I'm using a second type for the root (Trie) over
using a regular TrieNode, like you'd see with linked lists in C. There are a
few reasons for that.
For one, most operations on our Trie should be restricted to the root node. You
shouldn't be inserting entries from a node other than the root node.
Second, the user doesn't need to be aware of the underlying representation of
each node. Third, we want a single "owner" of the entire data structure so we
can deallocate all the nodes once it goes out of scope. If we just had
TrieNode and no root type, it'd be more difficult to implement the Drop
trait later.
Constructors
Now that the explanation is out of the way, let's move to our constructor
functions for Trie and TrieNode.
// src/trie.rs
use std::ptr;

impl Trie {
    pub fn new() -> Self {
        Self {
            root: TrieNode::new(),
        }
    }
}

impl TrieNode {
    fn new() -> Self {
        Self {
            data: false,
            nodes: [ptr::null_mut(); 26],
        }
    }
}
To better visualize this, each node looks like so:
Trie:
| root |
|---|
| TrieNode |
TrieNode:
| data | nodes |
|---|---|
| false | [null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null] |
Initial Insertion
Now that we have our data types and constructors, we need to be able to insert data into our trie, so let's write a method for that:
// src/trie.rs
/* New imports */
use std::alloc::{alloc_zeroed, Layout};

impl Trie {
    /* constructor code ... */

    pub fn insert(&mut self, key: &str, value: bool) {
        // Keep track of current node while traversing trie
        let mut cur_node = &mut self.root;
        // Iterate through characters
        for c in key.chars() {
            assert!('a' <= c && c <= 'z', "assert char is a-z");
            let idx = (c as u8 - 'a' as u8) as usize;
            // Get next node
            cur_node = if let Some(next) = unsafe { cur_node.nodes[idx].as_mut() } {
                // If pointer is not null, get the next node
                next
            } else {
                // If pointer is null, allocate a new node and return ref to it
                unsafe {
                    cur_node.nodes[idx] = alloc_zeroed(Layout::new::<TrieNode>()).cast();
                    cur_node.nodes[idx].as_mut().expect("memory alloced")
                }
            };
        }
        // Once at final node, set value
        cur_node.data = value;
    }
}
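The insert method above relies on alloc_zeroed producing a valid TrieNode. That holds for this struct, since an all-zero bool is false and all-zero pointers are null. As a hedged alternative, and not what this article uses, a node can also be allocated through Box and handed over as a raw pointer with Box::into_raw, which runs the normal constructor. The sketch below redefines a small TrieNode so it is self-contained; `alloc_node` and `demo` are hypothetical names.

```rust
use std::ptr;

struct TrieNode {
    data: bool,
    nodes: [*mut TrieNode; 26],
}

impl TrieNode {
    fn new() -> Self {
        Self { data: false, nodes: [ptr::null_mut(); 26] }
    }
}

/// Hypothetical helper: allocate a fully initialized node on the heap.
fn alloc_node() -> *mut TrieNode {
    // Box::new runs TrieNode::new(), so every field is set by the constructor.
    Box::into_raw(Box::new(TrieNode::new()))
}

/// Allocates a node, inspects it, frees it, and reports what it saw.
fn demo() -> (bool, bool) {
    let node = alloc_node();
    // SAFETY: `node` was just allocated above and nothing else references it.
    let node_ref = unsafe { &*node };
    let seen = (node_ref.data, node_ref.nodes.iter().all(|p| p.is_null()));
    // SAFETY: `node` came from Box::into_raw and is not used after this point.
    drop(unsafe { Box::from_raw(node) });
    seen
}
```

A pointer created this way has to be released with Box::from_raw rather than dealloc, so the two allocation styles should not be mixed within one data structure.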
Lookup (get)
We can insert words into our trie, but there's no way to check if the word
exists. In order to make sure our insert solution works correctly, we need to
write a get method to get the value stored at that node. So let's write that:
// src/trie.rs
impl Trie {
    /* constructor code ... */
    /* insert code ... */

    pub fn get(&self, key: &str) -> Option<bool> {
        let mut cur_node = &self.root;
        for c in key.chars() {
            assert!('a' <= c && c <= 'z', "assert char is a-z");
            let idx = (c as u8 - 'a' as u8) as usize;
            cur_node = unsafe { cur_node.nodes[idx].as_ref() }?;
        }
        if cur_node.data {
            Some(cur_node.data)
        } else {
            None
        }
    }
}
Let's go over this method. First we create a variable to track our current node,
cur_node, starting with the root node. Then, we iterate over all the
characters in the key. For each character we assert that it's valid, then we
calculate the index idx. We use the index to get the pointer
cur_node.nodes[idx] and get a reference to the node with as_ref, which
returns an Option. If our pointer is null, it returns None, otherwise it
returns Some with a reference to the data. We use ? to return early in the
case where the pointer is null, signifying that the key does not exist. Once
every character has been consumed, we return Some(true) if the final node is
marked as a word and None otherwise.
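The as_ref plus ? combination can be seen in isolation in the following minimal sketch. It uses a plain i32 instead of a TrieNode, and `read_through` and `demo` are hypothetical names used only here.

```rust
use std::ptr;

/// Reads the value behind a raw pointer, or returns None if it is null.
/// `read_through` is a hypothetical name used only for this sketch.
fn read_through(p: *const i32) -> Option<i32> {
    // as_ref turns a null pointer into None and a non-null one into Some(&T);
    // the ? operator then returns None from this function early.
    // SAFETY: this sketch only passes null or a pointer to a live i32.
    let value: &i32 = unsafe { p.as_ref() }?;
    Some(*value)
}

/// Calls the helper once with a valid pointer and once with null.
fn demo() -> (Option<i32>, Option<i32>) {
    let x = 7;
    (read_through(&x), read_through(ptr::null()))
}
```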
Testing insert & get
Now that we hopefully have a working Trie, let's write some quick unit
tests to check!
// src/trie.rs
/* at the bottom of our file trie.rs */
#[cfg(test)]
mod tests {
    use super::Trie;

    #[test]
    /// Make sure we can create the Trie
    fn creation() {
        Trie::new();
    }

    #[test]
    /// Test inserting and getting value
    fn insert_get() {
        let mut trie = Trie::new();
        trie.insert("apples", true);
        assert!(matches!(trie.get("apples"), Some(true)));
    }
}
So, we've written 2 tests: one to make sure we can create a Trie without
panicking, and one to test our insert and get methods. Let's run our
tests with cargo test and see what we get:
running 2 tests
test trie::tests::creation ... ok
test trie::tests::insert_get ... ok
Yay, our tests pass!
Safety
Our implementation may work, but there's one problem: we allocate memory for our data structure, but we never deallocate it, so we have a memory leak. In garbage collected languages like Python or Java, memory is automatically reclaimed by the garbage collector. Rust has no garbage collector. It normally frees memory for us through ownership, which hides memory management from the average developer, but once we allocate through raw pointers, freeing that memory becomes our job.
Ok, so if we have a memory leak, why doesn't Rust tell us? Leaking memory is
considered safe in Rust, so the compiler doesn't flag it. On top of that, we
told the Rust compiler "we know what we're doing" when we used the unsafe
blocks, and mistakes inside them can lead to undefined behavior (UB). There are
tools that can help us find these problems. Rust has a tool called Miri, which
can detect UB and memory leaks in our code. Let's run it on our tests with
cargo +nightly miri test to see if it picks up on anything.
You can install Miri through rustup:
rustup toolchain install nightly --component miri
Our tests pass, but we get the following error 6 times:
error: memory leaked: alloc33848 (Rust heap, size: 216, align: 8), allocated here:
--> C:\Users\hpmas\.rustup\toolchains\nightly-x86_64-pc-windows-msvc\lib\rustlib\src\rust\library\alloc\src\alloc.rs:170:14
|
170 | unsafe { __rust_alloc_zeroed(layout.size(), layout.align()) }
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: inside `std::alloc::alloc_zeroed` at C:\Users\hpmas\.rustup\toolchains\nightly-x86_64-pc-windows-msvc\lib\rustlib\src\rust\library\alloc\src\alloc.rs:170:14: 170:64
note: inside `trie::Trie::insert`
--> src\trie.rs:49:43
|
49 | cur_node.nodes[idx] = alloc_zeroed(Layout::new::<TrieNode>()).cast();
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
note: inside `trie::tests::insert_get`
--> src\trie.rs:93:9
|
93 | trie.insert("apples", true);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
note: inside closure
--> src\trie.rs:91:20
|
90 | #[test]
| ------- in this procedural macro expansion
91 | fn insert_get() {
| ^
= note: this error originates in the attribute macro `test` (in Nightly builds, run with -Z macro-backtrace for more info)
Looks like Miri was able to find some memory leaks. Let's try to fix these.
Deallocating
We allocated memory for our Trie, so let's properly deallocate it. In order to
do this, we have to go to each node we've allocated on the heap and manually
deallocate it using the dealloc function. We want to make sure we don't
deallocate a node while we're still using it, since that would be a "use after
free", which is UB and can lead to security vulnerabilities. So,
to avoid use after free, we will visit the nodes in postorder.
This ensures that we deallocate every child node before deallocating the parent.
Rust allows us to write code that runs when a data structure is dropped, with
the Drop trait.
We can do this recursively to safely deallocate all of our nodes:
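// src/trie.rs
/* New import */
use std::alloc::dealloc;

impl Drop for TrieNode {
    fn drop(&mut self) {
        for node in self.nodes {
            if !node.is_null() {
                unsafe {
                    // Run the child's own Drop first (which frees its children),
                    // then release the child's memory
                    node.drop_in_place();
                    dealloc(node as *mut _, Layout::new::<Self>());
                };
            }
        }
    }
}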
This Drop implementation will cause a TrieNode to iterate through each of its
child nodes and force them to be dropped before deallocating the child node
object. drop_in_place causes the Drop implementation to be called on the
object a pointer points to. We then deallocate the object so the memory is no
longer leaked.
There are 2 reasons we call drop_in_place instead of just deallocating the
object. First, we use it to recursively drop child nodes. Second, if our data
attribute is a type that needs to be properly cleaned up, its drop method will
also be called when the node's drop is called.
Since Trie owns the root node, once it goes out of scope its drop method is
called and the data structure will be properly deallocated.
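To see the order in which nodes are cleaned up, here is a small sketch that mirrors the Drop implementation above on a simplified node with two child slots. Each node records its name in a shared log after its children have been freed. Everything here (`Node`, `Log`, `new_node`, `drop_order`) is hypothetical scaffolding for illustration, and the nodes are allocated with alloc plus write rather than alloc_zeroed because they hold an Rc.

```rust
use std::alloc::{alloc, dealloc, handle_alloc_error, Layout};
use std::cell::RefCell;
use std::ptr;
use std::rc::Rc;

type Log = Rc<RefCell<Vec<&'static str>>>;

struct Node {
    name: &'static str,
    log: Log,
    children: [*mut Node; 2],
}

/// Hypothetical helper: heap-allocate a node and return a raw pointer to it.
fn new_node(name: &'static str, log: &Log, children: [*mut Node; 2]) -> *mut Node {
    let layout = Layout::new::<Node>();
    // SAFETY: the layout has non-zero size, and we initialize the memory
    // with `write` before anyone reads from it.
    unsafe {
        let raw = alloc(layout) as *mut Node;
        if raw.is_null() {
            handle_alloc_error(layout);
        }
        raw.write(Node { name, log: Rc::clone(log), children });
        raw
    }
}

impl Drop for Node {
    fn drop(&mut self) {
        for child in self.children {
            if !child.is_null() {
                // SAFETY: each child came from `new_node` and is freed exactly once.
                unsafe {
                    child.drop_in_place();
                    dealloc(child as *mut u8, Layout::new::<Node>());
                }
            }
        }
        // Children are gone by now, so this node is logged after them.
        self.log.borrow_mut().push(self.name);
    }
}

/// Builds root -> mid -> leaf, drops the root, and returns the logged order.
fn drop_order() -> Vec<&'static str> {
    let log: Log = Rc::new(RefCell::new(Vec::new()));
    let leaf = new_node("leaf", &log, [ptr::null_mut(); 2]);
    let mid = new_node("mid", &log, [leaf, ptr::null_mut()]);
    let root = Node { name: "root", log: Rc::clone(&log), children: [mid, ptr::null_mut()] };
    drop(root);
    let order = log.borrow().clone();
    order
}
```

The deepest node is logged first and the root last, which is the postorder behavior we rely on to avoid touching freed memory.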
Now let's try to run cargo +nightly miri test:
running 2 tests
test trie::tests::creation ... ok
test trie::tests::insert_get ... ok
Success! Miri didn't catch any memory errors. Now that we have a safe and
working data structure, let's improve it so it is more usable.
Improvements
Our current implementation only acts like a set that can hold a-z words, but we can improve it to store arbitrary types within it.
Generics
Our current struct has a fixed data type, a bool, which has limited
functionality. It'd be much more useful if we could allow for arbitrary data
types to be stored as values in our Trie. Luckily Rust has generics, so our
Trie can store arbitrary types. The data attribute of the TrieNode will
use an Option<T>, where Some means a value is set and None takes over the
role that false played before. Let's update our code to use generics:
// src/trie.rs
#[derive(Debug)]
/// ADT that represents Trie data structure
pub struct Trie<T> {
    root: TrieNode<T>,
}

#[derive(Debug)]
/// Node structure which gets allocated to heap
struct TrieNode<T> {
    data: Option<T>,
    nodes: [*mut TrieNode<T>; 26],
}

impl<T> TrieNode<T> {
    /* constructor ... */
}

impl<T> Trie<T> {
    /* constructor ... */

    pub fn insert(&mut self, key: &str, value: T) {
        /* -- snip -- */
        unsafe {
            cur_node.nodes[idx] = alloc_zeroed(Layout::new::<TrieNode<T>>()).cast();
            cur_node.nodes[idx].as_mut().expect("memory alloced")
        }
        /* -- snip -- */
    }

    pub fn get(&self, key: &str) -> Option<&T> {
        /* -- snip -- */
    }
}

impl<T> Drop for TrieNode<T> {
    /* fn drop() { ... }*/
}
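Since the listing above snips most method bodies, here is one way the generic pieces could fit together, written as a self-contained sketch rather than the article's exact code. It makes one deliberate substitution: nodes are allocated with Box::into_raw and freed with Box::from_raw instead of alloc_zeroed and dealloc, because, as far as I know, Rust does not promise that all-zero bytes are a valid None for every possible T. The `demo` function is hypothetical.

```rust
use std::ptr;

struct TrieNode<T> {
    data: Option<T>,
    nodes: [*mut TrieNode<T>; 26],
}

impl<T> TrieNode<T> {
    fn new() -> Self {
        Self { data: None, nodes: [ptr::null_mut(); 26] }
    }
}

pub struct Trie<T> {
    root: TrieNode<T>,
}

impl<T> Trie<T> {
    pub fn new() -> Self {
        Self { root: TrieNode::new() }
    }

    pub fn insert(&mut self, key: &str, value: T) {
        let mut cur_node = &mut self.root;
        for c in key.chars() {
            assert!(c.is_ascii_lowercase(), "assert char is a-z");
            let idx = (c as u8 - b'a') as usize;
            if cur_node.nodes[idx].is_null() {
                // Assumption of this sketch: Box-based allocation, not alloc_zeroed.
                cur_node.nodes[idx] = Box::into_raw(Box::new(TrieNode::new()));
            }
            // SAFETY: the pointer is non-null and only reachable through this trie.
            cur_node = unsafe { &mut *cur_node.nodes[idx] };
        }
        // Wrap the value so that None keeps meaning "no entry here".
        cur_node.data = Some(value);
    }

    pub fn get(&self, key: &str) -> Option<&T> {
        let mut cur_node = &self.root;
        for c in key.chars() {
            assert!(c.is_ascii_lowercase(), "assert char is a-z");
            let idx = (c as u8 - b'a') as usize;
            // SAFETY: child pointers are either null or point to live nodes.
            cur_node = unsafe { cur_node.nodes[idx].as_ref() }?;
        }
        // Option<T> -> Option<&T>, so the caller borrows the stored value.
        cur_node.data.as_ref()
    }
}

impl<T> Drop for TrieNode<T> {
    fn drop(&mut self) {
        for node in self.nodes {
            if !node.is_null() {
                // SAFETY: each child came from Box::into_raw and is freed once.
                drop(unsafe { Box::from_raw(node) });
            }
        }
    }
}

/// Looks up a stored key, a prefix of it, and an absent key.
fn demo() -> (Option<u32>, Option<u32>, Option<u32>) {
    let mut trie = Trie::new();
    trie.insert("apples", 3);
    let result = (
        trie.get("apples").copied(),
        trie.get("app").copied(),
        trie.get("pear").copied(),
    );
    result
}
```

Note that "app" has nodes in the trie but no value, so looking it up gives None just like a key that was never touched.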
We also need to update our tests
// src/trie.rs
#[cfg(test)]
mod tests {
    use super::Trie;

    #[test]
    fn creation() {
        Trie::<()>::new();
    }

    #[test]
    fn insert_get() {
        let mut trie = Trie::new();
        trie.insert("apples", true);
        assert!(matches!(trie.get("apples"), Some(true)));
    }

    #[test]
    fn get_non_existent() {
        let trie = Trie::<()>::new();
        assert!(matches!(trie.get("oranges"), None));
    }

    #[test]
    fn insert_box() {
        let mut trie = Trie::<Box<String>>::new();
        trie.insert("apples", Box::new(String::from("oranges")));
    }
}
For some of our tests, we use () as our value type. () is a zero-sized type, so
the data field only takes up the space needed to tell Some from None, which is
roughly equivalent to our original implementation with the bool.
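This size claim can be checked with std::mem::size_of. The sketch below is a quick sanity check rather than a guarantee: the language fixes the size of bool at one byte, while the size of Option<()> is what current compilers produce in practice. The function name `data_sizes` is hypothetical.

```rust
use std::mem::size_of;

/// Returns the sizes, in bytes, of the two candidate `data` field types:
/// the original bool and the generic Option<()>.
fn data_sizes() -> (usize, usize) {
    (size_of::<bool>(), size_of::<Option<()>>())
}
```

If the two numbers match, using Trie<()> as a set costs no extra memory per node compared to the bool version.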
contains
There's one more function we need for our example, the contains method. This
does what you'd expect: it checks whether the Trie contains a certain key,
returning true if it does and false otherwise.
// src/trie.rs
impl<T> Trie<T> {
    /* -- snip -- */

    pub fn contains(&self, key: &str) -> bool {
        self.get(key).is_some()
    }
}
The method is really simple: it uses get to try to find the value from its
key and uses is_some to check if it was found.
Building Our Spell Check
Ok, now that we have our basic data structure that can hold any value we need and check if the key exists, we have everything we need to write a spell check program. Our spell check program will accept 2 arguments, a dictionary file and a text file to spell check.
Let's first write a usage message to tell our user how to run the command. We'll make the first argument the dictionary file and the second argument the text file.
In this example, we're using words_alpha.txt from https://github.com/dwyl/english-words
// examples/spellcheck.rs
use std::path::PathBuf;

fn print_usage(exe_path: &PathBuf) {
    let exe = exe_path.file_name().expect("can get filename");
    println!("{} <dictionary> <txt_file>", exe.to_string_lossy());
}
Here we use PathBuf because the first argument of our program is the path to
the program being run. This way we don't hard-code the program name.
Now, let's write our simple argument parsing.
// examples/spellcheck.rs
use std::{env, process};

fn main() {
    let mut args = env::args();
    let prog_path = PathBuf::from(
        args.next()
            .expect("program name is always the first arg"),
    );
    let Some(dictionary_file) = args.next() else {
        println!("Not enough args");
        print_usage(&prog_path);
        process::exit(1);
    };
    let Some(text_file) = args.next() else {
        println!("Not enough args");
        print_usage(&prog_path);
        process::exit(1);
    };
    if args.next().is_some() {
        println!("Too many args");
        print_usage(&prog_path);
        process::exit(1);
    }
}
In Rust, Args is
an iterator over the arguments passed to the program. The first item is
conventionally the path/name of the program being run, so we expect it to
be there. Then we try to get the next two arguments. If an argument is not
found when we call args.next(), the else block runs, where we give a brief
error message and print the proper usage. The last part of our argument parsing
checks for an additional argument, and if there is one, it quits early with a
similar message.
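The same checks can also be pulled into a function that takes any iterator of strings, which makes the parsing logic testable without launching the program. This is an optional sketch, not part of the article's example: `CliArgs` and `parse_args` are hypothetical names, and the function assumes the program path has already been consumed from the iterator.

```rust
/// Parsed command-line arguments for the spell checker (hypothetical type).
#[derive(Debug, PartialEq)]
struct CliArgs {
    dictionary_file: String,
    text_file: String,
}

/// Hypothetical helper: expects exactly two remaining arguments.
/// The caller is assumed to have already taken the program path off `args`.
fn parse_args(mut args: impl Iterator<Item = String>) -> Result<CliArgs, &'static str> {
    // ok_or converts a missing argument into an error, and ? returns it early.
    let dictionary_file = args.next().ok_or("Not enough args")?;
    let text_file = args.next().ok_or("Not enough args")?;
    // Anything left over means the user passed more than two arguments.
    if args.next().is_some() {
        return Err("Too many args");
    }
    Ok(CliArgs { dictionary_file, text_file })
}
```

In main, this could be called with the env::args() iterator after the first next() call, printing the usage message whenever it returns an Err.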
Now, let's build our Trie with all the words in the dictionary.
// examples/spellcheck.rs
use std::fs::read_to_string;
use structs_from_scratch::trie::Trie;

fn main() {
    /* arg parsing code ... */

    let dict_raw = read_to_string(dictionary_file).unwrap();
    let mut dict = Trie::new();
    for word in dict_raw.split_whitespace() {
        dict.insert(word, ());
    }
}
The dictionary building is pretty simple: read the dictionary_file into a string,
then iterate over the words by splitting on whitespace (this way the file can be a
list of words separated by spaces, tabs, and/or newlines). Then for each word,
we insert it into the dictionary with the word as the key and () as the value
(again a zero-sized type, so it lets the Option act like our original bool implementation).
Once our dictionary is built we have to read through our input file.
// examples/spellcheck.rs
fn main() {
    /* snip code ... */

    let file = read_to_string(text_file).unwrap();
    for word in file.split_whitespace() {
        let trimmed = word
            .trim_matches(|c| match c {
                '-' | '`' | '\'' | '"' | '(' | ')' => true,
                _ => false,
            })
            .to_lowercase();
        if trimmed.is_empty() {
            continue;
        }
        let alpha_only: Vec<&str> = trimmed.matches(|c| matches!(c, 'a'..='z')).collect();
        if alpha_only.len() == trimmed.len() {
            if !dict.contains(&trimmed) {
                println!("{trimmed:?} misspelt");
            }
        }
    }
}
Getting the word out of the file is a little more involved. We don't want to
spell check a word that contains characters other than a-z. One thing to
keep in mind is that some words may sit next to a quote, parenthesis,
or another punctuation character. So we use trim_matches to trim those characters
if they are found: if a word starts or ends with '-', '`', ''', '"', '(', or ')',
they get trimmed off. We then check whether our trimmed string is empty. We could
have empty quotes, a stray '-', or other odd cases, so we skip those. The last
thing we do before we pass the value to our dictionary is make sure our word is
made of only a-z characters. Rust doesn't have an is_alpha method for
strings, so we filter the trimmed string down to just its a-z characters and
compare their count to the length of the trimmed string.
Now that we have a string that is valid for our dictionary lookup, we check whether
the trimmed word is a key in our dictionary with dict.contains(&trimmed). If it's
not there, we print out the word saying it's misspelt.
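The word clean-up steps described above can also be collected into a single function, which makes them easy to test on their own. The following is a sketch under the same rules as the loop in this article; the name `normalize` is hypothetical, and it uses chars().all(...) as an alternative way to express the "only a-z" check.

```rust
/// Hypothetical helper: strips surrounding punctuation, lowercases the word,
/// and returns it only if it is non-empty and made entirely of a-z.
fn normalize(word: &str) -> Option<String> {
    let trimmed = word
        // Only leading and trailing characters are removed, never inner ones.
        .trim_matches(|c: char| matches!(c, '-' | '`' | '\'' | '"' | '(' | ')'))
        .to_lowercase();
    if !trimmed.is_empty() && trimmed.chars().all(|c| c.is_ascii_lowercase()) {
        Some(trimmed)
    } else {
        // Empty after trimming, or contains something outside a-z: skip it.
        None
    }
}
```

A word like "don't" is skipped entirely because the apostrophe sits in the middle, and "word," is skipped because the comma is not in the trim list, which matches how the loop above behaves.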
Running our program
I'm going to run our program on this article to see what I may have misspelt.
Command:
cargo run --example spellcheck -- .\words_alpha.txt ..\..\content\data_structs\tries-p1.md
Results:
"structs" misspelt
"trie" misspelt
"trie" misspelt
"toc" misspelt
"trie" misspelt
"trie" misspelt
"spellcheck" misspelt
"adt" misspelt
"trie" misspelt
"trie" misspelt
"trienode" misspelt
"adt" misspelt
"trie" misspelt
"trie" misspelt
"trie" misspelt
"trienode" misspelt
"trie" misspelt
"impl" misspelt
"trie" misspelt
"impl" misspelt
"trienode" misspelt
"trienode" misspelt
"impl" misspelt
"trie" misspelt
"trie" misspelt
"idx" misspelt
"alloced" misspelt
"impl" misspelt
"trie" misspelt
"idx" misspelt
"trie" misspelt
"trie" misspelt
"trie" misspelt
"ub" misspelt
"rustup" misspelt
"toolchain" misspelt
"dealloc" misspelt
"ub" misspelt
"impl" misspelt
"trienode" misspelt
"trienode" misspelt
"trie" misspelt
"abitrary" misspelt
"trie" misspelt
"trienode" misspelt
"adt" misspelt
"trie" misspelt
"alloced" misspelt
"trie" misspelt
"trie" misspelt
"trie" misspelt
"trienode" misspelt
"trie" misspelt
"spellcheck" misspelt
"exe" misspelt
"pathbuf" misspelt
"args" misspelt
"trie" misspelt
"seperated" misspelt
"dicionary" misspelt
"spellcheck" misspelt
"spellcheck" misspelt
It works!
So, there are a few words we're using that our dictionary file doesn't have, and some are shorthands we use ("exe", "UB", "ADT", "idx"). It is also parsing our Rust code sections, so that's expected. It also caught a few misspelt things I'll have to fix before I publish this!
What's next?
We have the basic data structure working, but it can be further improved. For
example, there's no way to remove keys/values so anything we add is in our Trie
until it's dropped. A pain point that can be seen in our spell check is that in
order to use our Trie as a set, we specify () as the value type.
We can create a new type and interface for the Trie so that it more closely
acts as a set. Rust also has a bunch of helpful traits that we could implement
for our type. There are traits like IntoIterator to iterate over keys/values, Clone
to create a deep copy of our data structure, and Extend to combine one Trie's
keys/values into another one.
There are many other ways to improve our implementation, but this is a brief overview. We'll save these additions for later parts of this article, so if you just wanted to understand the data structure you can stop here.
Part 2 isn't fully written yet, but you can check back later for when it's published. This link will point to the article when it's done.