Skip to content

refactor: 简化算子设计 #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ jobs:
run:
cargo clippy
--all-features
--all-targets
--message-format=json | clippy-sarif | tee rust-clippy-results.sarif | sarif-fmt
continue-on-error: true

Expand Down
13 changes: 7 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[workspace]
members = ["operators"]
resolver = "2"
resolver = "3"
package.edition = "2024"

[workspace.dependencies]
clrt = { git = "https://github.com/InfiniTensor/clrt", rev = "984ac7a" }
Expand All @@ -11,8 +12,8 @@ infini-op = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "e83
infini-ccl = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "e8362c3" }
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "e8362c3" }

cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "f3ffbcc" }
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "f3ffbcc" }
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "f3ffbcc" }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "f3ffbcc" }
search-corex-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "f3ffbcc" }
cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "c2b12d3" }
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "c2b12d3" }
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "c2b12d3" }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "c2b12d3" }
search-corex-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "c2b12d3" }
10 changes: 5 additions & 5 deletions operators/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "operators"
version = "0.0.0"
edition = "2021"
edition.workspace = true
authors = ["YdrMaster <[email protected]>"]

[features]
Expand All @@ -13,13 +13,13 @@ nvidia-gpu = ["cuda", "cublas", "nccl", "fslock", "libloading"]
iluvatar-gpu = ["cuda", "cublas", "fslock", "libloading"]

[dependencies]
digit-layout = "0.2"
ndarray-layout = "0.1"
digit-layout = "0.3"
ndarray-layout = "0.2"
rayon = "1.10"
lru = "0.12"
lru = "0.14"
num-traits = "0.2"
itertools = "0.14"
half = "2.4"
half = "2.6"
log = "0.4"

