diff --git a/operators/src/attention/args.rs b/operators/src/attention/args.rs index 4c1ac43a..f2914d74 100644 --- a/operators/src/attention/args.rs +++ b/operators/src/attention/args.rs @@ -30,9 +30,11 @@ pub(super) struct Meta { pub seq: MaybeDyn, pub att: MaybeDyn, pub dh: MaybeDyn, + pub dv: MaybeDyn, } impl Args { + #[allow(clippy::too_many_arguments)] pub(crate) fn new_null( mask: AttnMask, dt: DigitLayout, @@ -41,17 +43,20 @@ impl Args { seq: MaybeDyn, att: MaybeDyn, dh: MaybeDyn, + dv: MaybeDyn, ) -> Self { - let qo_layout = TensorLayout::new_dyn(dt, &[nh, seq, dh], &[dyn_(); 3]); - let kv_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dh], &[dyn_(); 3]); + let q_layout = TensorLayout::new_dyn(dt, &[nh, seq, dh], &[dyn_(); 3]); + let k_layout = TensorLayout::new_dyn(dt, &[nkvh, seq, dh], &[dyn_(); 3]); + let v_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dv], &[dyn_(); 3]); + let o_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dh], &[dyn_(); 3]); Self { - q_layout: qo_layout.clone(), + q_layout: q_layout.clone(), q_base: null_mut(), - k_layout: kv_layout.clone(), + k_layout: k_layout.clone(), k_base: null(), - v_layout: kv_layout, + v_layout, v_base: null(), - o_layout: qo_layout, + o_layout, o_base: null_mut(), mask, } @@ -85,7 +90,8 @@ impl Args { nkvh: dim_distinct(&[nkvh_k, nkvh_v])?, seq: dim_distinct(&[seq_q, seq_o])?, att: dim_distinct(&[att_k, att_v])?, - dh: dim_distinct(&[dh_q, dh_k, dh_v, dh_o])?, + dh: dim_distinct(&[dh_q, dh_k])?, + dv: dim_distinct(&[dh_v, dh_o])?, }) } } diff --git a/operators/src/attention/cuda.rs b/operators/src/attention/cuda.rs index 208561c3..3a67ac90 100644 --- a/operators/src/attention/cuda.rs +++ b/operators/src/attention/cuda.rs @@ -16,6 +16,7 @@ mod test { seq.into(), att.into(), dyn_(), + dyn_(), ) } diff --git a/operators/src/attention/operator.rs b/operators/src/attention/operator.rs index c1281672..aeb253da 100644 --- a/operators/src/attention/operator.rs +++ b/operators/src/attention/operator.rs @@ -1,4 +1,4 @@ -use super::{args::Meta, Args, Attention}; +use super::{args::Meta, Args, Attention}; use crate::{ dyn_, fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc, SchemeError, TensorLayout, Workspace, WorkspaceCollector, @@ -53,6 +53,7 @@ where seq, att, dh, + dv, .. } = args.meta()?; let Args { @@ -64,11 +65,12 @@ where } = args; // 如果不能保证 nh seq att dh 已知,用任意值初始化算子 - let (Some(&nh), Some(&seq), Some(&att), Some(&dh)) = ( + let (Some(&nh), Some(&seq), Some(&att), Some(&dh), Some(&_dv)) = ( nh.get_static(), seq.get_static(), att.get_static(), dh.get_static(), + dv.get_static(), ) else { let mut wc = WorkspaceCollector::new(); @@ -149,6 +151,7 @@ where seq, att, dh, + dv, } = args.meta()?; let Args { mask, @@ -172,8 +175,8 @@ where let ele = dt.nbytes(); get_static! 
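+        // dv is the V/O head dimension, now tracked separately from dh (the Q/K head
+        // dimension), so the attention composite no longer assumes the value width
+        // equals the query/key width.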
{ nh seq dh - nh_sq seq_sq dh_sq - nkvh att + dv seq_sq dh_sq + nkvh att nh_sq nkvh_sk att_sk dh_sk }; @@ -191,7 +194,7 @@ where let (att_buf, workspace) = workspace.split_at_mut(att_size); let head_group = nh / nkvh; - let (q_layout, qx_layout, q_base) = match qx { + let (_q_layout, qx_layout, q_base) = match qx { None => { let q_layout = TensorLayout::new_contiguous(dt, &[nh, seq, dh]); let qx_layout = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, dh]); @@ -219,6 +222,7 @@ where let k_layout = TensorLayout::new(dt, k_layout.shape(), k_layout.strides()); let att_mat_mul = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, att]); let att_softmax = TensorLayout::new_contiguous(dt, &[nh, seq, att]); + let att_result = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, dv]); // att = q . k^T self.mat_mul.launch( @@ -248,7 +252,7 @@ where // q = att . v self.mat_mul.launch( &mat_mul::Args { - c_layout: qx_layout.clone(), + c_layout: att_result.clone(), c_base: q_base, beta: 0., a_layout: att_mat_mul, @@ -266,7 +270,7 @@ where &rearrange::Args { dst_layout: o_layout.clone(), dst_base: *o_base, - src_layout: q_layout.clone(), + src_layout: o_layout.clone(), src_base: q_base, }, workspace, diff --git a/operators/src/attention_kv_cached/operator.rs b/operators/src/attention_kv_cached/operator.rs index 81345f4c..d0c8412e 100644 --- a/operators/src/attention_kv_cached/operator.rs +++ b/operators/src/attention_kv_cached/operator.rs @@ -1,4 +1,4 @@ -use super::{args::Meta, Args, AttnKVCached}; +use super::{args::Meta, Args, AttnKVCached}; use crate::{ attention, dyn_, get_static, rearrange, shape_mismatch, ByteOf, Hardware, LaunchError, MaybeDyn, QueueAlloc, TensorLayout, WorkspaceCollector, @@ -66,7 +66,7 @@ where }; wc.push_sub(self.attention.scheme( - &attention::Args::new_null(args.mask, dt, nh, nkvh, seq, att, dh), + &attention::Args::new_null(args.mask, dt, nh, nkvh, seq, att, dh, dh), max_workspace_size, )?); diff --git a/operators/src/attention_mla/args.rs b/operators/src/attention_mla/args.rs new file mode 100644 index 00000000..d85bbc96 --- /dev/null +++ b/operators/src/attention_mla/args.rs @@ -0,0 +1,125 @@ +use crate::{ + dyn_, + fuesd_softmax::AttnMask, + utils::{dim_distinct, rank_error, type_distinct}, + ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, +}; +use digit_layout::DigitLayout; +use std::ptr::{null, null_mut}; + +pub struct Args { + // q传入的是是吸收后的 + pub q_layout: TensorLayout, + pub q_base: MutPtr, + + pub kv_layout: TensorLayout, + pub kv_base: ConstPtr, + + pub absorb_layout: TensorLayout, + pub absorb_base: ConstPtr, + + pub qr_layout: TensorLayout, + pub qr_base: ConstPtr, + + pub kr_layout: TensorLayout, + pub kr_base: ConstPtr, + + pub o_layout: TensorLayout, + pub o_base: MutPtr, + + pub mask: AttnMask, +} + +pub(super) struct Meta { + pub dt: DigitLayout, + pub nh: MaybeDyn, + pub seq: MaybeDyn, + pub att: MaybeDyn, + pub dkv: MaybeDyn, + pub dv: MaybeDyn, + pub dr: MaybeDyn, +} + +impl Args { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new_null( + mask: AttnMask, + dt: DigitLayout, + nh: MaybeDyn, + dkv: MaybeDyn, + seq: MaybeDyn, + att: MaybeDyn, + dv: MaybeDyn, + dr: MaybeDyn, + ) -> Self { + let q_layout = TensorLayout::new_dyn(dt, &[nh, seq, dkv], &[dyn_(); 3]); + let kv_layout = TensorLayout::new_dyn(dt, &[nh, att, dkv], &[dyn_(); 3]); + let absorb_layout = TensorLayout::new_dyn(dt, &[nh, dv, dkv], &[dyn_(); 3]); + let qr_layout = TensorLayout::new_dyn(dt, &[nh, seq, dr], &[dyn_(); 3]); + let kr_layout = 
TensorLayout::new_dyn(dt, &[nh, att, dr], &[dyn_(); 3]); + let o_layout = TensorLayout::new_dyn(dt, &[nh, seq, dv], &[dyn_(); 3]); + Self { + q_layout, + q_base: null_mut(), + kv_layout, + kv_base: null(), + absorb_layout, + absorb_base: null(), + qr_layout, + qr_base: null(), + kr_layout, + kr_base: null(), + o_layout, + o_base: null_mut(), + mask, + } + } + + pub(super) fn meta(&self) -> Result { + let Self { + q_layout, + kv_layout, + absorb_layout, + qr_layout, + kr_layout, + o_layout, + .. + } = self; + + let &[nh_q, seq_q, dkv_q] = q_layout.shape() else { + return Err(rank_error("q", 3, q_layout.ndim())); + }; + + let &[nh_kv, attn_kv, dkv_kv] = kv_layout.shape() else { + return Err(rank_error("kv", 3, kv_layout.ndim())); + }; + let &[nh_a, dv_a, dkv_a] = absorb_layout.shape() else { + return Err(rank_error("absorb", 3, absorb_layout.ndim())); + }; + let &[nh_qr, seq_qr, dr_qr] = qr_layout.shape() else { + return Err(rank_error("qr", 3, qr_layout.ndim())); + }; + let &[nh_kr, att_kr, dr_kr] = kr_layout.shape() else { + return Err(rank_error("kr", 3, kr_layout.ndim())); + }; + let &[nh_o, seq_o, dv_o] = o_layout.shape() else { + return Err(rank_error("o", 3, o_layout.ndim())); + }; + + Ok(Meta { + dt: type_distinct(&[ + q_layout.dt(), + kv_layout.dt(), + qr_layout.dt(), + kr_layout.dt(), + o_layout.dt(), + ])?, + nh: dim_distinct(&[nh_q, nh_kv, nh_a, nh_qr, nh_kr, nh_o])?, + seq: dim_distinct(&[seq_q, seq_o, seq_qr])?, + att: dim_distinct(&[attn_kv, att_kr])?, + dkv: dim_distinct(&[dkv_a, dkv_kv, dkv_q])?, + dv: dim_distinct(&[dv_a, dv_o])?, + dr: dim_distinct(&[dr_kr, dr_qr])?, + }) + } +} diff --git a/operators/src/attention_mla/common_cpu.rs b/operators/src/attention_mla/common_cpu.rs new file mode 100644 index 00000000..cf59d751 --- /dev/null +++ b/operators/src/attention_mla/common_cpu.rs @@ -0,0 +1 @@ +impl_op!(common_cpu, Cpu); diff --git a/operators/src/attention_mla/cuda.rs b/operators/src/attention_mla/cuda.rs new file mode 100644 index 00000000..94ef4e21 --- /dev/null +++ b/operators/src/attention_mla/cuda.rs @@ -0,0 +1 @@ +impl_op!(cuda, Gpu); diff --git a/operators/src/attention_mla/infini.rs b/operators/src/attention_mla/infini.rs new file mode 100644 index 00000000..a45f3e13 --- /dev/null +++ b/operators/src/attention_mla/infini.rs @@ -0,0 +1 @@ +impl_op!(infini, Device); diff --git a/operators/src/attention_mla/mod.rs b/operators/src/attention_mla/mod.rs new file mode 100644 index 00000000..03061907 --- /dev/null +++ b/operators/src/attention_mla/mod.rs @@ -0,0 +1,26 @@ +mod args; +mod operator; + +pub use args::Args; + +crate::op_trait!(AttentionMLA); + +macro_rules! 
impl_op { + ($dev:ident, $proc:ident) => { + pub type Operator = super::operator::Operator< + crate::$dev::$proc, + crate::mat_mul::$dev::Operator, + crate::fuesd_softmax::$dev::Operator, + crate::rearrange::$dev::Operator, + >; + }; +} + +#[cfg(any(use_cpu, test))] +pub mod common_cpu; +#[cfg(use_cuda)] +pub mod cuda; +#[cfg(use_infini)] +pub mod infini; +#[cfg(use_cl)] +pub mod opencl; diff --git a/operators/src/attention_mla/opencl.rs b/operators/src/attention_mla/opencl.rs new file mode 100644 index 00000000..cd4b7d79 --- /dev/null +++ b/operators/src/attention_mla/opencl.rs @@ -0,0 +1 @@ +impl_op!(opencl, ClDevice); diff --git a/operators/src/attention_mla/operator.rs b/operators/src/attention_mla/operator.rs new file mode 100644 index 00000000..26064967 --- /dev/null +++ b/operators/src/attention_mla/operator.rs @@ -0,0 +1,200 @@ +use super::{args::Meta, Args, AttentionMLA}; +use crate::{ + fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc, + SchemeError, TensorLayout, Workspace, +}; +use ndarray_layout::ArrayLayout; +use std::marker::PhantomData; + +pub struct Operator { + mat_mul: MatMul, + softmax: Softmax, + rearrange: Rearrange, + _phantom: PhantomData, +} + +impl AttentionMLA for Operator +where + H: Hardware, + M: mat_mul::MatMul, + S: fuesd_softmax::FusedSoftmax, + R: rearrange::Rearrange, +{ +} + +impl crate::Operator for Operator +where + H: Hardware, + M: mat_mul::MatMul, + S: fuesd_softmax::FusedSoftmax, + R: rearrange::Rearrange, +{ + type Hardware = H; + type TopoNode = H; + type Args = Args; + + fn new(node: &Self::TopoNode) -> Self { + Self { + mat_mul: M::new(node), + softmax: S::new(node), + rearrange: R::new(node), + _phantom: PhantomData, + } + } + + fn scheme( + &mut self, + args: &Self::Args, + max_workspace_size: usize, + ) -> Result { + // TODO + Ok(0) + } + + fn launch( + &self, + args: &Self::Args, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + QA: QueueAlloc, + { + let Meta { + dt, + nh, + seq, + att, + dkv, + dv, + dr, + } = args.meta()?; + let Args { + q_layout, + q_base, + kv_layout, + kv_base, + absorb_layout, + absorb_base, + qr_layout, + qr_base, + kr_layout, + kr_base, + o_layout, + o_base, + mask, + } = args; + + let &[nh_skv, att_skv, dkv_skv] = kv_layout.strides() else { + unreachable!() + }; + let &[nh_skr, att_skr, dr_skr] = kr_layout.strides() else { + unreachable!() + }; + let &[nh_sa, dv_sa, dkv_sa] = absorb_layout.strides() else { + unreachable!() + }; + + let ele = dt.nbytes(); + get_static! 
{ + nh seq dkv dr + nh_skv att_skv dkv_skv + nh_skr att_skr dr_skr + nh_sa dv_sa dkv_sa + dv att + }; + + #[inline(always)] + fn layout(shape: [usize; 3], strides: [isize; 3]) -> ArrayLayout<3> { + ArrayLayout::new(&shape, &strides, 0) + } + let kv_first_layout = layout([nh, att, dkv], [nh_skv, att_skv, dkv_skv]).transpose(&[2, 1]); + let kr_layout = layout([nh, att, dr], [nh_skr, att_skr, dr_skr]).transpose(&[2, 1]); + let a_layout = layout([nh, dv, dkv], [nh_sa, dv_sa, dkv_sa]).transpose(&[2, 1]); + let att_w_layout = TensorLayout::new_contiguous(dt, &[nh, seq, att]); + let attn_t_layout = TensorLayout::new_contiguous(dt, &[nh, seq, dkv]); + let att_w_size = nh * seq * att * ele; + let att_t_size = nh * seq * dkv * ele; + let mut workspace = Workspace::new(queue_alloc, workspace, att_w_size + att_t_size); + let (att_w_buf, workspace) = workspace.split_at_mut(att_w_size); + let (attn_t_buf, workspace) = workspace.split_at_mut(att_t_size); + + let kv_first_layout = + TensorLayout::new(dt, kv_first_layout.shape(), kv_first_layout.strides()); + let kr_layout = TensorLayout::new(dt, kr_layout.shape(), kr_layout.strides()); + let a_layout = TensorLayout::new(dt, a_layout.shape(), a_layout.strides()); + // att_w = qr*kr^T + q*kv^T + self.mat_mul.launch( + &mat_mul::Args { + c_layout: att_w_layout.clone(), + c_base: att_w_buf.as_mut_ptr(), + beta: 0., + a_layout: qr_layout.clone(), + a_base: *qr_base, + b_layout: kr_layout.clone(), + b_base: *kr_base, + alpha: ((dv + dr) as f32).sqrt().recip(), + }, + workspace, + queue_alloc, + )?; + + self.mat_mul.launch( + &mat_mul::Args { + c_layout: att_w_layout.clone(), + c_base: att_w_buf.as_mut_ptr(), + beta: 1., + a_layout: q_layout.clone(), + a_base: *q_base, + b_layout: kv_first_layout.clone(), + b_base: *kv_base, + alpha: ((dv + dr) as f32).sqrt().recip(), + }, + workspace, + queue_alloc, + )?; + // att_w = softmax(att) + self.softmax.launch( + &fuesd_softmax::Args { + att_mask: *mask, + att_layout: att_w_layout.clone(), + att_base: att_w_buf.as_mut_ptr(), + }, + workspace, + queue_alloc, + )?; + // attn_t=att_o*kv + self.mat_mul.launch( + &mat_mul::Args { + c_layout: attn_t_layout.clone(), + c_base: attn_t_buf.as_mut_ptr(), + beta: 0., + a_layout: att_w_layout.clone(), + a_base: att_w_buf.as_ptr(), + b_layout: kv_layout.clone(), + b_base: *kv_base, + alpha: 1., + }, + workspace, + queue_alloc, + )?; + + // attn =attn_t*absorb^T + self.mat_mul.launch( + &mat_mul::Args { + c_layout: o_layout.clone(), + c_base: *o_base, + beta: 0., + a_layout: attn_t_layout.clone(), + a_base: attn_t_buf.as_ptr(), + b_layout: a_layout.clone(), + b_base: *absorb_base, + alpha: 1., + }, + workspace, + queue_alloc, + )?; + + Ok(()) + } +} diff --git a/operators/src/attention_mla_cached/args.rs b/operators/src/attention_mla_cached/args.rs new file mode 100644 index 00000000..c1209e6e --- /dev/null +++ b/operators/src/attention_mla_cached/args.rs @@ -0,0 +1,103 @@ +use crate::{ + fuesd_softmax::AttnMask, + utils::{dim_distinct, rank_error, type_distinct}, + ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, +}; +use digit_layout::DigitLayout; + +pub struct Args { + // q传入的是是吸收后的 + pub q_layout: TensorLayout, + pub q_base: MutPtr, + + pub kv_layout: TensorLayout, + pub kv_base: ConstPtr, + + pub absorb_layout: TensorLayout, + pub absorb_base: ConstPtr, + + pub qr_layout: TensorLayout, + pub qr_base: ConstPtr, + + pub kr_layout: TensorLayout, + pub kr_base: ConstPtr, + + pub o_layout: TensorLayout, + pub o_base: MutPtr, + pub kv_cache_layout: TensorLayout, + pub 
kv_cache_base: MutPtr, + + pub kr_cache_layout: TensorLayout, + pub kr_cache_base: MutPtr, + + pub mask: AttnMask, + pub pos: MaybeDyn, +} + +pub(super) struct Meta { + pub dt: DigitLayout, + pub nh: MaybeDyn, + pub seq: MaybeDyn, + pub att: MaybeDyn, + pub dkv: MaybeDyn, + pub dv: MaybeDyn, + pub dr: MaybeDyn, +} + +impl Args { + pub(super) fn meta(&self) -> Result { + let Self { + q_layout, + kv_layout, + absorb_layout, + qr_layout, + kr_layout, + o_layout, + kv_cache_layout, + kr_cache_layout, + .. + } = self; + + let &[nh_q, seq_q, dkv_q] = q_layout.shape() else { + return Err(rank_error("q", 3, q_layout.ndim())); + }; + + let &[nh_kv, attn_kv, dkv_kv] = kv_layout.shape() else { + return Err(rank_error("kv", 3, kv_layout.ndim())); + }; + let &[nh_a, dv_a, dkv_a] = absorb_layout.shape() else { + return Err(rank_error("absorb", 3, absorb_layout.ndim())); + }; + let &[nh_qr, seq_qr, dr_qr] = qr_layout.shape() else { + return Err(rank_error("qr", 3, qr_layout.ndim())); + }; + let &[nh_kr, att_kr, dr_kr] = kr_layout.shape() else { + return Err(rank_error("kr", 3, kr_layout.ndim())); + }; + let &[nh_o, seq_o, dv_o] = o_layout.shape() else { + return Err(rank_error("o", 3, o_layout.ndim())); + }; + let &[nh_kvc, _buf, dkv_kvc] = kv_cache_layout.shape() else { + return Err(rank_error("k_cache", 3, kv_cache_layout.ndim())); + }; + let &[nh_krc, _buf, dr_krc] = kr_cache_layout.shape() else { + return Err(rank_error("v_cache", 3, kr_cache_layout.ndim())); + }; + + Ok(Meta { + dt: type_distinct(&[ + q_layout.dt(), + kv_layout.dt(), + qr_layout.dt(), + kr_layout.dt(), + o_layout.dt(), + ])?, + nh: dim_distinct(&[nh_q, nh_kv, nh_a, nh_qr, nh_kr, nh_o, nh_krc, nh_kvc])?, + seq: dim_distinct(&[seq_q, seq_o, seq_qr])?, + att: dim_distinct(&[attn_kv, att_kr])?, + dkv: dim_distinct(&[dkv_a, dkv_kv, dkv_q, dkv_kvc])?, + dv: dim_distinct(&[dv_a, dv_o])?, + dr: dim_distinct(&[dr_kr, dr_qr, dr_krc])?, + }) + } +} diff --git a/operators/src/attention_mla_cached/common_cpu.rs b/operators/src/attention_mla_cached/common_cpu.rs new file mode 100644 index 00000000..cf59d751 --- /dev/null +++ b/operators/src/attention_mla_cached/common_cpu.rs @@ -0,0 +1 @@ +impl_op!(common_cpu, Cpu); diff --git a/operators/src/attention_mla_cached/cuda.rs b/operators/src/attention_mla_cached/cuda.rs new file mode 100644 index 00000000..94ef4e21 --- /dev/null +++ b/operators/src/attention_mla_cached/cuda.rs @@ -0,0 +1 @@ +impl_op!(cuda, Gpu); diff --git a/operators/src/attention_mla_cached/infini.rs b/operators/src/attention_mla_cached/infini.rs new file mode 100644 index 00000000..a45f3e13 --- /dev/null +++ b/operators/src/attention_mla_cached/infini.rs @@ -0,0 +1 @@ +impl_op!(infini, Device); diff --git a/operators/src/attention_mla_cached/mod.rs b/operators/src/attention_mla_cached/mod.rs new file mode 100644 index 00000000..74a54976 --- /dev/null +++ b/operators/src/attention_mla_cached/mod.rs @@ -0,0 +1,25 @@ +mod args; +mod operator; + +pub use args::Args; + +crate::op_trait!(AttMLACached); + +macro_rules! 
impl_op { + ($dev:ident, $proc:ident) => { + pub type Operator = super::operator::Operator< + crate::$dev::$proc, + crate::rearrange::$dev::Operator, + crate::attention_mla::$dev::Operator, + >; + }; +} + +#[cfg(any(use_cpu, test))] +pub mod common_cpu; +#[cfg(use_cuda)] +pub mod cuda; +#[cfg(use_infini)] +pub mod infini; +#[cfg(use_cl)] +pub mod opencl; diff --git a/operators/src/attention_mla_cached/opencl.rs b/operators/src/attention_mla_cached/opencl.rs new file mode 100644 index 00000000..cd4b7d79 --- /dev/null +++ b/operators/src/attention_mla_cached/opencl.rs @@ -0,0 +1 @@ +impl_op!(opencl, ClDevice); diff --git a/operators/src/attention_mla_cached/operator.rs b/operators/src/attention_mla_cached/operator.rs new file mode 100644 index 00000000..24775a07 --- /dev/null +++ b/operators/src/attention_mla_cached/operator.rs @@ -0,0 +1,181 @@ +use super::{args::Meta, Args, AttMLACached}; +use crate::{ + attention_mla, get_static, rearrange, shape_mismatch, ByteOf, Hardware, LaunchError, + QueueAlloc, SchemeError, TensorLayout, +}; +use ndarray_layout::ArrayLayout; +use std::marker::PhantomData; +pub struct Operator { + rearrange: Rearrange, + attention: Attention, + _phantom: PhantomData, +} + +impl AttMLACached for Operator +where + H: Hardware, + R: rearrange::Rearrange, + A: attention_mla::AttentionMLA, +{ +} + +impl crate::Operator for Operator +where + H: Hardware, + R: rearrange::Rearrange, + A: attention_mla::AttentionMLA, +{ + type Hardware = H; + type TopoNode = H; + type Args = Args; + fn new(node: &Self::TopoNode) -> Self { + Self { + rearrange: R::new(node), + attention: A::new(node), + _phantom: PhantomData, + } + } + + fn scheme( + &mut self, + args: &Self::Args, + max_workspace_size: usize, + ) -> Result { + // TODO + Ok(0) + } + + fn launch( + &self, + args: &Self::Args, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + QA: QueueAlloc, + { + let Meta { + dt, + nh, + seq, + att, + dkv, + dv, + dr, + } = args.meta()?; + let Args { + q_layout, + q_base, + kv_layout, + kv_base, + absorb_layout, + absorb_base, + qr_layout, + qr_base, + kr_layout, + kr_base, + o_layout, + o_base, + kv_cache_layout, + kv_cache_base, + kr_cache_layout, + kr_cache_base, + mask, + pos, + } = args; + let &[nh_skv, att_skv, dkv_skv] = kv_layout.strides() else { + unreachable!() + }; + let &[nh_skr, att_skr, dr_skr] = kr_layout.strides() else { + unreachable!() + }; + let &[nh_sa, dv_sa, dkv_sa] = absorb_layout.strides() else { + unreachable!() + }; + + let &[_, buf_kv, _] = kv_cache_layout.shape() else { + unreachable!() + }; + let &[_, buf_kr, _] = kr_cache_layout.shape() else { + unreachable!() + }; + let &[nh_skvc, buf_skvc, dh_skvc] = kv_cache_layout.strides() else { + unreachable!() + }; + let &[nh_skrc, buf_skrc, dh_skrc] = kr_cache_layout.strides() else { + unreachable!() + }; + let ele = dt.nbytes(); + get_static! 
{ + nh seq dkv dr + pos + buf_kv buf_kr + nh_skvc buf_skvc dh_skvc + nh_skrc buf_skrc dh_skrc + + }; + + // 检查 cache 容量 + let att = pos + seq; + if buf_kr < att || buf_kv < att { + return Err(shape_mismatch("Out of cache buffer").into()); + } + // 连接 kv cache + #[inline(always)] + fn layout(shape: [usize; 3], strides: [isize; 3]) -> ArrayLayout<3> { + ArrayLayout::new(&shape, &strides, 0) + } + + let kvc_layout = layout([nh, buf_kv, dkv], [nh_skvc, buf_skvc, dh_skvc]); + let krc_layout = layout([nh, buf_kr, dr], [nh_skrc, buf_skrc, dh_skrc]); + + let kv_cat = kvc_layout.slice(1, pos, 1, seq); + let kr_cat = krc_layout.slice(1, pos, 1, seq); + + self.rearrange.launch( + &rearrange::Args { + dst_layout: TensorLayout::new(dt, kv_cat.shape(), kv_cat.strides()), + dst_base: unsafe { kv_cache_base.byte_add(kv_cat.offset() as _) }, + src_layout: kv_layout.clone(), + src_base: *kv_base, + }, + workspace, + queue_alloc, + )?; + self.rearrange.launch( + &rearrange::Args { + dst_layout: TensorLayout::new(dt, kr_cat.shape(), kr_cat.strides()), + dst_base: unsafe { kr_cache_base.byte_add(kr_cat.offset() as _) }, + src_layout: kr_layout.clone(), + src_base: *kr_base, + }, + workspace, + queue_alloc, + )?; + // attention + let kv_layout = kvc_layout.slice(1, 0, 1, att); + let kr_layout = krc_layout.slice(1, 0, 1, att); + assert_eq!(kv_layout.offset(), 0); + assert_eq!(kr_layout.offset(), 0); + self.attention.launch( + &attention_mla::Args { + mask: *mask, + q_layout: q_layout.clone(), + q_base: *q_base, + kv_layout: TensorLayout::new(dt, kv_layout.shape(), kv_layout.strides()), + kv_base: *kv_cache_base, + kr_layout: TensorLayout::new(dt, kr_layout.shape(), kr_layout.strides()), + kr_base: *kr_cache_base, + absorb_layout: absorb_layout.clone(), + absorb_base: *absorb_base, + qr_layout: qr_layout.clone(), + qr_base: *qr_base, + o_layout: o_layout.clone(), + o_base: *o_base, + }, + workspace, + queue_alloc, + )?; + Ok(()) + } +} diff --git a/operators/src/lib.rs b/operators/src/lib.rs index 02088acc..48788e47 100644 --- a/operators/src/lib.rs +++ b/operators/src/lib.rs @@ -8,6 +8,8 @@ pub mod add_rows; pub mod all_reduce; pub mod attention; pub mod attention_kv_cached; +pub mod attention_mla; +pub mod attention_mla_cached; pub mod broadcast; pub mod conv; pub mod fuesd_softmax; @@ -18,6 +20,7 @@ pub mod random_sample; pub mod rearrange; pub mod rms_norm; pub mod rope; +pub mod scale; pub mod swiglu; pub use common::*; diff --git a/operators/src/rearrange/args.rs b/operators/src/rearrange/args.rs index 64aa3d3d..f03767ba 100644 --- a/operators/src/rearrange/args.rs +++ b/operators/src/rearrange/args.rs @@ -268,13 +268,13 @@ fn test_scheme() { dst_layout: TensorLayout::new( F16, &shape, - &[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 1 * 2], + &[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 2], ), dst_base: null_mut(), src_layout: TensorLayout::new( F16, &shape, - &[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 1 * 2], + &[33554432 * 2, 16777216 * 2, 524288 * 2, 128 * 2, 2], ), src_base: null(), }; diff --git a/operators/src/rope/args.rs b/operators/src/rope/args.rs index 6c3dff97..373cd6df 100644 --- a/operators/src/rope/args.rs +++ b/operators/src/rope/args.rs @@ -1,10 +1,45 @@ -use crate::{ +use crate::{ type_not_support, utils::{dim_distinct, rank_error}, ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, }; use digit_layout::DigitLayout; +pub enum RopeType { + // 以下枚举通用一个 Scheme + Rope, + Pi { + s: f32, + }, + Ntk { + s: f32, + }, + Dyn { + s: f32, + a: f32, + }, + + // 以下枚举通用一个 
Scheme + NtkParts { + alpha: f32, + beta: f32, + l0: f32, + s: f32, + }, + Yarn { + alpha: f32, + beta: f32, + l0: f32, + s: f32, + }, + Long { + long: ConstPtr, + short: ConstPtr, + max_pos: u32, + origin_pos: u32, + }, +} + pub struct Args { pub t_layout: TensorLayout, pub t_base: MutPtr, @@ -15,6 +50,7 @@ pub struct Args { pub cos_layout: TensorLayout, pub cos_base: ConstPtr, pub theta: f32, + pub rope_type: RopeType, } pub(super) struct Meta { diff --git a/operators/src/rope/common_cpu/mod.rs b/operators/src/rope/common_cpu/mod.rs index 0630555b..dc11aa0d 100644 --- a/operators/src/rope/common_cpu/mod.rs +++ b/operators/src/rope/common_cpu/mod.rs @@ -1,11 +1,37 @@ -use super::{args::Meta, fill_pos, Args, Rope, Seq, SinCosTable}; +use super::{args::Meta, args::RopeType as R, fill_pos, Args, Rope, Seq, SinCosTable}; use crate::{ common_cpu::Cpu, get_static, strides_not_support, ByteOf, LaunchError, QueueAlloc, SchemeError, Unsigned, }; use digit_layout::{types as ty, DigitLayout}; use half::f16; +use std::ptr::null; +#[derive(Copy, Clone)] +enum NtkPartsType { + None, + Yarn, +} +#[derive(Copy, Clone)] +enum SchemeType { + Rope { + s: f32, + }, + Long { + long: *const T, + short: *const T, + s: f32, + origin_pos: u32, + }, + #[allow(dead_code)] + NtkParts { + alpha: f32, + beta: f32, + l0: f32, + s: f32, + ntktype: NtkPartsType, + }, +} pub struct Operator; impl Rope for Operator { @@ -78,6 +104,7 @@ impl crate::Operator for Operator { p_layout, p_base, theta, + rope_type, .. } = args; let &[_, nh, dh] = t_layout.shape() else { @@ -99,33 +126,130 @@ impl crate::Operator for Operator { return Err(strides_not_support("").into()); } - macro_rules! calculate { - ($t:ty, $p:ty) => { - Scheme::<$t, $p> { - nt, - nh, - dh, - st, - sh, - sp, - theta: *theta, - t_base: t_base.cast(), - p_base: p_base.cast(), + match rope_type { + R::Rope | R::Dyn { .. } | R::Ntk { .. } | R::Pi { .. } => { + let (theta, s) = match rope_type { + R::Rope => (*theta, 1.), + R::Dyn { s, a } => (theta * (a * s - a + 1.), 1.), + R::Ntk { s } => (theta * s, 1.), + R::Pi { s } => (*theta, *s), + _ => unreachable!(), + }; + macro_rules! calculate { + ($t:ty, $p:ty) => { + Scheme::<$t, $p> { + nt, + nh, + dh, + st, + sh, + sp, + theta, + t_base: t_base.cast(), + p_base: p_base.cast(), + scheme_type: SchemeType::Rope { s }, + } + .calculate() + }; } - .calculate() - }; - } - use digit_layout::types as ty; - match (dt_t, dt_p) { - (ty::F16, ty::U32) => calculate!(f16, u32), - (ty::F16, ty::U64) => calculate!(f16, u64), - (ty::F32, ty::U32) => calculate!(f32, u32), - (ty::F32, ty::U64) => calculate!(f32, u64), - (ty::F64, ty::U32) => calculate!(f64, u32), - (ty::F64, ty::U64) => calculate!(f64, u64), - _ => todo!(), + use digit_layout::types as ty; + match (dt_t, dt_p) { + (ty::F16, ty::U32) => calculate!(f16, u32), + (ty::F16, ty::U64) => calculate!(f16, u64), + (ty::F32, ty::U32) => calculate!(f32, u32), + (ty::F32, ty::U64) => calculate!(f32, u64), + (ty::F64, ty::U32) => calculate!(f64, u32), + (ty::F64, ty::U64) => calculate!(f64, u64), + _ => todo!(), + } + } + R::Long { + long, + short, + max_pos, + origin_pos, + } => { + let s = 1.0 + + ((*max_pos as f32 / *origin_pos as f32).ln() / (*origin_pos as f32).ln()) + .sqrt(); + macro_rules! 
calculate { + ($t:ty, $p:ty) => { + Scheme::<$t, $p> { + nt, + nh, + dh, + st, + sh, + sp, + theta: *theta, + t_base: t_base.cast(), + p_base: p_base.cast(), + scheme_type: SchemeType::Long { + long: long.cast(), + short: short.cast(), + s, + origin_pos: *origin_pos, + }, + } + .calculate() + }; + } + + use digit_layout::types as ty; + match (dt_t, dt_p) { + (ty::F16, ty::U32) => calculate!(f16, u32), + (ty::F16, ty::U64) => calculate!(f16, u64), + (ty::F32, ty::U32) => calculate!(f32, u32), + (ty::F32, ty::U64) => calculate!(f32, u64), + (ty::F64, ty::U32) => calculate!(f64, u32), + (ty::F64, ty::U64) => calculate!(f64, u64), + _ => todo!(), + } + } + R::Yarn { alpha, beta, l0, s } | R::NtkParts { alpha, beta, l0, s } => { + let ntktype = match rope_type { + R::NtkParts { .. } => NtkPartsType::None, + R::Yarn { .. } => NtkPartsType::Yarn, + _ => unreachable!(), + }; + macro_rules! calculate { + ($t:ty, $p:ty) => { + Scheme::<$t, $p> { + nt, + nh, + dh, + st, + sh, + sp, + theta: *theta, + t_base: t_base.cast(), + p_base: p_base.cast(), + scheme_type: SchemeType::NtkParts { + alpha: *alpha, + beta: *beta, + l0: *l0, + s: *s, + ntktype, + }, + } + .calculate() + }; + } + + use digit_layout::types as ty; + match (dt_t, dt_p) { + (ty::F16, ty::U32) => calculate!(f16, u32), + (ty::F16, ty::U64) => calculate!(f16, u64), + (ty::F32, ty::U32) => calculate!(f32, u32), + (ty::F32, ty::U64) => calculate!(f32, u64), + (ty::F64, ty::U32) => calculate!(f64, u32), + (ty::F64, ty::U64) => calculate!(f64, u64), + _ => todo!(), + } + } } + Ok(()) } } @@ -142,15 +266,15 @@ struct Scheme { theta: f32, t_base: *mut A, p_base: *const P, + scheme_type: SchemeType, } unsafe impl Send for Scheme {} unsafe impl Sync for Scheme {} - /// 激活值。 trait Activation: Sized { /// 激活值类型决定计算类型。 - type Calculation; + type Calculation: Copy; /// 计算流程。 fn calculate(pair: &mut [Self; 2], sin: Self::Calculation, cos: Self::Calculation); } @@ -187,15 +311,69 @@ impl Activation for f64 { } trait Position { - fn freq_sin_cos(self, k: isize, dh: isize, theta: f32) -> (Calculation, Calculation); + fn freq_sin_cos_rope( + self, + k: isize, + dh: isize, + theta: f32, + s: f32, + ) -> (Calculation, Calculation); + fn freq_sin_cos_long( + self, + k: isize, + dh: isize, + t: f32, + f: f32, + s: f32, + ) -> (Calculation, Calculation); + #[allow(clippy::too_many_arguments)] + fn freq_sin_cos_ntk_part( + self, + k: isize, + dh: isize, + theta: f32, + alpha: f32, + beta: f32, + l0: f32, + s: f32, + ntktype: NtkPartsType, + ) -> (Calculation, Calculation); } macro_rules! impl_position { ($a:ty) => { impl Position<$a> for T { #[inline] - fn freq_sin_cos(self, k: isize, dh: isize, theta: f32) -> ($a, $a) { - (self.val() as $a / (theta as $a).powf(k as $a / dh as $a)).sin_cos() + fn freq_sin_cos_rope(self, k: isize, dh: isize, theta: f32, s: f32) -> ($a, $a) { + (self.val() as $a * s as $a * (theta as $a).powf(k as $a / dh as $a).recip()) + .sin_cos() + } + #[inline] + fn freq_sin_cos_long(self, k: isize, dh: isize, t: f32, f: f32, s: f32) -> ($a, $a) { + let (sin, cos) = + (self.val() as $a * (t as $a).powf(k as $a / dh as $a).recip() * (f as $a).recip() ).sin_cos(); + (sin * s as $a, cos * s as $a) + } + #[inline] + fn freq_sin_cos_ntk_part( + self, + k: isize, + dh: isize, + theta: f32, + alpha: f32, + beta: f32, + l0: f32, + s: f32, + ntktype: NtkPartsType, + ) -> ($a, $a) { + use std::f32::consts::PI; + let pos = match ntktype { + NtkPartsType::None => self.val() as $a, + NtkPartsType::Yarn => self.val() as $a * (0.1 * s.ln() + 1.) 
as $a, + }; + let theta = theta.powf(k as f32 / dh as f32).recip(); + let r = ((l0 / (2. * PI / theta) - alpha) / (beta - alpha)).clamp(0., 1.); + (pos * ((1. - r) / s + r) as $a * theta as $a).sin_cos() } } }; @@ -206,8 +384,8 @@ impl_position!(f64); impl Scheme where - A: Activation, - P: Position + Sync + Copy, + A: Activation + Copy, + P: Position + Sync + Copy + Unsigned, { fn calculate(&self) { let &Self { @@ -220,6 +398,7 @@ where theta, t_base, p_base, + scheme_type, } = self; let nt = nt as isize; let nh = nh as isize; @@ -229,10 +408,39 @@ where for i in 0..nt { let t = unsafe { t_base.byte_offset(i * st).cast::<[A; 2]>() }; let p = unsafe { *p_base.byte_offset(i * sp) }; + let factor = match scheme_type { + SchemeType::Long { + long, + short, + origin_pos, + .. + } => unsafe { + if p.val() < origin_pos as usize { + short + } else { + long + } + }, + _ => null(), + }; for j in 0..nh { for k in 0..dh { let pair = unsafe { &mut *t.byte_offset(j * sh + k * sd) }; - let (sin, cos) = p.freq_sin_cos(k, dh, theta); + let (sin, cos) = match scheme_type { + SchemeType::Rope { s } => p.freq_sin_cos_rope(k, dh, theta, s), + SchemeType::Long { s, .. } => { + // TODO 这里先默认为 f32 + let factor = unsafe { *factor.byte_offset(k * 4).cast() }; + p.freq_sin_cos_long(k, dh, theta, factor, s) + } + SchemeType::NtkParts { + alpha, + beta, + l0, + s, + ntktype, + } => p.freq_sin_cos_ntk_part(k, dh, theta, alpha, beta, l0, s, ntktype), + }; A::calculate(pair, sin, cos) } } diff --git a/operators/src/rope/cuda/mod.rs b/operators/src/rope/cuda/mod.rs index 9d457120..6ac010c7 100644 --- a/operators/src/rope/cuda/mod.rs +++ b/operators/src/rope/cuda/mod.rs @@ -184,7 +184,7 @@ extern "C" __global__ void {POS_U64}( #[cfg(test)] mod test { use super::{Args, Gpu, Operator, POS_U32, POS_U64}; - use crate::{Hardware, Operator as _, TensorLayout}; + use crate::{rope::args, Hardware, Operator as _, TensorLayout}; use digit_layout::{ types::{F16, F64, U32}, DigitLayout, @@ -203,6 +203,7 @@ mod test { cos_layout: TensorLayout::new_dyn(dt_t, &[dyn_(); 2], &[dyn_(); 2]), cos_base: null(), theta: 0., + rope_type: args::RopeType::Rope, } } @@ -227,6 +228,7 @@ mod test { cos_layout: TensorLayout::new_contiguous(dt_t, &[0, dh]), cos_base: null(), theta, + rope_type: args::RopeType::Rope, } } diff --git a/operators/src/rope/mod.rs b/operators/src/rope/mod.rs index 1576c069..45a5badc 100644 --- a/operators/src/rope/mod.rs +++ b/operators/src/rope/mod.rs @@ -8,7 +8,7 @@ pub mod infini; pub mod opencl; mod args; -pub use args::Args; +pub use args::{Args, RopeType}; crate::op_trait! 
{ Rope /// 生成 sincos 表([2, n, dh])。 diff --git a/operators/src/scale/args.rs b/operators/src/scale/args.rs new file mode 100644 index 00000000..9e395148 --- /dev/null +++ b/operators/src/scale/args.rs @@ -0,0 +1,171 @@ +use crate::{ + get_static, rank_mismatch, shape_mismatch, shape_not_support, utils::type_distinct, ConstPtr, + Hardware, MutPtr, SchemeError, TensorLayout, +}; +use digit_layout::DigitLayout; +use itertools::izip; +use std::{ + cmp::Ordering, + ptr::{null, null_mut}, +}; + +#[derive(Clone)] +pub struct Args { + pub c_layout: TensorLayout, + pub c_base: MutPtr, + pub a_layout: TensorLayout, + pub a_base: ConstPtr, + pub s: f32, +} + +impl Args { + pub fn new_null(c_layout: TensorLayout, a_layout: TensorLayout) -> Self { + Self { + c_layout, + c_base: null_mut(), + a_layout, + a_base: null(), + s: 1.0, + } + } +} + +#[derive(Clone, Debug)] +pub(super) struct Scheme(DigitLayout, Box<[isize]>); + +impl Scheme { + pub fn new(args: &Args) -> Result { + let Args { + c_layout: c, + a_layout: a, + .. + } = args; + // # 检查基本属性 + let dt = type_distinct(&[c.dt(), a.dt()])?; + let ndim = c.ndim(); + if a.ndim() != ndim { + return Err(rank_mismatch(format!( + "c.ndim = {}, a.ndim = {}", + c.ndim(), + a.ndim(), + ))); + } + // # 输入形状 + #[derive(Clone, PartialEq, Eq, Debug)] + struct Dim { + d: usize, + c: isize, + a: isize, + } + let mut dims = Vec::with_capacity(ndim); + for (&d, &da, &sc, &sa) in izip!(c.shape(), a.shape(), c.strides(), a.strides(),) { + get_static! { + d da + sc sa + } + if da != d { + return Err(shape_mismatch(format!( + "c: {:?}, a: {:?}", + c.shape(), + a.shape(), + ))); + } + // 剔除初始的 1 长维度 + if d != 1 { + if sc == 0 { + return Err(shape_not_support("Reducing is not allowed for scale")); + } + dims.push(Dim { d, c: sc, a: sa }) + } + } + // # 排序 + dims.sort_unstable_by(|dim0, dim1| { + let &Dim { + d: d0, + c: c0, + a: a0, + } = dim0; + let &Dim { + d: d1, + c: c1, + a: a1, + } = dim1; + use Ordering::Equal as Eq; + match c0.abs().cmp(&c1.abs()) { + Eq => match a0.abs().cmp(&a1.abs()) { + Eq => d0.cmp(&d1), + ord => ord.reverse(), + }, + ord => ord.reverse(), + } + }); + // # 合并连续维度 + let mut ndim = dims.len(); + for i in (1..dims.len()).rev() { + let (head, tail) = dims.split_at_mut(i); + let f = &mut head[i - 1]; // f for front + let b = &mut tail[0]; // b for back + let d = b.d as isize; + if b.c * d == f.c && b.a * d == f.a { + *f = Dim { d: b.d * f.d, ..*b }; + *b = Dim { d: 1, c: 0, a: 0 }; + ndim -= 1 + } + } + // # 合并空间 + let mut layout = vec![0isize; 1 + ndim * 4].into_boxed_slice(); + { + let (idx, tail) = layout.split_at_mut(1 + ndim); + let (c_, tail) = tail.split_at_mut(ndim); + let (a_, _b) = tail.split_at_mut(ndim); + for (Dim { d, c, a }, idx, c_, a_) in + izip!(dims.into_iter().filter(|d| d.d != 1), &mut *idx, c_, a_) + { + *idx = d as _; + *c_ = c; + *a_ = a; + } + idx[ndim] = 1; + for i in (1..=ndim).rev() { + idx[i - 1] *= idx[i]; + } + } + Ok(Self(dt, layout)) + } + + #[inline] + pub const fn dt(&self) -> DigitLayout { + self.0 + } + + /// 执行方案维数。 + #[inline] + pub fn ndim(&self) -> usize { + (self.1.len() - 1) / 4 + } + + /// 读写单元数量。 + #[inline] + pub fn count(&self) -> usize { + self.1[0] as _ + } + + /// 索引步长。 + #[inline] + pub fn idx_strides(&self) -> &[isize] { + let ndim = self.ndim(); + &self.1[1..][..ndim] + } + + #[inline] + pub fn c_strides(&self) -> &[isize] { + let ndim = self.ndim(); + &self.1[1 + ndim..][..ndim] + } + + #[inline] + pub fn a_strides(&self) -> &[isize] { + let ndim = self.ndim(); + &self.1[1 + ndim * 2..][..ndim] + 
}
+}
diff --git a/operators/src/scale/common_cpu/mod.rs b/operators/src/scale/common_cpu/mod.rs
new file mode 100644
index 00000000..b4ba50cb
--- /dev/null
+++ b/operators/src/scale/common_cpu/mod.rs
@@ -0,0 +1,71 @@
+use super::{args::Scheme, Args, Scale};
+use crate::{common_cpu::Cpu, ByteOf, LaunchError, QueueAlloc, SchemeError};
+use digit_layout::types as ty;
+use half::f16;
+use rayon::iter::{IntoParallelIterator, ParallelIterator};
+
+pub struct Operator;
+
+impl Scale for Operator {}
+
+impl crate::Operator for Operator {
+    type Hardware = Cpu;
+    type TopoNode = Cpu;
+    type Args = Args;
+
+    #[inline]
+    fn new(_node: &Self::TopoNode) -> Self {
+        Self
+    }
+    #[inline]
+    fn scheme(
+        &mut self,
+        _args: &Self::Args,
+        _max_workspace_size: usize,
+    ) -> Result {
+        Ok(0)
+    }
+
+    fn launch(
+        &self,
+        args: &Self::Args,
+        _workspace: &mut [ByteOf],
+        _queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        QA: QueueAlloc,
+    {
+        let scheme = Scheme::new(args)?;
+        let c = args.c_base as isize;
+        let a = args.a_base as isize;
+        let s = args.s;
+        let idx_strides = scheme.idx_strides();
+        let c_strides = scheme.c_strides();
+        let a_strides = scheme.a_strides();
+        (0..scheme.count() as isize)
+            .into_par_iter()
+            .for_each(|mut rem| {
+                let mut c = c;
+                let mut a = a;
+                for (i, &s) in idx_strides.iter().enumerate() {
+                    let k = rem / s;
+                    c += k * c_strides[i];
+                    a += k * a_strides[i];
+                    rem %= s;
+                }
+                match scheme.dt() {
+                    ty::F16 => mul::<f16>(c, a, f16::from_f32(s)),
+                    ty::F32 => mul::<f32>(c, a, s),
+                    ty::F64 => mul::<f64>(c, a, s as f64),
+                    _ => todo!(),
+                }
+            });
+        Ok(())
+    }
+}
+
+fn mul<T: Copy + std::ops::Mul<Output = T>>(c: isize, a: isize, s: T) {
+    let c = c as *mut T;
+    let a = a as *const T;
+    unsafe { *c = a.read() * s }
+}
diff --git a/operators/src/scale/cuda/mod.rs b/operators/src/scale/cuda/mod.rs
new file mode 100644
index 00000000..9acec9fb
--- /dev/null
+++ b/operators/src/scale/cuda/mod.rs
@@ -0,0 +1,36 @@
+use super::{Args, Scale};
+use crate::{cuda::Gpu, ByteOf, LaunchError, QueueAlloc, SchemeError};
+
+pub struct Operator {}
+impl Scale for Operator {}
+
+impl crate::Operator for Operator {
+    type Hardware = Gpu;
+    type TopoNode = Gpu;
+    type Args = Args;
+
+    fn new(_node: &Self::TopoNode) -> Self {
+        Self {}
+    }
+
+    #[inline]
+    fn scheme(
+        &mut self,
+        _args: &Self::Args,
+        _max_workspace_size: usize,
+    ) -> Result {
+        Ok(0)
+    }
+
+    fn launch(
+        &self,
+        _args: &Self::Args,
+        _workspace: &mut [ByteOf],
+        _queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        QA: QueueAlloc,
+    {
+        Ok(())
+    }
+}
diff --git a/operators/src/scale/infini/mod.rs b/operators/src/scale/infini/mod.rs
new file mode 100644
index 00000000..183414e1
--- /dev/null
+++ b/operators/src/scale/infini/mod.rs
@@ -0,0 +1,36 @@
+use super::{Args, Scale};
+use crate::{infini::Device, ByteOf, LaunchError, QueueAlloc, SchemeError};
+
+pub struct Operator;
+
+impl Scale for Operator {}
+
+impl crate::Operator for Operator {
+    type Hardware = Device;
+    type TopoNode = Device;
+    type Args = Args;
+
+    fn new(_node: &Self::TopoNode) -> Self {
+        todo!()
+    }
+
+    fn scheme(
+        &mut self,
+        _args: &Self::Args,
+        _max_workspace_size: usize,
+    ) -> Result {
+        todo!()
+    }
+
+    fn launch(
+        &self,
+        _args: &Self::Args,
+        _workspace: &mut [ByteOf],
+        _queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        QA: QueueAlloc,
+    {
+        todo!()
+    }
+}
diff --git a/operators/src/scale/mod.rs b/operators/src/scale/mod.rs
new file mode 100644
index 00000000..e5d473b0
--- /dev/null
+++ b/operators/src/scale/mod.rs
@@ -0,0 +1,15 @@
+//! c = s * a
+
+#[cfg(any(use_cpu, test))]
+pub mod common_cpu;
+#[cfg(use_cuda)]
+pub mod cuda;
+#[cfg(use_infini)]
+pub mod infini;
+#[cfg(use_cl)]
+pub mod opencl;
+
+mod args;
+pub use args::Args;
+
+crate::op_trait!(Scale);
diff --git a/operators/src/scale/opencl/mod.rs b/operators/src/scale/opencl/mod.rs
new file mode 100644
index 00000000..369e8f81
--- /dev/null
+++ b/operators/src/scale/opencl/mod.rs
@@ -0,0 +1,36 @@
+use super::{Args, Scale};
+use crate::{opencl::ClDevice, ByteOf, LaunchError, QueueAlloc, SchemeError};
+
+pub struct Operator;
+
+impl Scale for Operator {}
+
+impl crate::Operator for Operator {
+    type Hardware = ClDevice;
+    type TopoNode = ClDevice;
+    type Args = Args;
+
+    fn new(_node: &Self::TopoNode) -> Self {
+        todo!()
+    }
+
+    fn scheme(
+        &mut self,
+        _args: &Self::Args,
+        _max_workspace_size: usize,
+    ) -> Result {
+        todo!()
+    }
+
+    fn launch(
+        &self,
+        _args: &Self::Args,
+        _workspace: &mut [ByteOf],
+        _queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        QA: QueueAlloc,
+    {
+        todo!()
+    }
+}
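
The scale operator's CPU path above is a strided element-wise kernel driven by the merged Scheme. As a plain-Rust reference for the same semantics (a standalone sketch, not crate code: the names are illustrative, strides are counted in elements rather than bytes, and only f32 is shown):

fn scale_f32(
    c: &mut [f32],
    a: &[f32],
    shape: &[usize],
    c_strides: &[usize],
    a_strides: &[usize],
    s: f32,
) {
    let count: usize = shape.iter().product();
    for flat in 0..count {
        // Decompose the flat index into a multi-index (last dimension fastest),
        // then map it into each tensor through that tensor's own strides.
        let (mut rem, mut ci, mut ai) = (flat, 0usize, 0usize);
        for (&d, (&sc, &sa)) in shape.iter().zip(c_strides.iter().zip(a_strides)).rev() {
            let k = rem % d;
            rem /= d;
            ci += k * sc;
            ai += k * sa;
        }
        c[ci] = s * a[ai];
    }
}

The crate's Scheme additionally merges contiguous dimensions and parallelises the outer loop with rayon, but the per-element arithmetic is exactly c = s * a.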