Skip to content

Commit 61dba6d

Browse files
committed
test: 支持从测例文件加载测例
Signed-off-by: YdrMaster <[email protected]>
1 parent a645f82 commit 61dba6d

File tree

4 files changed

+239
-0
lines changed

4 files changed

+239
-0
lines changed

operators/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,9 @@ search-cuda-tools.workspace = true
4444
search-corex-tools.workspace = true
4545

4646
[dev-dependencies]
47+
ggus = "0.4"
48+
memmap2 = "0.9"
49+
patricia_tree = "0.9"
50+
4751
gemm = "0.18"
4852
rand = "0.9"

operators/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
mod common;
44
mod handle;
5+
#[cfg(test)]
6+
mod test;
57

68
pub mod add;
79
pub mod add_rows;

operators/src/test/gguf.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
use super::MetaValue;
2+
use digit_layout::DigitLayout;
3+
use ggus::{
4+
GGuf, GGufError, GGufMetaDataValueType, GGufMetaError, GGufMetaMap, GGufMetaValueArray,
5+
GGufReader, GENERAL_ALIGNMENT,
6+
};
7+
use patricia_tree::StringPatriciaMap;
8+
9+
/// GGuf 文件内容,元信息和张量。
10+
pub(super) struct Content<'a> {
11+
pub meta_kvs: StringPatriciaMap<MetaValue<'a>>,
12+
pub tensors: StringPatriciaMap<GGufTensor<'a>>,
13+
}
14+
15+
impl GGufMetaMap for Content<'_> {
16+
fn get(&self, key: &str) -> Option<(GGufMetaDataValueType, &[u8])> {
17+
self.meta_kvs.get(key).map(|v| (v.ty, &*v.value))
18+
}
19+
}
20+
21+
#[derive(Clone, Debug)]
22+
pub(super) struct GGufTensor<'a> {
23+
pub ty: DigitLayout,
24+
pub shape: Vec<usize>,
25+
pub data: &'a [u8],
26+
}
27+
28+
impl<'a> Content<'a> {
29+
/// 从分片的 GGuf 文件解析内容。
30+
pub fn new(files: &[&'a [u8]]) -> Result<Self, GGufError> {
31+
std::thread::scope(|s| {
32+
let mut ans = Self {
33+
meta_kvs: Default::default(),
34+
tensors: Default::default(),
35+
};
36+
// 在多个线程中并行解析多个文件,并逐个合并到单独的结构体中
37+
for thread in files
38+
.into_iter()
39+
.map(|data| s.spawn(|| GGuf::new(data)))
40+
.collect::<Vec<_>>()
41+
.into_iter()
42+
{
43+
thread
44+
.join()
45+
.unwrap()
46+
.and_then(|gguf| ans.merge_file(gguf))?;
47+
}
48+
49+
Ok(ans)
50+
})
51+
}
52+
53+
fn merge_file(&mut self, others: GGuf<'a>) -> Result<(), GGufError> {
54+
// 合并元信息
55+
for (k, kv) in others.meta_kvs {
56+
if k == GENERAL_ALIGNMENT || k.starts_with("split.") {
57+
continue;
58+
}
59+
let value = MetaValue {
60+
ty: kv.ty(),
61+
value: kv.value_bytes(),
62+
};
63+
if self.meta_kvs.insert(k.to_string(), value).is_some() {
64+
return Err(GGufError::DuplicateMetaKey(k.into()));
65+
}
66+
}
67+
// 合并张量,并将形状转换到 usize 类型
68+
for (name, tensor) in others.tensors {
69+
let tensor = tensor.to_info();
70+
let tensor = GGufTensor {
71+
ty: tensor.ty().to_digit_layout(),
72+
shape: tensor.shape().iter().map(|&d| d as _).collect(),
73+
data: &others.data[tensor.offset()..][..tensor.nbytes()],
74+
};
75+
if self.tensors.insert(name.to_string(), tensor).is_some() {
76+
return Err(GGufError::DuplicateTensorName(name.into()));
77+
}
78+
}
79+
Ok(())
80+
}
81+
}
82+
83+
impl MetaValue<'_> {
84+
/// 从元信息读取 isize 数组,用于解析 strides
85+
pub fn to_vec_isize(&self) -> Result<Vec<isize>, GGufMetaError> {
86+
use GGufMetaDataValueType as Ty;
87+
88+
let mut reader = GGufReader::new(&self.value);
89+
let (ty, len) = match self.ty {
90+
Ty::Array => reader.read_arr_header().map_err(GGufMetaError::Read)?,
91+
ty => return Err(GGufMetaError::TypeMismatch(ty)),
92+
};
93+
94+
match ty {
95+
Ty::I32 => Ok(GGufMetaValueArray::<i32>::new(reader, len)
96+
.map(|x| x.unwrap() as _)
97+
.collect()),
98+
Ty::I64 => Ok(GGufMetaValueArray::<i64>::new(reader, len)
99+
.map(|x| x.unwrap() as _)
100+
.collect()),
101+
_ => Err(GGufMetaError::ArrTypeMismatch(ty)),
102+
}
103+
}
104+
}

