Skip to content

Commit d5c8b38

Browse files
GODRIVER-3444 Adjust getMore maxTimeMS Calculation for tailable awaitData Cursors (#1925)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent bb68567 commit d5c8b38

File tree

7 files changed

+287
-28
lines changed

7 files changed

+287
-28
lines changed

internal/driverutil/operation.go

+37
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66

77
package driverutil
88

9+
import (
10+
"context"
11+
"math"
12+
"time"
13+
)
14+
915
// Operation Names should be sourced from the command reference documentation:
1016
// https://www.mongodb.com/docs/manual/reference/command/
1117
const (
@@ -30,3 +36,34 @@ const (
3036
UpdateOp = "update" // UpdateOp is the name for updating
3137
BulkWriteOp = "bulkWrite" // BulkWriteOp is the name for client-level bulk write
3238
)
39+
40+
// CalculateMaxTimeMS calculates the maxTimeMS value to send to the server
41+
// based on the context deadline and the minimum round trip time. If the
42+
// calculated maxTimeMS is likely to cause a socket timeout, then this function
43+
// will return 0 and false.
44+
func CalculateMaxTimeMS(ctx context.Context, rttMin time.Duration) (int64, bool) {
45+
deadline, ok := ctx.Deadline()
46+
if !ok {
47+
return 0, true
48+
}
49+
50+
remainingTimeout := time.Until(deadline)
51+
52+
// Always round up to the next millisecond value so we never truncate the calculated
53+
// maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms).
54+
maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond)
55+
if maxTimeMS <= 0 {
56+
return 0, false
57+
}
58+
59+
// The server will return a "BadValue" error if maxTimeMS is greater
60+
// than the maximum positive int32 value (about 24.9 days). If the
61+
// user specified a timeout value greater than that, omit maxTimeMS
62+
// and let the client-side timeout handle cancelling the op if the
63+
// timeout is ever reached.
64+
if maxTimeMS > math.MaxInt32 {
65+
return 0, true
66+
}
67+
68+
return maxTimeMS, true
69+
}

internal/driverutil/operation_test.go

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright (C) MongoDB, Inc. 2025-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package driverutil
8+
9+
import (
10+
"context"
11+
"math"
12+
"testing"
13+
"time"
14+
15+
"go.mongodb.org/mongo-driver/v2/internal/assert"
16+
)
17+
18+
func TestCalculateMaxTimeMS(t *testing.T) {
19+
tests := []struct {
20+
name string
21+
ctx context.Context
22+
rttMin time.Duration
23+
wantZero bool
24+
wantOk bool
25+
wantPositive bool
26+
wantExact int64
27+
}{
28+
{
29+
name: "no deadline",
30+
ctx: context.Background(),
31+
rttMin: 10 * time.Millisecond,
32+
wantZero: true,
33+
wantOk: true,
34+
wantPositive: false,
35+
},
36+
{
37+
name: "deadline expired",
38+
ctx: func() context.Context {
39+
ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) //nolint:govet
40+
return ctx
41+
}(),
42+
wantZero: true,
43+
wantOk: false,
44+
wantPositive: false,
45+
},
46+
{
47+
name: "remaining timeout < rttMin",
48+
ctx: func() context.Context {
49+
ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(1*time.Millisecond)) //nolint:govet
50+
return ctx
51+
}(),
52+
rttMin: 10 * time.Millisecond,
53+
wantZero: true,
54+
wantOk: false,
55+
wantPositive: false,
56+
},
57+
{
58+
name: "normal positive result",
59+
ctx: func() context.Context {
60+
ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) //nolint:govet
61+
return ctx
62+
}(),
63+
wantZero: false,
64+
wantOk: true,
65+
wantPositive: true,
66+
},
67+
{
68+
name: "beyond maxInt32",
69+
ctx: func() context.Context {
70+
dur := time.Now().Add(time.Duration(math.MaxInt32+1000) * time.Millisecond)
71+
ctx, _ := context.WithDeadline(context.Background(), dur) //nolint:govet
72+
return ctx
73+
}(),
74+
wantZero: true,
75+
wantOk: true,
76+
wantPositive: false,
77+
},
78+
{
79+
name: "round up to 1ms",
80+
ctx: func() context.Context {
81+
ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(999*time.Microsecond)) //nolint:govet
82+
return ctx
83+
}(),
84+
wantOk: true,
85+
wantExact: 1,
86+
},
87+
}
88+
89+
for _, tt := range tests {
90+
t.Run(tt.name, func(t *testing.T) {
91+
got, got1 := CalculateMaxTimeMS(tt.ctx, tt.rttMin)
92+
93+
assert.Equal(t, tt.wantOk, got1)
94+
95+
if tt.wantExact > 0 && got != tt.wantExact {
96+
t.Errorf("CalculateMaxTimeMS() got = %v, want %v", got, tt.wantExact)
97+
}
98+
99+
if tt.wantZero && got != 0 {
100+
t.Errorf("CalculateMaxTimeMS() got = %v, want 0", got)
101+
}
102+
103+
if !tt.wantZero && got == 0 {
104+
t.Errorf("CalculateMaxTimeMS() got = %v, want > 0", got)
105+
}
106+
107+
if !tt.wantZero && tt.wantPositive && got <= 0 {
108+
t.Errorf("CalculateMaxTimeMS() got = %v, want > 0", got)
109+
}
110+
})
111+
}
112+
113+
}

