kmsv2: ignore cache key expiration on reads
Signed-off-by: Monis Khan <mok@microsoft.com>
This commit is contained in:
		@@ -40,6 +40,13 @@ func NewExpiringWithClock(clock clock.Clock) *Expiring {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// Expiring is a map whose entries expire after a per-entry timeout.
 | 
					// Expiring is a map whose entries expire after a per-entry timeout.
 | 
				
			||||||
type Expiring struct {
 | 
					type Expiring struct {
 | 
				
			||||||
 | 
						// AllowExpiredGet causes the expiration check to be skipped on Get.
 | 
				
			||||||
 | 
						// It should only be used when a key always corresponds to the exact same value.
 | 
				
			||||||
 | 
						// Thus when this field is true, expired keys are considered valid
 | 
				
			||||||
 | 
						// until the next call to Set (which causes the GC to run).
 | 
				
			||||||
 | 
						// It may not be changed concurrently with calls to Get.
 | 
				
			||||||
 | 
						AllowExpiredGet bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	clock clock.Clock
 | 
						clock clock.Clock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// mu protects the below fields
 | 
						// mu protects the below fields
 | 
				
			||||||
@@ -70,7 +77,10 @@ func (c *Expiring) Get(key interface{}) (val interface{}, ok bool) {
 | 
				
			|||||||
	c.mu.RLock()
 | 
						c.mu.RLock()
 | 
				
			||||||
	defer c.mu.RUnlock()
 | 
						defer c.mu.RUnlock()
 | 
				
			||||||
	e, ok := c.cache[key]
 | 
						e, ok := c.cache[key]
 | 
				
			||||||
	if !ok || !c.clock.Now().Before(e.expiry) {
 | 
						if !ok {
 | 
				
			||||||
 | 
							return nil, false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !c.AllowExpiredGet && !c.clock.Now().Before(e.expiry) {
 | 
				
			||||||
		return nil, false
 | 
							return nil, false
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return e.val, true
 | 
						return e.val, true
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -101,6 +101,36 @@ func TestExpiration(t *testing.T) {
 | 
				
			|||||||
	if _, ok := c.Get("a"); !ok {
 | 
						if _, ok := c.Get("a"); !ok {
 | 
				
			||||||
		t.Fatalf("we should have found a key")
 | 
							t.Fatalf("we should have found a key")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check getting an expired key with and without AllowExpiredGet
 | 
				
			||||||
 | 
						c.Set("b", "b", time.Second)
 | 
				
			||||||
 | 
						fc.Step(2 * time.Second)
 | 
				
			||||||
 | 
						if _, ok := c.Get("b"); ok {
 | 
				
			||||||
 | 
							t.Fatalf("we should not have found b key")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if count := c.Len(); count != 2 { // b is still in the cache
 | 
				
			||||||
 | 
							t.Errorf("expected two items got: %d", count)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						c.AllowExpiredGet = true
 | 
				
			||||||
 | 
						if _, ok := c.Get("b"); !ok {
 | 
				
			||||||
 | 
							t.Fatalf("we should have found b key")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if count := c.Len(); count != 2 { // b is still in the cache
 | 
				
			||||||
 | 
							t.Errorf("expected two items got: %d", count)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						c.Set("c", "c", time.Second)      // set some unrelated key to run gc
 | 
				
			||||||
 | 
						if count := c.Len(); count != 2 { // only a and c in the cache now
 | 
				
			||||||
 | 
							t.Errorf("expected two items got: %d", count)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if _, ok := c.Get("b"); ok {
 | 
				
			||||||
 | 
							t.Fatalf("we should not have found b key")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if _, ok := c.Get("a"); !ok {
 | 
				
			||||||
 | 
							t.Fatalf("we should have found a key")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if _, ok := c.Get("c"); !ok {
 | 
				
			||||||
 | 
							t.Fatalf("we should have found c key")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestGarbageCollection(t *testing.T) {
 | 
					func TestGarbageCollection(t *testing.T) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -50,8 +50,10 @@ type simpleCache struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func newSimpleCache(clock clock.Clock, ttl time.Duration) *simpleCache {
 | 
					func newSimpleCache(clock clock.Clock, ttl time.Duration) *simpleCache {
 | 
				
			||||||
 | 
						cache := utilcache.NewExpiringWithClock(clock)
 | 
				
			||||||
 | 
						cache.AllowExpiredGet = true // for a given key, the value (the decryptTransformer) is always the same
 | 
				
			||||||
	return &simpleCache{
 | 
						return &simpleCache{
 | 
				
			||||||
		cache: utilcache.NewExpiringWithClock(clock),
 | 
							cache: cache,
 | 
				
			||||||
		ttl:   ttl,
 | 
							ttl:   ttl,
 | 
				
			||||||
		hashPool: &sync.Pool{
 | 
							hashPool: &sync.Pool{
 | 
				
			||||||
			New: func() interface{} {
 | 
								New: func() interface{} {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -126,6 +126,18 @@ func TestSimpleCache(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Wait for the cache to expire
 | 
						// Wait for the cache to expire
 | 
				
			||||||
	fakeClock.Step(6 * time.Second)
 | 
						fakeClock.Step(6 * time.Second)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// expired reads still work until GC runs on write
 | 
				
			||||||
 | 
						for i := 0; i < 10; i++ {
 | 
				
			||||||
 | 
							k := fmt.Sprintf("key-%d", i)
 | 
				
			||||||
 | 
							if cache.get([]byte(k)) != transformer {
 | 
				
			||||||
 | 
								t.Fatalf("Expected to get the transformer for key %v", k)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// run GC by performing a write
 | 
				
			||||||
 | 
						cache.set([]byte("some-other-unrelated-key"), transformer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for i := 0; i < 10; i++ {
 | 
						for i := 0; i < 10; i++ {
 | 
				
			||||||
		k := fmt.Sprintf("key-%d", i)
 | 
							k := fmt.Sprintf("key-%d", i)
 | 
				
			||||||
		if cache.get([]byte(k)) != nil {
 | 
							if cache.get([]byte(k)) != nil {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -191,7 +191,11 @@ func TestEnvelopeCaching(t *testing.T) {
 | 
				
			|||||||
				t.Fatalf("envelopeTransformer transformed data incorrectly. Expected: %v, got %v", originalText, untransformedData)
 | 
									t.Fatalf("envelopeTransformer transformed data incorrectly. Expected: %v, got %v", originalText, untransformedData)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// advance the clock to allow cache entries to expire depending on TTL
 | 
				
			||||||
			fakeClock.Step(2 * time.Minute)
 | 
								fakeClock.Step(2 * time.Minute)
 | 
				
			||||||
 | 
								// force GC to run by performing a write
 | 
				
			||||||
 | 
								transformer.(*envelopeTransformer).cache.set([]byte("some-other-unrelated-key"), &envelopeTransformer{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			state, err = testStateFunc(ctx, envelopeService, fakeClock)()
 | 
								state, err = testStateFunc(ctx, envelopeService, fakeClock)()
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				t.Fatal(err)
 | 
									t.Fatal(err)
 | 
				
			||||||
@@ -867,6 +871,8 @@ func TestEnvelopeLogging(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			// advance the clock to trigger cache to expire, so we make a decrypt call that will log
 | 
								// advance the clock to trigger cache to expire, so we make a decrypt call that will log
 | 
				
			||||||
			fakeClock.Step(2 * time.Second)
 | 
								fakeClock.Step(2 * time.Second)
 | 
				
			||||||
 | 
								// force GC to run by performing a write
 | 
				
			||||||
 | 
								transformer.(*envelopeTransformer).cache.set([]byte("some-other-unrelated-key"), &envelopeTransformer{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			_, _, err = transformer.TransformFromStorage(tc.ctx, transformedData, dataCtx)
 | 
								_, _, err = transformer.TransformFromStorage(tc.ctx, transformedData, dataCtx)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user