1 change: 1 addition & 0 deletions Cargo.toml
@@ -18,6 +18,7 @@ serde = {version = "1.0", features = ["derive"]}
bincode = "2.0.1"
rustc-hash = "2.1.0"
regex-automata = "0.4.9"
flate2 = "1.1.5"

# Below are fragile dependencies, even minor updates of which often break the code
[dependencies.hf-hub]
111 changes: 111 additions & 0 deletions INDEX_BINARY_FORMAT.md
@@ -0,0 +1,111 @@
# Index Binary Format Specification

This document describes the binary format used for serializing and deserializing the `Index` structure.

## Overview

The `Index` is saved as a gzip-compressed binary file. The uncompressed data follows a structured format with fixed-size fields for efficient storage and retrieval.

## Binary Format Structure

All multi-byte integers are stored in **little-endian** format.

### Header Section

| Offset | Size (bits) | Field | Description |
|--------|-------------|-------|-------------|
| 0 | 32 | vocab_size | Size of the vocabulary used to build the index |
| 4 | 32 | eos_token_id | Token ID reserved for the end-of-sequence token |
| 8 | 32 | initial_state_id | ID of the initial state in the automaton |
| 12 | 32 | num_final_states | Number of final (accepting) states |
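
Since every header field is a fixed-size little-endian `u32`, decoding it needs only the standard library. A minimal sketch (the function name and return shape are illustrative, not the crate's actual API):

```rust
// Decode the 16-byte header described above (four little-endian u32 fields).
// Illustrative sketch, not the crate's implementation.
fn read_header(buf: &[u8]) -> Option<(u32, u32, u32, u32)> {
    if buf.len() < 16 {
        return None; // truncated header
    }
    let u32_at = |off: usize| u32::from_le_bytes(buf[off..off + 4].try_into().unwrap());
    // (vocab_size, eos_token_id, initial_state_id, num_final_states)
    Some((u32_at(0), u32_at(4), u32_at(8), u32_at(12)))
}

fn main() {
    // Header for vocab_size=50000, eos=2, initial_state=0, 3 final states.
    let mut buf = Vec::new();
    for v in [50_000u32, 2, 0, 3] {
        buf.extend_from_slice(&v.to_le_bytes());
    }
    assert_eq!(read_header(&buf), Some((50_000, 2, 0, 3)));
}
```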

### Final States Section

Starting at offset 16, this section contains the IDs of all final states.

| Size (bits) | Field | Description |
|-------------|-------|-------------|
| 32 × num_final_states | final_state_ids | Array of final state IDs |

### Index Type

| Size (bits) | Field | Description |
|-------------|-------|-------------|
| 8 | index_type | Type identifier for the index format (currently only type 1 is supported) |

### Transitions Section (Type 1)

The format of this section depends on the index type. For type 1:

#### States Header

| Size (bits) | Field | Description |
|-------------|-------|-------------|
| 32 | num_states | Number of states with transitions |

#### For Each State

For each of the `num_states` states:

| Size (bits) | Field | Description |
|-------------|-------|-------------|
| 32 | state_id | ID of the current state |
| 32 | num_transitions | Number of transitions from this state |

#### For Each Transition

For each of the `num_transitions` transitions in a state:

| Size (bits) | Field | Description |
|-------------|-------|-------------|
| 32 | token_id | Token ID that triggers this transition |
| 32 | next_state_id | Destination state ID for this transition |
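
Because every field is fixed-size, the uncompressed payload length follows directly from the counts. A sketch of the arithmetic (hypothetical helper, not part of the crate's API):

```rust
// Uncompressed payload size per the tables above:
//   16-byte header
// + 4 bytes per final state
// + 1 byte index type
// + 4 bytes num_states
// + per state: state_id (4) + num_transitions (4) + 8 bytes per transition.
fn payload_size(num_final_states: usize, transitions_per_state: &[usize]) -> usize {
    let header = 16;
    let finals = 4 * num_final_states;
    let index_type = 1;
    let num_states_field = 4;
    let states: usize = transitions_per_state.iter().map(|&t| 8 + 8 * t).sum();
    header + finals + index_type + num_states_field + states
}

fn main() {
    // Two final states; one state with 3 transitions, one with 1:
    // 16 + 8 + 1 + 4 + (8 + 24) + (8 + 8) = 77 bytes.
    assert_eq!(payload_size(2, &[3, 1]), 77);
}
```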

## Compression

The entire binary structure described above is gzip-compressed (via the `flate2` crate at its default compression level) before being written to disk.

## Example Layout

```
┌─────────────────────────────────────────────────────────┐
│ Compressed File (gzip) │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Uncompressed Binary Data │ │
│ │ ┌───────────────────────────────────────────────┐ │ │
│ │ │ Header (16 bytes) │ │ │
│ │ │ - vocab_size (4 bytes) │ │ │
│ │ │ - eos_token_id (4 bytes) │ │ │
│ │ │ - initial_state_id (4 bytes) │ │ │
│ │ │ - num_final_states (4 bytes) │ │ │
│ │ └───────────────────────────────────────────────┘ │ │
│ │ ┌───────────────────────────────────────────────┐ │ │
│ │ │ Final States (4 bytes × num_final_states) │ │ │
│ │ └───────────────────────────────────────────────┘ │ │
│ │ ┌───────────────────────────────────────────────┐ │ │
│ │ │ Index Type (1 byte) │ │ │
│ │ └───────────────────────────────────────────────┘ │ │
│ │ ┌───────────────────────────────────────────────┐ │ │
│ │ │ Transitions Section │ │ │
│ │ │ - num_states (4 bytes) │ │ │
│ │ │ - For each state: │ │ │
│ │ │ - state_id (4 bytes) │ │ │
│ │ │ - num_transitions (4 bytes) │ │ │
│ │ │ - For each transition: │ │ │
│ │ │ - token_id (4 bytes) │ │ │
│ │ │ - next_state_id (4 bytes) │ │ │
│ │ └───────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
```
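
The layout above, minus the gzip layer, can be exercised with a short std-only encoder. The `encode` function and its signature are hypothetical, mirroring the tables in this document rather than the crate's actual implementation:

```rust
use std::collections::BTreeMap;

// Illustrative encoder for the uncompressed layout (gzip layer omitted).
// Names and signatures are hypothetical, not the crate's API.
fn encode(
    vocab_size: u32,
    eos_token_id: u32,
    initial_state: u32,
    final_states: &[u32],
    transitions: &BTreeMap<u32, BTreeMap<u32, u32>>,
) -> Vec<u8> {
    let mut out = Vec::new();
    // Header: four little-endian u32 fields.
    for v in [vocab_size, eos_token_id, initial_state, final_states.len() as u32] {
        out.extend_from_slice(&v.to_le_bytes());
    }
    for &f in final_states {
        out.extend_from_slice(&f.to_le_bytes());
    }
    out.push(1); // index type 1
    out.extend_from_slice(&(transitions.len() as u32).to_le_bytes());
    for (&state_id, map) in transitions {
        out.extend_from_slice(&state_id.to_le_bytes());
        out.extend_from_slice(&(map.len() as u32).to_le_bytes());
        for (&token_id, &next_state) in map {
            out.extend_from_slice(&token_id.to_le_bytes());
            out.extend_from_slice(&next_state.to_le_bytes());
        }
    }
    out
}

fn main() {
    let transitions = BTreeMap::from([(0u32, BTreeMap::from([(5u32, 2u32)]))]);
    let bytes = encode(4, 3, 0, &[2], &transitions);
    // 16 (header) + 4 (one final state) + 1 (type) + 4 (num_states)
    // + 8 (state entry) + 8 (one transition) = 41 bytes.
    assert_eq!(bytes.len(), 41);
    assert_eq!(bytes[20], 1); // type byte sits right after the final-state list
}
```

Note the sketch uses `BTreeMap` so its output bytes are deterministic; a hash-map-backed writer produces the same logical index but may order states and transitions differently between saves.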

## Version History

- **Type 1**: Initial format supporting basic state transitions with token-to-state mappings.

## Future Extensions

The index type field allows for future extensions of the format. New index types can be added to support:
- Optimized storage formats for sparse or dense transition tables
- Compressed transition representations
- Alternative state machine encodings
2 changes: 2 additions & 0 deletions src/error.rs
@@ -75,6 +75,8 @@ pub enum Error {
error_state: u32,
missing_tokens: Vec<String>,
},
#[error("IO error: {0}")]
IOError(String),
}

impl Error {
240 changes: 240 additions & 0 deletions src/index.rs
@@ -193,6 +193,142 @@ impl Index {
})
}

pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
use std::io::Write;
let mut buffer = Vec::new();

// Write vocab_size (32 bits)
buffer.extend_from_slice(&(self.vocab_size as u32).to_le_bytes());

// Write eos_token_id (32 bits)
buffer.extend_from_slice(&self.eos_token_id.to_le_bytes());

// Write initial_state_id (32 bits)
buffer.extend_from_slice(&self.initial_state.to_le_bytes());

// Write number of final states (32 bits)
buffer.extend_from_slice(&(self.final_states.len() as u32).to_le_bytes());

// Write final states (32 bits each)
for &final_state in &self.final_states {
buffer.extend_from_slice(&final_state.to_le_bytes());
}

// Write index type (8 bits) - Type 1 for now
buffer.push(1u8);

// Write number of states with transitions (32 bits)
buffer.extend_from_slice(&(self.transitions.len() as u32).to_le_bytes());

// Write transitions for each state
for (&state_id, transitions_map) in &self.transitions {
// Write state ID (32 bits)
buffer.extend_from_slice(&state_id.to_le_bytes());

// Write number of transitions (32 bits)
buffer.extend_from_slice(&(transitions_map.len() as u32).to_le_bytes());

// Write each transition (TokenId -> StateId)
for (&token_id, &next_state_id) in transitions_map {
buffer.extend_from_slice(&token_id.to_le_bytes());
buffer.extend_from_slice(&next_state_id.to_le_bytes());
}
}

// Write compressed data to file
let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
encoder.write_all(&buffer).map_err(|e| Error::IOError(e.to_string()))?;
let compressed_data = encoder.finish().map_err(|e| Error::IOError(e.to_string()))?;

std::fs::write(path, compressed_data).map_err(|e| Error::IOError(e.to_string()))?;

Ok(())
}

pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
use std::io::Read;

// Read and decompress file
let compressed_data = std::fs::read(path).map_err(|e| Error::IOError(e.to_string()))?;
let mut decoder = flate2::read::GzDecoder::new(&compressed_data[..]);
let mut buffer = Vec::new();
decoder.read_to_end(&mut buffer).map_err(|e| Error::IOError(e.to_string()))?;

let mut cursor = 0;

// Helper to read u32
let read_u32 = |buf: &[u8], pos: &mut usize| -> Result<u32> {
if *pos + 4 > buf.len() {
return Err(Error::IOError("Unexpected end of buffer".to_string()));
}
let value = u32::from_le_bytes([buf[*pos], buf[*pos + 1], buf[*pos + 2], buf[*pos + 3]]);
*pos += 4;
Ok(value)
};

// Read vocab_size (32 bits)
let vocab_size = read_u32(&buffer, &mut cursor)? as usize;

// Read eos_token_id (32 bits)
let eos_token_id = read_u32(&buffer, &mut cursor)?;

// Read initial_state_id (32 bits)
let initial_state = read_u32(&buffer, &mut cursor)?;

// Read number of final states (32 bits)
let num_final_states = read_u32(&buffer, &mut cursor)? as usize;

// Read final states
let mut final_states = HashSet::default();
for _ in 0..num_final_states {
let final_state = read_u32(&buffer, &mut cursor)?;
final_states.insert(final_state);
}

// Read index type (8 bits)
if cursor >= buffer.len() {
return Err(Error::IOError("Unexpected end of buffer".to_string()));
}
let index_type = buffer[cursor];
cursor += 1;

if index_type != 1 {
return Err(Error::IOError(format!("Unsupported index type: {}", index_type)));
}

// Read number of states with transitions (32 bits)
let num_states = read_u32(&buffer, &mut cursor)? as usize;

// Read transitions
let mut transitions: HashMap<StateId, HashMap<TokenId, StateId>> = HashMap::default();
for _ in 0..num_states {
// Read state ID (32 bits)
let state_id = read_u32(&buffer, &mut cursor)?;

// Read number of transitions (32 bits)
let num_transitions = read_u32(&buffer, &mut cursor)? as usize;

// Read each transition
let mut state_transitions = HashMap::default();
for _ in 0..num_transitions {
let token_id = read_u32(&buffer, &mut cursor)?;
let next_state_id = read_u32(&buffer, &mut cursor)?;
state_transitions.insert(token_id, next_state_id);
}

transitions.insert(state_id, state_transitions);
}

Ok(Self {
initial_state,
final_states,
transitions,
eos_token_id,
vocab_size,
})
}

/// Returns the ID of the initial state in the automaton.
pub fn initial_state(&self) -> StateId {
self.initial_state
@@ -391,4 +527,108 @@ mod tests {
panic!("Expected IncompatibleVocabulary error");
}
}

#[test]
fn test_save_and_load() {
let regex = "0|[1-9][0-9]*";
let eos_token_id = 4;
let mut vocabulary = Vocabulary::new(eos_token_id);
for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
vocabulary
.try_insert(token, token_id as u32)
.expect("Insert failed");
}

let original_index = Index::new(regex, &vocabulary).expect("Index failed");

// Save to temporary file
let temp_path = std::env::temp_dir().join("test_index.bin");
original_index.save(&temp_path).expect("Save failed");

// Load from file
let loaded_index = Index::load(&temp_path).expect("Load failed");

// Cleanup
std::fs::remove_file(&temp_path).ok();

// Verify equality
assert_eq!(original_index, loaded_index);
assert_eq!(original_index.initial_state(), loaded_index.initial_state());
assert_eq!(original_index.final_states(), loaded_index.final_states());
assert_eq!(original_index.transitions(), loaded_index.transitions());
assert_eq!(original_index.vocab_size(), loaded_index.vocab_size());
}

#[test]
fn test_save_and_load_multibyte() {
let regex = "😇| [😈-😍][😇-😎]*";
let mut vocabulary = Vocabulary::new(8);
for (token, token_id) in [(" 😍", 5), ("blah", 0), ("😇", 2), ("😈a", 1), ("😍", 3)] {
vocabulary
.try_insert(token, token_id as u32)
.expect("Insert failed");
}
for (token, token_id) in [
(vec![32, 240, 159, 152, 136], 7),
(vec![32, 240, 159, 152, 141], 6),
(vec![240, 159, 152, 141], 4),
] {
vocabulary
.try_insert(token, token_id as u32)
.expect("Insert failed");
}

let original_index = Index::new(regex, &vocabulary).expect("Index failed");

let temp_path = std::env::temp_dir().join("test_index_multibyte.bin");
original_index.save(&temp_path).expect("Save failed");
let loaded_index = Index::load(&temp_path).expect("Load failed");
std::fs::remove_file(&temp_path).ok();

assert_eq!(original_index, loaded_index);
}

#[test]
fn test_load_nonexistent_file() {
let result = Index::load("/nonexistent/path/index.bin");
assert!(result.is_err());
assert!(matches!(result, Err(Error::IOError(_))));
}

#[test]
fn test_load_corrupted_file() {
let temp_path = std::env::temp_dir().join("test_corrupted.bin");
std::fs::write(&temp_path, b"corrupted data").expect("Write failed");

let result = Index::load(&temp_path);
std::fs::remove_file(&temp_path).ok();

assert!(result.is_err());
}

#[test]
fn test_save_preserves_file_size() {
let regex = "0|[1-9][0-9]*";
let mut vocabulary = Vocabulary::new(4);
for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
vocabulary
.try_insert(token, token_id as u32)
.expect("Insert failed");
}

let index = Index::new(regex, &vocabulary).expect("Index failed");
let temp_path = std::env::temp_dir().join("test_size.bin");

index.save(&temp_path).expect("Save failed");
let metadata = std::fs::metadata(&temp_path).expect("Metadata failed");

// File should exist and be non-empty
assert!(metadata.len() > 0);

// The gzip output for this tiny index should stay well under 10 KB:
// the raw payload is just the 16-byte header, one final-state list,
// the type byte, and a handful of transitions.
assert!(metadata.len() < 10000);

std::fs::remove_file(&temp_path).ok();
}
}
1 change: 1 addition & 0 deletions src/lib.rs
@@ -87,6 +87,7 @@ pub mod prelude;
pub mod primitives;
pub mod vocabulary;


pub use error::{Error, Result};

#[cfg(feature = "python-bindings")]