
feat(candle): add FlashMistral #308

Merged: 4 commits, Jun 27, 2024
Changes from all commits
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

18 changes: 13 additions & 5 deletions backends/candle/src/flash_attn.rs
@@ -31,13 +31,17 @@ pub(crate) fn flash_attn_varlen(
max_seqlen_k: usize,
softmax_scale: f32,
causal: bool,
window_size_left: Option<usize>,
) -> Result<Tensor, candle::Error> {
let runtime_compute_cap = get_runtime_compute_cap();

if runtime_compute_cap == 75 {
if alibi_slopes.is_some() {
candle::bail!("Flash attention v1 does not support alibi");
}
if window_size_left.is_some() {
candle::bail!("Flash attention v1 does not support attention windowing");
}

#[cfg(feature = "flash-attn-v1")]
{
@@ -59,10 +63,12 @@ pub(crate) fn flash_attn_varlen(
} else if (80..90).contains(&runtime_compute_cap) || runtime_compute_cap == 90 {
#[cfg(feature = "flash-attn")]
{
- use candle_flash_attn::{flash_attn_varlen, flash_attn_varlen_alibi};
+ use candle_flash_attn::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed};

let window_size_right = if causal { Some(0) } else { None };

let attention = if let Some(alibi_slopes) = alibi_slopes {
- flash_attn_varlen_alibi(
+ flash_attn_varlen_alibi_windowed(
q,
k,
v,
@@ -72,10 +78,11 @@ pub(crate) fn flash_attn_varlen(
max_seqlen_q,
max_seqlen_k,
softmax_scale,
- causal,
+ window_size_left,
+ window_size_right,
)
} else {
- flash_attn_varlen(
+ flash_attn_varlen_windowed(
q,
k,
v,
@@ -84,7 +91,8 @@ pub(crate) fn flash_attn_varlen(
max_seqlen_q,
max_seqlen_k,
softmax_scale,
- causal,
+ window_size_left,
+ window_size_right,
)
};

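For reviewers unfamiliar with the new parameters: a minimal standalone sketch of the masking rule that `window_size_left` / `window_size_right` encode in the FlashAttention kernels, assuming the usual convention (`None` means unbounded on that side; `causal` becomes `window_size_right = Some(0)`, exactly as in the hunk above). `attends` is a hypothetical helper for illustration, not code from this PR.

```rust
// With a (left, right) window, a query at position `q` may attend to a key at
// position `k` iff  q - left <= k <= q + right; `None` lifts the bound on that
// side. Mistral-style sliding-window attention bounds the left side with the
// configured window size.
fn attends(q: i64, k: i64, left: Option<i64>, right: Option<i64>) -> bool {
    left.map_or(true, |w| k >= q - w) && right.map_or(true, |w| k <= q + w)
}

fn main() {
    // Causal attention with a sliding window of 3: each query sees itself and
    // at most its 3 predecessors.
    let (left, right) = (Some(3), Some(0));
    for q in 0..6_i64 {
        let keys: Vec<i64> = (0..6).filter(|&k| attends(q, k, left, right)).collect();
        println!("query {q} attends to keys {keys:?}");
    }
}
```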
1 change: 1 addition & 0 deletions backends/candle/src/layers/linear.rs
@@ -7,6 +7,7 @@ use serde::Deserialize;
pub enum HiddenAct {
Gelu,
Relu,
#[serde(alias = "silu")]
Swiglu,
}

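Context for the new alias: Mistral checkpoints declare `"hidden_act": "silu"` in their `config.json`, and the alias lets that spelling map onto the existing `Swiglu` variant. A standalone sketch of the serde behaviour, assuming a `rename_all = "lowercase"` attribute on the enum (the real attributes are not visible in this diff):

```rust
use serde::Deserialize;

// Illustration only; the real enum lives in backends/candle/src/layers/linear.rs.
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
    Gelu,
    Relu,
    #[serde(alias = "silu")]
    Swiglu,
}

fn main() -> serde_json::Result<()> {
    // Both spellings deserialize to the same variant.
    assert_eq!(serde_json::from_str::<HiddenAct>("\"silu\"")?, HiddenAct::Swiglu);
    assert_eq!(serde_json::from_str::<HiddenAct>("\"swiglu\"")?, HiddenAct::Swiglu);
    Ok(())
}
```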
4 changes: 4 additions & 0 deletions backends/candle/src/layers/mod.rs
@@ -2,7 +2,11 @@
mod cublaslt;
mod layer_norm;
mod linear;
#[allow(dead_code, unused)]
mod rms_norm;

pub use cublaslt::get_cublas_lt_wrapper;
pub use layer_norm::LayerNorm;
pub use linear::{HiddenAct, Linear};
#[allow(unused_imports)]
pub use rms_norm::RMSNorm;
96 changes: 96 additions & 0 deletions backends/candle/src/layers/rms_norm.rs
@@ -0,0 +1,96 @@
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::VarBuilder;

#[derive(Debug)]
pub struct RMSNorm {
weight: Tensor,
epsilon: f32,
span: tracing::Span,
}

impl RMSNorm {
pub fn load(vb: VarBuilder, hidden_size: usize, epsilon: f32) -> Result<Self> {
Ok(Self {
weight: vb
.get(hidden_size, "weight")
.or_else(|_| vb.get(hidden_size, "gamma"))?,
epsilon,
span: tracing::span!(tracing::Level::TRACE, "rms-norm"),
})
}

pub fn forward(
&self,
hidden_states: &Tensor,
residual: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let _enter = self.span.enter();

match hidden_states.device() {
Device::Cpu | Device::Metal(_) => {
let mut hidden_states = hidden_states.clone();
let residual_add = if let Some(residual) = residual {
let residual_add = hidden_states.add(residual)?;
hidden_states = residual_add.clone();
residual_add
} else {
hidden_states.clone()
};

let hidden_states_dtype = hidden_states.dtype();
let internal_dtype = match hidden_states_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = hidden_states.dim(D::Minus1)?;
let hidden_states = hidden_states.to_dtype(internal_dtype)?;
let norm_hidden_states =
(hidden_states.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let hidden_states_normed = hidden_states
.broadcast_div(&(norm_hidden_states + self.epsilon as f64)?.sqrt()?)?;
Ok((
hidden_states_normed
.to_dtype(hidden_states_dtype)?
.broadcast_mul(&self.weight)?,
residual_add,
))
}
Device::Cuda(_) => {
#[cfg(feature = "cuda")]
{
use candle_layer_norm::{fused_add_rms_norm, rms_norm};

let original_shape = hidden_states.shape();
let hidden_states = hidden_states.flatten_to(D::Minus2)?;

if let Some(residual) = residual {
let residual = residual.flatten_to(D::Minus2)?;

let (result, residual_add) = fused_add_rms_norm(
&hidden_states,
&residual,
&self.weight,
None,
self.epsilon,
)?;
Ok((
result.reshape(original_shape)?,
residual_add.reshape(original_shape)?,
))
} else {
let residual_add = hidden_states.clone();

let result = rms_norm(&hidden_states, &self.weight, None, self.epsilon)?;

Ok((
result.reshape(original_shape)?,
residual_add.reshape(original_shape)?,
))
}
}
#[cfg(not(feature = "cuda"))]
candle::bail!("`cuda` feature is not enabled")
}
}
}
}
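As a sanity check on the arithmetic above, a scalar reference of what the CPU/Metal branch computes (square, mean, add epsilon, take the reciprocal square root, scale by the learned weight); illustration only, not part of the PR. Note that `forward` also returns the pre-normalization residual sum, which the CUDA path gets from `fused_add_rms_norm` in a single kernel.

```rust
// Reference RMSNorm over one hidden vector:
//   y_i = w_i * x_i / sqrt(mean_j(x_j^2) + eps)
fn rms_norm_ref(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}

fn main() {
    let x = [1.0_f32, -2.0, 3.0, -4.0];
    let w = [1.0_f32; 4];
    println!("{:?}", rms_norm_ref(&x, &w, 1e-6));
}
```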
97 changes: 78 additions & 19 deletions backends/candle/src/lib.rs
@@ -12,12 +12,12 @@ use crate::compute_cap::{
};
use crate::models::{
BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
- Model, NomicBertModel, NomicConfig,
+ MistralConfig, Model, NomicBertModel, NomicConfig,
};
#[cfg(feature = "cuda")]
use crate::models::{
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
- FlashNomicBertModel,
+ FlashMistralModel, FlashNomicBertModel,
};
use anyhow::Context;
use candle::{DType, Device};
@@ -56,6 +56,7 @@ enum Config {
DistilBert(DistilBertConfig),
#[serde(rename(deserialize = "nomic_bert"))]
NomicBert(NomicConfig),
Mistral(MistralConfig),
}

pub struct CandleBackend {
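For context on how the new variant is selected: a Mistral `config.json` carries `"model_type": "mistral"`, and the `Config` enum dispatches on that field. Below is a standalone sketch under the assumption that the real enum is internally tagged on `model_type` with lowercased variant names (the `rename(deserialize = "nomic_bert")` attribute above points that way, but the actual tag attributes are outside this diff); `ConfigSketch` and its fields are invented for the example.

```rust
use serde::Deserialize;

// Illustration only -- not the real `Config` enum.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "lowercase")]
enum ConfigSketch {
    Bert { hidden_size: usize },
    Mistral { hidden_size: usize, sliding_window: usize },
}

fn main() -> serde_json::Result<()> {
    let raw = r#"{ "model_type": "mistral", "hidden_size": 4096, "sliding_window": 4096 }"#;
    let cfg: ConfigSketch = serde_json::from_str(raw)?;
    println!("{cfg:?}"); // Mistral { hidden_size: 4096, sliding_window: 4096 }
    Ok(())
}
```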
@@ -69,6 +70,54 @@ impl CandleBackend {
dtype: String,
model_type: ModelType,
) -> Result<Self, BackendError> {
// Default files
let default_safetensors = model_path.join("model.safetensors");
let default_pytorch = model_path.join("pytorch_model.bin");

// Single Files
let model_files = if default_safetensors.exists() {
vec![default_safetensors]
} else if default_pytorch.exists() {
vec![default_pytorch]
}
// Sharded weights
else {
// Get index file
let index_file = model_path.join("model.safetensors.index.json");

// Parse file
let index_file_string: String = std::fs::read_to_string(&index_file)
.map_err(|err| BackendError::Start(err.to_string()))?;
let json: serde_json::Value = serde_json::from_str(&index_file_string)
.map_err(|err| BackendError::Start(err.to_string()))?;

let weight_map = match json.get("weight_map") {
None => {
return Err(BackendError::Start(format!(
"no weight map in {index_file:?}"
)));
}
Some(serde_json::Value::Object(map)) => map,
Some(_) => {
return Err(BackendError::Start(format!(
"weight map in {index_file:?} is not a map"
)));
}
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}

// Collect paths
safetensors_files
.iter()
.map(|n| model_path.join(n))
.collect()
};

// Load config
let config: String = std::fs::read_to_string(model_path.join("config.json"))
.context("Unable to read config file")
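The sharded branch above reads the standard `model.safetensors.index.json` layout: its `weight_map` object maps tensor names to shard files, and deduplicating the values gives the set of files to mmap. A self-contained sketch of that parsing step, mirroring the logic above outside the backend:

```rust
use std::collections::HashSet;

// Collect the unique shard files referenced by a safetensors index payload.
fn shard_files(index_json: &str) -> serde_json::Result<HashSet<String>> {
    let json: serde_json::Value = serde_json::from_str(index_json)?;
    let mut files = HashSet::new();
    if let Some(map) = json.get("weight_map").and_then(|v| v.as_object()) {
        for value in map.values() {
            if let Some(file) = value.as_str() {
                files.insert(file.to_string());
            }
        }
    }
    Ok(files)
}

fn main() -> serde_json::Result<()> {
    let index = r#"{
        "weight_map": {
            "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
            "model.norm.weight": "model-00002-of-00002.safetensors"
        }
    }"#;
    println!("{:?}", shard_files(index)?);
    Ok(())
}
```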
@@ -115,17 +164,10 @@
)))
}?;

- let safetensors_path = model_path.join("model.safetensors");
- let vb = if safetensors_path.exists() {
- unsafe {
- VarBuilder::from_mmaped_safetensors(
- &[model_path.join("model.safetensors")],
- dtype,
- &device,
- )
- }
+ let vb = if model_files.len() == 1 && model_files[0].extension().unwrap() == "bin" {
+ VarBuilder::from_pth(&model_files[0], dtype, &device)
} else {
- VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device)
+ unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device) }
}
.s()?;

@@ -136,7 +178,7 @@
)),
(Config::Bert(config), Device::Cpu | Device::Metal(_)) => match config {
BertConfigWrapper::JinaBert(config) => {
- tracing::info!("Starting JinaBertModel model on {:?}", device);
+ tracing::info!("Starting JinaBert model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
BertConfigWrapper::JinaCodeBert(config) => {
@@ -160,15 +202,19 @@
))
}
(Config::DistilBert(config), Device::Cpu | Device::Metal(_)) => {
- tracing::info!("Starting DistilBertModel model on {:?}", device);
+ tracing::info!("Starting DistilBert model on {:?}", device);
Ok(Box::new(
DistilBertModel::load(vb, &config, model_type).s()?,
))
}
(Config::NomicBert(config), Device::Cpu | Device::Metal(_)) => {
- tracing::info!("Starting NomicBertModel model on {:?}", device);
+ tracing::info!("Starting NomicBert model on {:?}", device);
Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
}
(Config::Mistral(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
"Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
.to_string(),
)),
#[cfg(feature = "cuda")]
(Config::Bert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -198,7 +244,7 @@
} else {
match config {
BertConfigWrapper::JinaBert(config) => {
- tracing::info!("Starting JinaBertModel model on {:?}", device);
+ tracing::info!("Starting JinaBert model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
BertConfigWrapper::JinaCodeBert(config) => {
@@ -245,7 +291,7 @@
.to_lowercase()
== "true"
{
- tracing::info!("Starting FlashDistilBertModel model on {:?}", device);
+ tracing::info!("Starting FlashDistilBert model on {:?}", device);
Ok(Box::new(
FlashDistilBertModel::load(vb, &config, model_type).s()?,
))
@@ -265,15 +311,28 @@
.to_lowercase()
== "true"
{
- tracing::info!("Starting FlashNomicBertModel model on {:?}", device);
+ tracing::info!("Starting FlashNomicBert model on {:?}", device);
Ok(Box::new(
FlashNomicBertModel::load(vb, &config, model_type).s()?,
))
} else {
- tracing::info!("Starting NomicBertModel model on {:?}", device);
+ tracing::info!("Starting NomicBert model on {:?}", device);
Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
}
}
#[cfg(feature = "cuda")]
(Config::Mistral(config), Device::Cuda(_)) => {
if dtype != DType::F16
|| !cfg!(feature = "flash-attn")
|| get_runtime_compute_cap().unwrap() < 80
{
return Err(BackendError::Start("Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
}
tracing::info!("Starting FlashMistral model on {:?}", device);
Ok(Box::new(
FlashMistralModel::load(vb, &config, model_type).s()?,
))
}
};

Ok(Self {
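Taken together with the flash_attn.rs changes, there are now two CUDA gates: flash-attn v1 on compute capability 75 (no alibi, no windowing) and flash-attn v2 on 80–90, which is what the Mistral arm above requires on top of fp16. A hypothetical helper summarizing that dispatch, not code from this PR:

```rust
// Mirrors the runtime checks in flash_attn.rs / lib.rs: which flash attention
// build (if any) a given CUDA compute capability can use.
fn flash_attn_variant(compute_cap: usize) -> Option<&'static str> {
    match compute_cap {
        75 => Some("flash-attn-v1 (no alibi, no sliding window)"),
        80..=90 => Some("flash-attn v2"),
        _ => None,
    }
}

fn main() {
    for cap in [70, 75, 80, 89, 90, 100] {
        println!("compute cap {cap}: {:?}", flash_attn_variant(cap));
    }
}
```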
6 changes: 6 additions & 0 deletions backends/candle/src/models/bert.rs
@@ -638,6 +638,10 @@ impl BertModel {
(pool, Some(classifier), None)
}
ModelType::Embedding(pool) => {
if pool == Pool::LastToken {
candle::bail!("`last_token` is not supported for Bert");
}

let splade = if pool == Pool::Splade {
Some(BertSpladeHead::load_roberta(vb.clone(), config)?)
} else {
@@ -832,6 +836,8 @@ impl BertModel {
let pooled_embeddings = match self.pool {
// CLS pooling
Pool::Cls => outputs.i((.., 0))?,
// Last token pooling is not supported for this model
Pool::LastToken => unreachable!(),
// Mean pooling
Pool::Mean => {
if let Some(ref attention_mask) = attention_mask {
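`Pool::LastToken` is rejected for BERT above; it is presumably intended for causal models like Mistral, where the final non-padding token is the only one that has attended to the whole sequence. A standalone sketch of what last-token pooling does over the packed variable-length layout the flash kernels use (`cu_seqlens` holds cumulative sequence lengths); `last_token_pool` is hypothetical, not from this PR.

```rust
// Packed varlen layout: sequence i occupies rows cu_seqlens[i] .. cu_seqlens[i + 1]
// of `hidden`. Last-token pooling keeps the final row of each sequence.
fn last_token_pool(hidden: &[Vec<f32>], cu_seqlens: &[usize]) -> Vec<Vec<f32>> {
    cu_seqlens
        .windows(2)
        .map(|w| hidden[w[1] - 1].clone())
        .collect()
}

fn main() {
    // Two sequences of lengths 3 and 2, hidden size 2.
    let hidden = vec![
        vec![0.1, 0.2], vec![0.3, 0.4], vec![0.5, 0.6], // sequence 0
        vec![0.7, 0.8], vec![0.9, 1.0],                 // sequence 1
    ];
    let cu_seqlens = [0, 3, 5];
    println!("{:?}", last_token_pool(&hidden, &cu_seqlens));
}
```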