From 160edf7823ce6376a03767ff5fb04e26ada3a9fd Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Mon, 17 Mar 2025 16:45:07 -0700 Subject: [PATCH 1/3] grpc: fix bug causing an extra Read if a compressed message is the same size as the limit --- rpc_util.go | 46 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/rpc_util.go b/rpc_util.go index a8ddb0af5285..b9ebfba75c97 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -869,14 +869,15 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompress if err != nil { return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the message: %v", err) } + limitedReader := limitReader(dcReader, int64(maxReceiveMessageSize)) - out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)), pool) + out, err := mem.ReadAll(limitedReader, pool) if err != nil { out.Free() return nil, status.Errorf(codes.Internal, "grpc: failed to read decompressed data: %v", err) } - if out.Len() == maxReceiveMessageSize && !atEOF(dcReader) { + if limitedReader.exceeded() { out.Free() return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize) } @@ -885,12 +886,6 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompress return nil, status.Errorf(codes.Internal, "grpc: no decompressor available for compressed payload") } -// atEOF reads data from r and returns true if zero bytes could be read and r.Read returns EOF. -func atEOF(dcReader io.Reader) bool { - n, err := dcReader.Read(make([]byte, 1)) - return n == 0 && err == io.EOF -} - type recvCompressor interface { RecvCompress() string } @@ -1035,6 +1030,41 @@ func setCallInfoCodec(c *callInfo) error { return nil } +type limitedReader struct { + r io.Reader // the underlying reader + n int64 // how many bytes remain before the limit +} + +// limitReader returns a wrapper around r that may read at most one byte more +// than limit to determine if r contains more data than the limit. +func limitReader(r io.Reader, limit int64) *limitedReader { + return &limitedReader{r: r, n: limit} +} + +func (l *limitedReader) Read(p []byte) (n int, err error) { + if l.n < int64(len(p)) { + // We have space in the input buffer to read past the limit remaining in + // l.n. Truncate the input buffer, but read one extra byte to determine + // overflow. + p = p[0 : l.n+1] + } + n, err = l.r.Read(p) + l.n -= int64(n) + if l.n < 0 && err == nil { + // We read more bytes from r than the limit allowed. Convert to io.EOF, + // and exceeded() will return true. + n-- + err = io.EOF + } + return n, err +} + +// exceeded returns true iff l's reader was read beyond the limit specified at +// creation. +func (l *limitedReader) exceeded() bool { + return l.n < 0 +} + // The SupportPackageIsVersion variables are referenced from generated protocol // buffer files to ensure compatibility with the gRPC version used. The latest // support package version is 9. From 4dff8d1df4eca2b836c2c8d362370c8524a03e38 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Tue, 18 Mar 2025 13:48:17 -0700 Subject: [PATCH 2/3] Switch back to io.LimitReader with limit+1. Add test case. --- rpc_util.go | 46 +++++++------------------------------- rpc_util_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 38 deletions(-) diff --git a/rpc_util.go b/rpc_util.go index b9ebfba75c97..ad20e9dff206 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -869,15 +869,20 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompress if err != nil { return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the message: %v", err) } - limitedReader := limitReader(dcReader, int64(maxReceiveMessageSize)) - out, err := mem.ReadAll(limitedReader, pool) + // Read at most one byte more than the limit from the decompressor. + // Unless the limit is MaxInt64, in which case, that's impossible, so + // apply no limit. + if limit := int64(maxReceiveMessageSize); limit < math.MaxInt64 { + dcReader = io.LimitReader(dcReader, limit+1) + } + out, err := mem.ReadAll(dcReader, pool) if err != nil { out.Free() return nil, status.Errorf(codes.Internal, "grpc: failed to read decompressed data: %v", err) } - if limitedReader.exceeded() { + if out.Len() > maxReceiveMessageSize { out.Free() return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize) } @@ -1030,41 +1035,6 @@ func setCallInfoCodec(c *callInfo) error { return nil } -type limitedReader struct { - r io.Reader // the underlying reader - n int64 // how many bytes remain before the limit -} - -// limitReader returns a wrapper around r that may read at most one byte more -// than limit to determine if r contains more data than the limit. -func limitReader(r io.Reader, limit int64) *limitedReader { - return &limitedReader{r: r, n: limit} -} - -func (l *limitedReader) Read(p []byte) (n int, err error) { - if l.n < int64(len(p)) { - // We have space in the input buffer to read past the limit remaining in - // l.n. Truncate the input buffer, but read one extra byte to determine - // overflow. - p = p[0 : l.n+1] - } - n, err = l.r.Read(p) - l.n -= int64(n) - if l.n < 0 && err == nil { - // We read more bytes from r than the limit allowed. Convert to io.EOF, - // and exceeded() will return true. - n-- - err = io.EOF - } - return n, err -} - -// exceeded returns true iff l's reader was read beyond the limit specified at -// creation. -func (l *limitedReader) exceeded() bool { - return l.n < 0 -} - // The SupportPackageIsVersion variables are referenced from generated protocol // buffer files to ensure compatibility with the gRPC version used. The latest // support package version is 9. diff --git a/rpc_util_test.go b/rpc_util_test.go index 608cc1002471..0da2bec5bce7 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -21,11 +21,14 @@ package grpc import ( "bytes" "compress/gzip" + "context" "errors" "io" "math" "reflect" + "sync" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -421,3 +424,57 @@ func (s) TestDecompress(t *testing.T) { }) } } + +type mockCompressor struct { + // Written to by the io.Reader on every call to Read. + ch chan<- struct{} +} + +func (m *mockCompressor) Compress(io.Writer) (io.WriteCloser, error) { + panic("unimplemented") +} + +func (m *mockCompressor) Decompress(io.Reader) (io.Reader, error) { + return m, nil +} + +func (m *mockCompressor) Read(p []byte) (int, error) { + m.ch <- struct{}{} + return 1, io.EOF +} + +func (m *mockCompressor) Name() string { return "" } + +// Tests that the decompressor's Read method is not called after it returns EOF. +func (s) TestDecompress_NoReadAfterEOF(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + ch := make(chan struct{}, 10) + mc := &mockCompressor{ch: ch} + in := mem.BufferSlice{mem.NewBuffer(&[]byte{1, 2, 3}, nil)} + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + out, err := decompress(mc, in, nil, 1, mem.DefaultBufferPool()) + if err != nil { + t.Errorf("Unexpected error from decompress: %v", err) + return + } + out.Free() + }() + select { + case <-ch: + case <-ctx.Done(): + t.Fatalf("Timed out waiting for call to compressor") + } + ctx, cancel = context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel() + select { + case <-ch: + t.Fatalf("Unexpected second compressor.Read call detected") + case <-ctx.Done(): + } + wg.Wait() +} From 03ac4588d110b45ff0c36374473731d8f46e6253 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Tue, 18 Mar 2025 14:53:32 -0700 Subject: [PATCH 3/3] review comments --- clientconn_test.go | 1 + rpc_util_test.go | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clientconn_test.go b/clientconn_test.go index 691c007f5e9a..9cca5a8eb7b7 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -50,6 +50,7 @@ import ( const ( defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond stateRecordingBalancerName = "state_recording_balancer" grpclbServiceConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}` rrServiceConfig = `{"loadBalancingPolicy": [{"round_robin": {}}]}` diff --git a/rpc_util_test.go b/rpc_util_test.go index 0da2bec5bce7..a5c5cb8b17e2 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -28,7 +28,6 @@ import ( "reflect" "sync" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -438,7 +437,7 @@ func (m *mockCompressor) Decompress(io.Reader) (io.Reader, error) { return m, nil } -func (m *mockCompressor) Read(p []byte) (int, error) { +func (m *mockCompressor) Read([]byte) (int, error) { m.ch <- struct{}{} return 1, io.EOF } @@ -469,7 +468,7 @@ func (s) TestDecompress_NoReadAfterEOF(t *testing.T) { case <-ctx.Done(): t.Fatalf("Timed out waiting for call to compressor") } - ctx, cancel = context.WithTimeout(ctx, 10*time.Millisecond) + ctx, cancel = context.WithTimeout(ctx, defaultTestShortTimeout) defer cancel() select { case <-ch: