Skip to content

Commit 6945025

Browse files
fix: interrupt/skip retry on context cancel (#144)
* fix: interrupt retry sleep if context is cancelled * add unit tests for sleep interrupt * Early exit for context cancellation/deadline --------- Co-authored-by: Anmol Chopra <anmol.chopra@gojek.com>
1 parent f59885b commit 6945025

8 files changed

Lines changed: 254 additions & 4 deletions

File tree

httpclient/client.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,12 @@ func (c *Client) Do(request *http.Request) (*http.Response, error) {
147147
_ = response.Body.Close()
148148
}
149149
if i > 0 {
150-
time.Sleep(c.retrier.NextInterval(i - 1)) // sleep after closing the previous response body
150+
if err := internal.SleepInterruptible(request.Context(), c.retrier.NextInterval(i-1)); err != nil {
151+
multiErr.Push(err.Error())
152+
c.reportError(request, err)
153+
// no point of retrying after context has been cancelled
154+
break
155+
}
151156

152157
request, err = internal.CloneRequest(request, reqGetBody) // Clone the request to reset the body for retry
153158
if err != nil {
@@ -162,11 +167,18 @@ func (c *Client) Do(request *http.Request) (*http.Response, error) {
162167
if err != nil {
163168
multiErr.Push(err.Error())
164169
c.reportError(request, err)
170+
if internal.IsCtxDone(request.Context()) {
171+
break
172+
}
165173
continue
166174
}
167175
c.reportRequestEnd(request, response)
168176

169177
if response.StatusCode >= http.StatusInternalServerError {
178+
if internal.IsCtxDone(request.Context()) {
179+
break
180+
}
181+
170182
continue
171183
}
172184

httpclient/client_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package httpclient
22

33
import (
44
"bytes"
5+
"context"
56
"io"
67
"net/http"
78
"net/http/httptest"
89
"strings"
10+
"sync/atomic"
911
"testing"
1012
"time"
1113

@@ -518,3 +520,107 @@ func respBody(t *testing.T, response *http.Response) string {
518520

519521
return string(respBody)
520522
}
523+
524+
func TestHTTPClientDoContextCancelledDuringRetry(t *testing.T) {
525+
noOfRetries := 3
526+
backoffInterval := 100 * time.Millisecond
527+
maximumJitterInterval := 10 * time.Millisecond
528+
529+
client := NewClient(
530+
WithHTTPTimeout(10*time.Millisecond),
531+
WithRetryCount(noOfRetries),
532+
WithRetrier(heimdall.NewRetrier(heimdall.NewConstantBackoff(backoffInterval, maximumJitterInterval))),
533+
)
534+
535+
count := atomic.Int32{}
536+
count.Store(0)
537+
dummyHandler := func(w http.ResponseWriter, r *http.Request) {
538+
w.WriteHeader(http.StatusInternalServerError)
539+
count.Add(1)
540+
}
541+
542+
server := httptest.NewServer(http.HandlerFunc(dummyHandler))
543+
defer server.Close()
544+
545+
ctx, cancel := context.WithCancel(context.Background())
546+
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
547+
require.NoError(t, err)
548+
req = req.WithContext(ctx)
549+
550+
// Cancel the context after a short delay to simulate context cancellation during sleep
551+
go func() {
552+
time.Sleep(10 * time.Millisecond)
553+
cancel()
554+
}()
555+
556+
_, err = client.Do(req)
557+
require.Error(t, err)
558+
assert.Contains(t, err.Error(), context.Canceled.Error())
559+
assert.Less(t, count.Load(), int32(noOfRetries+1), "should not have completed all retries due to context cancellation")
560+
}
561+
562+
func TestHTTPClientDoContextCancelledBeforeRetry(t *testing.T) {
563+
client := NewClient(
564+
WithHTTPTimeout(10*time.Millisecond),
565+
WithRetryCount(3),
566+
WithRetrier(heimdall.NewRetrierFunc(func(retry int) time.Duration {
567+
assert.Fail(t, "should not have retrier func due to context cancellation")
568+
return 0
569+
})),
570+
)
571+
ctx, cancel := context.WithCancel(context.Background())
572+
573+
count := atomic.Int32{}
574+
count.Store(0)
575+
dummyHandler := func(w http.ResponseWriter, r *http.Request) {
576+
cancel() // Cancel immediately
577+
count.Add(1)
578+
w.WriteHeader(http.StatusInternalServerError)
579+
}
580+
581+
server := httptest.NewServer(http.HandlerFunc(dummyHandler))
582+
defer server.Close()
583+
584+
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
585+
require.NoError(t, err)
586+
req = req.WithContext(ctx)
587+
588+
_, err = client.Do(req)
589+
require.Error(t, err)
590+
assert.Contains(t, err.Error(), context.Canceled.Error())
591+
assert.Equal(t, int32(1), count.Load())
592+
}
593+
594+
func TestHTTPClientDoContextTimeoutDuringRetry(t *testing.T) {
595+
noOfRetries := 3
596+
backoffInterval := 100 * time.Millisecond
597+
maximumJitterInterval := 10 * time.Millisecond
598+
599+
client := NewClient(
600+
WithHTTPTimeout(10*time.Millisecond),
601+
WithRetryCount(noOfRetries),
602+
WithRetrier(heimdall.NewRetrier(heimdall.NewConstantBackoff(backoffInterval, maximumJitterInterval))),
603+
)
604+
605+
count := atomic.Int32{}
606+
count.Store(0)
607+
dummyHandler := func(w http.ResponseWriter, r *http.Request) {
608+
w.WriteHeader(http.StatusInternalServerError)
609+
count.Add(1)
610+
}
611+
612+
server := httptest.NewServer(http.HandlerFunc(dummyHandler))
613+
defer server.Close()
614+
615+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Millisecond)
616+
defer cancel()
617+
618+
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
619+
require.NoError(t, err)
620+
req = req.WithContext(ctx)
621+
622+
_, err = client.Do(req)
623+
require.Error(t, err)
624+
assert.Contains(t, err.Error(), context.DeadlineExceeded.Error())
625+
assert.Less(t, count.Load(), int32(noOfRetries+1), "should not have completed all retries due to context timeout")
626+
}

hystrix/hystrix_client.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,10 @@ func (hhc *Client) Do(request *http.Request) (*http.Response, error) {
194194
}
195195

196196
if i > 0 {
197-
time.Sleep(hhc.retrier.NextInterval(i - 1)) // sleep after closing the previous response body
197+
err = internal.SleepInterruptible(request.Context(), hhc.retrier.NextInterval(i-1))
198+
if err != nil {
199+
return nil, err
200+
}
198201

199202
request, err = internal.CloneRequest(request, reqGetBody) // Clone the request to reset the body for retry
200203
if err != nil {
@@ -203,7 +206,7 @@ func (hhc *Client) Do(request *http.Request) (*http.Response, error) {
203206
}
204207

205208
response, err = hhc.hystrixDo(request)
206-
if err == nil {
209+
if err == nil || internal.IsCtxDone(request.Context()) {
207210
break
208211
}
209212
}

hystrix/hystrix_client_test.go

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,13 @@ func TestHystrixHTTPClientDoContextCancelled(t *testing.T) {
606606
const cmdName = "some_command_name_for_ctx_cncl"
607607
r := newSimpleMetricRegistry()
608608

609-
client := NewClient(WithCommandName(cmdName))
609+
client := NewClient(
610+
WithCommandName(cmdName),
611+
WithRetryCount(3),
612+
WithRetrier(heimdall.NewRetrierFunc(func(retry int) time.Duration {
613+
assert.Fail(t, "should not invoke retrier func due to context cancellation")
614+
return 0
615+
})))
610616

611617
req, err := http.NewRequest(http.MethodGet, "http://localhost/test", nil)
612618
require.NoError(t, err)
@@ -632,6 +638,44 @@ func TestHystrixHTTPClientDoContextCancelled(t *testing.T) {
632638
assert.Zero(t, m.ContextDeadlineExceeded)
633639
}
634640

641+
func TestHystrixHTTPClientDoContextCancelledDuringRetryBackoff(t *testing.T) {
642+
const cmdName = "some_command_name_for_ctx_cncl"
643+
ctx, cnclFN := context.WithCancel(context.Background())
644+
645+
var retrierFuncCount int
646+
var retrierFuncStartTime time.Time
647+
648+
client := NewClient(
649+
WithCommandName(cmdName),
650+
WithRetryCount(3),
651+
WithRetrier(heimdall.NewRetrierFunc(func(retry int) time.Duration {
652+
retrierFuncStartTime = time.Now()
653+
retrierFuncCount++
654+
cnclFN()
655+
return time.Hour
656+
})))
657+
658+
reqCount := atomic.Int32{}
659+
reqCount.Store(0)
660+
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
661+
reqCount.Add(1)
662+
rw.WriteHeader(http.StatusInternalServerError)
663+
}))
664+
defer server.Close()
665+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
666+
require.NoError(t, err)
667+
668+
response, err := client.Do(req)
669+
endTime := time.Now()
670+
require.Contains(t, err.Error(), context.Canceled.Error())
671+
require.Nil(t, response)
672+
assert.Equal(t, int32(1), reqCount.Load())
673+
require.Equal(t, 1, retrierFuncCount)
674+
assert.Less(t, endTime.Sub(retrierFuncStartTime), time.Millisecond)
675+
676+
time.Sleep(time.Second)
677+
}
678+
635679
func TestResponseBodyStreaming(t *testing.T) {
636680
const cmdName = "response_body_streamig"
637681
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

internal/ctx.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package internal
2+
3+
import "context"
4+
5+
// IsCtxDone reports whether ctx's Done channel is already closed.
//
// The check is non-blocking: it returns false straight away when the channel
// is not yet ready to be received from.
func IsCtxDone(ctx context.Context) bool {
	select {
	case <-ctx.Done():
		// Done is closed; fall through to the positive result below.
	default:
		return false
	}
	return true
}

internal/ctx_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package internal_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/gojek/heimdall/v7/internal"
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func TestIsCtxDone(t *testing.T) {
13+
assert.False(t, internal.IsCtxDone(context.Background()))
14+
15+
ctx, cancel := context.WithCancel(context.Background())
16+
assert.False(t, internal.IsCtxDone(ctx))
17+
cancel()
18+
assert.True(t, internal.IsCtxDone(ctx))
19+
20+
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond)
21+
defer cancel()
22+
assert.False(t, internal.IsCtxDone(ctx))
23+
time.Sleep(12 * time.Millisecond)
24+
assert.True(t, internal.IsCtxDone(ctx))
25+
}

internal/sleep.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package internal
2+
3+
import (
4+
"context"
5+
"time"
6+
)
7+
8+
// SleepInterruptible blocks for the duration d or until ctx is done,
// whichever happens first.
//
// It returns nil when the full duration has elapsed, and ctx.Err() when the
// wait was cut short because ctx was cancelled or its deadline expired. If ctx
// is already done when the function is called, it returns ctx.Err()
// immediately, regardless of d. A zero or negative d with a live context
// returns nil without waiting.
func SleepInterruptible(ctx context.Context, d time.Duration) error {
	// Check the context up front so an already-done context always wins, even
	// when d <= 0 and both select cases below could be ready at once.
	if err := ctx.Err(); err != nil {
		return err
	}
	if d <= 0 {
		return nil
	}

	// Use an explicit timer rather than time.After so it can be stopped as
	// soon as the context interrupts the wait, instead of staying pending for
	// the rest of d.
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

internal/sleep_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package internal_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/gojek/heimdall/v7/internal"
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func TestSleepInterruptible_CancelledContext(t *testing.T) {
13+
ctx, cancel := context.WithCancel(context.Background())
14+
15+
// Cancel the context immediately
16+
cancel()
17+
18+
err := internal.SleepInterruptible(ctx, 10*time.Second) // Long duration to ensure cancellation is what stops it
19+
assert.Error(t, err)
20+
assert.Equal(t, context.Canceled, err)
21+
}
22+
23+
func TestSleepInterruptible_CompletesWithoutCancel(t *testing.T) {
24+
ctx := context.Background()
25+
start := time.Now()
26+
27+
err := internal.SleepInterruptible(ctx, 50*time.Millisecond) // Short sleep time
28+
elapsed := time.Since(start)
29+
30+
assert.NoError(t, err)
31+
assert.True(t, elapsed.Milliseconds() >= int64(50), "Sleep duration should be at least 50ms") // Ensure it slept approximately 50ms
32+
}

0 commit comments

Comments
 (0)