Skip to content

Commit 85e1ad1

Browse files
danieldkDaniël de Kok
authored and committed
Modernize and improve error handling
- Merge the `Error` and `ErrorKind` enums. - Move the `Error` enum to the `error` module. - Derive trait implementations using the `thiserror` crate. - Make the `Error` enum non-exhaustive. - Replace the `ChunkIdentifier::try_from` method by an implementation of the `TryFrom` trait.
1 parent 886b27c commit 85e1ad1

File tree

20 files changed

+366
-420
lines changed

20 files changed

+366
-420
lines changed

.github/workflows/rust.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
matrix:
1111
rust:
1212
- stable
13-
- 1.37.0
13+
- 1.40.0
1414
steps:
1515
- uses: actions/checkout@v1
1616
- uses: actions-rs/toolchain@v1
@@ -49,7 +49,7 @@ jobs:
4949
matrix:
5050
rust:
5151
- stable
52-
- 1.37.0
52+
- 1.40.0
5353
steps:
5454
- uses: actions/checkout@v1
5555
- uses: actions-rs/toolchain@v1

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ rand = "0.7"
2727
rand_xorshift = "0.2"
2828
reductive = "0.4"
2929
serde = { version = "1", features = ["derive"] }
30+
thiserror = "1"
3031
toml = "0.5"
3132

3233
[dependencies.memmap]

src/chunks/io.rs

Lines changed: 44 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
use std::convert::TryFrom;
12
use std::fmt::{self, Display};
23
use std::fs::File;
34
use std::io::{BufReader, Read, Seek, Write};
45

56
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
67

7-
use crate::io::{Error, ErrorKind, Result};
8+
use crate::error::{Error, Result};
89

910
const MODEL_VERSION: u32 = 0;
1011

@@ -25,39 +26,20 @@ pub enum ChunkIdentifier {
2526
}
2627