gemm = { version = "0.18", optional = true }
Expand Down
17 changes: 10 additions & 7 deletions operators/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,21 @@ fn main() {
{
infini.define()
}
let use_nvidia = cfg!(feature = "nvidia-gpu") && find_cuda_root().is_some();

// iluvatar
let use_iluvatar = cfg!(feature = "iluvatar-gpu") && find_corex().is_some();
if use_iluvatar {
iluvatar.define();
cuda.define();
return;
}

let use_nvidia = cfg!(feature = "nvidia-gpu") && find_cuda_root().is_some();
if use_nvidia {
nvidia.define();
if find_nccl_root().is_some() {
nccl.define()
}
}
if use_iluvatar {
iluvatar.define()
}
if use_nvidia || use_iluvatar {
cuda.define()
cuda.define();
}
}
69 changes: 16 additions & 53 deletions operators/src/.clang-format
Original file line number Diff line number Diff line change
@@ -1,66 +1,29 @@
# Generated from CLion C/C++ Code Style settings
---
BasedOnStyle: LLVM
AccessModifierOffset: -4
AlignAfterOpenBracket: Align
# AlignConsecutiveAssignments: None
AlignOperands: Align
AllowAllArgumentsOnNextLine: false
AllowAllConstructorInitializersOnNextLine: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: Always
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: Always
AllowShortLambdasOnASingleLine: All
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterReturnType: None
AlwaysBreakTemplateDeclarations: No
BreakBeforeBraces: Custom
IndentWidth: 4 # 缩进宽度,LLVM 默认值为 2,改为 4
AccessModifierOffset: -4 # public/protected/private 访问控制符相对成员的偏移,与 IndentWidth 配合,LLVM 默认值为 -2
AlignOperands: AlignAfterOperator # 双目运算符的行间对齐,LLVM 默认值为 Align,改为带符号一起换行
ColumnLimit: 0 # 列宽限制,LLVM 默认值为 80,改为不限制
AllowShortBlocksOnASingleLine: Always # 是否允许短块(单个语句的块)不换行,LLVM 默认值为 Never,改为允许
AllowShortLoopsOnASingleLine: true # 是否允许短循环不换行,LLVM 默认值为 false,改为允许
InsertBraces: true # 是否在 if/else/for/do/while 等控制语句后自动插入大括号,LLVM 默认值为 false,改为插入
BreakBeforeBraces: Custom # 大括号换行配置,LLVM 默认值为 Attach,改为自定义以使 BraceWrapping 生效
BraceWrapping:
AfterCaseLabel: false
AfterClass: false
AfterControlStatement: Never
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
AfterExternBlock: false
BeforeCatch: false
BeforeElse: false
BeforeLambdaBody: false
BeforeWhile: false
IndentBraces: false
SplitEmptyFunction: false
SplitEmptyFunction: true
SplitEmptyRecord: true
BreakBeforeBinaryOperators: None
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: BeforeColon
BreakInheritanceList: BeforeColon
ColumnLimit: 0
CompactNamespaces: true
ContinuationIndentWidth: 4
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4
KeepEmptyLinesAtTheStartOfBlocks: true
MaxEmptyLinesToKeep: 2
NamespaceIndentation: All
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PointerAlignment: Right
ReflowComments: false
SpaceAfterCStyleCast: true
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 0
SpacesInAngles: false
SpacesInCStyleCastParentheses: false
SpacesInContainerLiterals: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
TabWidth: 4
UseTab: Never
SplitEmptyNamespace: true
20 changes: 8 additions & 12 deletions operators/src/add/args.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
get_static, rank_mismatch, shape_mismatch, shape_not_support, utils::type_distinct, ConstPtr,
Hardware, MutPtr, SchemeError, TensorLayout,
ConstPtr, Hardware, LaunchError, MutPtr, TensorLayout, rank_mismatch, shape_mismatch,
shape_not_support, utils::type_distinct,
};
use digit_layout::DigitLayout;
use itertools::izip;
Expand Down Expand Up @@ -40,15 +40,15 @@ impl<H: Hardware> Args<H> {
pub(super) struct Scheme(DigitLayout, Box<[isize]>);

impl Scheme {
pub fn new<H: Hardware>(args: &Args<H>) -> Result<Self, SchemeError> {
pub fn new<H: Hardware>(args: &Args<H>) -> Result<Self, LaunchError> {
let Args {
c_layout: c,
a_layout: a,
b_layout: b,
..
} = args;
// # 检查基本属性
let dt = type_distinct(&[c.dt(), a.dt(), b.dt()])?;
let dt = type_distinct(&[c.dt, a.dt, b.dt])?;
let ndim = c.ndim();
if a.ndim() != ndim || b.ndim() != ndim {
return Err(rank_mismatch(format!(
Expand All @@ -68,17 +68,13 @@ impl Scheme {
}
let mut dims = Vec::with_capacity(ndim);
for (&d, &da, &db, &sc, &sa, &sb) in izip!(
c.shape(),
a.shape(),
b.shape(),
c.shape_group(),
a.shape_group(),
b.shape_group(),
c.strides(),
a.strides(),
b.strides()
b.strides(),
) {
get_static! {
d da db
sc sa sb
}
if da != d || db != d {
return Err(shape_mismatch(format!(
"c: {:?}, a: {:?}, b: {:?}",
Expand Down
12 changes: 2 additions & 10 deletions operators/src/add/common_cpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{args::Scheme, Add, Args};
use crate::{common_cpu::Cpu, ByteOf, LaunchError, QueueAlloc, SchemeError};
use super::{Add, Args, args::Scheme};
use crate::{ByteOf, LaunchError, QueueAlloc, common_cpu::Cpu};
use digit_layout::types as ty;
use half::f16;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
Expand All @@ -17,14 +17,6 @@ impl crate::Operator for Operator {
fn new(_node: &Self::TopoNode) -> Self {
Self
}
#[inline]
fn scheme(
&mut self,
_args: &Self::Args,
_max_workspace_size: usize,
) -> Result<usize, SchemeError> {
Ok(0)
}

fn launch<QA>(
&self,
Expand Down
2 changes: 1 addition & 1 deletion operators/src/add/cuda/add.cuh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
template<class Tdata>
template <class Tdata>
static __device__ void _add(
Tdata *__restrict__ c,
Tdata const *__restrict__ a,
Expand Down
67 changes: 18 additions & 49 deletions operators/src/add/cuda/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use super::{args::Scheme, Add, Args};
use super::{Add, Args, args::Scheme};
use crate::{
cuda::{dt_name, Gpu, Handle, ModuleBox},
ByteOf, LaunchError, QueueAlloc, SchemeDiversity,
cuda::{Gpu, Handle, ModuleBox, dt_name},
shape_not_support, strides_not_support,
utils::{gcd, type_distinct},
ByteOf, LaunchError, QueueAlloc, SchemeDiversity, SchemeError,
utils::gcd,
};
use cuda::params;
use digit_layout::DigitLayout;
use lru::LruCache;
use std::{
ffi::{c_uint, CString},
ffi::c_uint,
sync::{Arc, Mutex},
};

Expand All @@ -32,20 +33,6 @@ impl crate::Operator for Operator {
}
}

#[inline]
fn scheme(
&mut self,
args: &Self::Args,
_max_workspace_size: usize,
) -> Result<usize, SchemeError> {
let dt = type_distinct(&[args.c_layout.dt(), args.a_layout.dt(), args.b_layout.dt()])?;
self.schemes
.lock()
.unwrap()
.get_or_insert(dt, || compile(&self.handle, dt));
Ok(0)
}

fn launch<QA>(
&self,
args: &Self::Args,
Expand All @@ -60,20 +47,20 @@ impl crate::Operator for Operator {
let count = scheme.count();

let &[1] = scheme.idx_strides() else {
return Err(shape_not_support("").into());
return Err(shape_not_support(""));
};
let &[sc] = scheme.c_strides() else {
return Err(shape_not_support("").into());
return Err(shape_not_support(""));
};
let &[sa] = scheme.a_strides() else {
return Err(shape_not_support("").into());
return Err(shape_not_support(""));
};
let &[sb] = scheme.b_strides() else {
return Err(shape_not_support("").into());
return Err(shape_not_support(""));
};
let unit = dt.nbytes() as isize;
if sc != unit || sa != unit || sb != unit {
return Err(strides_not_support("").into());
return Err(strides_not_support(""));
}

let block_dims = gcd(count, self.max_threads_block);
Expand All @@ -84,18 +71,15 @@ impl crate::Operator for Operator {
b_base,
..
} = args;
let params = cuda::params![c_base, a_base, b_base];

self.schemes
.lock()
.unwrap()
.get_or_insert(dt, || compile(&self.handle, dt))
.launch(
CString::new("add").unwrap(),
grid_dims as c_uint,
block_dims as c_uint,
params.as_ptr(),
0,
c"add",
(grid_dims as c_uint, block_dims as c_uint, 0),
&params![*c_base, *a_base, *b_base].to_ptrs(),
queue_alloc.queue(),
);
Ok(())
Expand Down Expand Up @@ -124,25 +108,12 @@ extern "C" __global__ void add(
#[cfg(test)]
mod test {
use super::{Args, Gpu, Operator};
use crate::{dyn_, Hardware, Operator as _, TensorLayout};
use crate::{Hardware, Operator as _, TensorLayout};
use digit_layout::{
types::{F16, F64},
DigitLayout,
types::{F16, F64},
};
use std::ptr::null;

fn dyn_args<H: Hardware>(dt: DigitLayout) -> Args<H> {
use std::ptr::null_mut;
let layout = TensorLayout::new_dyn(dt, &[dyn_(); 2], &[dyn_(); 2]);
Args {
c_layout: layout.clone(),
c_base: null_mut(),
a_layout: layout.clone(),
a_base: null(),
b_layout: layout.clone(),
b_base: null(),
}
}
fn args<H: Hardware>(
dt: DigitLayout,
n: usize,
Expand Down Expand Up @@ -178,10 +149,8 @@ mod test {
return;
};

let mut cpu_op = RefOp::new(&Cpu);
let mut gpu_op = Operator::new(&gpu);
cpu_op.scheme(&dyn_args(F64), 0).unwrap();
gpu_op.scheme(&dyn_args(F16), 0).unwrap();
let cpu_op = RefOp::new(&Cpu);
let gpu_op = Operator::new(&gpu);

let n = 1;
let d = 768;
Expand Down
Loading