kmsv2: ignore cache key expiration on reads
Signed-off-by: Monis Khan <mok@microsoft.com>
This commit is contained in:
		@@ -40,6 +40,13 @@ func NewExpiringWithClock(clock clock.Clock) *Expiring {
 | 
			
		||||
 | 
			
		||||
// Expiring is a map whose entries expire after a per-entry timeout.
 | 
			
		||||
type Expiring struct {
 | 
			
		||||
	// AllowExpiredGet causes the expiration check to be skipped on Get.
 | 
			
		||||
	// It should only be used when a key always corresponds to the exact same value.
 | 
			
		||||
	// Thus when this field is true, expired keys are considered valid
 | 
			
		||||
	// until the next call to Set (which causes the GC to run).
 | 
			
		||||
	// It may not be changed concurrently with calls to Get.
 | 
			
		||||
	AllowExpiredGet bool
 | 
			
		||||
 | 
			
		||||
	clock clock.Clock
 | 
			
		||||
 | 
			
		||||
	// mu protects the below fields
 | 
			
		||||
@@ -70,7 +77,10 @@ func (c *Expiring) Get(key interface{}) (val interface{}, ok bool) {
 | 
			
		||||
	c.mu.RLock()
 | 
			
		||||
	defer c.mu.RUnlock()
 | 
			
		||||
	e, ok := c.cache[key]
 | 
			
		||||
	if !ok || !c.clock.Now().Before(e.expiry) {
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
	if !c.AllowExpiredGet && !c.clock.Now().Before(e.expiry) {
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
	return e.val, true
 | 
			
		||||
 
 | 
			
		||||
@@ -101,6 +101,36 @@ func TestExpiration(t *testing.T) {
 | 
			
		||||
	if _, ok := c.Get("a"); !ok {
 | 
			
		||||
		t.Fatalf("we should have found a key")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check getting an expired key with and without AllowExpiredGet
 | 
			
		||||
	c.Set("b", "b", time.Second)
 | 
			
		||||
	fc.Step(2 * time.Second)
 | 
			
		||||
	if _, ok := c.Get("b"); ok {
 | 
			
		||||
		t.Fatalf("we should not have found b key")
 | 
			
		||||
	}
 | 
			
		||||
	if count := c.Len(); count != 2 { // b is still in the cache
 | 
			
		||||
		t.Errorf("expected two items got: %d", count)
 | 
			
		||||
	}
 | 
			
		||||
	c.AllowExpiredGet = true
 | 
			
		||||
	if _, ok := c.Get("b"); !ok {
 | 
			
		||||
		t.Fatalf("we should have found b key")
 | 
			
		||||
	}
 | 
			
		||||
	if count := c.Len(); count != 2 { // b is still in the cache
 | 
			
		||||
		t.Errorf("expected two items got: %d", count)
 | 
			
		||||
	}
 | 
			
		||||
	c.Set("c", "c", time.Second)      // set some unrelated key to run gc
 | 
			
		||||
	if count := c.Len(); count != 2 { // only a and c in the cache now
 | 
			
		||||
		t.Errorf("expected two items got: %d", count)
 | 
			
		||||
	}
 | 
			
		||||
	if _, ok := c.Get("b"); ok {
 | 
			
		||||
		t.Fatalf("we should not have found b key")
 | 
			
		||||
	}
 | 
			
		||||
	if _, ok := c.Get("a"); !ok {
 | 
			
		||||
		t.Fatalf("we should have found a key")
 | 
			
		||||
	}
 | 
			
		||||
	if _, ok := c.Get("c"); !ok {
 | 
			
		||||
		t.Fatalf("we should have found c key")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGarbageCollection(t *testing.T) {
 | 
			
		||||
 
 | 
			
		||||
@@ -50,8 +50,10 @@ type simpleCache struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newSimpleCache(clock clock.Clock, ttl time.Duration) *simpleCache {
 | 
			
		||||
	cache := utilcache.NewExpiringWithClock(clock)
 | 
			
		||||
	cache.AllowExpiredGet = true // for a given key, the value (the decryptTransformer) is always the same
 | 
			
		||||
	return &simpleCache{
 | 
			
		||||
		cache: utilcache.NewExpiringWithClock(clock),
 | 
			
		||||
		cache: cache,
 | 
			
		||||
		ttl:   ttl,
 | 
			
		||||
		hashPool: &sync.Pool{
 | 
			
		||||
			New: func() interface{} {
 | 
			
		||||
 
 | 
			
		||||
@@ -126,6 +126,18 @@ func TestSimpleCache(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	// Wait for the cache to expire
 | 
			
		||||
	fakeClock.Step(6 * time.Second)
 | 
			
		||||
 | 
			
		||||
	// expired reads still work until GC runs on write
 | 
			
		||||
	for i := 0; i < 10; i++ {
 | 
			
		||||
		k := fmt.Sprintf("key-%d", i)
 | 
			
		||||
		if cache.get([]byte(k)) != transformer {
 | 
			
		||||
			t.Fatalf("Expected to get the transformer for key %v", k)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// run GC by performing a write
 | 
			
		||||
	cache.set([]byte("some-other-unrelated-key"), transformer)
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < 10; i++ {
 | 
			
		||||
		k := fmt.Sprintf("key-%d", i)
 | 
			
		||||
		if cache.get([]byte(k)) != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -191,7 +191,11 @@ func TestEnvelopeCaching(t *testing.T) {
 | 
			
		||||
				t.Fatalf("envelopeTransformer transformed data incorrectly. Expected: %v, got %v", originalText, untransformedData)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// advance the clock to allow cache entries to expire depending on TTL
 | 
			
		||||
			fakeClock.Step(2 * time.Minute)
 | 
			
		||||
			// force GC to run by performing a write
 | 
			
		||||
			transformer.(*envelopeTransformer).cache.set([]byte("some-other-unrelated-key"), &envelopeTransformer{})
 | 
			
		||||
 | 
			
		||||
			state, err = testStateFunc(ctx, envelopeService, fakeClock)()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatal(err)
 | 
			
		||||
@@ -867,6 +871,8 @@ func TestEnvelopeLogging(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
			// advance the clock to trigger cache to expire, so we make a decrypt call that will log
 | 
			
		||||
			fakeClock.Step(2 * time.Second)
 | 
			
		||||
			// force GC to run by performing a write
 | 
			
		||||
			transformer.(*envelopeTransformer).cache.set([]byte("some-other-unrelated-key"), &envelopeTransformer{})
 | 
			
		||||
 | 
			
		||||
			_, _, err = transformer.TransformFromStorage(tc.ctx, transformedData, dataCtx)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user