Skip to content

Commit a726174

Browse files
committed
CreateMode: fixing incorrect create mode TTL including refactor of create flags
BREAKING Client behavior: since this proposes a new parsing of the Create flag integer this will break clients if they rely on CreateContainer as it was creating znodes that may not have been containers. Changes: - Fix FlagTTL value from 4 -> 5 - Adding all known CreateMode values to Flag constants. - Adding a createMode private struct behind the flag integer, replicate CreateMode from ZK java lib. - Create, CreateContainer, CreateTTL methods now parse the flag integer passed in based on the constants defined. - Rewrite tests to better catch CreateContainer and CreateTTL (no assertions on zk behavior since zk does not expose znode mode) - Added Change-Detector unit tests for create mode values.
1 parent 6131812 commit a726174

File tree

6 files changed

+242
-79
lines changed

6 files changed

+242
-79
lines changed

conn.go

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,43 +1055,63 @@ func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) {
10551055
// same as the input, for example when creating a sequence znode the returned path
10561056
// will be the input path with a sequence number appended.
10571057
func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, error) {
1058-
if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil {
1058+
createMode, err := parseCreateMode(flags)
1059+
if err != nil {
1060+
return "", err
1061+
}
1062+
1063+
if err := validatePath(path, createMode.sequential); err != nil {
10591064
return "", err
10601065
}
10611066

10621067
res := &createResponse{}
1063-
_, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil)
1068+
_, err = c.request(opCreate, &CreateRequest{path, data, acl, createMode.toFlag()}, res, nil)
10641069
if err == ErrConnectionClosed {
10651070
return "", err
10661071
}
10671072
return res.Path, err
10681073
}
10691074

