Skip to content

Commit 4b784ee

Browse files
authored
add model struct in PS Go (#1764)
* add go model definition * add model struct in Go * fix unit test * fix comment * rename IndexedSlices in elastic.proto to IndexedSlicesProto
1 parent 565486d commit 4b784ee

File tree

10 files changed

+207
-51
lines changed

10 files changed

+207
-51
lines changed

elasticdl/pkg/commonnew/embedding_table.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ func (e *EmbeddingTable) GetEmbeddingVectors(indices []int64) *Tensor {
4343
}
4444

4545
// SetEmbeddingVectors sets (indices, value) pair to embedding vector
46-
func (e *EmbeddingTable) SetEmbeddingVectors(idxslice *IndexedTensor) error {
47-
for i, index := range idxslice.Indices {
46+
func (e *EmbeddingTable) SetEmbeddingVectors(idxslice *IndexedSlices) error {
47+
for i, index := range idxslice.Ids {
4848
value := e.GetEmbeddingVector(index)
49-
copy(value.Buffer, idxslice.GetRow(int64(i)).Buffer)
49+
copy(value.Buffer, idxslice.ConcatTensors.GetRow(int64(i)).Buffer)
5050
}
5151
return nil
5252
}

elasticdl/pkg/commonnew/embedding_table_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func TestEmbeddingTableGet(t *testing.T) {
1616
e1 := NewEmbeddingTable(2, "zero", Float32)
1717
v1 := e1.GetEmbeddingVector(1) // Note: this is a reference type, future changes have effect on it
1818
t1 := NewTensor([]float32{1, 2}, []int64{1, 2})
19-
it := NewIndexedTensor(t1, []int64{1})
19+
it := NewIndexedSlices(t1, []int64{1})
2020
e1.SetEmbeddingVectors(it)
2121
assert.Equal(t, Slice(v1).([]float32), []float32{1, 2}, "GetEmbeddingVector FAIL")
2222

@@ -30,7 +30,7 @@ func TestEmbeddingTableSet(t *testing.T) {
3030
i := []int64{1, 3, 5}
3131
v := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}
3232
tensor := NewTensor(v, []int64{3, 2})
33-
it := NewIndexedTensor(tensor, i)
33+
it := NewIndexedSlices(tensor, i)
3434

3535
err := e.SetEmbeddingVectors(it)
3636
assert.Nil(t, err)

elasticdl/pkg/commonnew/tensor.go

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ func (t *Tensor) IsValid() bool {
142142
return true
143143
}
144144

145-
// DeserializeFromTensorPB transforms pb to tensor
146-
func DeserializeFromTensorPB(pb *tensor_go_proto.TensorProto) *Tensor {
145+
// DeserializeFromTensorProto transforms pb to tensor
146+
func DeserializeFromTensorProto(pb *tensor_go_proto.TensorProto) *Tensor {
147147
pbDim := pb.GetTensorShape().GetDim()
148148
dims := make([]int64, len(pbDim), len(pbDim))
149149
for i, iDim := range pbDim {
@@ -160,8 +160,8 @@ func DeserializeFromTensorPB(pb *tensor_go_proto.TensorProto) *Tensor {
160160
}
161161
}
162162

163-
// SerializeToTensor transforms tensor to pb
164-
func (t *Tensor) SerializeToTensor() *tensor_go_proto.TensorProto {
163+
// SerializeToTensorProto transforms tensor to pb
164+
func (t *Tensor) SerializeToTensorProto() *tensor_go_proto.TensorProto {
165165
shapeDim := make([]*tensor_shape_go_proto.TensorShapeProto_Dim, len(t.Dims), len(t.Dims))
166166
for i, dim := range t.Dims {
167167
shapeDim[i] = &tensor_shape_go_proto.TensorShapeProto_Dim{
@@ -178,35 +178,35 @@ func (t *Tensor) SerializeToTensor() *tensor_go_proto.TensorProto {
178178
}
179179
}
180180

181-
// IndexedTensor : IndexedSlice in memory representation
182-
type IndexedTensor struct {
183-
Tensor
184-
Indices []int64
181+
// IndexedSlices : IndexedSlice in memory representation
182+
type IndexedSlices struct {
183+
ConcatTensors *Tensor
184+
Ids []int64
185185
}
186186

187-
// NewIndexedTensor return a IndexedTensor instance
188-
func NewIndexedTensor(t *Tensor, indices []int64) *IndexedTensor {
189-
return &IndexedTensor{
190-
Tensor: *t,
191-
Indices: indices,
187+
// NewIndexedSlices returns an IndexedSlices instance
188+
func NewIndexedSlices(t *Tensor, ids []int64) *IndexedSlices {
189+
return &IndexedSlices{
190+
ConcatTensors: t,
191+
Ids: ids,
192192
}
193193
}
194194

195-
// SerializeToIndexedSlices return proto.IndexedSlices
196-
func (t *IndexedTensor) SerializeToIndexedSlices() *proto.IndexedSlices {
197-
if t.Dims[0] != int64(len(t.Indices)) || len(t.Dims) != 2 {
195+
// SerializeToIndexedSlicesProto returns proto.IndexedSlicesProto
196+
func (t *IndexedSlices) SerializeToIndexedSlicesProto() *proto.IndexedSlicesProto {
197+
if t.ConcatTensors.Dims[0] != int64(len(t.Ids)) || len(t.ConcatTensors.Dims) != 2 {
198198
return nil
199199
}
200-
return &proto.IndexedSlices{
201-
ConcatTensors: t.SerializeToTensor(),
202-
Ids: t.Indices,
200+
return &proto.IndexedSlicesProto{
201+
ConcatTensors: t.ConcatTensors.SerializeToTensorProto(),
202+
Ids: t.Ids,
203203
}
204204
}
205205

206-
// DeserializeFromIndexedSlicePB return common.IndexedTensor
207-
func DeserializeFromIndexedSlicePB(pb *proto.IndexedSlices) *IndexedTensor {
208-
return &IndexedTensor{
209-
Tensor: *DeserializeFromTensorPB(pb.ConcatTensors),
210-
Indices: pb.Ids,
206+
// DeserializeFromIndexedSliceProto returns an IndexedSlices
207+
func DeserializeFromIndexedSliceProto(pb *proto.IndexedSlicesProto) *IndexedSlices {
208+
return &IndexedSlices{
209+
ConcatTensors: DeserializeFromTensorProto(pb.ConcatTensors),
210+
Ids: pb.Ids,
211211
}
212212
}

elasticdl/pkg/commonnew/tensor_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ func TestPbTransform(t *testing.T) {
6363
Dtype: Float32,
6464
}
6565

66-
t1 := DeserializeFromTensorPB(&pb)
66+
t1 := DeserializeFromTensorProto(&pb)
6767
assert.Equal(t, Slice(t1).([]float32), val, "Deserialize FAIL")
6868

69-
pb2 := t1.SerializeToTensor()
69+
pb2 := t1.SerializeToTensorProto()
7070
assert.Equal(t, pb2.GetTensorContent(), bval, "Serialize FAIL")
7171
}

elasticdl/pkg/kernelnew/kernel.go

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,23 @@ func SGD(grad *commonnew.Tensor, param *commonnew.Tensor, lr float32) error {
1919
}
2020

2121
// SparseSGD kernel
22-
func SparseSGD(grad *commonnew.IndexedTensor, param *commonnew.EmbeddingTable, lr float32) error {
23-
if grad.Dims[1] != param.Dim {
22+
func SparseSGD(grad *commonnew.IndexedSlices, param *commonnew.EmbeddingTable, lr float32) error {
23+
if grad.ConcatTensors.Dims[1] != param.Dim {
2424
return fmt.Errorf("grad width is not equal to embedding dim")
2525
}
26-
for i, index := range grad.Indices {
26+
for i, index := range grad.Ids {
2727
vector := param.GetEmbeddingVector(index)
28-
subGrad := grad.GetRow(int64(i))
28+
subGrad := grad.ConcatTensors.GetRow(int64(i))
2929
SGD(subGrad, vector, lr)
3030
}
3131
return nil
3232
}
3333

3434
// IndexedSGD kernel
35-
func IndexedSGD(grad *commonnew.IndexedTensor, param *commonnew.Tensor, lr float32) error {
36-
for i, index := range grad.Indices {
35+
func IndexedSGD(grad *commonnew.IndexedSlices, param *commonnew.Tensor, lr float32) error {
36+
for i, index := range grad.Ids {
3737
vector := param.GetRow(index)
38-
subGrad := grad.GetRow(int64(i))
38+
subGrad := grad.ConcatTensors.GetRow(int64(i))
3939
SGD(subGrad, vector, lr)
4040
}
4141
return nil
@@ -62,12 +62,15 @@ func Adam(grad *commonnew.Tensor, param *commonnew.Tensor, m *commonnew.Tensor,
6262
}
6363

6464
// SparseAdam kernel
65-
func SparseAdam(grad *commonnew.IndexedTensor, param *commonnew.EmbeddingTable, m *commonnew.EmbeddingTable, v *commonnew.EmbeddingTable, lr float32, step int64, beta1 float32, beta2 float32, epsilon float32, amsgrad bool, maxSquare *commonnew.EmbeddingTable) error {
66-
if grad.Dims[1] != param.Dim {
65+
func SparseAdam(grad *commonnew.IndexedSlices, param *commonnew.EmbeddingTable,
66+
m *commonnew.EmbeddingTable, v *commonnew.EmbeddingTable, lr float32,
67+
step int64, beta1 float32, beta2 float32, epsilon float32, amsgrad bool,
68+
maxSquare *commonnew.EmbeddingTable) error {
69+
if grad.ConcatTensors.Dims[1] != param.Dim {
6770
return fmt.Errorf("grad width is not equal to embedding dim")
6871
}
69-
for i, index := range grad.Indices {
70-
subgrad := grad.GetRow(int64(i))
72+
for i, index := range grad.Ids {
73+
subgrad := grad.ConcatTensors.GetRow(int64(i))
7174
subparam := param.GetEmbeddingVector(index)
7275
subm := m.GetEmbeddingVector(index)
7376
subv := v.GetEmbeddingVector(index)
@@ -81,12 +84,15 @@ func SparseAdam(grad *commonnew.IndexedTensor, param *commonnew.EmbeddingTable,
8184
}
8285

8386
// IndexedAdam kernel
84-
func IndexedAdam(grad *commonnew.IndexedTensor, param *commonnew.Tensor, m *commonnew.Tensor, v *commonnew.Tensor, lr float32, step int64, beta1 float32, beta2 float32, epsilon float32, amsgrad bool, maxSquare *commonnew.Tensor) error {
85-
if grad.Dims[1] != param.Dims[1] {
87+
func IndexedAdam(grad *commonnew.IndexedSlices, param *commonnew.Tensor,
88+
m *commonnew.Tensor, v *commonnew.Tensor, lr float32, step int64,
89+
beta1 float32, beta2 float32, epsilon float32, amsgrad bool,
90+
maxSquare *commonnew.Tensor) error {
91+
if grad.ConcatTensors.Dims[1] != param.Dims[1] {
8692
return fmt.Errorf("grad width is not equal to embedding dim")
8793
}
88-
for i, index := range grad.Indices {
89-
subgrad := grad.GetRow(int64(i))
94+
for i, index := range grad.Ids {
95+
subgrad := grad.ConcatTensors.GetRow(int64(i))
9096
subparam := param.GetRow(index)
9197
subm := m.GetRow(index)
9298
subv := v.GetRow(index)

elasticdl/pkg/kernelnew/kernel_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func TestSparseSGD(t *testing.T) {
3838
d := []int64{3, 2}
3939
indices := []int64{1, 3, 3}
4040
grad := commonnew.NewTensor(a, d)
41-
isgrad := commonnew.NewIndexedTensor(grad, indices)
41+
isgrad := commonnew.NewIndexedSlices(grad, indices)
4242

4343
table := commonnew.NewEmbeddingTable(2, "zero", commonnew.Float32)
4444

@@ -194,7 +194,7 @@ func TestSparseAdam(t *testing.T) {
194194
m := commonnew.NewTensor(rawM, dim)
195195
v := commonnew.NewTensor(rawV, dim)
196196
maxSquare := commonnew.NewTensor(rawMaxSquare, dim)
197-
isgrad := commonnew.NewIndexedTensor(grad, []int64{1})
197+
isgrad := commonnew.NewIndexedSlices(grad, []int64{1})
198198

199199
ptable.EmbeddingVectors[1] = param
200200
mtable.EmbeddingVectors[1] = m

elasticdl/pkg/ps/model.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package ps
2+
3+
import (
4+
"elasticdl.org/elasticdl/pkg/commonnew"
5+
"elasticdl.org/elasticdl/pkg/proto"
6+
"fmt"
7+
)
8+
9+
// Model contains dense parameters and embedding tables
10+
type Model struct {
11+
DenseParameters map[string]*commonnew.Tensor
12+
EmbeddingTables map[string]*commonnew.EmbeddingTable
13+
Version int32
14+
Initialized bool
15+
}
16+
17+
// NewModel creates a model instance
18+
func NewModel() *Model {
19+
return &Model{
20+
DenseParameters: make(map[string]*commonnew.Tensor),
21+
EmbeddingTables: make(map[string]*commonnew.EmbeddingTable),
22+
}
23+
}
24+
25+
// GetDenseParameter returns dense parameter pointer
26+
func (model *Model) GetDenseParameter(name string) *commonnew.Tensor {
27+
if value, ok := model.DenseParameters[name]; ok {
28+
return value
29+
}
30+
return nil
31+
}
32+
33+
// GetEmbeddingTable returns embedding table pointer
34+
func (model *Model) GetEmbeddingTable(name string) *commonnew.EmbeddingTable {
35+
if value, ok := model.EmbeddingTables[name]; ok {
36+
return value
37+
}
38+
return nil
39+
}
40+
41+
// SetEmbeddingTableInfo sets embedding table info of an embedding param
42+
func (model *Model) SetEmbeddingTableInfo(info *proto.EmbeddingTableInfo) {
43+
if _, ok := model.EmbeddingTables[info.Name]; ok {
44+
return
45+
}
46+
t := commonnew.NewEmbeddingTable(info.Dim, info.Initializer, info.Dtype)
47+
model.EmbeddingTables[info.Name] = t
48+
}
49+
50+
// InitFromModelPB inits the model from model PB
51+
func (model *Model) InitFromModelPB(pb *proto.Model) error {
52+
for _, v := range pb.EmbeddingTableInfo {
53+
model.SetEmbeddingTableInfo(v)
54+
}
55+
for name, v := range pb.DenseParameters {
56+
model.DenseParameters[name] = commonnew.DeserializeFromTensorProto(v)
57+
}
58+
for name, v := range pb.EmbeddingTables {
59+
table := model.GetEmbeddingTable(name)
60+
if table == nil {
61+
return fmt.Errorf("Embedding table %s is not created", name)
62+
}
63+
iv := commonnew.DeserializeFromIndexedSliceProto(v)
64+
err := model.EmbeddingTables[name].SetEmbeddingVectors(iv)
65+
if err != nil {
66+
return err
67+
}
68+
}
69+
if pb.Version >= 0 {
70+
model.Version = pb.Version
71+
}
72+
return nil
73+
}

elasticdl/pkg/ps/model_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package ps
2+
3+
import (
4+
"elasticdl.org/elasticdl/pkg/commonnew"
5+
"elasticdl.org/elasticdl/pkg/proto"
6+
"github.com/stretchr/testify/assert"
7+
"testing"
8+
)
9+
10+
func TestModelInit(t *testing.T) {
11+
d1 := []int64{2, 3}
12+
v1 := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}
13+
t1 := commonnew.NewTensor(v1, d1)
14+
15+
d2 := []int64{2, 2}
16+
v2 := []float32{1.0, 2.0, 1.1, 2.2}
17+
t2 := commonnew.NewTensor(v2, d2)
18+
19+
model := NewModel()
20+
model.DenseParameters["t1"] = t1
21+
model.DenseParameters["t2"] = t2
22+
23+
assert.Len(t, model.DenseParameters, 2)
24+
assert.Contains(t, model.DenseParameters, "t1")
25+
assert.Contains(t, model.DenseParameters, "t2")
26+
27+
assert.Equal(t, model.GetDenseParameter("t1").Dims, d1)
28+
assert.Equal(t, model.GetDenseParameter("t2").Dims, d2)
29+
assert.Nil(t, model.GetDenseParameter("t3"))
30+
}
31+
32+
func TestModelInitFrom(t *testing.T) {
33+
var modelPB = proto.Model{
34+
Version: int32(1),
35+
EmbeddingTables: make(map[string]*proto.IndexedSlicesProto),
36+
EmbeddingTableInfo: []*proto.EmbeddingTableInfo{},
37+
}
38+
d1 := []int64{3, 2}
39+
v1 := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}
40+
t1 := commonnew.NewTensor(v1, d1)
41+
42+
i1 := []int64{1, 3, 5}
43+
var is = commonnew.IndexedSlices{
44+
ConcatTensors: t1,
45+
Ids: i1,
46+
}
47+
isPB := is.SerializeToIndexedSlicesProto()
48+
modelPB.EmbeddingTables["e1"] = isPB
49+
50+
var epb = proto.EmbeddingTableInfo{
51+
Name: "e1",
52+
Dim: 2,
53+
Initializer: "zero",
54+
Dtype: commonnew.Float32,
55+
}
56+
modelPB.EmbeddingTableInfo = append(modelPB.EmbeddingTableInfo, &epb)
57+
58+
model := NewModel()
59+
assert.NotNil(t, model)
60+
err := model.InitFromModelPB(&modelPB)
61+
62+
assert.Nil(t, err)
63+
assert.Contains(t, model.EmbeddingTables, "e1")
64+
65+
e1 := model.GetEmbeddingTable("e1")
66+
assert.Equal(t, int64(2), e1.Dim)
67+
assert.Equal(t, 3, len(e1.EmbeddingVectors))
68+
69+
ev1 := e1.GetEmbeddingVector(1)
70+
assert.True(t, commonnew.CompareFloatArray([]float32{1.0, 2.0}, commonnew.Slice(ev1).([]float32), 0.0001))
71+
72+
ev3 := e1.GetEmbeddingVector(3)
73+
assert.True(t, commonnew.CompareFloatArray([]float32{3.0, 4.0}, commonnew.Slice(ev3).([]float32), 0.0001))
74+
75+
ev5 := e1.GetEmbeddingVector(5)
76+
assert.True(t, commonnew.CompareFloatArray([]float32{5.0, 6.0}, commonnew.Slice(ev5).([]float32), 0.0001))
77+
}

elasticdl/proto/elasticdl.proto

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ message Tensor {
7070
tensorflow.DataType dtype = 5;
7171
}
7272

73-
message IndexedSlices {
73+
message IndexedSlicesProto {
7474
tensorflow.TensorProto concat_tensors = 1;
7575
repeated int64 ids = 2;
7676
}
@@ -87,7 +87,7 @@ message Model {
8787
repeated Tensor param = 2;
8888
repeated EmbeddingTableInfo embedding_table_info = 3;
8989
map<string, tensorflow.TensorProto> dense_parameters = 4;
90-
map<string, IndexedSlices> embedding_tables = 5;
90+
map<string, IndexedSlicesProto> embedding_tables = 5;
9191
}
9292

9393
message GetTaskRequest {

elasticdl/python/common/tensor_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,6 @@ def serialize_indexed_slices(slices, pb):
9696

9797

9898
def indexed_slices_to_pb(slices):
99-
pb = elasticdl_pb2.IndexedSlices()
99+
pb = elasticdl_pb2.IndexedSlicesProto()
100100
serialize_indexed_slices(slices, pb)
101101
return pb

0 commit comments

Comments
 (0)