internal/integration/cursor_test.go

+70
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"go.mongodb.org/mongo-driver/v2/internal/assert"
1818
"go.mongodb.org/mongo-driver/v2/internal/failpoint"
1919
"go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
20+
"go.mongodb.org/mongo-driver/v2/internal/require"
2021
"go.mongodb.org/mongo-driver/v2/mongo"
2122
"go.mongodb.org/mongo-driver/v2/mongo/options"
2223
)
@@ -303,6 +304,75 @@ func TestCursor(t *testing.T) {
303304
batchSize = sizeVal.Int32()
304305
assert.Equal(mt, int32(4), batchSize, "expected batchSize 4, got %v", batchSize)
305306
})
307+
308+
tailableAwaitDataCursorOpts := mtest.NewOptions().MinServerVersion("4.4").
309+
Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single)
310+
311+
mt.RunOpts("tailable awaitData cursor", tailableAwaitDataCursorOpts, func(mt *mtest.T) {
312+
mt.Run("apply remaining timeoutMS if less than maxAwaitTimeMS", func(mt *mtest.T) {
313+
initCollection(mt, mt.Coll)
314+
mt.ClearEvents()
315+
316+
// Create a find cursor
317+
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(100 * time.Millisecond)
318+
319+
cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
320+
require.NoError(mt, err)
321+
322+
_ = mt.GetStartedEvent() // Empty find from started list.
323+
324+
defer cursor.Close(context.Background())
325+
326+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
327+
defer cancel()
328+
329+
// Iterate twice to force a getMore
330+
cursor.Next(ctx)
331+
cursor.Next(ctx)
332+
333+
cmd := mt.GetStartedEvent().Command
334+
335+
maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
336+
require.NoError(mt, err)
337+
338+
got, ok := maxTimeMSRaw.AsInt64OK()
339+
require.True(mt, ok)
340+
341+
assert.LessOrEqual(mt, got, int64(50))
342+
})
343+
344+
mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", tailableAwaitDataCursorOpts, func(mt *mtest.T) {
345+
initCollection(mt, mt.Coll)
346+
mt.ClearEvents()
347+
348+
// Create a find cursor
349+
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond)
350+
351+
cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
352+
require.NoError(mt, err)
353+
354+
_ = mt.GetStartedEvent() // Empty find from started list.
355+
356+
defer cursor.Close(context.Background())
357+
358+
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
359+
defer cancel()
360+
361+
// Iterate twice to force a getMore
362+
cursor.Next(ctx)
363+
cursor.Next(ctx)
364+
365+
cmd := mt.GetStartedEvent().Command
366+
367+
maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
368+
require.NoError(mt, err)
369+
370+
got, ok := maxTimeMSRaw.AsInt64OK()
371+
require.True(mt, ok)
372+
373+
assert.LessOrEqual(mt, got, int64(50))
374+
})
375+
})
306376
}
307377