operators/src/test/mod.rs

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
mod gguf;
2+
3+
use digit_layout::DigitLayout;
4+
use gguf::{Content, GGufTensor};
5+
use ggus::{GGufMetaDataValueType, GGufMetaMap, GGufMetaMapExt};
6+
use memmap2::Mmap;
7+
use ndarray_layout::{ArrayLayout, Endian::BigEndian};
8+
use std::{collections::HashMap, env::var_os, fs::File};
9+
10+
/// 测例数据。
11+
pub struct TestCase<'a> {
12+
/// 测例在文件中的序号。
13+
pub index: usize,
14+
/// 测例元信息,即传递给算子非张量参数。
15+
pub attributes: HashMap<String, MetaValue<'a>>,
16+
/// 测例张量,包括算子的输入和正确答案。
17+
pub tensors: HashMap<String, Tensor<'a>>,
18+
}
19+
20+
/// 元信息键值对。
21+
pub struct MetaValue<'a> {
22+
pub ty: GGufMetaDataValueType,
23+
pub value: &'a [u8],
24+
}
25+
26+
/// 测例张量。
27+
pub struct Tensor<'a> {
28+
pub ty: DigitLayout,
29+
pub layout: ArrayLayout<4>,
30+
pub data: &'a [u8],
31+
}
32+
33+
impl GGufMetaMap for TestCase<'_> {
34+
fn get(&self, key: &str) -> Option<(GGufMetaDataValueType, &[u8])> {
35+
self.attributes.get(key).map(|v| (v.ty, &*v.value))
36+
}
37+
}
38+
39+
impl<'a> Content<'a> {
40+
pub fn into_cases(self) -> HashMap<String, Vec<TestCase<'a>>> {
41+
assert_eq!(self.general_architecture().unwrap(), "infiniop-test");
42+
let mut ans = HashMap::new();
43+
44+
let ntest = self.get_usize("test_count").unwrap();
45+
let Self {
46+
mut meta_kvs,
47+
mut tensors,
48+
} = self;
49+
for i in 0..ntest {
50+
let prefix = format!("test.{i}.");
51+
52+
let mut meta_kvs = meta_kvs
53+
.split_by_prefix(&prefix)
54+
.into_iter()
55+
.map(|(k, v)| (k[prefix.len()..].to_string(), v))
56+
.collect::<HashMap<_, _>>();
57+
let tensors = tensors
58+
.split_by_prefix(&prefix)
59+
.into_iter()
60+
.map(|(k, v)| {
61+
let GGufTensor {
62+
ty,
63+
mut shape,
64+
data,
65+
} = v;
66+
shape.reverse();
67+
68+
let k = k[prefix.len()..].to_string();
69+
let element_size = ty.nbytes();
70+
let layout = if let Some(strides) = meta_kvs.remove(&format!("{k}.strides")) {
71+
let mut strides = strides.to_vec_isize().unwrap();
72+
for x in &mut strides {
73+
*x *= element_size as isize
74+
}
75+
strides.reverse();
76+
77+
ArrayLayout::<4>::new(&shape, &strides, 0)
78+
} else {
79+
ArrayLayout::<4>::new_contiguous(&shape, BigEndian, element_size)
80+
};
81+
(k, Tensor { ty, layout, data })
82+
})
83+
.collect();
84+
85+
let case = TestCase {
86+
index: i,
87+
attributes: meta_kvs,
88+
tensors,
89+
};
90+
91+
let op_name = case.get_str("op_name").unwrap().to_string();
92+
ans.entry(op_name).or_insert_with(Vec::new).push(case);
93+
}
94+
95+
ans
96+
}
97+
}
98+
99+
#[test]
100+
fn test() {
101+
let Some(name) = var_os("TEST_CASES") else {
102+
eprintln!("TEST_CASES not set");
103+
return;
104+
};
105+
let Ok(file) = File::open(&name) else {
106+
eprintln!("Failed to open {}", name.to_string_lossy());
107+
return;
108+
};
109+
let mmap = unsafe { Mmap::map(&file).unwrap() };
110+
let cases = Content::new(&[&*mmap]).unwrap().into_cases();
111+
112+
for (op_name, cases) in cases {
113+
for case in cases {
114+
println!("Test case {}: {op_name}", case.index);
115+
for (k, v) in &case.attributes {
116+
println!(" {k}: {:?}", v.ty)
117+
}
118+
for (k, v) in &case.tensors {
119+
println!(
120+
" {k}: {} {:?} / {:?} {}",
121+
v.ty,
122+
v.layout.shape(),
123+
v.layout.strides(),
124+
v.data.len()
125+
)
126+
}
127+
}
128+
}
129+
}

0 commit comments

Comments
 (0)