From 120d30f0102e1d29e7609f2b20fe25bb12251222 Mon Sep 17 00:00:00 2001 From: onenewcode Date: Wed, 12 Feb 2025 19:22:30 +0800 Subject: [PATCH 1/7] =?UTF-8?q?fix:=20=E6=94=AF=E6=8C=81attention=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=20q=EF=BC=8Ck=EF=BC=8Cv=E4=B8=8D=E7=9B=B8=E5=90=8C?= =?UTF-8?q?=E7=9A=84=E6=83=85=E5=86=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operators/src/attention/args.rs | 19 ++++++++++++------- operators/src/attention/cuda.rs | 1 + operators/src/attention/operator.rs | 16 ++++++++++------ operators/src/attention_kv_cached/operator.rs | 4 ++-- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/operators/src/attention/args.rs b/operators/src/attention/args.rs index 4c1ac43a..a617cc0f 100644 --- a/operators/src/attention/args.rs +++ b/operators/src/attention/args.rs @@ -30,6 +30,7 @@ pub(super) struct Meta { pub seq: MaybeDyn, pub att: MaybeDyn, pub dh: MaybeDyn, + pub dv: MaybeDyn, } impl Args { @@ -41,17 +42,20 @@ impl Args { seq: MaybeDyn, att: MaybeDyn, dh: MaybeDyn, + dv: MaybeDyn, ) -> Self { - let qo_layout = TensorLayout::new_dyn(dt, &[nh, seq, dh], &[dyn_(); 3]); - let kv_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dh], &[dyn_(); 3]); + let q_layout = TensorLayout::new_dyn(dt, &[nh, seq, dh], &[dyn_(); 3]); + let k_layout = TensorLayout::new_dyn(dt, &[nkvh, seq, dh], &[dyn_(); 3]); + let v_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dv], &[dyn_(); 3]); + let o_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dh], &[dyn_(); 3]); Self { - q_layout: qo_layout.clone(), + q_layout: q_layout.clone(), q_base: null_mut(), - k_layout: kv_layout.clone(), + k_layout: k_layout.clone(), k_base: null(), - v_layout: kv_layout, + v_layout: v_layout, v_base: null(), - o_layout: qo_layout, + o_layout: o_layout, o_base: null_mut(), mask, } @@ -85,7 +89,8 @@ impl Args { nkvh: dim_distinct(&[nkvh_k, nkvh_v])?, seq: dim_distinct(&[seq_q, seq_o])?, att: dim_distinct(&[att_k, att_v])?, - dh: dim_distinct(&[dh_q, dh_k, dh_v, dh_o])?, + dh: dim_distinct(&[dh_q, dh_k])?, + dv: dim_distinct(&[dh_v, dh_o])?, }) } } diff --git a/operators/src/attention/cuda.rs b/operators/src/attention/cuda.rs index 208561c3..3a67ac90 100644 --- a/operators/src/attention/cuda.rs +++ b/operators/src/attention/cuda.rs @@ -16,6 +16,7 @@ mod test { seq.into(), att.into(), dyn_(), + dyn_(), ) } diff --git a/operators/src/attention/operator.rs b/operators/src/attention/operator.rs index c1281672..be9418ec 100644 --- a/operators/src/attention/operator.rs +++ b/operators/src/attention/operator.rs @@ -1,4 +1,4 @@ -use super::{args::Meta, Args, Attention}; +use super::{args::Meta, Args, Attention}; use crate::{ dyn_, fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc, SchemeError, TensorLayout, Workspace, WorkspaceCollector, @@ -53,6 +53,7 @@ where seq, att, dh, + dv, .. } = args.meta()?; let Args { @@ -64,11 +65,12 @@ where } = args; // 如果不能保证 nh seq att dh 已知,用任意值初始化算子 - let (Some(&nh), Some(&seq), Some(&att), Some(&dh)) = ( + let (Some(&nh), Some(&seq), Some(&att), Some(&dh), Some(&dv)) = ( nh.get_static(), seq.get_static(), att.get_static(), dh.get_static(), + dv.get_static(), ) else { let mut wc = WorkspaceCollector::new(); @@ -149,6 +151,7 @@ where seq, att, dh, + dv, } = args.meta()?; let Args { mask, @@ -172,8 +175,8 @@ where let ele = dt.nbytes(); get_static! { nh seq dh - nh_sq seq_sq dh_sq - nkvh att + dv seq_sq dh_sq + nkvh att nh_sq nkvh_sk att_sk dh_sk }; @@ -219,6 +222,7 @@ where let k_layout = TensorLayout::new(dt, k_layout.shape(), k_layout.strides()); let att_mat_mul = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, att]); let att_softmax = TensorLayout::new_contiguous(dt, &[nh, seq, att]); + let att_result = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, dv]); // att = q . k^T self.mat_mul.launch( @@ -248,7 +252,7 @@ where // q = att . v self.mat_mul.launch( &mat_mul::Args { - c_layout: qx_layout.clone(), + c_layout: att_result.clone(), c_base: q_base, beta: 0., a_layout: att_mat_mul, @@ -266,7 +270,7 @@ where &rearrange::Args { dst_layout: o_layout.clone(), dst_base: *o_base, - src_layout: q_layout.clone(), + src_layout: o_layout.clone(), src_base: q_base, }, workspace, diff --git a/operators/src/attention_kv_cached/operator.rs b/operators/src/attention_kv_cached/operator.rs index 81345f4c..d0c8412e 100644 --- a/operators/src/attention_kv_cached/operator.rs +++ b/operators/src/attention_kv_cached/operator.rs @@ -1,4 +1,4 @@ -use super::{args::Meta, Args, AttnKVCached}; +use super::{args::Meta, Args, AttnKVCached}; use crate::{ attention, dyn_, get_static, rearrange, shape_mismatch, ByteOf, Hardware, LaunchError, MaybeDyn, QueueAlloc, TensorLayout, WorkspaceCollector, @@ -66,7 +66,7 @@ where }; wc.push_sub(self.attention.scheme( - &attention::Args::new_null(args.mask, dt, nh, nkvh, seq, att, dh), + &attention::Args::new_null(args.mask, dt, nh, nkvh, seq, att, dh, dh), max_workspace_size, )?); From f4a83f7205145e4ff47d8df6eeb2e15445db5044 Mon Sep 17 00:00:00 2001 From: onenewcode Date: Fri, 14 Feb 2025 13:39:33 +0800 Subject: [PATCH 2/7] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0Scale=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operators/src/lib.rs | 1 + operators/src/scale/args.rs | 174 ++++++++++++++++++++++++++ operators/src/scale/common_cpu/mod.rs | 71 +++++++++++ operators/src/scale/cuda/mod.rs | 49 ++++++++ operators/src/scale/infini/mod.rs | 36 ++++++ operators/src/scale/mod.rs | 15 +++ operators/src/scale/opencl/mod.rs | 36 ++++++ 7 files changed, 382 insertions(+) create mode 100644 operators/src/scale/args.rs create mode 100644 operators/src/scale/common_cpu/mod.rs create mode 100644 operators/src/scale/cuda/mod.rs create mode 100644 operators/src/scale/infini/mod.rs create mode 100644 operators/src/scale/mod.rs create mode 100644 operators/src/scale/opencl/mod.rs diff --git a/operators/src/lib.rs b/operators/src/lib.rs index 02088acc..3c86f3ac 100644 --- a/operators/src/lib.rs +++ b/operators/src/lib.rs @@ -18,6 +18,7 @@ pub mod random_sample; pub mod rearrange; pub mod rms_norm; pub mod rope; +pub mod scale; pub mod swiglu; pub use common::*; diff --git a/operators/src/scale/args.rs b/operators/src/scale/args.rs new file mode 100644 index 00000000..21951926 --- /dev/null +++ b/operators/src/scale/args.rs @@ -0,0 +1,174 @@ +use crate::{ + get_static, rank_mismatch, shape_mismatch, shape_not_support, utils::type_distinct, ConstPtr, + Hardware, MutPtr, SchemeError, TensorLayout, +}; +use digit_layout::DigitLayout; +use itertools::izip; +use std::{ + cmp::Ordering, + ptr::{null, null_mut}, +}; + +#[derive(Clone)] +pub struct Args { + pub c_layout: TensorLayout, + pub c_base: MutPtr, + pub a_layout: TensorLayout, + pub a_base: ConstPtr, + pub s: f32, +} + +impl Args { + pub fn new_null( + c_layout: TensorLayout, + a_layout: TensorLayout, + b_layout: TensorLayout, + ) -> Self { + Self { + c_layout, + c_base: null_mut(), + a_layout, + a_base: null(), + s: 1.0, + } + } +} + +#[derive(Clone, Debug)] +pub(super) struct Scheme(DigitLayout, Box<[isize]>); + +impl Scheme { + pub fn new(args: &Args) -> Result { + let Args { + c_layout: c, + a_layout: a, + .. + } = args; + // # 检查基本属性 + let dt = type_distinct(&[c.dt(), a.dt()])?; + let ndim = c.ndim(); + if a.ndim() != ndim { + return Err(rank_mismatch(format!( + "c.ndim = {}, a.ndim = {}", + c.ndim(), + a.ndim(), + ))); + } + // # 输入形状 + #[derive(Clone, PartialEq, Eq, Debug)] + struct Dim { + d: usize, + c: isize, + a: isize, + } + let mut dims = Vec::with_capacity(ndim); + for (&d, &da, &sc, &sa) in izip!(c.shape(), a.shape(), c.strides(), a.strides(),) { + get_static! { + d da + sc sa + } + if da != d { + return Err(shape_mismatch(format!( + "c: {:?}, a: {:?}", + c.shape(), + a.shape(), + ))); + } + // 剔除初始的 1 长维度 + if d != 1 { + if sc == 0 { + return Err(shape_not_support("Reducing is not allowed for scale")); + } + dims.push(Dim { d, c: sc, a: sa }) + } + } + // # 排序 + dims.sort_unstable_by(|dim0, dim1| { + let &Dim { + d: d0, + c: c0, + a: a0, + } = dim0; + let &Dim { + d: d1, + c: c1, + a: a1, + } = dim1; + use Ordering::Equal as Eq; + match c0.abs().cmp(&c1.abs()) { + Eq => match a0.abs().cmp(&a1.abs()) { + ord => ord.reverse(), + }, + ord => ord.reverse(), + } + }); + // # 合并连续维度 + let mut ndim = dims.len(); + for i in (1..dims.len()).rev() { + let (head, tail) = dims.split_at_mut(i); + let f = &mut head[i - 1]; // f for front + let b = &mut tail[0]; // b for back + let d = b.d as isize; + if b.c * d == f.c && b.a * d == f.a { + *f = Dim { d: b.d * f.d, ..*b }; + *b = Dim { d: 1, c: 0, a: 0 }; + ndim -= 1 + } + } + // # 合并空间 + let mut layout = vec![0isize; 1 + ndim * 4].into_boxed_slice(); + { + let (idx, tail) = layout.split_at_mut(1 + ndim); + let (c_, tail) = tail.split_at_mut(ndim); + let (a_, b_) = tail.split_at_mut(ndim); + for (Dim { d, c, a }, idx, c_, a_) in + izip!(dims.into_iter().filter(|d| d.d != 1), &mut *idx, c_, a_) + { + *idx = d as _; + *c_ = c; + *a_ = a; + } + idx[ndim] = 1; + for i in (1..=ndim).rev() { + idx[i - 1] *= idx[i]; + } + } + Ok(Self(dt, layout)) + } + + #[inline] + pub const fn dt(&self) -> DigitLayout { + self.0 + } + + /// 执行方案维数。 + #[inline] + pub fn ndim(&self) -> usize { + (self.1.len() - 1) / 4 + } + + /// 读写单元数量。 + #[inline] + pub fn count(&self) -> usize { + self.1[0] as _ + } + + /// 索引步长。 + #[inline] + pub fn idx_strides(&self) -> &[isize] { + let ndim = self.ndim(); + &self.1[1..][..ndim] + } + + #[inline] + pub fn c_strides(&self) -> &[isize] { + let ndim = self.ndim(); + &self.1[1 + ndim..][..ndim] + } + + #[inline] + pub fn a_strides(&self) -> &[isize] { + let ndim = self.ndim(); + &self.1[1 + ndim * 2..][..ndim] + } +} diff --git a/operators/src/scale/common_cpu/mod.rs b/operators/src/scale/common_cpu/mod.rs new file mode 100644 index 00000000..b4ba50cb --- /dev/null +++ b/operators/src/scale/common_cpu/mod.rs @@ -0,0 +1,71 @@ +use super::{args::Scheme, Args, Scale}; +use crate::{common_cpu::Cpu, ByteOf, LaunchError, QueueAlloc, SchemeError}; +use digit_layout::types as ty; +use half::f16; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; + +pub struct Operator; + +impl Scale for Operator {} + +impl crate::Operator for Operator { + type Hardware = Cpu; + type TopoNode = Cpu; + type Args = Args; + + #[inline] + fn new(_node: &Self::TopoNode) -> Self { + Self + } + #[inline] + fn scheme( + &mut self, + _args: &Self::Args, + _max_workspace_size: usize, + ) -> Result { + Ok(0) + } + + fn launch( + &self, + args: &Self::Args, + _workspace: &mut [ByteOf], + _queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + QA: QueueAlloc, + { + let scheme = Scheme::new(args)?; + let c = args.c_base as isize; + let a = args.a_base as isize; + let s = args.s; + let idx_strides = scheme.idx_strides(); + let c_strides = scheme.c_strides(); + let a_strides = scheme.a_strides(); + (0..scheme.count() as isize) + .into_par_iter() + .for_each(|mut rem| { + let mut c = c; + let mut a = a; + for (i, &s) in idx_strides.iter().enumerate() { + let k = rem / s; + c += k * c_strides[i]; + a += k * a_strides[i]; + rem %= s; + } + match scheme.dt() { + ty::F16 => mul::(c, a, f16::from_f32(s)), + ty::F32 => mul::(c, a, s), + ty::F64 => mul::(c, a, s as f64), + _ => todo!(), + } + }); + Ok(()) + } +} + +fn mul>(c: isize, a: isize, s: T) { + let c = c as *mut T; + let a = a as *const T; + unsafe { *c = a.read() * s } +} diff --git a/operators/src/scale/cuda/mod.rs b/operators/src/scale/cuda/mod.rs new file mode 100644 index 00000000..79eb00c9 --- /dev/null +++ b/operators/src/scale/cuda/mod.rs @@ -0,0 +1,49 @@ +use super::{args::Scheme, Args, Scale}; +use crate::{ + cuda::{dt_name, Gpu, Handle, ModuleBox}, + shape_not_support, strides_not_support, + utils::{gcd, type_distinct}, + ByteOf, LaunchError, QueueAlloc, SchemeDiversity, SchemeError, +}; +use digit_layout::DigitLayout; +use lru::LruCache; +use std::{ + ffi::{c_uint, CString}, + sync::{Arc, Mutex}, +}; + +pub struct Operator {} +impl Scale for Operator {} + +impl crate::Operator for Operator { + type Hardware = Gpu; + type TopoNode = Gpu; + type Args = Args; + + fn new(node: &Self::TopoNode) -> Self { + Self {} + } + + #[inline] + fn scheme( + &mut self, + args: &Self::Args, + _max_workspace_size: usize, + ) -> Result { + todo!(); + Ok(0) + } + + fn launch( + &self, + _args: &Self::Args, + _workspace: &mut [ByteOf], + _queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + QA: QueueAlloc, + { + todo!(); + Ok(()) + } +} diff --git a/operators/src/scale/infini/mod.rs b/operators/src/scale/infini/mod.rs new file mode 100644 index 00000000..183414e1 --- /dev/null +++ b/operators/src/scale/infini/mod.rs @@ -0,0 +1,36 @@ +use super::{Args, Scale}; +use crate::{infini::Device, ByteOf, LaunchError, QueueAlloc, SchemeError}; + +pub struct Operator; + +impl Add for Operator {} + +impl crate::Operator for Operator { + type Hardware = Device; + type TopoNode = Device; + type Args = Args; + + fn new(_node: &Self::TopoNode) -> Self { + todo!() + } + + fn scheme( + &mut self, + _args: &Self::Args, + _max_workspace_size: usize, + ) -> Result { + todo!() + } + + fn launch( + &self, + _args: &Self::Args, + _workspace: &mut [ByteOf], + _queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + QA: QueueAlloc, + { + todo!() + } +} diff --git a/operators/src/scale/mod.rs b/operators/src/scale/mod.rs new file mode 100644 index 00000000..e5d473b0 --- /dev/null +++ b/operators/src/scale/mod.rs @@ -0,0 +1,15 @@ +//! c =s*a + +#[cfg(any(use_cpu, test))] +pub mod common_cpu; +#[cfg(use_cuda)] +pub mod cuda; +#[cfg(use_infini)] +pub mod infini; +#[cfg(use_cl)] +pub mod opencl; + +mod args; +pub use args::Args; + +crate::op_trait!(Scale); diff --git a/operators/src/scale/opencl/mod.rs b/operators/src/scale/opencl/mod.rs new file mode 100644 index 00000000..369e8f81 --- /dev/null +++ b/operators/src/scale/opencl/mod.rs @@ -0,0 +1,36 @@ +use super::{Args, Scale}; +use crate::{opencl::ClDevice, ByteOf, LaunchError, QueueAlloc, SchemeError}; + +pub struct Operator; + +impl Scale for Operator {} + +impl crate::Operator for Operator { + type Hardware = ClDevice; + type TopoNode = ClDevice; + type Args = Args; + + fn new(_node: &Self::TopoNode) -> Self { + todo!() + } + + fn scheme( + &mut self, + _args: &Self::Args, + _max_workspace_size: usize, + ) -> Result { + todo!() + } + + fn launch( + &self, + _args: &Self::Args, + _workspace: &mut [ByteOf], + _queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + QA: QueueAlloc, + { + todo!() + } +} From ab65dfa4d28ab94ac8bb39ae6af3e88a39caf23d Mon Sep 17 00:00:00 2001 From: onenewcode Date: Mon, 17 Feb 2025 14:29:02 +0800 Subject: [PATCH 3/7] =?UTF-8?q?fix:=20=E6=B7=BB=E5=8A=A0=E6=9B=B4=E5=A4=9A?= =?UTF-8?q?rope?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operators/src/attention/args.rs | 5 +- operators/src/attention/operator.rs | 4 +- operators/src/rearrange/args.rs | 4 +- operators/src/rope/args.rs | 38 +++- operators/src/rope/common_cpu/mod.rs | 274 +++++++++++++++++++++++---- operators/src/rope/cuda/mod.rs | 4 +- operators/src/rope/mod.rs | 2 +- operators/src/scale/args.rs | 9 +- operators/src/scale/cuda/mod.rs | 21 +- 9 files changed, 296 insertions(+), 65 deletions(-) diff --git a/operators/src/attention/args.rs b/operators/src/attention/args.rs index a617cc0f..f2914d74 100644 --- a/operators/src/attention/args.rs +++ b/operators/src/attention/args.rs @@ -34,6 +34,7 @@ pub(super) struct Meta { } impl Args { + #[allow(clippy::too_many_arguments)] pub(crate) fn new_null( mask: AttnMask, dt: DigitLayout, @@ -53,9 +54,9 @@ impl Args { q_base: null_mut(), k_layout: k_layout.clone(), k_base: null(), - v_layout: v_layout, + v_layout, v_base: null(), - o_layout: o_layout, + o_layout, o_base: null_mut(), mask, } diff --git a/operators/src/attention/operator.rs b/operators/src/attention/operator.rs index be9418ec..aeb253da 100644 --- a/operators/src/attention/operator.rs +++ b/operators/src/attention/operator.rs @@ -65,7 +65,7 @@ where } = args; // 如果不能保证 nh seq att dh 已知,用任意值初始化算子 - let (Some(&nh), Some(&seq), Some(&att), Some(&dh), Some(&dv)) = ( + let (Some(&nh), Some(&seq), Some(&att), Some(&dh), Some(&_dv)) = ( nh.get_static(), seq.get_static(), att.get_static(), @@ -194,7 +194,7 @@ where let (att_buf, workspace) = workspace.split_at_mut(att_size); let head_group = nh / nkvh; - let (q_layout, qx_layout, q_base) = match qx { + let (_q_layout, qx_layout, q_base) = match qx { None => { let q_layout = TensorLayout::new_contiguous(dt, &[nh, seq, dh]); let qx_layout = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, dh]); diff --git a/operators/src/rearrange/args.rs b/operators/src/rearrange/args.rs index 64aa3d3d..f03767ba 100644 --- a/operators/src/rearrange/args.rs +++ b/operators/src/rearrange/args.rs @@ -268,13 +268,13 @@ fn test_scheme() { dst_layout: TensorLayout::new( F16, &shape, - &[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 1 * 2], + &[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 2], ), dst_base: null_mut(), src_layout: TensorLayout::new( F16, &shape, - &[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 1 * 2], + &[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 2], ), src_base: null(), }; diff --git a/operators/src/rope/args.rs b/operators/src/rope/args.rs index 6c3dff97..373cd6df 100644 --- a/operators/src/rope/args.rs +++ b/operators/src/rope/args.rs @@ -1,10 +1,45 @@ -use crate::{ +use crate::{ type_not_support, utils::{dim_distinct, rank_error}, ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, }; use digit_layout::DigitLayout; +pub enum RopeType { + // 以下枚举通用一个 Scheme + Rope, + Pi { + s: f32, + }, + Ntk { + s: f32, + }, + Dyn { + s: f32, + a: f32, + }, + + // 以下枚举通用一个 Scheme + NtkParts { + alpha: f32, + beta: f32, + l0: f32, + s: f32, + }, + Yarn { + alpha: f32, + beta: f32, + l0: f32, + s: f32, + }, + Long { + long: ConstPtr, + short: ConstPtr, + max_pos: u32, + origin_pos: u32, + }, +} + pub struct Args { pub t_layout: TensorLayout, pub t_base: MutPtr, @@ -15,6 +50,7 @@ pub struct Args { pub cos_layout: TensorLayout, pub cos_base: ConstPtr, pub theta: f32, + pub rope_type: RopeType, } pub(super) struct Meta { diff --git a/operators/src/rope/common_cpu/mod.rs b/operators/src/rope/common_cpu/mod.rs index 0630555b..2f52e34f 100644 --- a/operators/src/rope/common_cpu/mod.rs +++ b/operators/src/rope/common_cpu/mod.rs @@ -1,11 +1,38 @@ -use super::{args::Meta, fill_pos, Args, Rope, Seq, SinCosTable}; +use std::ptr::null; + +use super::{args::Meta, args::RopeType as R, fill_pos, Args, Rope, Seq, SinCosTable}; use crate::{ common_cpu::Cpu, get_static, strides_not_support, ByteOf, LaunchError, QueueAlloc, SchemeError, Unsigned, }; use digit_layout::{types as ty, DigitLayout}; use half::f16; +#[derive(Copy, Clone)] +enum NtkPartsType { + None, + Yarn, +} +#[derive(Copy, Clone)] +enum SchemeType { + Rope { + s: f32, + }, + Long { + long: *const T, + short: *const T, + s: f32, + origin_pos: u32, + }, + #[allow(dead_code)] + NtkParts { + alpha: f32, + beta: f32, + l0: f32, + s: f32, + ntktype: NtkPartsType, + }, +} pub struct Operator; impl Rope for Operator { @@ -78,6 +105,7 @@ impl crate::Operator for Operator { p_layout, p_base, theta, + rope_type, .. } = args; let &[_, nh, dh] = t_layout.shape() else { @@ -99,33 +127,130 @@ impl crate::Operator for Operator { return Err(strides_not_support("").into()); } - macro_rules! calculate { - ($t:ty, $p:ty) => { - Scheme::<$t, $p> { - nt, - nh, - dh, - st, - sh, - sp, - theta: *theta, - t_base: t_base.cast(), - p_base: p_base.cast(), + match rope_type { + R::Rope | R::Dyn { .. } | R::Ntk { .. } | R::Pi { .. } => { + let (theta, s) = match rope_type { + R::Rope => (*theta, 1.), + R::Dyn { s, a } => (theta * (a * s - a + 1.), 1.), + R::Ntk { s } => (theta * s, 1.), + R::Pi { s } => (*theta, *s), + _ => unreachable!(), + }; + macro_rules! calculate { + ($t:ty, $p:ty) => { + Scheme::<$t, $p> { + nt, + nh, + dh, + st, + sh, + sp, + theta, + t_base: t_base.cast(), + p_base: p_base.cast(), + scheme_type: SchemeType::Rope { s }, + } + .calculate() + }; } - .calculate() - }; - } - use digit_layout::types as ty; - match (dt_t, dt_p) { - (ty::F16, ty::U32) => calculate!(f16, u32), - (ty::F16, ty::U64) => calculate!(f16, u64), - (ty::F32, ty::U32) => calculate!(f32, u32), - (ty::F32, ty::U64) => calculate!(f32, u64), - (ty::F64, ty::U32) => calculate!(f64, u32), - (ty::F64, ty::U64) => calculate!(f64, u64), - _ => todo!(), + use digit_layout::types as ty; + match (dt_t, dt_p) { + (ty::F16, ty::U32) => calculate!(f16, u32), + (ty::F16, ty::U64) => calculate!(f16, u64), + (ty::F32, ty::U32) => calculate!(f32, u32), + (ty::F32, ty::U64) => calculate!(f32, u64), + (ty::F64, ty::U32) => calculate!(f64, u32), + (ty::F64, ty::U64) => calculate!(f64, u64), + _ => todo!(), + } + } + R::Long { + long, + short, + max_pos, + origin_pos, + } => { + let s = 1.0 + + ((*max_pos as f32 / *origin_pos as f32).ln() / (*origin_pos as f32).ln()) + .sqrt(); + macro_rules! calculate { + ($t:ty, $p:ty) => { + Scheme::<$t, $p> { + nt, + nh, + dh, + st, + sh, + sp, + theta: *theta, + t_base: t_base.cast(), + p_base: p_base.cast(), + scheme_type: SchemeType::Long { + long: long.cast(), + short: short.cast(), + s, + origin_pos: *origin_pos, + }, + } + .calculate() + }; + } + + use digit_layout::types as ty; + match (dt_t, dt_p) { + (ty::F16, ty::U32) => calculate!(f16, u32), + (ty::F16, ty::U64) => calculate!(f16, u64), + (ty::F32, ty::U32) => calculate!(f32, u32), + (ty::F32, ty::U64) => calculate!(f32, u64), + (ty::F64, ty::U32) => calculate!(f64, u32), + (ty::F64, ty::U64) => calculate!(f64, u64), + _ => todo!(), + } + } + R::Yarn { alpha, beta, l0, s } | R::NtkParts { alpha, beta, l0, s } => { + let ntktype = match rope_type { + R::NtkParts { .. } => NtkPartsType::None, + R::Yarn { .. } => NtkPartsType::Yarn, + _ => unreachable!(), + }; + macro_rules! calculate { + ($t:ty, $p:ty) => { + Scheme::<$t, $p> { + nt, + nh, + dh, + st, + sh, + sp, + theta: *theta, + t_base: t_base.cast(), + p_base: p_base.cast(), + scheme_type: SchemeType::NtkParts { + alpha: *alpha, + beta: *beta, + l0: *l0, + s: *s, + ntktype, + }, + } + .calculate() + }; + } + + use digit_layout::types as ty; + match (dt_t, dt_p) { + (ty::F16, ty::U32) => calculate!(f16, u32), + (ty::F16, ty::U64) => calculate!(f16, u64), + (ty::F32, ty::U32) => calculate!(f32, u32), + (ty::F32, ty::U64) => calculate!(f32, u64), + (ty::F64, ty::U32) => calculate!(f64, u32), + (ty::F64, ty::U64) => calculate!(f64, u64), + _ => todo!(), + } + } } + Ok(()) } } @@ -142,15 +267,15 @@ struct Scheme { theta: f32, t_base: *mut A, p_base: *const P, + scheme_type: SchemeType, } unsafe impl Send for Scheme {} unsafe impl Sync for Scheme {} - /// 激活值。 trait Activation: Sized { /// 激活值类型决定计算类型。 - type Calculation; + type Calculation: Copy; /// 计算流程。 fn calculate(pair: &mut [Self; 2], sin: Self::Calculation, cos: Self::Calculation); } @@ -187,15 +312,69 @@ impl Activation for f64 { } trait Position { - fn freq_sin_cos(self, k: isize, dh: isize, theta: f32) -> (Calculation, Calculation); + fn freq_sin_cos_rope( + self, + k: isize, + dh: isize, + theta: f32, + s: f32, + ) -> (Calculation, Calculation); + fn freq_sin_cos_long( + self, + k: isize, + dh: isize, + t: f32, + f: Calculation, + s: f32, + ) -> (Calculation, Calculation); + #[allow(clippy::too_many_arguments)] + fn freq_sin_cos_ntk_part( + self, + k: isize, + dh: isize, + theta: f32, + alpha: f32, + beta: f32, + l0: f32, + s: f32, + ntktype: NtkPartsType, + ) -> (Calculation, Calculation); } macro_rules! impl_position { ($a:ty) => { impl Position<$a> for T { #[inline] - fn freq_sin_cos(self, k: isize, dh: isize, theta: f32) -> ($a, $a) { - (self.val() as $a / (theta as $a).powf(k as $a / dh as $a)).sin_cos() + fn freq_sin_cos_rope(self, k: isize, dh: isize, theta: f32, s: f32) -> ($a, $a) { + (self.val() as $a * s as $a * (theta as $a).powf(k as $a / dh as $a).recip()) + .sin_cos() + } + #[inline] + fn freq_sin_cos_long(self, k: isize, dh: isize, t: f32, f: $a, s: f32) -> ($a, $a) { + let (sin, cos) = + (self.val() as $a * (t as $a).powf(k as $a / dh as $a).recip() * f).sin_cos(); + (sin * s as $a, cos * s as $a) + } + #[inline] + fn freq_sin_cos_ntk_part( + self, + k: isize, + dh: isize, + theta: f32, + alpha: f32, + beta: f32, + l0: f32, + s: f32, + ntktype: NtkPartsType, + ) -> ($a, $a) { + use std::f32::consts::PI; + let pos = match ntktype { + NtkPartsType::None => self.val() as $a, + NtkPartsType::Yarn => self.val() as $a * (0.1 * s.ln() + 1.) as $a, + }; + let theta = theta.powf(k as f32 / dh as f32).recip(); + let r = ((l0 / (2. * PI / theta) - alpha) / (beta - alpha)).clamp(0., 1.); + (pos * ((1. - r) / s + r) as $a * theta as $a).sin_cos() } } }; @@ -206,8 +385,8 @@ impl_position!(f64); impl Scheme where - A: Activation, - P: Position + Sync + Copy, + A: Activation + Copy, + P: Position + Sync + Copy + Unsigned, { fn calculate(&self) { let &Self { @@ -220,6 +399,7 @@ where theta, t_base, p_base, + scheme_type, } = self; let nt = nt as isize; let nh = nh as isize; @@ -229,10 +409,38 @@ where for i in 0..nt { let t = unsafe { t_base.byte_offset(i * st).cast::<[A; 2]>() }; let p = unsafe { *p_base.byte_offset(i * sp) }; + let factor = match scheme_type { + SchemeType::Long { + long, + short, + origin_pos, + .. + } => unsafe { + if p.val() < origin_pos as usize { + short.byte_offset(i * st).cast() + } else { + long.byte_offset(i * st).cast() + } + }, + _ => null(), + }; for j in 0..nh { for k in 0..dh { let pair = unsafe { &mut *t.byte_offset(j * sh + k * sd) }; - let (sin, cos) = p.freq_sin_cos(k, dh, theta); + let (sin, cos) = match scheme_type { + SchemeType::Rope { s } => p.freq_sin_cos_rope(k, dh, theta, s), + SchemeType::Long { s, .. } => { + let factor = unsafe { *factor }; + p.freq_sin_cos_long(k, dh, theta, factor, s) + } + SchemeType::NtkParts { + alpha, + beta, + l0, + s, + ntktype, + } => p.freq_sin_cos_ntk_part(k, dh, theta, alpha, beta, l0, s, ntktype), + }; A::calculate(pair, sin, cos) } } diff --git a/operators/src/rope/cuda/mod.rs b/operators/src/rope/cuda/mod.rs index 9d457120..6ac010c7 100644 --- a/operators/src/rope/cuda/mod.rs +++ b/operators/src/rope/cuda/mod.rs @@ -184,7 +184,7 @@ extern "C" __global__ void {POS_U64}( #[cfg(test)] mod test { use super::{Args, Gpu, Operator, POS_U32, POS_U64}; - use crate::{Hardware, Operator as _, TensorLayout}; + use crate::{rope::args, Hardware, Operator as _, TensorLayout}; use digit_layout::{ types::{F16, F64, U32}, DigitLayout, @@ -203,6 +203,7 @@ mod test { cos_layout: TensorLayout::new_dyn(dt_t, &[dyn_(); 2], &[dyn_(); 2]), cos_base: null(), theta: 0., + rope_type: args::RopeType::Rope, } } @@ -227,6 +228,7 @@ mod test { cos_layout: TensorLayout::new_contiguous(dt_t, &[0, dh]), cos_base: null(), theta, + rope_type: args::RopeType::Rope, } } diff --git a/operators/src/rope/mod.rs b/operators/src/rope/mod.rs index 1576c069..45a5badc 100644 --- a/operators/src/rope/mod.rs +++ b/operators/src/rope/mod.rs @@ -8,7 +8,7 @@ pub mod infini; pub mod opencl; mod args; -pub use args::Args; +pub use args::{Args, RopeType}; crate::op_trait! { Rope /// 生成 sincos 表([2, n, dh])。 diff --git a/operators/src/scale/args.rs b/operators/src/scale/args.rs index 21951926..9e395148 100644 --- a/operators/src/scale/args.rs +++ b/operators/src/scale/args.rs @@ -19,11 +19,7 @@ pub struct Args { } impl Args { - pub fn new_null( - c_layout: TensorLayout, - a_layout: TensorLayout, - b_layout: TensorLayout, - ) -> Self { + pub fn new_null(c_layout: TensorLayout, a_layout: TensorLayout) -> Self { Self { c_layout, c_base: null_mut(), @@ -97,6 +93,7 @@ impl Scheme { use Ordering::Equal as Eq; match c0.abs().cmp(&c1.abs()) { Eq => match a0.abs().cmp(&a1.abs()) { + Eq => d0.cmp(&d1), ord => ord.reverse(), }, ord => ord.reverse(), @@ -120,7 +117,7 @@ impl Scheme { { let (idx, tail) = layout.split_at_mut(1 + ndim); let (c_, tail) = tail.split_at_mut(ndim); - let (a_, b_) = tail.split_at_mut(ndim); + let (a_, _b) = tail.split_at_mut(ndim); for (Dim { d, c, a }, idx, c_, a_) in izip!(dims.into_iter().filter(|d| d.d != 1), &mut *idx, c_, a_) { diff --git a/operators/src/scale/cuda/mod.rs b/operators/src/scale/cuda/mod.rs index 79eb00c9..9acec9fb 100644 --- a/operators/src/scale/cuda/mod.rs +++ b/operators/src/scale/cuda/mod.rs @@ -1,16 +1,5 @@ -use super::{args::Scheme, Args, Scale}; -use crate::{ - cuda::{dt_name, Gpu, Handle, ModuleBox}, - shape_not_support, strides_not_support, - utils::{gcd, type_distinct}, - ByteOf, LaunchError, QueueAlloc, SchemeDiversity, SchemeError, -}; -use digit_layout::DigitLayout; -use lru::LruCache; -use std::{ - ffi::{c_uint, CString}, - sync::{Arc, Mutex}, -}; +use super::{Args, Scale}; +use crate::{cuda::Gpu, ByteOf, LaunchError, QueueAlloc, SchemeError}; pub struct Operator {} impl Scale for Operator {} @@ -20,17 +9,16 @@ impl crate::Operator for Operator { type TopoNode = Gpu; type Args = Args; - fn new(node: &Self::TopoNode) -> Self { + fn new(_node: &Self::TopoNode) -> Self { Self {} } #[inline] fn scheme( &mut self, - args: &Self::Args, + _args: &Self::Args, _max_workspace_size: usize, ) -> Result { - todo!(); Ok(0) } @@ -43,7 +31,6 @@ impl crate::Operator for Operator { where QA: QueueAlloc, { - todo!(); Ok(()) } } From 6989b925e32b9768842beff653ea583f45af7b1b Mon Sep 17 00:00:00 2001 From: onenewcode Date: Mon, 24 Feb 2025 04:57:23 +0000 Subject: [PATCH 4/7] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0mmla=E7=9A=84?= =?UTF-8?q?=E5=90=B8=E6=94=B6=E7=89=88=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operators/src/attention_mla/args.rs | 125 +++++++++++++ operators/src/attention_mla/common_cpu.rs | 1 + operators/src/attention_mla/cuda.rs | 1 + operators/src/attention_mla/infini.rs | 1 + operators/src/attention_mla/mod.rs | 26 +++ operators/src/attention_mla/opencl.rs | 1 + operators/src/attention_mla/operator.rs | 203 ++++++++++++++++++++++ operators/src/lib.rs | 1 + operators/src/rope/common_cpu/mod.rs | 3 +- 9 files changed, 360 insertions(+), 2 deletions(-) create mode 100644 operators/src/attention_mla/args.rs create mode 100644 operators/src/attention_mla/common_cpu.rs create mode 100644 operators/src/attention_mla/cuda.rs create mode 100644 operators/src/attention_mla/infini.rs create mode 100644 operators/src/attention_mla/mod.rs create mode 100644 operators/src/attention_mla/opencl.rs create mode 100644 operators/src/attention_mla/operator.rs diff --git a/operators/src/attention_mla/args.rs b/operators/src/attention_mla/args.rs new file mode 100644 index 00000000..d85bbc96 --- /dev/null +++ b/operators/src/attention_mla/args.rs @@ -0,0 +1,125 @@ +use crate::{ + dyn_, + fuesd_softmax::AttnMask, + utils::{dim_distinct, rank_error, type_distinct}, + ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, +}; +use digit_layout::DigitLayout; +use std::ptr::{null, null_mut}; + +pub struct Args { + // q传入的是是吸收后的 + pub q_layout: TensorLayout, + pub q_base: MutPtr, + + pub kv_layout: TensorLayout, + pub kv_base: ConstPtr, + + pub absorb_layout: TensorLayout, + pub absorb_base: ConstPtr, + + pub qr_layout: TensorLayout, + pub qr_base: ConstPtr, + + pub kr_layout: TensorLayout, + pub kr_base: ConstPtr, + + pub o_layout: TensorLayout, + pub o_base: MutPtr, + + pub mask: AttnMask, +} + +pub(super) struct Meta { + pub dt: DigitLayout, + pub nh: MaybeDyn, + pub seq: MaybeDyn, + pub att: MaybeDyn, + pub dkv: MaybeDyn, + pub dv: MaybeDyn, + pub dr: MaybeDyn, +} + +impl Args { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new_null( + mask: AttnMask, + dt: DigitLayout, + nh: MaybeDyn, + dkv: MaybeDyn, + seq: MaybeDyn, + att: MaybeDyn, + dv: MaybeDyn, + dr: MaybeDyn, + ) -> Self { + let q_layout = TensorLayout::new_dyn(dt, &[nh, seq, dkv], &[dyn_(); 3]); + let kv_layout = TensorLayout::new_dyn(dt, &[nh, att, dkv], &[dyn_(); 3]); + let absorb_layout = TensorLayout::new_dyn(dt, &[nh, dv, dkv], &[dyn_(); 3]); + let qr_layout = TensorLayout::new_dyn(dt, &[nh, seq, dr], &[dyn_(); 3]); + let kr_layout = TensorLayout::new_dyn(dt, &[nh, att, dr], &[dyn_(); 3]); + let o_layout = TensorLayout::new_dyn(dt, &[nh, seq, dv], &[dyn_(); 3]); + Self { + q_layout, + q_base: null_mut(), + kv_layout, + kv_base: null(), + absorb_layout, + absorb_base: null(), + qr_layout, + qr_base: null(), + kr_layout, + kr_base: null(), + o_layout, + o_base: null_mut(), + mask, + } + } + + pub(super) fn meta(&self) -> Result { + let Self { + q_layout, + kv_layout, + absorb_layout, + qr_layout, + kr_layout, + o_layout, + .. + } = self; + + let &[nh_q, seq_q, dkv_q] = q_layout.shape() else { + return Err(rank_error("q", 3, q_layout.ndim())); + }; + + let &[nh_kv, attn_kv, dkv_kv] = kv_layout.shape() else { + return Err(rank_error("kv", 3, kv_layout.ndim())); + }; + let &[nh_a, dv_a, dkv_a] = absorb_layout.shape() else { + return Err(rank_error("absorb", 3, absorb_layout.ndim())); + }; + let &[nh_qr, seq_qr, dr_qr] = qr_layout.shape() else { + return Err(rank_error("qr", 3, qr_layout.ndim())); + }; + let &[nh_kr, att_kr, dr_kr] = kr_layout.shape() else { + return Err(rank_error("kr", 3, kr_layout.ndim())); + }; + let &[nh_o, seq_o, dv_o] = o_layout.shape() else { + return Err(rank_error("o", 3, o_layout.ndim())); + }; + + Ok(Meta { + dt: type_distinct(&[ + q_layout.dt(), + kv_layout.dt(), + qr_layout.dt(), + kr_layout.dt(), + o_layout.dt(), + ])?, + nh: dim_distinct(&[nh_q, nh_kv, nh_a, nh_qr, nh_kr, nh_o])?, + seq: dim_distinct(&[seq_q, seq_o, seq_qr])?, + att: dim_distinct(&[attn_kv, att_kr])?, + dkv: dim_distinct(&[dkv_a, dkv_kv, dkv_q])?, + dv: dim_distinct(&[dv_a, dv_o])?, + dr: dim_distinct(&[dr_kr, dr_qr])?, + }) + } +} diff --git a/operators/src/attention_mla/common_cpu.rs b/operators/src/attention_mla/common_cpu.rs new file mode 100644 index 00000000..cf59d751 --- /dev/null +++ b/operators/src/attention_mla/common_cpu.rs @@ -0,0 +1 @@ +impl_op!(common_cpu, Cpu); diff --git a/operators/src/attention_mla/cuda.rs b/operators/src/attention_mla/cuda.rs new file mode 100644 index 00000000..94ef4e21 --- /dev/null +++ b/operators/src/attention_mla/cuda.rs @@ -0,0 +1 @@ +impl_op!(cuda, Gpu); diff --git a/operators/src/attention_mla/infini.rs b/operators/src/attention_mla/infini.rs new file mode 100644 index 00000000..a45f3e13 --- /dev/null +++ b/operators/src/attention_mla/infini.rs @@ -0,0 +1 @@ +impl_op!(infini, Device); diff --git a/operators/src/attention_mla/mod.rs b/operators/src/attention_mla/mod.rs new file mode 100644 index 00000000..03061907 --- /dev/null +++ b/operators/src/attention_mla/mod.rs @@ -0,0 +1,26 @@ +mod args; +mod operator; + +pub use args::Args; + +crate::op_trait!(AttentionMLA); + +macro_rules! impl_op { + ($dev:ident, $proc:ident) => { + pub type Operator = super::operator::Operator< + crate::$dev::$proc, + crate::mat_mul::$dev::Operator, + crate::fuesd_softmax::$dev::Operator, + crate::rearrange::$dev::Operator, + >; + }; +} + +#[cfg(any(use_cpu, test))] +pub mod common_cpu; +#[cfg(use_cuda)] +pub mod cuda; +#[cfg(use_infini)] +pub mod infini; +#[cfg(use_cl)] +pub mod opencl; diff --git a/operators/src/attention_mla/opencl.rs b/operators/src/attention_mla/opencl.rs new file mode 100644 index 00000000..cd4b7d79 --- /dev/null +++ b/operators/src/attention_mla/opencl.rs @@ -0,0 +1 @@ +impl_op!(opencl, ClDevice); diff --git a/operators/src/attention_mla/operator.rs b/operators/src/attention_mla/operator.rs new file mode 100644 index 00000000..fefd7add --- /dev/null +++ b/operators/src/attention_mla/operator.rs @@ -0,0 +1,203 @@ +use super::{args::Meta, Args, AttentionMLA}; +use crate::{ + dyn_, fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc, + SchemeError, TensorLayout, Workspace, WorkspaceCollector, +}; +use ndarray_layout::ArrayLayout; +use std::marker::PhantomData; + +pub struct Operator { + mat_mul: MatMul, + softmax: Softmax, + rearrange: Rearrange, + _phantom: PhantomData, +} + +impl AttentionMLA for Operator +where + H: Hardware, + M: mat_mul::MatMul, + S: fuesd_softmax::FusedSoftmax, + R: rearrange::Rearrange, +{ +} + +impl crate::Operator for Operator +where + H: Hardware, + M: mat_mul::MatMul, + S: fuesd_softmax::FusedSoftmax, + R: rearrange::Rearrange, +{ + type Hardware = H; + type TopoNode = H; + type Args = Args; + + fn new(node: &Self::TopoNode) -> Self { + Self { + mat_mul: M::new(node), + softmax: S::new(node), + rearrange: R::new(node), + _phantom: PhantomData, + } + } + + fn scheme( + &mut self, + args: &Self::Args, + max_workspace_size: usize, + ) -> Result { + // TODO + Ok(0) + } + + fn launch( + &self, + args: &Self::Args, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + QA: QueueAlloc, + { + let Meta { + dt, + nh, + seq, + att, + dkv, + dv, + dr, + } = args.meta()?; + let Args { + q_layout, + q_base, + kv_layout, + kv_base, + absorb_layout, + absorb_base, + qr_layout, + qr_base, + kr_layout, + kr_base, + o_layout, + o_base, + mask, + } = args; + + let &[nh_skv, att_skv, dkv_skv] = kv_layout.strides() else { + unreachable!() + }; + let &[nh_skr, att_skr, dr_skr] = kr_layout.strides() else { + unreachable!() + }; + let &[nh_sa, dv_sa, dkv_sa] = absorb_layout.strides() else { + unreachable!() + }; + let &[nh_so, seq_so, dv_so] = o_layout.strides() else { + unreachable!() + }; + let ele = dt.nbytes(); + get_static! { + nh seq dkv dr + nh_skv att_skv dkv_skv + nh_skr att_skr dr_skr + nh_sa dv_sa dkv_sa + nh_so seq_so dv_so + dv att + }; + + #[inline(always)] + fn layout(shape: [usize; 3], strides: [isize; 3]) -> ArrayLayout<3> { + ArrayLayout::new(&shape, &strides, 0) + } + let kv_first_layout = layout([nh, att, dkv], [nh_skv, att_skv, dkv_skv]).transpose(&[2, 1]); + let kr_layout = layout([nh, att, dr], [nh_skr, att_skr, dr_skr]).transpose(&[2, 1]); + let a_layout = layout([nh, dv, dkv], [nh_sa, dv_sa, dkv_sa]).transpose(&[2, 1]); + let att_w_layout = TensorLayout::new_contiguous(dt, &[nh, seq, att]); + let attn_t_layout = TensorLayout::new_contiguous(dt, &[nh, seq, dkv]); + let att_w_size = nh * seq * att * ele; + let att_t_size = nh * seq * dkv * ele; + let mut workspace = Workspace::new(queue_alloc, workspace, att_w_size + att_t_size); + let (att_w_buf, workspace) = workspace.split_at_mut(att_w_size); + let (attn_t_buf, workspace) = workspace.split_at_mut(att_t_size); + + let kv_first_layout = + TensorLayout::new(dt, kv_first_layout.shape(), kv_first_layout.strides()); + let kr_layout = TensorLayout::new(dt, kr_layout.shape(), kr_layout.strides()); + let a_layout = TensorLayout::new(dt, a_layout.shape(), a_layout.strides()); + // att_w = qr*kr^T + q*kv^T + self.mat_mul.launch( + &mat_mul::Args { + c_layout: att_w_layout.clone(), + c_base: att_w_buf.as_mut_ptr(), + beta: 0., + a_layout: qr_layout.clone(), + a_base: *qr_base, + b_layout: kr_layout.clone(), + b_base: *kr_base, + alpha: ((dv + dr) as f32).sqrt().recip(), + }, + workspace, + queue_alloc, + )?; + + self.mat_mul.launch( + &mat_mul::Args { + c_layout: att_w_layout.clone(), + c_base: att_w_buf.as_mut_ptr(), + beta: 1., + a_layout: q_layout.clone(), + a_base: *q_base, + b_layout: kv_first_layout.clone(), + b_base: *kv_base, + alpha: ((dv + dr) as f32).sqrt().recip(), + }, + workspace, + queue_alloc, + )?; + // att_w = softmax(att) + self.softmax.launch( + &fuesd_softmax::Args { + att_mask: *mask, + att_layout: att_w_layout.clone(), + att_base: att_w_buf.as_mut_ptr(), + }, + workspace, + queue_alloc, + )?; + // attn_t=att_o*kv + self.mat_mul.launch( + &mat_mul::Args { + c_layout: attn_t_layout.clone(), + c_base: attn_t_buf.as_mut_ptr(), + beta: 0., + a_layout: att_w_layout.clone(), + a_base: att_w_buf.as_ptr(), + b_layout: kv_layout.clone(), + b_base: *kv_base, + alpha: 1., + }, + workspace, + queue_alloc, + )?; + + // attn =attn_t*absorb^T + self.mat_mul.launch( + &mat_mul::Args { + c_layout: o_layout.clone(), + c_base: *o_base, + beta: 0., + a_layout: attn_t_layout.clone(), + a_base: attn_t_buf.as_ptr(), + b_layout: a_layout.clone(), + b_base: *absorb_base, + alpha: 1., + }, + workspace, + queue_alloc, + )?; + + Ok(()) + } +} diff --git a/operators/src/lib.rs b/operators/src/lib.rs index 3c86f3ac..2011da58 100644 --- a/operators/src/lib.rs +++ b/operators/src/lib.rs @@ -8,6 +8,7 @@ pub mod add_rows; pub mod all_reduce; pub mod attention; pub mod attention_kv_cached; +pub mod attention_mla; pub mod broadcast; pub mod conv; pub mod fuesd_softmax; diff --git a/operators/src/rope/common_cpu/mod.rs b/operators/src/rope/common_cpu/mod.rs index 2f52e34f..9e0cb690 100644 --- a/operators/src/rope/common_cpu/mod.rs +++ b/operators/src/rope/common_cpu/mod.rs @@ -1,5 +1,3 @@ -use std::ptr::null; - use super::{args::Meta, args::RopeType as R, fill_pos, Args, Rope, Seq, SinCosTable}; use crate::{ common_cpu::Cpu, get_static, strides_not_support, ByteOf, LaunchError, QueueAlloc, SchemeError, @@ -7,6 +5,7 @@ use crate::{ }; use digit_layout::{types as ty, DigitLayout}; use half::f16; +use std::ptr::null; #[derive(Copy, Clone)] enum NtkPartsType { None, From ab2c19647495e250ee340ecadc81b2793d470843 Mon Sep 17 00:00:00 2001 From: onenewcode Date: Mon, 24 Feb 2025 16:08:51 +0800 Subject: [PATCH 5/7] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0mla=5Fcache?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operators/src/attention_mla/operator.rs | 11 +- operators/src/attention_mla_cached/args.rs | 103 ++++++++++ .../src/attention_mla_cached/common_cpu.rs | 1 + operators/src/attention_mla_cached/cuda.rs | 1 + operators/src/attention_mla_cached/infini.rs | 1 + operators/src/attention_mla_cached/mod.rs | 25 +++ operators/src/attention_mla_cached/opencl.rs | 1 + .../src/attention_mla_cached/operator.rs | 183 ++++++++++++++++++ operators/src/lib.rs | 1 + 9 files changed, 320 insertions(+), 7 deletions(-) create mode 100644 operators/src/attention_mla_cached/args.rs create mode 100644 operators/src/attention_mla_cached/common_cpu.rs create mode 100644 operators/src/attention_mla_cached/cuda.rs create mode 100644 operators/src/attention_mla_cached/infini.rs create mode 100644 operators/src/attention_mla_cached/mod.rs create mode 100644 operators/src/attention_mla_cached/opencl.rs create mode 100644 operators/src/attention_mla_cached/operator.rs diff --git a/operators/src/attention_mla/operator.rs b/operators/src/attention_mla/operator.rs index fefd7add..26064967 100644 --- a/operators/src/attention_mla/operator.rs +++ b/operators/src/attention_mla/operator.rs @@ -1,7 +1,7 @@ use super::{args::Meta, Args, AttentionMLA}; use crate::{ - dyn_, fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc, - SchemeError, TensorLayout, Workspace, WorkspaceCollector, + fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc, + SchemeError, TensorLayout, Workspace, }; use ndarray_layout::ArrayLayout; use std::marker::PhantomData; @@ -94,16 +94,13 @@ where let &[nh_sa, dv_sa, dkv_sa] = absorb_layout.strides() else { unreachable!() }; - let &[nh_so, seq_so, dv_so] = o_layout.strides() else { - unreachable!() - }; + let ele = dt.nbytes(); get_static! { nh seq dkv dr nh_skv att_skv dkv_skv nh_skr att_skr dr_skr nh_sa dv_sa dkv_sa - nh_so seq_so dv_so dv att }; @@ -141,7 +138,7 @@ where workspace, queue_alloc, )?; - + self.mat_mul.launch( &mat_mul::Args { c_layout: att_w_layout.clone(), diff --git a/operators/src/attention_mla_cached/args.rs b/operators/src/attention_mla_cached/args.rs new file mode 100644 index 00000000..c1209e6e --- /dev/null +++ b/operators/src/attention_mla_cached/args.rs @@ -0,0 +1,103 @@ +use crate::{ + fuesd_softmax::AttnMask, + utils::{dim_distinct, rank_error, type_distinct}, + ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, +}; +use digit_layout::DigitLayout; + +pub struct Args { + // q传入的是是吸收后的 + pub q_layout: TensorLayout, + pub q_base: MutPtr, + + pub kv_layout: TensorLayout, + pub kv_base: ConstPtr, + + pub absorb_layout: TensorLayout, + pub absorb_base: ConstPtr, + + pub qr_layout: TensorLayout, + pub qr_base: ConstPtr, + + pub kr_layout: TensorLayout, + pub kr_base: ConstPtr, + + pub o_layout: TensorLayout, + pub o_base: MutPtr, + pub kv_cache_layout: TensorLayout, + pub kv_cache_base: MutPtr, + + pub kr_cache_layout: TensorLayout, + pub kr_cache_base: MutPtr, + + pub mask: AttnMask, + pub pos: MaybeDyn, +} + +pub(super) struct Meta { + pub dt: DigitLayout, + pub nh: MaybeDyn, + pub seq: MaybeDyn, + pub att: MaybeDyn, + pub dkv: MaybeDyn, + pub dv: MaybeDyn, + pub dr: MaybeDyn, +} + +impl Args { + pub(super) fn meta(&self) -> Result { + let Self { + q_layout, + kv_layout, + absorb_layout, + qr_layout, + kr_layout, + o_layout, + kv_cache_layout, + kr_cache_layout, + .. + } = self; + + let &[nh_q, seq_q, dkv_q] = q_layout.shape() else { + return Err(rank_error("q", 3, q_layout.ndim())); + }; + + let &[nh_kv, attn_kv, dkv_kv] = kv_layout.shape() else { + return Err(rank_error("kv", 3, kv_layout.ndim())); + }; + let &[nh_a, dv_a, dkv_a] = absorb_layout.shape() else { + return Err(rank_error("absorb", 3, absorb_layout.ndim())); + }; + let &[nh_qr, seq_qr, dr_qr] = qr_layout.shape() else { + return Err(rank_error("qr", 3, qr_layout.ndim())); + }; + let &[nh_kr, att_kr, dr_kr] = kr_layout.shape() else { + return Err(rank_error("kr", 3, kr_layout.ndim())); + }; + let &[nh_o, seq_o, dv_o] = o_layout.shape() else { + return Err(rank_error("o", 3, o_layout.ndim())); + }; + let &[nh_kvc, _buf, dkv_kvc] = kv_cache_layout.shape() else { + return Err(rank_error("k_cache", 3, kv_cache_layout.ndim())); + }; + let &[nh_krc, _buf, dr_krc] = kr_cache_layout.shape() else { + return Err(rank_error("v_cache", 3, kr_cache_layout.ndim())); + }; + + Ok(Meta { + dt: type_distinct(&[ + q_layout.dt(), + kv_layout.dt(), + qr_layout.dt(), + kr_layout.dt(), + o_layout.dt(), + ])?, + nh: dim_distinct(&[nh_q, nh_kv, nh_a, nh_qr, nh_kr, nh_o, nh_krc, nh_kvc])?, + seq: dim_distinct(&[seq_q, seq_o, seq_qr])?, + att: dim_distinct(&[attn_kv, att_kr])?, + dkv: dim_distinct(&[dkv_a, dkv_kv, dkv_q, dkv_kvc])?, + dv: dim_distinct(&[dv_a, dv_o])?, + dr: dim_distinct(&[dr_kr, dr_qr, dr_krc])?, + }) + } +} diff --git a/operators/src/attention_mla_cached/common_cpu.rs b/operators/src/attention_mla_cached/common_cpu.rs new file mode 100644 index 00000000..cf59d751 --- /dev/null +++ b/operators/src/attention_mla_cached/common_cpu.rs @@ -0,0 +1 @@ +impl_op!(common_cpu, Cpu); diff --git a/operators/src/attention_mla_cached/cuda.rs b/operators/src/attention_mla_cached/cuda.rs new file mode 100644 index 00000000..94ef4e21 --- /dev/null +++ b/operators/src/attention_mla_cached/cuda.rs @@ -0,0 +1 @@ +impl_op!(cuda, Gpu); diff --git a/operators/src/attention_mla_cached/infini.rs b/operators/src/attention_mla_cached/infini.rs new file mode 100644 index 00000000..a45f3e13 --- /dev/null +++ b/operators/src/attention_mla_cached/infini.rs @@ -0,0 +1 @@ +impl_op!(infini, Device); diff --git a/operators/src/attention_mla_cached/mod.rs b/operators/src/attention_mla_cached/mod.rs new file mode 100644 index 00000000..74a54976 --- /dev/null +++ b/operators/src/attention_mla_cached/mod.rs @@ -0,0 +1,25 @@ +mod args; +mod operator; + +pub use args::Args; + +crate::op_trait!(AttMLACached); + +macro_rules! impl_op { + ($dev:ident, $proc:ident) => { + pub type Operator = super::operator::Operator< + crate::$dev::$proc, + crate::rearrange::$dev::Operator, + crate::attention_mla::$dev::Operator, + >; + }; +} + +#[cfg(any(use_cpu, test))] +pub mod common_cpu; +#[cfg(use_cuda)] +pub mod cuda; +#[cfg(use_infini)] +pub mod infini; +#[cfg(use_cl)] +pub mod opencl; diff --git a/operators/src/attention_mla_cached/opencl.rs b/operators/src/attention_mla_cached/opencl.rs new file mode 100644 index 00000000..cd4b7d79 --- /dev/null +++ b/operators/src/attention_mla_cached/opencl.rs @@ -0,0 +1 @@ +impl_op!(opencl, ClDevice); diff --git a/operators/src/attention_mla_cached/operator.rs b/operators/src/attention_mla_cached/operator.rs new file mode 100644 index 00000000..ce619228 --- /dev/null +++ b/operators/src/attention_mla_cached/operator.rs @@ -0,0 +1,183 @@ +use crate::attention_mla_cached::args::Meta; +use crate::attention_mla_cached::{Args, AttMLACached}; +use crate::{ + attention_mla, get_static, rearrange, shape_mismatch, ByteOf, Hardware, LaunchError, + QueueAlloc, SchemeError, TensorLayout, +}; +use ndarray_layout::ArrayLayout; +use std::marker::PhantomData; + +pub struct Operator { + rearrange: Rearrange, + attention: Attention, + _phantom: PhantomData, +} + +impl AttMLACached for Operator +where + H: Hardware, + R: rearrange::Rearrange, + A: attention_mla::AttentionMLA, +{ +} + +impl crate::Operator for Operator +where + H: Hardware, + R: rearrange::Rearrange, + A: attention_mla::AttentionMLA, +{ + type Hardware = H; + type TopoNode = H; + type Args = crate::attention_mla_cached::Args; + fn new(node: &Self::TopoNode) -> Self { + Self { + rearrange: R::new(node), + attention: A::new(node), + _phantom: PhantomData, + } + } + + fn scheme( + &mut self, + args: &Self::Args, + max_workspace_size: usize, + ) -> Result { + // TODO + Ok(0) + } + + fn launch( + &self, + args: &Self::Args, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + QA: QueueAlloc, + { + let Meta { + dt, + nh, + seq, + att, + dkv, + dv, + dr, + } = args.meta()?; + let Args { + q_layout, + q_base, + kv_layout, + kv_base, + absorb_layout, + absorb_base, + qr_layout, + qr_base, + kr_layout, + kr_base, + o_layout, + o_base, + kv_cache_layout, + kv_cache_base, + kr_cache_layout, + kr_cache_base, + mask, + pos, + } = args; + let &[nh_skv, att_skv, dkv_skv] = kv_layout.strides() else { + unreachable!() + }; + let &[nh_skr, att_skr, dr_skr] = kr_layout.strides() else { + unreachable!() + }; + let &[nh_sa, dv_sa, dkv_sa] = absorb_layout.strides() else { + unreachable!() + }; + + let &[_, buf_kv, _] = kv_cache_layout.shape() else { + unreachable!() + }; + let &[_, buf_kr, _] = kr_cache_layout.shape() else { + unreachable!() + }; + let &[nh_skvc, buf_skvc, dh_skvc] = kv_cache_layout.strides() else { + unreachable!() + }; + let &[nh_skrc, buf_skrc, dh_skrc] = kr_cache_layout.strides() else { + unreachable!() + }; + let ele = dt.nbytes(); + get_static! { + nh seq dkv dr + pos + buf_kv buf_kr + nh_skvc buf_skvc dh_skvc + nh_skrc buf_skrc dh_skrc + + }; + + // 检查 cache 容量 + let att = pos + seq; + if buf_kr < att || buf_kv < att { + return Err(shape_mismatch("Out of cache buffer").into()); + } + // 连接 kv cache + #[inline(always)] + fn layout(shape: [usize; 3], strides: [isize; 3]) -> ArrayLayout<3> { + ArrayLayout::new(&shape, &strides, 0) + } + + let kvc_layout = layout([nh, buf_kv, dkv], [nh_skvc, buf_skvc, dh_skvc]); + let krc_layout = layout([nh, buf_kr, dr], [nh_skrc, buf_skrc, dh_skrc]); + + let kv_cat = kvc_layout.slice(1, pos, 1, seq); + let kr_cat = krc_layout.slice(1, pos, 1, seq); + + self.rearrange.launch( + &rearrange::Args { + dst_layout: TensorLayout::new(dt, kv_cat.shape(), kv_cat.strides()), + dst_base: unsafe { kv_cache_base.byte_add(kv_cat.offset() as _) }, + src_layout: kv_layout.clone(), + src_base: *kv_base, + }, + workspace, + queue_alloc, + )?; + self.rearrange.launch( + &rearrange::Args { + dst_layout: TensorLayout::new(dt, kr_cat.shape(), kr_cat.strides()), + dst_base: unsafe { kr_cache_base.byte_add(kr_cat.offset() as _) }, + src_layout: kr_layout.clone(), + src_base: *kr_base, + }, + workspace, + queue_alloc, + )?; + // attention + let kv_layout = kvc_layout.slice(1, 0, 1, att); + let kr_layout = krc_layout.slice(1, 0, 1, att); + assert_eq!(kv_layout.offset(), 0); + assert_eq!(kr_layout.offset(), 0); + self.attention.launch( + &attention_mla::Args { + mask: *mask, + q_layout: q_layout.clone(), + q_base: *q_base, + kv_layout: TensorLayout::new(dt, kv_layout.shape(), kv_layout.strides()), + kv_base: *kv_cache_base, + kr_layout: TensorLayout::new(dt, kr_layout.shape(), kr_layout.strides()), + kr_base: *kr_cache_base, + absorb_layout: absorb_layout.clone(), + absorb_base: *absorb_base, + qr_layout: qr_layout.clone(), + qr_base: *qr_base, + o_layout: o_layout.clone(), + o_base: *o_base, + }, + workspace, + queue_alloc, + )?; + Ok(()) + } +} diff --git a/operators/src/lib.rs b/operators/src/lib.rs index 2011da58..48788e47 100644 --- a/operators/src/lib.rs +++ b/operators/src/lib.rs @@ -9,6 +9,7 @@ pub mod all_reduce; pub mod attention; pub mod attention_kv_cached; pub mod attention_mla; +pub mod attention_mla_cached; pub mod broadcast; pub mod conv; pub mod fuesd_softmax; From 1e4b2f5a0442550891454726fb27c449906df889 Mon Sep 17 00:00:00 2001 From: onenewcode Date: Tue, 25 Feb 2025 10:04:27 +0000 Subject: [PATCH 6/7] t --- operators/src/attention_mla_cached/operator.rs | 6 ++---- operators/src/rope/common_cpu/mod.rs | 7 ++++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/operators/src/attention_mla_cached/operator.rs b/operators/src/attention_mla_cached/operator.rs index ce619228..24775a07 100644 --- a/operators/src/attention_mla_cached/operator.rs +++ b/operators/src/attention_mla_cached/operator.rs @@ -1,12 +1,10 @@ -use crate::attention_mla_cached::args::Meta; -use crate::attention_mla_cached::{Args, AttMLACached}; +use super::{args::Meta, Args, AttMLACached}; use crate::{ attention_mla, get_static, rearrange, shape_mismatch, ByteOf, Hardware, LaunchError, QueueAlloc, SchemeError, TensorLayout, }; use ndarray_layout::ArrayLayout; use std::marker::PhantomData; - pub struct Operator { rearrange: Rearrange, attention: Attention, @@ -29,7 +27,7 @@ where { type Hardware = H; type TopoNode = H; - type Args = crate::attention_mla_cached::Args; + type Args = Args; fn new(node: &Self::TopoNode) -> Self { Self { rearrange: R::new(node), diff --git a/operators/src/rope/common_cpu/mod.rs b/operators/src/rope/common_cpu/mod.rs index 9e0cb690..35aa5067 100644 --- a/operators/src/rope/common_cpu/mod.rs +++ b/operators/src/rope/common_cpu/mod.rs @@ -416,9 +416,9 @@ where .. } => unsafe { if p.val() < origin_pos as usize { - short.byte_offset(i * st).cast() + short } else { - long.byte_offset(i * st).cast() + long } }, _ => null(), @@ -429,7 +429,8 @@ where let (sin, cos) = match scheme_type { SchemeType::Rope { s } => p.freq_sin_cos_rope(k, dh, theta, s), SchemeType::Long { s, .. } => { - let factor = unsafe { *factor }; + // TODO 这里先默认为 f32 + let factor = unsafe { *factor.byte_offset(k * 4).cast() }; p.freq_sin_cos_long(k, dh, theta, factor, s) } SchemeType::NtkParts { From 20f9c5a02fe17f77c91f48ce43df2b1939438757 Mon Sep 17 00:00:00 2001 From: onenewcode Date: Wed, 26 Feb 2025 07:26:12 +0000 Subject: [PATCH 7/7] finish --- operators/src/rope/common_cpu/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/operators/src/rope/common_cpu/mod.rs b/operators/src/rope/common_cpu/mod.rs index 35aa5067..dc11aa0d 100644 --- a/operators/src/rope/common_cpu/mod.rs +++ b/operators/src/rope/common_cpu/mod.rs @@ -323,7 +323,7 @@ trait Position { k: isize, dh: isize, t: f32, - f: Calculation, + f: f32, s: f32, ) -> (Calculation, Calculation); #[allow(clippy::too_many_arguments)] @@ -349,9 +349,9 @@ macro_rules! impl_position { .sin_cos() } #[inline] - fn freq_sin_cos_long(self, k: isize, dh: isize, t: f32, f: $a, s: f32) -> ($a, $a) { + fn freq_sin_cos_long(self, k: isize, dh: isize, t: f32, f: f32, s: f32) -> ($a, $a) { let (sin, cos) = - (self.val() as $a * (t as $a).powf(k as $a / dh as $a).recip() * f).sin_cos(); + (self.val() as $a * (t as $a).powf(k as $a / dh as $a).recip() * (f as $a).recip() ).sin_cos(); (sin * s as $a, cos * s as $a) } #[inline]