308378
type tryNextCursor interface {

internal/integration/unified/collection_operation_execution.go

+15
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"context"
1111
"errors"
1212
"fmt"
13+
"strings"
1314
"time"
1415

1516
"go.mongodb.org/mongo-driver/v2/bson"
@@ -1485,6 +1486,20 @@ func createFindCursor(ctx context.Context, operation *operation) (*cursorResult,
14851486
opts.SetSkip(int64(val.Int32()))
14861487
case "sort":
14871488
opts.SetSort(val.Document())
1489+
case "timeoutMode":
1490+
return nil, newSkipTestError("timeoutMode is not supported")
1491+
case "cursorType":
1492+
switch strings.ToLower(val.StringValue()) {
1493+
case "tailable":
1494+
opts.SetCursorType(options.Tailable)
1495+
case "tailableawait":
1496+
opts.SetCursorType(options.TailableAwait)
1497+
case "nontailable":
1498+
opts.SetCursorType(options.NonTailable)
1499+
}
1500+
case "maxAwaitTimeMS":
1501+
maxAwaitTimeMS := time.Duration(val.Int32()) * time.Millisecond
1502+
opts.SetMaxAwaitTime(maxAwaitTimeMS)
14881503
default:
14891504
return nil, fmt.Errorf("unrecognized find option %q", key)
14901505
}

internal/spectest/skip.go

+19-3
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,10 @@ var skipTests = map[string][]string{
346346
"TestUnifiedSpec/client-side-operations-timeout/tests/retryability-timeoutMS.json/operation_is_retried_multiple_times_for_non-zero_timeoutMS_-_aggregate_on_collection",
347347
"TestUnifiedSpec/client-side-operations-timeout/tests/retryability-timeoutMS.json/operation_is_retried_multiple_times_for_non-zero_timeoutMS_-_aggregate_on_database",
348348
"TestUnifiedSpec/client-side-operations-timeout/tests/gridfs-find.json/timeoutMS_applied_to_find_command",
349+
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find",
350+
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set",
351+
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set",
352+
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_-_failure",
349353
},
350354

351355
// TODO(GODRIVER-3411): Tests require "getMore" with "maxTimeMS" settings. Not
@@ -448,7 +452,6 @@ var skipTests = map[string][]string{
448452
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/change_stream_can_be_iterated_again_if_previous_iteration_times_out",
449453
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/timeoutMS_is_refreshed_for_getMore_-_failure",
450454
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS",
451-
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
452455
},
453456

454457
// Unknown CSOT:
@@ -584,12 +587,10 @@ var skipTests = map[string][]string{
584587
"TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-timeoutMS.json",
585588
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_timeoutMode_is_cursor_lifetime",
586589
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS",
587-
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
588590
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find",
589591
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set",
590592
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set",
591593
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_-_failure",
592-
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS",
593594
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_maxAwaitTimeMS_if_less_than_remaining_timeout",
594595
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-non-awaitData.json/error_if_timeoutMode_is_cursor_lifetime",
595596
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-non-awaitData.json/timeoutMS_applied_to_find",
@@ -819,6 +820,21 @@ var skipTests = map[string][]string{
819820
"TestUnifiedSpec/transactions-convenient-api/tests/unified/transaction-options.json/withTransaction_explicit_transaction_options_override_client_options",
820821
"TestUnifiedSpec/transactions-convenient-api/tests/unified/commit.json/withTransaction_commits_after_callback_returns",
821822
},
823+
824+
// GODRIVER-3473: the implementation of DRIVERS-2868 makes it clear that the
825+
// Go Driver does not correctly implement the following validation for
826+
// tailable awaitData cursors:
827+
//
828+
// Drivers MUST error if this option is set, timeoutMS is set to a
829+
// non-zero value, and maxAwaitTimeMS is greater than or equal to
830+
// timeoutMS.
831+
//
832+
// Once GODRIVER-3473 is completed, we can continue running these tests.
833+
"When constructing tailable awaitData cusors must validate, timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or equal to timeoutMS (GODRIVER-3473)": {
834+
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS",
835+
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
836+
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
837+
},
822838
}
823839

824840
// CheckSkip checks if the fully-qualified test name matches a list of skipped test names for a given reason.

x/mongo/driver/batch_cursor.go

+28-2
Original file line numberDiff line numberDiff line change
@@ -381,14 +381,40 @@ func (bc *BatchCursor) getMore(ctx context.Context) {
381381

382382
bc.err = Operation{
383383
CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) {
384+
// If maxAwaitTime > remaining timeoutMS - minRoundTripTime, then use
385+
// send remaining TimeoutMS - minRoundTripTime allowing the server an
386+
// opportunity to respond with an empty batch.
387+
var maxTimeMS int64
388+
if bc.maxAwaitTime != nil {
389+
_, ctxDeadlineSet := ctx.Deadline()
390+
391+
if ctxDeadlineSet {
392+
rttMonitor := bc.Server().RTTMonitor()
393+
394+
var ok bool
395+
maxTimeMS, ok = driverutil.CalculateMaxTimeMS(ctx, rttMonitor.Min())
396+
if !ok && maxTimeMS <= 0 {
397+
return nil, fmt.Errorf(
398+
"calculated server-side timeout (%v ms) is less than or equal to 0 (%v): %w",
399+
maxTimeMS,
400+
rttMonitor.Stats(),
401+
ErrDeadlineWouldBeExceeded)
402+
}
403+
}
404+
405+
if !ctxDeadlineSet || bc.maxAwaitTime.Milliseconds() < maxTimeMS {
406+
maxTimeMS = bc.maxAwaitTime.Milliseconds()
407+
}
408+
}
409+
384410
dst = bsoncore.AppendInt64Element(dst, "getMore", bc.id)
385411
dst = bsoncore.AppendStringElement(dst, "collection", bc.collection)
386412
if numToReturn > 0 {
387413
dst = bsoncore.AppendInt32Element(dst, "batchSize", numToReturn)
388414
}
389415

390-
if bc.maxAwaitTime != nil && *bc.maxAwaitTime > 0 {
391-
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(*bc.maxAwaitTime)/int64(time.Millisecond))
416+
if maxTimeMS > 0 {
417+
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", maxTimeMS)
392418
}
393419

394420
comment, err := codecutil.MarshalValue(bc.comment, bc.encoderFn)

0 commit comments

Comments
 (0)