2728
impl ChunkIdentifier {
28-
pub fn try_from(identifier: u32) -> Option<Self> {
29-
use self::ChunkIdentifier::*;
30-
31-
match identifier {
32-
1 => Some(SimpleVocab),
33-
2 => Some(NdArray),
34-
3 => Some(BucketSubwordVocab),
35-
4 => Some(QuantizedArray),
36-
5 => Some(Metadata),
37-
6 => Some(NdNorms),
38-
7 => Some(FastTextSubwordVocab),
39-
8 => Some(ExplicitSubwordVocab),
40-
_ => None,
41-
}
42-
}
43-
4429
/// Read and ensure that the chunk has the given identifier.
4530
pub fn ensure_chunk_type<R>(read: &mut R, identifier: ChunkIdentifier) -> Result<()>
4631
where
4732
R: Read,
4833
{
4934
let chunk_id = read
5035
.read_u32::<LittleEndian>()
51-
.map_err(|e| ErrorKind::io_error("Cannot read chunk identifier", e))?;
52-
let chunk_id = ChunkIdentifier::try_from(chunk_id)
53-
.ok_or_else(|| ErrorKind::Format(format!("Unknown chunk identifier: {}", chunk_id)))
54-
.map_err(Error::from)?;
36+
.map_err(|e| Error::io_error("Cannot read chunk identifier", e))?;
37+
let chunk_id = ChunkIdentifier::try_from(chunk_id)?;
5538
if chunk_id != identifier {
56-
return Err(ErrorKind::Format(format!(
39+
return Err(Error::Format(format!(
5740
"Invalid chunk identifier, expected: {}, got: {}",
5841
identifier, chunk_id
59-
))
60-
.into());
42+
)));
6143
}
6244

6345
Ok(())
@@ -82,6 +64,26 @@ impl Display for ChunkIdentifier {
8264
}
8365
}
8466

67+
impl TryFrom<u32> for ChunkIdentifier {
68+
type Error = Error;
69+
70+
fn try_from(identifier: u32) -> Result<Self> {
71+
use self::ChunkIdentifier::*;
72+
73+
match identifier {
74+
1 => Ok(SimpleVocab),
75+
2 => Ok(NdArray),
76+
3 => Ok(BucketSubwordVocab),
77+
4 => Ok(QuantizedArray),
78+
5 => Ok(Metadata),
79+
6 => Ok(NdNorms),
80+
7 => Ok(FastTextSubwordVocab),
81+
8 => Ok(ExplicitSubwordVocab),
82+
unknown => Err(Error::UnknownChunkIdentifier(unknown)),
83+
}
84+
}
85+
}
86+
8587
/// Trait defining identifiers for data types.
8688
pub trait TypeId {
8789
/// Read and ensure that the data type is equal to `Self`.
@@ -102,14 +104,13 @@ macro_rules! typeid_impl {
102104
{
103105
let type_id = read
104106
.read_u32::<LittleEndian>()
105-
.map_err(|e| ErrorKind::io_error("Cannot read type identifier", e))?;
107+
.map_err(|e| Error::io_error("Cannot read type identifier", e))?;
106108
if type_id != Self::type_id() {
107-
return Err(ErrorKind::Format(format!(
109+
return Err(Error::Format(format!(
108110
"Invalid type, expected: {}, got: {}",
109111
Self::type_id(),
110112
type_id
111-
))
112-
.into());
113+
)));
113114
}
114115

115116
Ok(())
@@ -183,18 +184,18 @@ impl WriteChunk for Header {
183184
{
184185
write
185186
.write_all(&MAGIC)
186-
.map_err(|e| ErrorKind::io_error("Cannot write magic", e))?;
187+
.map_err(|e| Error::io_error("Cannot write magic", e))?;
187188
write
188189
.write_u32::<LittleEndian>(MODEL_VERSION)
189-
.map_err(|e| ErrorKind::io_error("Cannot write model version", e))?;
190+
.map_err(|e| Error::io_error("Cannot write model version", e))?;
190191
write
191192
.write_u32::<LittleEndian>(self.chunk_identifiers.len() as u32)
192-
.map_err(|e| ErrorKind::io_error("Cannot write chunk identifiers length", e))?;
193+
.map_err(|e| Error::io_error("Cannot write chunk identifiers length", e))?;
193194

194195
for &identifier in &self.chunk_identifiers {
195196
write
196197
.write_u32::<LittleEndian>(identifier as u32)
197-
.map_err(|e| ErrorKind::io_error("Cannot write chunk identifier", e))?;
198+
.map_err(|e| Error::io_error("Cannot write chunk identifier", e))?;
198199
}
199200

200201
Ok(())
@@ -209,40 +210,36 @@ impl ReadChunk for Header {
209210
// Magic and version ceremony.
210211
let mut magic = [0u8; 4];
211212
read.read_exact(&mut magic)
212-
.map_err(|e| ErrorKind::io_error("Cannot read magic", e))?;
213+
.map_err(|e| Error::io_error("Cannot read magic", e))?;
213214

214215
if magic != MAGIC {
215-
return Err(ErrorKind::Format(format!(
216+
return Err(Error::Format(format!(
216217
"Expected 'FiFu' as magic, got: {}",
217218
String::from_utf8_lossy(&magic).into_owned()
218-
))
219-
.into());
219+
)));
220220
}
221221

222222
let version = read
223223
.read_u32::<LittleEndian>()
224-
.map_err(|e| ErrorKind::io_error("Cannot read model version", e))?;
224+
.map_err(|e| Error::io_error("Cannot read model version", e))?;
225225
if version != MODEL_VERSION {
226-
return Err(
227-
ErrorKind::Format(format!("Unknown finalfusion version: {}", version)).into(),
228-
);
226+
return Err(Error::Format(format!(
227+
"Unknown finalfusion version: {}",
228+
version
229+
)));
229230
}
230231

231232
// Read chunk identifiers.
232233
let chunk_identifiers_len = read
233234
.read_u32::<LittleEndian>()
234-
.map_err(|e| ErrorKind::io_error("Cannot read chunk identifiers length", e))?
235+
.map_err(|e| Error::io_error("Cannot read chunk identifiers length", e))?
235236
as usize;
236237
let mut chunk_identifiers = Vec::with_capacity(chunk_identifiers_len);
237238
for _ in 0..chunk_identifiers_len {
238239
let identifier = read
239240
.read_u32::<LittleEndian>()
240-
.map_err(|e| ErrorKind::io_error("Cannot read chunk identifier", e))?;
241-
let chunk_identifier = ChunkIdentifier::try_from(identifier)
242-
.ok_or_else(|| {
243-
ErrorKind::Format(format!("Unknown chunk identifier: {}", identifier))
244-
})
245-
.map_err(Error::from)?;
241+
.map_err(|e| Error::io_error("Cannot read chunk identifier", e))?;
242+
let chunk_identifier = ChunkIdentifier::try_from(identifier)?;
246243
chunk_identifiers.push(chunk_identifier);
247244
}
248245

src/chunks/metadata.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ use std::ops::{Deref, DerefMut};
66
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
77
use toml::Value;
88

9-
use super::io::{ChunkIdentifier, Header, ReadChunk, WriteChunk};
10-
use crate::io::{Error, ErrorKind, ReadMetadata, Result};
9+
use crate::chunks::io::{ChunkIdentifier, Header, ReadChunk, WriteChunk};
10+
use crate::error::{Error, Result};
11+
use crate::io::ReadMetadata;
1112

1213
/// Embeddings metadata.
1314
///
@@ -52,23 +53,22 @@ impl ReadChunk for Metadata {
5253
ChunkIdentifier::ensure_chunk_type(read, ChunkIdentifier::Metadata)?;
5354

5455
// Read chunk length.
55-
let chunk_len = read
56-
.read_u64::<LittleEndian>()
57-
.map_err(|e| ErrorKind::io_error("Cannot read chunk length", e))?
58-
as usize;
56+
let chunk_len =
57+
read.read_u64::<LittleEndian>()
58+
.map_err(|e| Error::io_error("Cannot read chunk length", e))? as usize;
5959

6060
// Read TOML data.
6161
let mut buf = vec![0; chunk_len];
6262
read.read_exact(&mut buf)
63-
.map_err(|e| ErrorKind::io_error("Cannot read TOML metadata", e))?;
63+
.map_err(|e| Error::io_error("Cannot read TOML metadata", e))?;
6464
let buf_str = String::from_utf8(buf)
65-
.map_err(|e| ErrorKind::Format(format!("TOML metadata contains invalid UTF-8: {}", e)))
65+
.map_err(|e| Error::Format(format!("TOML metadata contains invalid UTF-8: {}", e)))
6666
.map_err(Error::from)?;
6767

6868
Ok(Metadata::new(
6969
buf_str
7070
.parse::<Value>()
71-
.map_err(|e| ErrorKind::Format(format!("Cannot deserialize TOML metadata: {}", e)))
71+
.map_err(|e| Error::Format(format!("Cannot deserialize TOML metadata: {}", e)))
7272
.map_err(Error::from)?,
7373
))
7474
}
@@ -87,13 +87,13 @@ impl WriteChunk for Metadata {
8787

8888
write
8989
.write_u32::<LittleEndian>(self.chunk_identifier() as u32)
90-
.map_err(|e| ErrorKind::io_error("Cannot write metadata chunk identifier", e))?;
90+
.map_err(|e| Error::io_error("Cannot write metadata chunk identifier", e))?;
9191
write
9292
.write_u64::<LittleEndian>(metadata_str.len() as u64)
93-
.map_err(|e| ErrorKind::io_error("Cannot write metadata length", e))?;
93+
.map_err(|e| Error::io_error("Cannot write metadata length", e))?;
9494
write
9595
.write_all(metadata_str.as_bytes())
96-
.map_err(|e| ErrorKind::io_error("Cannot write metadata", e))?;
96+
.map_err(|e| Error::io_error("Cannot write metadata", e))?;
9797

9898
Ok(())
9999
}
@@ -108,9 +108,9 @@ impl ReadMetadata for Option<Metadata> {
108108
let chunks = header.chunk_identifiers();
109109

110110
if chunks.is_empty() {
111-
return Err(
112-
ErrorKind::Format(String::from("Embedding file does not contain chunks")).into(),
113-
);
111+
return Err(Error::Format(String::from(
112+
"Embedding file does not contain chunks",
113+
)));
114114
}
115115

116116
if header.chunk_identifiers()[0] == ChunkIdentifier::Metadata {

src/chunks/norms.rs

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use std::ops::Deref;
77
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
88
use ndarray::Array1;
99

10-
use super::io::{ChunkIdentifier, ReadChunk, TypeId, WriteChunk};
11-
use crate::io::{ErrorKind, Result};
10+
use crate::chunks::io::{ChunkIdentifier, ReadChunk, TypeId, WriteChunk};
11+
use crate::error::{Error, Result};
1212
use crate::util::padding;
1313

1414
/// Chunk for storing embedding l2 norms.
@@ -58,24 +58,25 @@ impl ReadChunk for NdNorms {
5858

5959
// Read and discard chunk length.
6060
read.read_u64::<LittleEndian>()
61-
.map_err(|e| ErrorKind::io_error("Cannot read norms chunk length", e))?;
61+
.map_err(|e| Error::io_error("Cannot read norms chunk length", e))?;
6262

6363
let len = read
6464
.read_u64::<LittleEndian>()
65-
.map_err(|e| ErrorKind::io_error("Cannot read norms vector length", e))?
65+
.map_err(|e| Error::io_error("Cannot read norms vector length", e))?
6666
as usize;
6767

6868
f32::ensure_data_type(read)?;
6969

70-
let n_padding = padding::<f32>(read.seek(SeekFrom::Current(0)).map_err(|e| {
71-
ErrorKind::io_error("Cannot get file position for computing padding", e)
72-
})?);
70+
let n_padding =
71+
padding::<f32>(read.seek(SeekFrom::Current(0)).map_err(|e| {
72+
Error::io_error("Cannot get file position for computing padding", e)
73+
})?);
7374
read.seek(SeekFrom::Current(n_padding as i64))
74-
.map_err(|e| ErrorKind::io_error("Cannot skip padding", e))?;
75+
.map_err(|e| Error::io_error("Cannot skip padding", e))?;
7576

7677
let mut data = vec![0f32; len];
7778
read.read_f32_into::<LittleEndian>(&mut data)
78-
.map_err(|e| ErrorKind::io_error("Cannot read norms", e))?;
79+
.map_err(|e| Error::io_error("Cannot read norms", e))?;
7980

8081
Ok(NdNorms::new(data))
8182
}
@@ -92,10 +93,11 @@ impl WriteChunk for NdNorms {
9293
{
9394
write
9495
.write_u32::<LittleEndian>(ChunkIdentifier::NdNorms as u32)
95-
.map_err(|e| ErrorKind::io_error("Cannot write norms chunk identifier", e))?;
96-
let n_padding = padding::<f32>(write.seek(SeekFrom::Current(0)).map_err(|e| {
97-
ErrorKind::io_error("Cannot get file position for computing padding", e)
98-
})?);
96+
.map_err(|e| Error::io_error("Cannot write norms chunk identifier", e))?;
97+
let n_padding =
98+
padding::<f32>(write.seek(SeekFrom::Current(0)).map_err(|e| {
99+
Error::io_error("Cannot get file position for computing padding", e)
100+
})?);
99101

100102
// Chunk size: len (u64), type id (u32), padding ([0,4) bytes), vector.
101103
let chunk_len = size_of::<u64>()
@@ -104,23 +106,23 @@ impl WriteChunk for NdNorms {
104106
+ (self.len() * size_of::<f32>());
105107
write
106108
.write_u64::<LittleEndian>(chunk_len as u64)
107-
.map_err(|e| ErrorKind::io_error("Cannot write norms chunk length", e))?;
109+
.map_err(|e| Error::io_error("Cannot write norms chunk length", e))?;
108110
write
109111
.write_u64::<LittleEndian>(self.len() as u64)
110-
.map_err(|e| ErrorKind::io_error("Cannot write norms vector length", e))?;
112+
.map_err(|e| Error::io_error("Cannot write norms vector length", e))?;
111113
write
112114
.write_u32::<LittleEndian>(f32::type_id())
113-
.map_err(|e| ErrorKind::io_error("Cannot write norms vector type identifier", e))?;
115+
.map_err(|e| Error::io_error("Cannot write norms vector type identifier", e))?;
114116

115117
let padding = vec![0; n_padding as usize];
116118
write
117119
.write_all(&padding)
118-
.map_err(|e| ErrorKind::io_error("Cannot write padding", e))?;
120+
.map_err(|e| Error::io_error("Cannot write padding", e))?;
119121

120122
for &val in self.iter() {
121123
write
122124
.write_f32::<LittleEndian>(val)
123-
.map_err(|e| ErrorKind::io_error("Cannot write norm", e))?;
125+
.map_err(|e| Error::io_error("Cannot write norm", e))?;
124126
}
125127

126128
Ok(())

0 commit comments

Comments
 (0)