Skip to content

fix(mongodb): replica set initialization & connection handling #2984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions modules/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
_ "embed"
"errors"
"fmt"
"net"
"net/url"
"time"

"github.com/testcontainers/testcontainers-go"
Expand Down Expand Up @@ -125,10 +127,23 @@ func (c *MongoDBContainer) ConnectionString(ctx context.Context) (string, error)
if err != nil {
return "", err
}
u := url.URL{
Scheme: "mongodb",
Host: net.JoinHostPort(host, port.Port()),
Path: "/",
}

if c.username != "" && c.password != "" {
return fmt.Sprintf("mongodb://%s:%s@%s:%s", c.username, c.password, host, port.Port()), nil
u.User = url.UserPassword(c.username, c.password)
}

if c.replicaSet != "" {
q := url.Values{}
q.Add("replicaSet", c.replicaSet)
u.RawQuery = q.Encode()
}
return c.Endpoint(ctx, "mongodb")

return u.String(), nil
}

func setupEntrypointForAuth(req *testcontainers.GenericContainerRequest) {
Expand Down Expand Up @@ -176,17 +191,27 @@ func initiateReplicaSet(req *testcontainers.GenericContainerRequest, cli mongoCl
req.LifecycleHooks, testcontainers.ContainerLifecycleHooks{
PostStarts: []testcontainers.ContainerHook{
func(ctx context.Context, c testcontainers.Container) error {
ip, err := c.ContainerIP(ctx)
// Wait for MongoDB to be ready
if err := waitForMongoReady(ctx, c, cli); err != nil {
return fmt.Errorf("wait for mongo: %w", err)
}

// Initiate replica set
host, err := c.Host(ctx)
if err != nil {
return fmt.Errorf("get host: %w", err)
}
mappedPort, err := c.MappedPort(ctx, "27017/tcp")
if err != nil {
return fmt.Errorf("container ip: %w", err)
return fmt.Errorf("get mapped port: %w", err)
}

cmd := cli.eval(
"rs.initiate({ _id: '%s', members: [ { _id: 0, host: '%s:27017' } ] })",
"rs.initiate({ _id: '%s', members: [ { _id: 0, host: '%s:%s' } ] })",
replSetName,
ip,
host,
mappedPort.Port(),
)

return wait.ForExec(cmd).WaitUntilReady(ctx, c)
},
},
Expand All @@ -208,3 +233,7 @@ func withAuthReplicaset(
return nil
}
}

func waitForMongoReady(ctx context.Context, c testcontainers.Container, cli mongoCli) error {
return wait.ForExec(cli.eval("db.runCommand({ ping: 1 })")).WaitUntilReady(ctx, c)
}
93 changes: 88 additions & 5 deletions modules/mongodb/mongodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package mongodb_test

import (
"context"
"errors"
"fmt"
"net"
"net/url"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -13,7 +17,52 @@ import (
"github.com/testcontainers/testcontainers-go/modules/mongodb"
)

func localNonLoopbackIP() (string, error) {
interfaces, err := net.Interfaces()
if err != nil {
return "", fmt.Errorf("list network interfaces: %w", err)
}

for _, iface := range interfaces {
// Skip down or loopback interfaces.
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
continue
}

addrs, err := iface.Addrs()
if err != nil {
continue // try next interface
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
// Check if it's a valid IPv4 and not loopback.
if ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue // not IPv4
}
return ip.String(), nil
}
}
return "", errors.New("no non-loopback IPv4 address found")
}

func TestMongoDB(t *testing.T) {
host, err := localNonLoopbackIP()
if err != nil {
host = "host.docker.internal"
}
t.Setenv("TESTCONTAINERS_HOST_OVERRIDE", host)
type tests struct {
name string
img string
Expand Down Expand Up @@ -125,18 +174,52 @@ func TestMongoDB(t *testing.T) {
endpoint, err := mongodbContainer.ConnectionString(ctx)
require.NoError(tt, err)

// Force direct connection to the container to avoid the replica set
// connection string that is returned by the container itself when
// using the replica set option.
// Force direct connection to the container.
mongoClient, err := mongo.Connect(ctx, options.Client().ApplyURI(endpoint).SetDirect(true))
require.NoError(tt, err)

err = mongoClient.Ping(ctx, nil)
require.NoError(tt, err)
require.Equal(t, "test", mongoClient.Database("test").Name())
require.Equal(tt, "test", mongoClient.Database("test").Name())

_, err = mongoClient.Database("testcontainer").Collection("test").InsertOne(context.Background(), bson.M{})
// Basic insert test.
_, err = mongoClient.Database("testcontainer").Collection("test").InsertOne(ctx, bson.M{})
require.NoError(tt, err)

// If the container is configured with a replica set, run the change stream test.
if hasReplica, _ := hasReplicaSet(endpoint); hasReplica {
coll := mongoClient.Database("test").Collection("changes")
stream, err := coll.Watch(ctx, mongo.Pipeline{})
require.NoError(tt, err)
defer stream.Close(ctx)

doc := bson.M{"message": "hello change streams"}
_, err = coll.InsertOne(ctx, doc)
require.NoError(tt, err)

require.True(tt, stream.Next(ctx))
var changeEvent bson.M
err = stream.Decode(&changeEvent)
require.NoError(tt, err)

opType, ok := changeEvent["operationType"].(string)
require.True(tt, ok, "Expected operationType field")
require.Equal(tt, "insert", opType, "Expected operationType to be 'insert'")

fullDoc, ok := changeEvent["fullDocument"].(bson.M)
require.True(tt, ok, "Expected fullDocument field")
require.Equal(tt, "hello change streams", fullDoc["message"])
}
})
}
}

// hasReplicaSet checks if the connection string includes a replicaSet query parameter.
func hasReplicaSet(connStr string) (bool, error) {
u, err := url.Parse(connStr)
if err != nil {
return false, fmt.Errorf("parse connection string: %w", err)
}
q := u.Query()
return q.Get("replicaSet") != "", nil
}
Loading