diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs
index 2d6384e3..1f794e62 100644
--- a/backends/ort/src/lib.rs
+++ b/backends/ort/src/lib.rs
@@ -1,7 +1,8 @@
 use ndarray::{s, Axis};
 use nohash_hasher::BuildNoHashHasher;
-use ort::session::{builder::GraphOptimizationLevel, Session};
+use ort::session::{builder::GraphOptimizationLevel, Session, SessionInputValue};
 use serde::Deserialize;
+use std::borrow::Cow;
 use std::collections::HashMap;
 use std::ops::{Div, Mul};
 use std::path::Path;
@@ -11,21 +12,49 @@ use text_embeddings_backend_core::{
 };
 
 #[derive(Debug, Clone, Deserialize)]
-pub struct PastKeyValuesConfig {
+pub struct Config {
+    pub pad_token_id: Option<usize>,
+    pub eos_token_id: Option<usize>,
+    // NOTE: the fields below are only required when the ONNX model expects the `past_key_values`
+    // as input, i.e., whenever the ONNX model has been exported with optimized MHA nodes
     pub hidden_size: usize,
     pub num_hidden_layers: usize,
-    pub num_key_value_heads: usize,
+    pub num_key_value_heads: Option<usize>,
+}
+
+#[derive(Debug, Clone, Default, Deserialize)]
+#[serde(rename_all = "lowercase")]
+enum PaddingSide {
+    Left,
+    #[default]
+    Right,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct TokenizerConfig {
+    #[serde(default)]
+    padding_side: PaddingSide,
+}
+
+struct ModelInputs {
+    pub input_ids: ndarray::Array2<i64>,
+    pub attention_mask: ndarray::Array2<i64>,
+    pub token_type_ids: Option<ndarray::Array2<i64>>,
+    pub position_ids: Option<ndarray::Array2<i64>>,
+    pub input_lengths: ndarray::Array1<f32>,
+    pub past_key_values: Option<Vec<(ndarray::Array4<f32>, ndarray::Array4<f32>)>>,
 }
 
 pub struct OrtBackend {
     session: Mutex<Session>,
+    config: Config,
+    tokenizer_config: TokenizerConfig,
     token_type_ids: bool,
     // NOTE: required since the key can either be `token_type_ids` or `input_type`
     token_type_ids_key: String,
     position_ids: bool,
     past_key_values: bool,
-    past_key_values_config: Option<PastKeyValuesConfig>,
     pool: Pool,
 }
@@ -62,6 +91,27 @@ impl OrtBackend {
             }
         };
 
+        let config: Config = {
+            let content = std::fs::read_to_string(model_path.join("config.json"))
+                .map_err(|e| BackendError::Start(format!("Failed to read `config.json`: {}", e)))?;
+            serde_json::from_str(&content)
+                .map_err(|e| BackendError::Start(format!("Failed to parse `config.json`: {}", e)))?
+        };
+
+        let tokenizer_config_path = model_path.join("tokenizer_config.json");
+        let tokenizer_config: TokenizerConfig = if tokenizer_config_path.exists() {
+            let content = std::fs::read_to_string(&tokenizer_config_path).map_err(|e| {
+                BackendError::Start(format!("Failed to read `tokenizer_config.json`: {}", e))
+            })?;
+            serde_json::from_str(&content).map_err(|e| {
+                BackendError::Start(format!("Failed to parse `tokenizer_config.json`: {}", e))
+            })?
+        } else {
+            TokenizerConfig {
+                padding_side: PaddingSide::default(),
+            }
+        };
+
         let session = Session::builder()
             .s()?
            .with_intra_threads(num_cpus::get())
@@ -94,175 +144,262 @@ impl OrtBackend {
             }
         }
 
-        let past_key_values_config = match past_key_values {
-            true => {
-                let path = model_path.join("config.json");
-                if !path.exists() {
-                    return Err(BackendError::Start(format!(
-                        "config.json not found at {:?}",
-                        path
-                    )));
-                }
-                let content = std::fs::read_to_string(path).map_err(|e| {
-                    BackendError::Start(format!("Failed to read config.json: {}", e))
-                })?;
-                Some(
-                    serde_json::from_str::<PastKeyValuesConfig>(&content).map_err(|e| {
-                        BackendError::Start(format!("Failed to parse config.json: {}", e))
-                    })?,
-                )
-            }
-            false => None,
-        };
-
         Ok(Self {
             session: Mutex::new(session),
+            config,
+            tokenizer_config,
             token_type_ids,
             token_type_ids_key,
             position_ids,
             past_key_values,
-            past_key_values_config,
             pool,
         })
     }
 }
 
-impl Backend for OrtBackend {
-    fn max_batch_size(&self) -> Option<usize> {
-        Some(8)
-    }
-
-    fn health(&self) -> Result<(), BackendError> {
-        Ok(())
-    }
-
-    fn is_padded(&self) -> bool {
-        true
-    }
-
-    fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
+impl OrtBackend {
+    fn prepare_inputs(
+        &self,
+        batch: &Batch,
+        padding_side: &PaddingSide,
+    ) -> Result<(ModelInputs, bool), BackendError> {
         let batch_size = batch.len();
         let max_length = batch.max_length as usize;
+        let elems = batch_size * max_length;
+
+        let pad_token_id = self
+            .config
+            .pad_token_id
+            .unwrap_or(self.config.eos_token_id.unwrap_or(0)) as i64;
 
-        // Whether a least one of the request in the batch is padded
-        let mut masking = false;
-
-        let (input_ids, token_type_ids, input_lengths, attention_mask, position_ids) = {
-            let elems = batch_size * max_length;
-
+        let (input_ids, attention_mask, token_type_ids, position_ids, input_lengths, masking) =
             if batch_size > 1 {
-                // Prepare padded batch
                 let mut input_ids = Vec::with_capacity(elems);
-                let mut token_type_ids = Vec::with_capacity(elems);
                 let mut attention_mask = Vec::with_capacity(elems);
+                let mut token_type_ids = Vec::with_capacity(elems);
                 let mut position_ids = Vec::with_capacity(elems);
                 let mut input_lengths = Vec::with_capacity(batch_size);
 
-                for i in 0..batch_size {
-                    let start = batch.cumulative_seq_lengths[i] as usize;
-                    let end = batch.cumulative_seq_lengths[i + 1] as usize;
-                    let seq_length = (end - start) as u32;
-                    input_lengths.push(seq_length as f32);
-
-                    // Copy values
-                    for (pos, j) in (start..end).enumerate() {
-                        input_ids.push(batch.input_ids[j] as i64);
-                        token_type_ids.push(batch.token_type_ids[j] as i64);
-                        attention_mask.push(1_i64);
-                        position_ids.push(pos as i64);
-                    }
-
-                    // Add padding if needed
-                    let padding = batch.max_length - seq_length;
-                    if padding > 0 {
-                        // Set bool to use attention mask
-                        masking = true;
-                        for pad_pos in 0..padding {
-                            input_ids.push(0);
-                            token_type_ids.push(0);
-                            attention_mask.push(0_i64);
-                            position_ids.push((seq_length + pad_pos) as i64);
-                        }
-                    }
-                }
-                (
-                    input_ids,
-                    token_type_ids,
-                    input_lengths,
-                    attention_mask,
-                    position_ids,
-                )
+                // Whether at least one of the requests in the batch is padded
+                let mut masking = false;
+
+                match padding_side {
+                    PaddingSide::Right => {
+                        for i in 0..batch_size {
+                            let start = batch.cumulative_seq_lengths[i] as usize;
+                            let end = batch.cumulative_seq_lengths[i + 1] as usize;
+                            let seq_length = (end - start) as u32;
+                            input_lengths.push(seq_length as f32);
+
+                            for (pos, j) in (start..end).enumerate() {
+                                input_ids.push(batch.input_ids[j] as i64);
+                                attention_mask.push(1_i64);
+                                token_type_ids.push(batch.token_type_ids[j] as i64);
+                                position_ids.push(pos as i64);
+                            }
+
+                            let padding = batch.max_length - seq_length;
+                            if padding > 0 {
+                                // NOTE: Set `masking=true` to use the attention mask, as not
+                                // all the sequences in the batch have the same length
+                                masking = true;
+                                for pad_pos in 0..padding {
+                                    input_ids.push(pad_token_id);
+                                    attention_mask.push(0_i64);
+                                    token_type_ids.push(0);
+                                    position_ids.push((seq_length + pad_pos) as i64);
+                                }
+                            }
+                        }
+
+                        (
+                            input_ids,
+                            attention_mask,
+                            token_type_ids,
+                            position_ids,
+                            input_lengths,
+                            masking,
+                        )
+                    }
+                    PaddingSide::Left => {
+                        for i in 0..batch_size {
+                            let start = batch.cumulative_seq_lengths[i] as usize;
+                            let end = batch.cumulative_seq_lengths[i + 1] as usize;
+                            let seq_length = (end - start) as u32;
+                            input_lengths.push(seq_length as f32);
+
+                            let padding = batch.max_length - seq_length;
+                            if padding > 0 {
+                                // NOTE: Set `masking=true` to use the attention mask, as not
+                                // all the sequences in the batch have the same length
+                                masking = true;
+                                for _ in 0..padding {
+                                    input_ids.push(pad_token_id);
+                                    attention_mask.push(0_i64);
+                                    token_type_ids.push(0);
+                                    position_ids.push(0);
+                                }
+                            }
+
+                            for (pos, j) in (start..end).enumerate() {
+                                input_ids.push(batch.input_ids[j] as i64);
+                                attention_mask.push(1_i64);
+                                token_type_ids.push(batch.token_type_ids[j] as i64);
+                                position_ids.push((padding + pos as u32) as i64);
+                            }
+                        }
+
+                        (
+                            input_ids,
+                            attention_mask,
+                            token_type_ids,
+                            position_ids,
+                            input_lengths,
+                            masking,
+                        )
+                    }
+                }
             } else {
                 let attention_mask = vec![1_i64; elems];
                 let position_ids: Vec<i64> = (0..max_length as i64).collect();
 
                 (
-                    batch.input_ids.into_iter().map(|v| v as i64).collect(),
-                    batch.token_type_ids.into_iter().map(|v| v as i64).collect(),
-                    vec![batch.max_length as f32],
+                    batch.input_ids.iter().map(|v| *v as i64).collect(),
                     attention_mask,
+                    batch.token_type_ids.iter().map(|v| *v as i64).collect(),
                     position_ids,
+                    vec![batch.max_length as f32],
+                    // NOTE: no need to mask the inputs when the batch only contains one element
+                    false,
                 )
-            }
-        };
+            };
 
-        // Create ndarrays
         let input_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), input_ids).e()?;
         let attention_mask =
             ndarray::Array2::from_shape_vec((batch_size, max_length), attention_mask).e()?;
-        let position_ids =
-            ndarray::Array2::from_shape_vec((batch_size, max_length), position_ids).e()?;
+
+        let token_type_ids = if self.token_type_ids {
+            Some(ndarray::Array2::from_shape_vec((batch_size, max_length), token_type_ids).e()?)
+        } else {
+            None
+        };
+
+        let position_ids = if self.position_ids {
+            Some(ndarray::Array2::from_shape_vec((batch_size, max_length), position_ids).e()?)
+        } else {
+            None
+        };
+
         let input_lengths = ndarray::Array1::from_vec(input_lengths);
 
-        let inputs = {
-            let mut inputs = ort::inputs![
-                "input_ids" => ort::value::Tensor::from_array(input_ids).e()?,
-                "attention_mask" => ort::value::Tensor::from_array(attention_mask.clone()).e()?,
-            ];
-
-            if self.token_type_ids {
-                let token_type_ids_tensor =
-                    ndarray::Array2::from_shape_vec((batch_size, max_length), token_type_ids)
-                        .e()?;
-                let token_type_ids_value =
-                    ort::value::Tensor::from_array(token_type_ids_tensor).e()?;
-                inputs.push((
-                    self.token_type_ids_key.clone().into(),
-                    token_type_ids_value.into(),
-                ));
-            }
-
-            if self.position_ids {
-                let position_ids_value = ort::value::Tensor::from_array(position_ids).e()?;
-                inputs.push(("position_ids".into(), position_ids_value.into()));
-            }
-
-            if self.past_key_values {
-                let config = self.past_key_values_config.as_ref().unwrap();
-                let head_size = config.hidden_size / config.num_key_value_heads;
-
-                for i in 0..config.num_hidden_layers {
-                    let key_shape = (batch_size, config.num_key_value_heads, 0, head_size);
-                    let value_shape = (batch_size, config.num_key_value_heads, 0, head_size);
-
-                    let empty_key = ndarray::Array4::<f32>::zeros(key_shape);
-                    let empty_value = ndarray::Array4::<f32>::zeros(value_shape);
-
-                    let key_value = ort::value::Tensor::from_array(empty_key).e()?;
-                    let value_value = ort::value::Tensor::from_array(empty_value).e()?;
-                    inputs.push((
-                        format!("past_key_values.{}.key", i).into(),
-                        key_value.into(),
-                    ));
-                    inputs.push((
-                        format!("past_key_values.{}.value", i).into(),
-                        value_value.into(),
-                    ));
-                }
-            }
-
-            inputs
-        };
+        let past_key_values = if self.past_key_values {
+            let hidden_size = self.config.hidden_size;
+            let num_hidden_layers = self.config.num_hidden_layers;
+            let num_key_value_heads = self
+                .config
+                .num_key_value_heads
+                .unwrap_or(self.config.num_hidden_layers);
+            let head_size = hidden_size / num_key_value_heads;
+            let mut arrays = Vec::new();
+
+            for _ in 0..num_hidden_layers {
+                let shape = (batch_size, num_key_value_heads, 0, head_size);
+                let key_array = ndarray::Array4::<f32>::zeros(shape);
+                let value_array = ndarray::Array4::<f32>::zeros(shape);
+                arrays.push((key_array, value_array));
+            }
+
+            Some(arrays)
+        } else {
+            None
+        };
+
+        Ok((
+            ModelInputs {
+                input_ids,
+                attention_mask,
+                token_type_ids,
+                position_ids,
+                input_lengths,
+                past_key_values,
+            },
+            masking,
+        ))
+    }
+
+    fn prepare_ort_inputs(
+        &self,
+        input_ids: ndarray::Array2<i64>,
+        attention_mask: ndarray::Array2<i64>,
+        token_type_ids: Option<ndarray::Array2<i64>>,
+        position_ids: Option<ndarray::Array2<i64>>,
+        past_key_values: Option<Vec<(ndarray::Array4<f32>, ndarray::Array4<f32>)>>,
+    ) -> Result<Vec<(Cow<'static, str>, SessionInputValue<'_>)>, BackendError> {
+        let mut inputs = ort::inputs![
+            "input_ids" => ort::value::Tensor::from_array(input_ids).e()?,
+            "attention_mask" => ort::value::Tensor::from_array(attention_mask).e()?,
+        ];
+
+        if let Some(token_type_ids) = token_type_ids {
+            let token_type_ids = ort::value::Tensor::from_array(token_type_ids).e()?;
+            inputs.push((
+                self.token_type_ids_key.clone().into(),
+                token_type_ids.into(),
+            ));
+        }
+
+        if let Some(position_ids) = position_ids {
+            let position_ids = ort::value::Tensor::from_array(position_ids).e()?;
+            inputs.push(("position_ids".into(), position_ids.into()));
+        }
+
+        if let Some(past_key_values) = past_key_values {
+            for (layer_idx, (key, value)) in past_key_values.into_iter().enumerate() {
+                let key = ort::value::Tensor::from_array(key).e()?;
+                let value = ort::value::Tensor::from_array(value).e()?;
+                inputs.push((
+                    format!("past_key_values.{}.key", layer_idx).into(),
+                    key.into(),
+                ));
+                inputs.push((
+                    format!("past_key_values.{}.value", layer_idx).into(),
+                    value.into(),
+                ));
+            }
+        }
+
+        Ok(inputs)
+    }
+}
+
+impl Backend for OrtBackend {
+    fn max_batch_size(&self) -> Option<usize> {
+        Some(8)
+    }
+
+    fn health(&self) -> Result<(), BackendError> {
+        Ok(())
+    }
+
+    fn is_padded(&self) -> bool {
+        true
+    }
+
+    fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
+        let batch_size = batch.len();
+        let max_length = batch.max_length as usize;
+
+        let (model_inputs, masking) =
+            self.prepare_inputs(&batch, &self.tokenizer_config.padding_side)?;
+
+        let inputs = self.prepare_ort_inputs(
+            model_inputs.input_ids,
+            model_inputs.attention_mask.clone(),
+            model_inputs.token_type_ids,
+            model_inputs.position_ids,
+            model_inputs.past_key_values,
+        )?;
 
         // Run model
         let mut session = self.session.lock().unwrap();
@@ -303,18 +440,73 @@ impl Backend for OrtBackend {
         };
 
         let pooled_embeddings = match self.pool {
-            Pool::Cls => outputs.slice(s![.., 0, ..]).into_owned().into_dyn(),
-            Pool::LastToken => {
-                let axis_len = outputs.len_of(Axis(1));
-                outputs
-                    .slice(s![.., axis_len - 1, ..])
-                    .into_owned()
-                    .into_dyn()
-            }
+            Pool::Cls => match self.tokenizer_config.padding_side {
+                PaddingSide::Left => {
+                    if masking {
+                        let mut cls_embeddings = Vec::new();
+                        for (batch_idx, &seq_length) in
+                            model_inputs.input_lengths.iter().enumerate()
+                        {
+                            let padding = max_length as f32 - seq_length;
+                            let cls_pos = padding as usize;
+                            cls_embeddings
+                                .push(outputs.slice(s![batch_idx, cls_pos, ..]).to_owned());
+                        }
+                        ndarray::stack(
+                            Axis(0),
+                            &cls_embeddings.iter().map(|x| x.view()).collect::<Vec<_>>(),
+                        )
+                        .unwrap()
+                        .into_dyn()
+                    } else {
+                        outputs.slice(s![.., 0, ..]).into_owned().into_dyn()
+                    }
+                }
+                PaddingSide::Right => outputs.slice(s![.., 0, ..]).into_owned().into_dyn(),
+            },
+            Pool::LastToken => match self.tokenizer_config.padding_side {
+                // NOTE: when using left-padding, the last token is always in the last
+                // position, as the padding tokens are on the left (and since the last
+                // token in the sequence is the EOS token, `axis_len - 1` indexes it)
+                PaddingSide::Left => {
+                    let axis_len = outputs.len_of(Axis(1));
+                    outputs
+                        .slice(s![.., axis_len - 1, ..])
+                        .into_owned()
+                        .into_dyn()
+                }
+                PaddingSide::Right => {
+                    if masking {
+                        let mut last_token_embeddings = Vec::new();
+                        for (batch_idx, &seq_length) in
+                            model_inputs.input_lengths.iter().enumerate()
+                        {
+                            let last_pos = seq_length as usize - 1;
+                            last_token_embeddings
+                                .push(outputs.slice(s![batch_idx, last_pos, ..]).to_owned());
+                        }
+                        ndarray::stack(
+                            Axis(0),
+                            &last_token_embeddings
+                                .iter()
+                                .map(|x| x.view())
+                                .collect::<Vec<_>>(),
+                        )
+                        .unwrap()
+                        .into_dyn()
+                    } else {
+                        let axis_len = outputs.len_of(Axis(1));
+                        outputs
+                            .slice(s![.., axis_len - 1, ..])
+                            .into_owned()
+                            .into_dyn()
+                    }
+                }
+            },
             Pool::Mean => {
                 if masking {
-                    let mut attention_mask = attention_mask;
-                    let mut input_lengths = input_lengths;
+                    let mut attention_mask = model_inputs.attention_mask;
+                    let mut input_lengths = model_inputs.input_lengths;
 
                     if let Some(indices) = indices {
                         // Select values in the batch
@@ -322,14 +514,33 @@ impl Backend for OrtBackend {
                         input_lengths = input_lengths.select(Axis(0), &indices);
                     };
 
-                    // Cast and reshape
-                    let attention_mask = attention_mask.mapv(|x| x as f32).insert_axis(Axis(2));
-
-                    // Mask padded values
-                    outputs = outputs.mul(attention_mask);
-                    outputs
-                        .sum_axis(Axis(1))
-                        .div(input_lengths.insert_axis(Axis(1)))
+                    match self.tokenizer_config.padding_side {
+                        PaddingSide::Left => {
+                            let mut mean_embeddings = Vec::new();
+                            for (batch_idx, &seq_length) in input_lengths.iter().enumerate() {
+                                let padding = max_length as f32 - seq_length;
+                                let start_pos = padding as usize;
+                                let valid_embeddings =
+                                    outputs.slice(s![batch_idx, start_pos.., ..]);
+                                mean_embeddings
+                                    .push(valid_embeddings.mean_axis(Axis(0)).unwrap());
+                            }
+                            ndarray::stack(
+                                Axis(0),
+                                &mean_embeddings.iter().map(|x| x.view()).collect::<Vec<_>>(),
+                            )
+                            .unwrap()
+                            .into_dyn()
+                        }
+                        PaddingSide::Right => {
+                            let attention_mask =
+                                attention_mask.mapv(|x| x as f32).insert_axis(Axis(2));
+                            outputs = outputs.mul(attention_mask);
+                            outputs
+                                .sum_axis(Axis(1))
+                                .div(input_lengths.insert_axis(Axis(1)))
+                        }
+                    }
                 } else {
                     outputs.mean_axis(Axis(1)).unwrap()
                 }
@@ -356,22 +567,49 @@ impl Backend for OrtBackend {
         // member of the batch that require pooling
         // or if batch_size > 1 and the members of the batch have different lengths
         let raw_embeddings = if (masking || has_pooling_requests) && batch_size > 1 {
-            let mut final_indices: Vec<usize> = Vec::with_capacity(batch_size * max_length);
-
-            for i in batch.raw_indices.iter() {
-                let start = i * batch.max_length;
-                let i = *i as usize;
-                let length =
-                    batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i];
-
-                for j in start..start + length {
-                    // Add indices for the tokens of this specific member of the batch
-                    final_indices.push(j as usize);
-                }
-            }
-
-            // Select the tokens with final indices
-            outputs.select(Axis(0), &final_indices)
+            match self.tokenizer_config.padding_side {
+                PaddingSide::Left => {
+                    let mut final_indices: Vec<usize> =
+                        Vec::with_capacity(batch_size * max_length);
+
+                    for i in batch.raw_indices.iter() {
+                        let i = *i as usize;
+                        let length = batch.cumulative_seq_lengths[i + 1]
+                            - batch.cumulative_seq_lengths[i];
+                        let padding = batch.max_length - length;
+
+                        // For left padding, actual tokens start after the padding
+                        let start = i * batch.max_length as usize + padding as usize;
+                        let end = start + length as usize;
+
+                        for j in start..end {
+                            final_indices.push(j);
+                        }
+                    }
+
+                    // Select the tokens with final indices
+                    outputs.select(Axis(0), &final_indices)
+                }
+                PaddingSide::Right => {
+                    let mut final_indices: Vec<usize> =
+                        Vec::with_capacity(batch_size * max_length);
+
+                    for i in batch.raw_indices.iter() {
+                        let start = i * batch.max_length;
+                        let i = *i as usize;
+                        let length = batch.cumulative_seq_lengths[i + 1]
+                            - batch.cumulative_seq_lengths[i];
+
+                        for j in start..start + length {
+                            // Add indices for the tokens of this specific member of the batch
+                            final_indices.push(j as usize);
+                        }
+                    }
+
+                    // Select the tokens with final indices
+                    outputs.select(Axis(0), &final_indices)
+                }
+            }
         } else {
             outputs
         };
@@ -399,112 +637,16 @@ impl Backend for OrtBackend {
 
     fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
         let batch_size = batch.len();
-        let max_length = batch.max_length as usize;
 
-        let (input_ids, token_type_ids, attention_mask, position_ids) = {
-            let elems = batch_size * max_length;
+        let (model_inputs, _) = self.prepare_inputs(&batch, &self.tokenizer_config.padding_side)?;
 
-            if batch_size > 1 {
-                // Prepare padded batch
-                let mut input_ids = Vec::with_capacity(elems);
-                let mut token_type_ids = Vec::with_capacity(elems);
-                let mut attention_mask = Vec::with_capacity(elems);
-                let mut position_ids = Vec::with_capacity(elems);
-
-                for i in 0..batch_size {
-                    let start = batch.cumulative_seq_lengths[i] as usize;
-                    let end = batch.cumulative_seq_lengths[i + 1] as usize;
-                    let seq_length = (end - start) as u32;
-
-                    // Copy values
-                    for (pos, j) in (start..end).enumerate() {
-                        input_ids.push(batch.input_ids[j] as i64);
-                        token_type_ids.push(batch.token_type_ids[j] as i64);
-                        attention_mask.push(1_i64);
-                        position_ids.push(pos as i64);
-                    }
-
-                    // Add padding if needed
-                    let padding = batch.max_length - seq_length;
-                    if padding > 0 {
-                        for pad_pos in 0..padding {
-                            input_ids.push(0);
-                            token_type_ids.push(0);
-                            attention_mask.push(0_i64);
-                            position_ids.push((seq_length + pad_pos) as i64);
-                        }
-                    }
-                }
-                (input_ids, token_type_ids, attention_mask, position_ids)
-            } else {
-                let attention_mask = vec![1_i64; elems];
-                let position_ids: Vec<i64> = (0..max_length as i64).collect();
-
-                (
-                    batch.input_ids.into_iter().map(|v| v as i64).collect(),
-                    batch.token_type_ids.into_iter().map(|v| v as i64).collect(),
-                    attention_mask,
-                    position_ids,
-                )
-            }
-        };
-
-        // Create ndarrays
-        let input_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), input_ids).e()?;
-        let attention_mask =
-            ndarray::Array2::from_shape_vec((batch_size, max_length), attention_mask).e()?;
-        let position_ids =
-            ndarray::Array2::from_shape_vec((batch_size, max_length), position_ids).e()?;
-
-        let inputs = {
-            let mut inputs = ort::inputs![
-                "input_ids" => ort::value::Tensor::from_array(input_ids).e()?,
-                "attention_mask" => ort::value::Tensor::from_array(attention_mask.clone()).e()?,
-            ];
-
-            if self.token_type_ids {
-                let token_type_ids_tensor =
-                    ndarray::Array2::from_shape_vec((batch_size, max_length), token_type_ids)
-                        .e()?;
-                let token_type_ids_value =
-                    ort::value::Tensor::from_array(token_type_ids_tensor).e()?;
-                inputs.push((
-                    self.token_type_ids_key.clone().into(),
-                    token_type_ids_value.into(),
-                ));
-            }
-
-            if self.position_ids {
-                let position_ids_value = ort::value::Tensor::from_array(position_ids).e()?;
-                inputs.push(("position_ids".into(), position_ids_value.into()));
-            }
-
-            if self.past_key_values {
-                let config = self.past_key_values_config.as_ref().unwrap();
-                let head_size = config.hidden_size / config.num_key_value_heads;
-
-                for i in 0..config.num_hidden_layers {
-                    let key_shape = (batch_size, config.num_key_value_heads, 0, head_size);
-                    let value_shape = (batch_size, config.num_key_value_heads, 0, head_size);
-
-                    let empty_key = ndarray::Array4::<f32>::zeros(key_shape);
-                    let empty_value = ndarray::Array4::<f32>::zeros(value_shape);
-
-                    let key_value = ort::value::Tensor::from_array(empty_key).e()?;
-                    let value_value = ort::value::Tensor::from_array(empty_value).e()?;
-                    inputs.push((
-                        format!("past_key_values.{}.key", i).into(),
-                        key_value.into(),
-                    ));
-                    inputs.push((
-                        format!("past_key_values.{}.value", i).into(),
-                        value_value.into(),
-                    ));
-                }
-            }
-
-            inputs
-        };
+        let inputs = self.prepare_ort_inputs(
+            model_inputs.input_ids,
+            model_inputs.attention_mask.clone(),
+            model_inputs.token_type_ids,
+            model_inputs.position_ids,
+            model_inputs.past_key_values,
+        )?;
 
         // Run model
         let mut session = self.session.lock().unwrap();
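
Note (outside the patch): the pooling changes above all reduce to the same index
arithmetic. For a row padded to `max_length` that holds `seq_length` real tokens, left
padding places the sequence at the end of the row, so every position of interest shifts
by `max_length - seq_length`. A minimal, self-contained sketch of that arithmetic
follows; the helper names (`cls_position`, `last_token_position`) are hypothetical and
only mirror the positions read by the `PaddingSide` matches in `embed`:

    #[derive(Clone, Copy)]
    enum PaddingSide {
        Left,
        Right,
    }

    /// Position of the first real token, which is where CLS pooling reads: with left
    /// padding it sits right after the `max_length - seq_length` pad tokens.
    fn cls_position(side: PaddingSide, max_length: usize, seq_length: usize) -> usize {
        match side {
            PaddingSide::Left => max_length - seq_length,
            PaddingSide::Right => 0,
        }
    }

    /// Position of the last real token, which is where last-token pooling reads: with
    /// left padding it is always the final slot of the row.
    fn last_token_position(side: PaddingSide, max_length: usize, seq_length: usize) -> usize {
        match side {
            PaddingSide::Left => max_length - 1,
            PaddingSide::Right => seq_length - 1,
        }
    }

    fn main() {
        // A 5-token sequence padded to max_length = 8
        assert_eq!(cls_position(PaddingSide::Right, 8, 5), 0);
        assert_eq!(last_token_position(PaddingSide::Right, 8, 5), 4);
        assert_eq!(cls_position(PaddingSide::Left, 8, 5), 3);
        assert_eq!(last_token_position(PaddingSide::Left, 8, 5), 7);
    }

Mean pooling follows the same rule: with right padding the masked sum is divided by the
sequence length, while with left padding the patch averages the slice that starts at
`max_length - seq_length`, i.e. `outputs.slice(s![batch_idx, start_pos.., ..])`.
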
diff --git a/backends/src/lib.rs b/backends/src/lib.rs
index 85f26e6c..49a31c0a 100644
--- a/backends/src/lib.rs
+++ b/backends/src/lib.rs
@@ -380,6 +380,18 @@ async fn init_backend(
         }
     }
 
+    // NOTE: for ONNX we need to retrieve the `tokenizer_config.json` to identify which
+    // `padding_side` needs to be applied for the input processing and the pooling
+    if let Some(api_repo) = api_repo.as_ref() {
+        tracing::info!("Downloading `tokenizer_config.json`");
+        match api_repo.get("tokenizer_config.json").await {
+            Ok(_) => (),
+            Err(err) => {
+                tracing::warn!("Could not download `tokenizer_config.json`: {}", err)
+            }
+        }
+    }
+
     let backend = OrtBackend::new(&model_path, dtype.to_string(), model_type.clone());
     match backend {
         Ok(b) => return Ok(Box::new(b)),
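
Note (outside the patch): what makes all of the above opt-in is the `TokenizerConfig`
deserialization, where a missing `tokenizer_config.json` or a missing `padding_side`
field falls back to right padding. A runnable sketch of just that piece, assuming only
`serde` (with the `derive` feature) and `serde_json` as dependencies; `PartialEq` is
added here purely for the assertions:

    use serde::Deserialize;

    #[derive(Debug, Clone, Default, Deserialize, PartialEq)]
    #[serde(rename_all = "lowercase")]
    enum PaddingSide {
        Left,
        #[default]
        Right,
    }

    #[derive(Debug, Deserialize)]
    struct TokenizerConfig {
        // Falls back to `PaddingSide::default()` (i.e. `Right`) when the field is
        // absent, mirroring the `else` branch in `OrtBackend::new`
        #[serde(default)]
        padding_side: PaddingSide,
    }

    fn main() {
        // Tokenizer configs for decoder-style models typically declare left padding
        let left: TokenizerConfig =
            serde_json::from_str(r#"{"padding_side": "left"}"#).unwrap();
        assert_eq!(left.padding_side, PaddingSide::Left);

        // Encoder-style tokenizer configs usually omit the field entirely
        let missing: TokenizerConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(missing.padding_side, PaddingSide::Right);
    }

This is consistent with the server-side change in `backends/src/lib.rs`, which only
warns (rather than fails) when `tokenizer_config.json` cannot be downloaded, since the
backend can still start with the default right-padding behavior.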