diff --git a/cbreaker/cbreaker.go b/cbreaker/cbreaker.go index b330bf01..d439be6d 100644 --- a/cbreaker/cbreaker.go +++ b/cbreaker/cbreaker.go @@ -8,6 +8,8 @@ // // Once the Circuit breaker condition is met, it enters the "Tripped" state, where it activates fallback scenario // for all requests during the FallbackDuration time period and reset the stats for the location. +// RequestVolumeThreshold can be used to require a minimum number of requests in the rolling window +// before the condition is evaluated. // // After FallbackDuration time period passes, Circuit breaker enters "Recovering" state, during that state it will // start passing some traffic back to the endpoints, increasing the amount of passed requests using linear function: @@ -56,6 +58,8 @@ type CircuitBreaker struct { checkPeriod time.Duration lastCheck clock.Time + requestVolumeThreshold int64 + fallback http.Handler next http.Handler @@ -251,6 +255,10 @@ func (c *CircuitBreaker) checkAndSet() { return } + if c.metrics.TotalCount() < c.requestVolumeThreshold { + return + } + if !c.condition(c) { return } diff --git a/cbreaker/cbreaker_test.go b/cbreaker/cbreaker_test.go index a9cea71f..e5db1bfd 100644 --- a/cbreaker/cbreaker_test.go +++ b/cbreaker/cbreaker_test.go @@ -207,6 +207,42 @@ func TestCircuitBreaker_triggerDuringRecovery(t *testing.T) { assert.Equal(t, cbState(stateTripped), cb.state) } +func TestCircuitBreaker_requestVolumeThreshold(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusGatewayTimeout) + }) + + testutils.FreezeTime(t) + + cb, err := New(handler, triggerNetRatio, RequestVolumeThreshold(20)) + require.NoError(t, err) + + srv := httptest.NewServer(cb) + t.Cleanup(srv.Close) + + cb.metrics = statsResponseCodes(statusCode{Code: http.StatusGatewayTimeout, Count: 18}) + + clock.Advance(defaultCheckPeriod + clock.Millisecond) + + re, _, err := testutils.Get(srv.URL) + require.NoError(t, err) + assert.Equal(t, http.StatusGatewayTimeout, re.StatusCode) + assert.Equal(t, cbState(stateStandby), cb.state) + + cb.metrics = statsResponseCodes(statusCode{Code: http.StatusGatewayTimeout, Count: 19}) + + clock.Advance(defaultCheckPeriod + clock.Millisecond) + + re, _, err = testutils.Get(srv.URL) + require.NoError(t, err) + assert.Equal(t, http.StatusGatewayTimeout, re.StatusCode) + assert.Equal(t, cbState(stateTripped), cb.state) + + re, _, err = testutils.Get(srv.URL) + require.NoError(t, err) + assert.Equal(t, http.StatusServiceUnavailable, re.StatusCode) +} + func TestCircuitBreaker_sideEffects(t *testing.T) { srv1Chan := make(chan *http.Request, 1) diff --git a/cbreaker/options.go b/cbreaker/options.go index 5c395a99..949046af 100644 --- a/cbreaker/options.go +++ b/cbreaker/options.go @@ -53,6 +53,15 @@ func CheckPeriod(d time.Duration) Option { } } +// RequestVolumeThreshold sets the minimum number of requests in the rolling +// window before the CircuitBreaker can trip. Defaults to 0. +func RequestVolumeThreshold(n int64) Option { + return func(c *CircuitBreaker) error { + c.requestVolumeThreshold = n + return nil + } +} + // OnTripped sets a SideEffect to run when entering the Tripped state. // Only one SideEffect can be set for this hook. func OnTripped(s SideEffect) Option {