Skip to content

grpc: fix bug causing an extra Read if a compressed message is the same size as the limit #8178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clientconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (

const (
defaultTestTimeout = 10 * time.Second
defaultTestShortTimeout = 10 * time.Millisecond
stateRecordingBalancerName = "state_recording_balancer"
grpclbServiceConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}`
rrServiceConfig = `{"loadBalancingPolicy": [{"round_robin": {}}]}`
Expand Down
16 changes: 8 additions & 8 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -870,13 +870,19 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompress
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the message: %v", err)
}

out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)), pool)
// Read at most one byte more than the limit from the decompressor.
// Unless the limit is MaxInt64, in which case, that's impossible, so
// apply no limit.
if limit := int64(maxReceiveMessageSize); limit < math.MaxInt64 {
dcReader = io.LimitReader(dcReader, limit+1)
}
out, err := mem.ReadAll(dcReader, pool)
if err != nil {
out.Free()
return nil, status.Errorf(codes.Internal, "grpc: failed to read decompressed data: %v", err)
}

if out.Len() == maxReceiveMessageSize && !atEOF(dcReader) {
if out.Len() > maxReceiveMessageSize {
out.Free()
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize)
}
Expand All @@ -885,12 +891,6 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompress
return nil, status.Errorf(codes.Internal, "grpc: no decompressor available for compressed payload")
}

// atEOF reads data from r and returns true if zero bytes could be read and r.Read returns EOF.
func atEOF(dcReader io.Reader) bool {
n, err := dcReader.Read(make([]byte, 1))
return n == 0 && err == io.EOF
}

type recvCompressor interface {
RecvCompress() string
}
Expand Down
56 changes: 56 additions & 0 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ package grpc
import (
"bytes"
"compress/gzip"
"context"
"errors"
"io"
"math"
"reflect"
"sync"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -421,3 +423,57 @@ func (s) TestDecompress(t *testing.T) {
})
}
}

type mockCompressor struct {
// Written to by the io.Reader on every call to Read.
ch chan<- struct{}
}

func (m *mockCompressor) Compress(io.Writer) (io.WriteCloser, error) {
panic("unimplemented")
}

func (m *mockCompressor) Decompress(io.Reader) (io.Reader, error) {
return m, nil
}

func (m *mockCompressor) Read([]byte) (int, error) {
m.ch <- struct{}{}
return 1, io.EOF
}

func (m *mockCompressor) Name() string { return "" }

// Tests that the decompressor's Read method is not called after it returns EOF.
func (s) TestDecompress_NoReadAfterEOF(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

ch := make(chan struct{}, 10)
mc := &mockCompressor{ch: ch}
in := mem.BufferSlice{mem.NewBuffer(&[]byte{1, 2, 3}, nil)}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
out, err := decompress(mc, in, nil, 1, mem.DefaultBufferPool())
if err != nil {
t.Errorf("Unexpected error from decompress: %v", err)
return
}
out.Free()
}()
select {
case <-ch:
case <-ctx.Done():
t.Fatalf("Timed out waiting for call to compressor")
}
ctx, cancel = context.WithTimeout(ctx, defaultTestShortTimeout)
defer cancel()
select {
case <-ch:
t.Fatalf("Unexpected second compressor.Read call detected")
case <-ctx.Done():
}
wg.Wait()
}