10701075
// CreateContainer creates a container znode and returns the path.
1071-
func (c *Conn) CreateContainer(path string, data []byte, flags int32, acl []ACL) (string, error) {
1072-
if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil {
1076+
//
1077+
// Containers cannot be ephemeral or sequential, or have TTLs.
1078+
// Ensure that we reject flags for TTL, Sequence, and Ephemeral.
1079+
func (c *Conn) CreateContainer(path string, data []byte, flag int32, acl []ACL) (string, error) {
1080+
createMode, err := parseCreateMode(flag)
1081+
if err != nil {
10731082
return "", err
10741083
}
1075-
if flags&FlagTTL != FlagTTL {
1084+
1085+
if err := validatePath(path, createMode.sequential); err != nil {
1086+
return "", err
1087+
}
1088+
1089+
if !createMode.isContainer {
10761090
return "", ErrInvalidFlags
10771091
}
10781092

10791093
res := &createResponse{}
1080-
_, err := c.request(opCreateContainer, &CreateContainerRequest{path, data, acl, flags}, res, nil)
1094+
_, err = c.request(opCreateContainer, &CreateRequest{path, data, acl, createMode.toFlag()}, res, nil)
10811095
return res.Path, err
10821096
}
10831097

10841098
// CreateTTL creates a TTL znode, which will be automatically deleted by server after the TTL.
1085-
func (c *Conn) CreateTTL(path string, data []byte, flags int32, acl []ACL, ttl time.Duration) (string, error) {
1086-
if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil {
1099+
func (c *Conn) CreateTTL(path string, data []byte, flag int32, acl []ACL, ttl time.Duration) (string, error) {
1100+
createMode, err := parseCreateMode(flag)
1101+
if err != nil {
10871102
return "", err
10881103
}
1089-
if flags&FlagTTL != FlagTTL {
1104+
1105+
if err := validatePath(path, createMode.sequential); err != nil {
1106+
return "", err
1107+
}
1108+
1109+
if !createMode.isTTL {
10901110
return "", ErrInvalidFlags
10911111
}
10921112

10931113
res := &createResponse{}
1094-
_, err := c.request(opCreateTTL, &CreateTTLRequest{path, data, acl, flags, ttl.Milliseconds()}, res, nil)
1114+
_, err = c.request(opCreateTTL, &CreateTTLRequest{path, data, acl, createMode.toFlag(), ttl.Milliseconds()}, res, nil)
10951115
return res.Path, err
10961116
}
10971117

constants.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,6 @@ const (
7575
StateHasSession = State(101)
7676
)
7777

78-
const (
79-
// FlagEphemeral means the node is ephemeral.
80-
FlagEphemeral = 1
81-
FlagSequence = 2
82-
FlagTTL = 4
83-
)
84-
8578
var (
8679
stateNames = map[State]string{
8780
StateUnknown: "StateUnknown",

create_mode.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package zk
2+
3+
import "fmt"
4+
5+
const (
6+
FlagPersistent = 0
7+
FlagEphemeral = 1
8+
FlagSequence = 2
9+
FlagEphemeralSequential = 3
10+
FlagContainer = 4
11+
FlagTTL = 5
12+
FlagPersistentSequentialWithTTL = 6
13+
)
14+
15+
type createMode struct {
16+
flag int32
17+
ephemeral bool
18+
sequential bool
19+
isContainer bool
20+
isTTL bool
21+
}
22+
23+
func (cm *createMode) toFlag() int32 {
24+
return cm.flag
25+
}
26+
27+
// parsing a flag integer into the CreateMode needed to call the correct
28+
// Create RPC to Zookeeper.
29+
//
30+
// NOTE: This parse method is designed to be able to copy and paste the same
31+
// CreateMode ENUM constructors from Java:
32+
// https://github.com/apache/zookeeper/blob/master/zookeeper-server/src/main/java/org/apache/zookeeper/CreateMode.java
33+
func parseCreateMode(flag int32) (createMode, error) {
34+
switch flag {
35+
case FlagPersistent:
36+
return createMode{0, false, false, false, false}, nil
37+
case FlagEphemeral:
38+
return createMode{1, true, false, false, false}, nil
39+
case FlagSequence:
40+
return createMode{2, false, true, false, false}, nil
41+
case FlagEphemeralSequential:
42+
return createMode{3, true, true, false, false}, nil
43+
case FlagContainer:
44+
return createMode{4, false, false, true, false}, nil
45+
case FlagTTL:
46+
return createMode{5, false, false, false, true}, nil
47+
case FlagPersistentSequentialWithTTL:
48+
return createMode{6, false, true, false, true}, nil
49+
default:
50+
return createMode{}, fmt.Errorf("invalid flag value: [%v]", flag)
51+
}
52+
}

create_mode_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package zk
2+
3+
import (
4+
"strings"
5+
"testing"
6+
)
7+
8+
func TestParseCreateMode(t *testing.T) {
9+
changeDetectorTests := []struct {
10+
name string
11+
flag int32
12+
wantIntValue int32
13+
}{
14+
{"valid flag createmode 0 persistant", FlagPersistent, 0},
15+
{"ephemeral", FlagEphemeral, 1},
16+
{"sequential", FlagSequence, 2},
17+
{"ephemeral sequential", FlagEphemeralSequential, 3},
18+
{"container", FlagContainer, 4},
19+
{"ttl", FlagTTL, 5},
20+
{"persistentSequential w/TTL", FlagPersistentSequentialWithTTL, 6},
21+
}
22+
for _, tt := range changeDetectorTests {
23+
t.Run(tt.name, func(t *testing.T) {
24+
cm, err := parseCreateMode(tt.flag)
25+
requireNoError(t, err)
26+
if cm.toFlag() != tt.wantIntValue {
27+
// change detector test for enum values.
28+
t.Fatalf("createmode value of flag; want: %v, got: %v", cm.toFlag(), tt.wantIntValue)
29+
}
30+
})
31+
}
32+
33+
t.Run("failed to parse", func(t *testing.T) {
34+
cm, err := parseCreateMode(-123)
35+
if err == nil {
36+
t.Fatalf("error expected, got: %v", cm)
37+
}
38+
if !strings.Contains(err.Error(), "invalid flag value") {
39+
t.Fatalf("unexpected error value: %v", err)
40+
}
41+
})
42+
43+
}

structs.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,6 @@ type CreateRequest struct {
166166
Flags int32
167167
}
168168

169-
type CreateContainerRequest CreateRequest
170-
171169
type CreateTTLRequest struct {
172170
Path string
173171
Data []byte
@@ -598,10 +596,8 @@ func requestStructForOp(op int32) interface{} {
598596
switch op {
599597
case opClose:
600598
return &closeRequest{}
601-
case opCreate:
599+
case opCreate, opCreateContainer:
602600
return &CreateRequest{}
603-
case opCreateContainer:
604-
return &CreateContainerRequest{}
605601
case opCreateTTL:
606602
return &CreateTTLRequest{}
607603
case opDelete:

zk_test.go

Lines changed: 116 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -109,46 +109,59 @@ func TestIntegration_CreateTTL(t *testing.T) {
109109
}
110110
defer zk.Close()
111111

112-
path := "/gozk-test"
113-
114-
if err := zk.Delete(path, -1); err != nil && err != ErrNoNode {
115-
t.Fatalf("Delete returned error: %+v", err)
116-
}
117-
if _, err := zk.CreateTTL("", []byte{1, 2, 3, 4}, FlagTTL|FlagEphemeral, WorldACL(PermAll), 60*time.Second); err != ErrInvalidPath {
118-
t.Fatalf("Create path check failed")
119-
}
120-
if _, err := zk.CreateTTL(path, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll), 60*time.Second); err != ErrInvalidFlags {
121-
t.Fatalf("Create flags check failed")
122-
}
123-
if p, err := zk.CreateTTL(path, []byte{1, 2, 3, 4}, FlagTTL|FlagEphemeral, WorldACL(PermAll), 60*time.Second); err != nil {
124-
t.Fatalf("Create returned error: %+v", err)
125-
} else if p != path {
126-
t.Fatalf("Create returned different path '%s' != '%s'", p, path)
127-
}
128-
if data, stat, err := zk.Get(path); err != nil {
129-
t.Fatalf("Get returned error: %+v", err)
130-
} else if stat == nil {
131-
t.Fatal("Get returned nil stat")
132-
} else if len(data) < 4 {
133-
t.Fatal("Get returned wrong size data")
134-
}
112+
tests := []struct {
113+
name string
114+
createFlags int32
115+
giveDuration time.Duration
116+
wantErr string
117+
}{
118+
{
119+
name: "valid create ttl",
120+
createFlags: FlagTTL,
121+
giveDuration: time.Minute,
122+
},
123+
{
124+
name: "valid change detector",
125+
createFlags: 5,
126+
giveDuration: time.Minute,
127+
},
128+
{
129+
name: "invalid flag for create mode",
130+
createFlags: 999,
131+
giveDuration: time.Minute,
132+
wantErr: "invalid flag value: [999]",
133+
},
134+
}
135+
136+
const testPath = "/ttl_znode_tests"
137+
// create sub node to create per test in avoiding using the root path.
138+
_, err = zk.Create(testPath, nil /* data */, FlagPersistent, WorldACL(PermAll))
139+
requireNoError(t, err)
140+
141+
for idx, tt := range tests {
142+
t.Run(tt.name, func(t *testing.T) {
143+
path := filepath.Join(testPath, fmt.Sprint(idx))
144+
_, err := zk.CreateTTL(path, []byte{12}, tt.createFlags, WorldACL(PermAll), tt.giveDuration)
145+
if tt.wantErr == "" {
146+
requireNoError(t, err, fmt.Sprintf("error not expected: path; %q; flags %v", path, tt.createFlags))
147+
return
148+
}
135149

136-
if err := zk.Delete(path, -1); err != nil && err != ErrNoNode {
137-
t.Fatalf("Delete returned error: %+v", err)
138-
}
139-
if p, err := zk.CreateTTL(path, []byte{1, 2, 3, 4}, FlagTTL|FlagSequence, WorldACL(PermAll), 60*time.Second); err != nil {
140-
t.Fatalf("Create returned error: %+v", err)
141-
} else if !strings.HasPrefix(p, path) {
142-
t.Fatalf("Create returned invalid path '%s' are not '%s' with sequence", p, path)
143-
} else if data, stat, err := zk.Get(p); err != nil {
144-
t.Fatalf("Get returned error: %+v", err)
145-
} else if stat == nil {
146-
t.Fatal("Get returned nil stat")
147-
} else if len(data) < 4 {
148-
t.Fatal("Get returned wrong size data")
150+
// want an error
151+
if err == nil {
152+
t.Fatalf("did not get expected error: %q", tt.wantErr)
153+
}
154+
if !strings.Contains(err.Error(), tt.wantErr) {
155+
t.Fatalf("wanted error not found: %v; got: %v", tt.wantErr, err.Error())
156+
}
157+
})
149158
}
150159
}
151160

161+
// NOTE: Currently there is not a way to get the znode after creating and
162+
// asserting it is once mode or another. This means these tests are only testing the
163+
// path of creation, but is not asserting that the resulting znode is the
164+
// mode we set with flags.
152165
func TestIntegration_CreateContainer(t *testing.T) {
153166
ts, err := StartTestCluster(t, 1, nil, logWriter{t: t, p: "[ZKERR] "})
154167
if err != nil {
@@ -161,28 +174,74 @@ func TestIntegration_CreateContainer(t *testing.T) {
161174
}
162175
defer zk.Close()
163176

164-
path := "/gozk-test"
177+
tests := []struct {
178+
name string
179+
createFlags int32
180+
wantErr string
181+
}{
182+
{
183+
name: "valid create container",
184+
createFlags: FlagContainer,
185+
},
186+
{
187+
name: "valid create container hard coded flag int",
188+
createFlags: 4,
189+
// container flag, ensure matches ZK Create Mode (change detector test)
190+
},
191+
{
192+
name: "invalid create mode",
193+
createFlags: 999,
194+
wantErr: "invalid flag value: [999]",
195+
},
196+
{
197+
name: "invalid containers cannot be persistant",
198+
createFlags: FlagPersistent,
199+
wantErr: ErrInvalidFlags.Error(),
200+
},
201+
{
202+
name: "invalid containers cannot be ephemeral",
203+
createFlags: FlagEphemeral,
204+
wantErr: ErrInvalidFlags.Error(),
205+
},
206+
{
207+
name: "invalid containers cannot be sequential",
208+
createFlags: FlagSequence,
209+
wantErr: ErrInvalidFlags.Error(),
210+
},
211+
{
212+
name: "invalid container and sequential",
213+
createFlags: FlagContainer | FlagSequence,
214+
wantErr: ErrInvalidFlags.Error(),
215+
},
216+
{
217+
name: "invliad TTLs cannot be used with Container znodes",
218+
createFlags: FlagTTL,
219+
wantErr: ErrInvalidFlags.Error(),
220+
},
221+
}
222+
223+
const testPath = "/container_test_znode"
224+
// create sub node to create per test in avoiding using the root path.
225+
_, err = zk.Create(testPath, nil /* data */, FlagPersistent, WorldACL(PermAll))
226+
requireNoError(t, err)
227+
228+
for idx, tt := range tests {
229+
t.Run(tt.name, func(t *testing.T) {
230+
path := filepath.Join(testPath, fmt.Sprint(idx))
231+
_, err := zk.CreateContainer(path, []byte{12}, tt.createFlags, WorldACL(PermAll))
232+
if tt.wantErr == "" {
233+
requireNoError(t, err, fmt.Sprintf("error not expected: path; %q; flags %v", path, tt.createFlags))
234+
return
235+
}
165236

166-
if err := zk.Delete(path, -1); err != nil && err != ErrNoNode {
167-
t.Fatalf("Delete returned error: %+v", err)
168-
}
169-
if _, err := zk.CreateContainer("", []byte{1, 2, 3, 4}, FlagTTL, WorldACL(PermAll)); err != ErrInvalidPath {
170-
t.Fatalf("Create path check failed")
171-
}
172-
if _, err := zk.CreateContainer(path, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != ErrInvalidFlags {
173-
t.Fatalf("Create flags check failed")
174-
}
175-
if p, err := zk.CreateContainer(path, []byte{1, 2, 3, 4}, FlagTTL, WorldACL(PermAll)); err != nil {
176-
t.Fatalf("Create returned error: %+v", err)
177-
} else if p != path {
178-
t.Fatalf("Create returned different path '%s' != '%s'", p, path)
179-
}
180-
if data, stat, err := zk.Get(path); err != nil {
181-
t.Fatalf("Get returned error: %+v", err)
182-
} else if stat == nil {
183-
t.Fatal("Get returned nil stat")
184-
} else if len(data) < 4 {
185-
t.Fatal("Get returned wrong size data")
237+
// want an error
238+
if err == nil {
239+
t.Fatalf("did not get expected error: %q", tt.wantErr)
240+
}
241+
if !strings.Contains(err.Error(), tt.wantErr) {
242+
t.Fatalf("wanted error not found: %v; got: %v", tt.wantErr, err.Error())
243+
}
244+
})
186245
}
187246
}
188247

0 commit comments

Comments
 (0)