diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index 882cdb8a..ca3bff5f 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -500,7 +500,10 @@ impl Backend for CandleBackend {
         // Used for indexing in the raw_embeddings tensor
         let input_lengths: Vec<usize> = (0..batch.len())
             .map(|i| {
-                (batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i]) as usize
+                batch.cumulative_seq_lengths[i + 1]
+                    .checked_sub(batch.cumulative_seq_lengths[i])
+                    .expect("Invalid cumulative sequence lengths: sequence lengths must be non-decreasing")
+                    as usize
             })
             .collect();
 
@@ -565,3 +568,45 @@ impl WrapErr for Result {
         self.map_err(|e| BackendError::Inference(e.to_string()))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use text_embeddings_backend_core::Batch;
+
+    #[test]
+    #[should_panic(expected = "Invalid cumulative sequence lengths")]
+    fn test_invalid_cumulative_seq_lengths() {
+        // Create a mock backend for testing
+        let device = Device::Cpu;
+        let model = Box::new(MockModel);
+        let backend = CandleBackend { device, model };
+
+        // Create a batch with invalid cumulative sequence lengths (decreasing)
+        let batch = Batch {
+            input_ids: vec![1, 2, 3, 4],
+            token_type_ids: vec![0, 0, 0, 0],
+            position_ids: vec![0, 1, 2, 3],
+            cumulative_seq_lengths: vec![0, 3, 2], // Invalid: 3 > 2
+            max_length: 4,
+            pooled_indices: vec![0],
+            raw_indices: vec![],
+        };
+
+        // This should panic due to invalid cumulative sequence lengths
+        let _ = backend.embed(batch);
+    }
+
+    // Mock model for testing
+    struct MockModel;
+
+    impl crate::models::Model for MockModel {
+        fn is_padded(&self) -> bool {
+            false
+        }
+
+        fn embed(&self, _batch: Batch) -> candle::Result<(Option<candle::Tensor>, Option<candle::Tensor>)> {
+            Ok((None, None))
+        }
+    }
+}