1
+ use std:: convert:: TryFrom ;
1
2
use std:: fmt:: { self , Display } ;
2
3
use std:: fs:: File ;
3
4
use std:: io:: { BufReader , Read , Seek , Write } ;
4
5
5
6
use byteorder:: { LittleEndian , ReadBytesExt , WriteBytesExt } ;
6
7
7
- use crate :: io :: { Error , ErrorKind , Result } ;
8
+ use crate :: error :: { Error , Result } ;
8
9
9
10
const MODEL_VERSION : u32 = 0 ;
10
11
@@ -25,39 +26,20 @@ pub enum ChunkIdentifier {
25
26
}
26
27
27
28
impl ChunkIdentifier {
28
- pub fn try_from ( identifier : u32 ) -> Option < Self > {
29
- use self :: ChunkIdentifier :: * ;
30
-
31
- match identifier {
32
- 1 => Some ( SimpleVocab ) ,
33
- 2 => Some ( NdArray ) ,
34
- 3 => Some ( BucketSubwordVocab ) ,
35
- 4 => Some ( QuantizedArray ) ,
36
- 5 => Some ( Metadata ) ,
37
- 6 => Some ( NdNorms ) ,
38
- 7 => Some ( FastTextSubwordVocab ) ,
39
- 8 => Some ( ExplicitSubwordVocab ) ,
40
- _ => None ,
41
- }
42
- }
43
-
44
29
/// Read and ensure that the chunk has the given identifier.
45
30
pub fn ensure_chunk_type < R > ( read : & mut R , identifier : ChunkIdentifier ) -> Result < ( ) >
46
31
where
47
32
R : Read ,
48
33
{
49
34
let chunk_id = read
50
35
. read_u32 :: < LittleEndian > ( )
51
- . map_err ( |e| ErrorKind :: io_error ( "Cannot read chunk identifier" , e) ) ?;
52
- let chunk_id = ChunkIdentifier :: try_from ( chunk_id)
53
- . ok_or_else ( || ErrorKind :: Format ( format ! ( "Unknown chunk identifier: {}" , chunk_id) ) )
54
- . map_err ( Error :: from) ?;
36
+ . map_err ( |e| Error :: io_error ( "Cannot read chunk identifier" , e) ) ?;
37
+ let chunk_id = ChunkIdentifier :: try_from ( chunk_id) ?;
55
38
if chunk_id != identifier {
56
- return Err ( ErrorKind :: Format ( format ! (
39
+ return Err ( Error :: Format ( format ! (
57
40
"Invalid chunk identifier, expected: {}, got: {}" ,
58
41
identifier, chunk_id
59
- ) )
60
- . into ( ) ) ;
42
+ ) ) ) ;
61
43
}
62
44
63
45
Ok ( ( ) )
@@ -82,6 +64,26 @@ impl Display for ChunkIdentifier {
82
64
}
83
65
}
84
66
67
+ impl TryFrom < u32 > for ChunkIdentifier {
68
+ type Error = Error ;
69
+
70
+ fn try_from ( identifier : u32 ) -> Result < Self > {
71
+ use self :: ChunkIdentifier :: * ;
72
+
73
+ match identifier {
74
+ 1 => Ok ( SimpleVocab ) ,
75
+ 2 => Ok ( NdArray ) ,
76
+ 3 => Ok ( BucketSubwordVocab ) ,
77
+ 4 => Ok ( QuantizedArray ) ,
78
+ 5 => Ok ( Metadata ) ,
79
+ 6 => Ok ( NdNorms ) ,
80
+ 7 => Ok ( FastTextSubwordVocab ) ,
81
+ 8 => Ok ( ExplicitSubwordVocab ) ,
82
+ unknown => Err ( Error :: UnknownChunkIdentifier ( unknown) ) ,
83
+ }
84
+ }
85
+ }
86
+
85
87
/// Trait defining identifiers for data types.
86
88
pub trait TypeId {
87
89
/// Read and ensure that the data type is equal to `Self`.
@@ -102,14 +104,13 @@ macro_rules! typeid_impl {
102
104
{
103
105
let type_id = read
104
106
. read_u32:: <LittleEndian >( )
105
- . map_err( |e| ErrorKind :: io_error( "Cannot read type identifier" , e) ) ?;
107
+ . map_err( |e| Error :: io_error( "Cannot read type identifier" , e) ) ?;
106
108
if type_id != Self :: type_id( ) {
107
- return Err ( ErrorKind :: Format ( format!(
109
+ return Err ( Error :: Format ( format!(
108
110
"Invalid type, expected: {}, got: {}" ,
109
111
Self :: type_id( ) ,
110
112
type_id
111
- ) )
112
- . into( ) ) ;
113
+ ) ) ) ;
113
114
}
114
115
115
116
Ok ( ( ) )
@@ -183,18 +184,18 @@ impl WriteChunk for Header {
183
184
{
184
185
write
185
186
. write_all ( & MAGIC )
186
- . map_err ( |e| ErrorKind :: io_error ( "Cannot write magic" , e) ) ?;
187
+ . map_err ( |e| Error :: io_error ( "Cannot write magic" , e) ) ?;
187
188
write
188
189
. write_u32 :: < LittleEndian > ( MODEL_VERSION )
189
- . map_err ( |e| ErrorKind :: io_error ( "Cannot write model version" , e) ) ?;
190
+ . map_err ( |e| Error :: io_error ( "Cannot write model version" , e) ) ?;
190
191
write
191
192
. write_u32 :: < LittleEndian > ( self . chunk_identifiers . len ( ) as u32 )
192
- . map_err ( |e| ErrorKind :: io_error ( "Cannot write chunk identifiers length" , e) ) ?;
193
+ . map_err ( |e| Error :: io_error ( "Cannot write chunk identifiers length" , e) ) ?;
193
194
194
195
for & identifier in & self . chunk_identifiers {
195
196
write
196
197
. write_u32 :: < LittleEndian > ( identifier as u32 )
197
- . map_err ( |e| ErrorKind :: io_error ( "Cannot write chunk identifier" , e) ) ?;
198
+ . map_err ( |e| Error :: io_error ( "Cannot write chunk identifier" , e) ) ?;
198
199
}
199
200
200
201
Ok ( ( ) )
@@ -209,40 +210,36 @@ impl ReadChunk for Header {
209
210
// Magic and version ceremony.
210
211
let mut magic = [ 0u8 ; 4 ] ;
211
212
read. read_exact ( & mut magic)
212
- . map_err ( |e| ErrorKind :: io_error ( "Cannot read magic" , e) ) ?;
213
+ . map_err ( |e| Error :: io_error ( "Cannot read magic" , e) ) ?;
213
214
214
215
if magic != MAGIC {
215
- return Err ( ErrorKind :: Format ( format ! (
216
+ return Err ( Error :: Format ( format ! (
216
217
"Expected 'FiFu' as magic, got: {}" ,
217
218
String :: from_utf8_lossy( & magic) . into_owned( )
218
- ) )
219
- . into ( ) ) ;
219
+ ) ) ) ;
220
220
}
221
221
222
222
let version = read
223
223
. read_u32 :: < LittleEndian > ( )
224
- . map_err ( |e| ErrorKind :: io_error ( "Cannot read model version" , e) ) ?;
224
+ . map_err ( |e| Error :: io_error ( "Cannot read model version" , e) ) ?;
225
225
if version != MODEL_VERSION {
226
- return Err (
227
- ErrorKind :: Format ( format ! ( "Unknown finalfusion version: {}" , version) ) . into ( ) ,
228
- ) ;
226
+ return Err ( Error :: Format ( format ! (
227
+ "Unknown finalfusion version: {}" ,
228
+ version
229
+ ) ) ) ;
229
230
}
230
231
231
232
// Read chunk identifiers.
232
233
let chunk_identifiers_len = read
233
234
. read_u32 :: < LittleEndian > ( )
234
- . map_err ( |e| ErrorKind :: io_error ( "Cannot read chunk identifiers length" , e) ) ?
235
+ . map_err ( |e| Error :: io_error ( "Cannot read chunk identifiers length" , e) ) ?
235
236
as usize ;
236
237
let mut chunk_identifiers = Vec :: with_capacity ( chunk_identifiers_len) ;
237
238
for _ in 0 ..chunk_identifiers_len {
238
239
let identifier = read
239
240
. read_u32 :: < LittleEndian > ( )
240
- . map_err ( |e| ErrorKind :: io_error ( "Cannot read chunk identifier" , e) ) ?;
241
- let chunk_identifier = ChunkIdentifier :: try_from ( identifier)
242
- . ok_or_else ( || {
243
- ErrorKind :: Format ( format ! ( "Unknown chunk identifier: {}" , identifier) )
244
- } )
245
- . map_err ( Error :: from) ?;
241
+ . map_err ( |e| Error :: io_error ( "Cannot read chunk identifier" , e) ) ?;
242
+ let chunk_identifier = ChunkIdentifier :: try_from ( identifier) ?;
246
243
chunk_identifiers. push ( chunk_identifier) ;
247
244
}
248
245
0 commit comments