kubernetes/vendor/github.com/skynetservices/skydns/cache/cache.go
2016-05-20 15:08:11 -07:00

168 lines
4.1 KiB
Go

// Copyright (c) 2014 The SkyDNS Authors. All rights reserved.
// Use of this source code is governed by The MIT License (MIT) that can be
// found in the LICENSE file.
package cache
// Cache that holds RRs and, for DNSSEC, their RRSIGs.
// TODO(miek): there is a lot of copying going on to keep this code free of
// data races. This should be optimized.
import (
"crypto/sha1"
"sync"
"time"
"github.com/miekg/dns"
)
// elem is a single cache entry: the cached message and the time after which
// it should no longer be used. Entries added by InsertSignature store the
// RRSIG as the only record in msg.Answer and leave the other sections empty,
// which wastes some memory.
type elem struct {
	expiration time.Time // insert time + Cache.ttl for messages, or the RRSIG expiration for signatures
	msg        *dns.Msg  // private copy of the cached data; never handed out directly
}
// Cache holds DNS messages and, for DNSSEC, RRSIG signatures, keyed by
// caller-supplied strings. Eviction is delegated to EvictRandom, which picks
// victims in map iteration order, so eviction is effectively random.
// Remove, InsertMessage, InsertSignature and Search take the embedded
// RWMutex themselves; EvictRandom expects its caller to hold the write lock.
type Cache struct {
	sync.RWMutex                  // guards m
	capacity     int              // configured size bound; <= 0 makes InsertMessage, InsertSignature and Search no-ops
	m            map[string]*elem // entries keyed by the strings passed to Insert*/Search/Remove
	ttl          time.Duration    // lifetime given to entries added by InsertMessage
}
// New creates an empty Cache bounded by capacity entries. Messages stored
// with InsertMessage are stamped with an expiration ttl seconds after
// insertion. A capacity <= 0 yields a cache whose inserts and lookups do
// nothing.
func New(capacity, ttl int) *Cache {
	return &Cache{
		capacity: capacity,
		m:        make(map[string]*elem),
		ttl:      time.Second * time.Duration(ttl),
	}
}
// Capacity returns the capacity the cache was created with.
func (c *Cache) Capacity() int {
	return c.capacity
}
// Remove deletes the entry stored under key s, if there is one. It holds the
// write lock for the duration of the delete.
func (c *Cache) Remove(s string) {
	c.Lock()
	defer c.Unlock()
	delete(c.m, s)
}
// EvictRandom trims the cache back to its capacity by deleting entries until
// at most c.capacity remain. Victims are taken in map iteration order, which
// the Go spec leaves unspecified, so the choice is effectively arbitrary.
// A cache at or below capacity is left untouched.
// Must be called under a write lock.
func (c *Cache) EvictRandom() {
	excess := len(c.m) - c.capacity
	if excess <= 0 {
		return
	}
	// Deleting the current key during a range over a map is permitted.
	// The count is checked after the decrement, so it stops exactly at
	// capacity. If capacity is negative the loop simply empties the map.
	for k := range c.m {
		delete(c.m, k)
		excess--
		if excess == 0 {
			break
		}
	}
}
// InsertMessage stores a copy of msg under key s, stamped to expire c.ttl
// after now (UTC). The TTL is expected to be small, e.g. 60-300 seconds.
// If s is already present the existing entry is kept as-is, and its
// expiration is not refreshed. EvictRandom runs after every insert attempt.
// Nothing happens when the cache capacity is <= 0.
func (c *Cache) InsertMessage(s string, msg *dns.Msg) {
	if c.capacity <= 0 {
		return
	}

	c.Lock()
	defer c.Unlock()

	if _, exists := c.m[s]; !exists {
		expires := time.Now().UTC().Add(c.ttl)
		c.m[s] = &elem{expiration: expires, msg: msg.Copy()}
	}
	c.EvictRandom()
}
// InsertSignature stores a copy of sig under key s. The entry's expiration
// is derived from the signature's own Expiration field rather than from the
// cache ttl. The RRSIG is wrapped as the sole record of an otherwise empty
// dns.Msg. An existing entry for s is left untouched. Nothing happens when
// the cache capacity is <= 0.
func (c *Cache) InsertSignature(s string, sig *dns.RRSIG) {
	if c.capacity <= 0 {
		return
	}

	const window = int64(1) << 31

	c.Lock()
	if _, present := c.m[s]; !present {
		exp := int64(sig.Expiration)
		// NOTE(review): this looks intended to map the 32-bit expiration
		// into the current era (RFC 4034 serial arithmetic) — confirm.
		// As written, with a uint32 Expiration and a non-negative Unix
		// time, exp-now is always < 2^32, so wraps is always clamped to 0
		// and the expiry is just sig.Expiration as Unix seconds.
		wraps := (exp-time.Now().Unix())/window - 1
		if wraps < 0 {
			wraps = 0
		}
		expiry := time.Unix(exp-wraps*window, 0).UTC()
		c.m[s] = &elem{
			expiration: expiry,
			msg:        &dns.Msg{Answer: []dns.RR{dns.Copy(sig)}},
		}
	}
	c.EvictRandom()
	c.Unlock()
}
// Search looks up key s. On a hit it returns a copy of the cached message
// (so callers may modify it freely), the entry's expiration time and true.
// On a miss, or when the cache capacity is <= 0, it returns nil, the zero
// time and false. Expired entries are not filtered out here: callers must
// compare the returned time themselves.
func (c *Cache) Search(s string) (*dns.Msg, time.Time, bool) {
	if c.capacity <= 0 {
		return nil, time.Time{}, false
	}

	c.RLock()
	defer c.RUnlock()

	e, found := c.m[s]
	if !found {
		return nil, time.Time{}, false
	}
	return e.msg.Copy(), e.expiration, true
}
// Key returns the cache key for question q. The key is the SHA-1 digest of
// the query name, the query type and one flag byte. The flag byte records
// whether the request asked for DNSSEC and whether it arrived over TCP, so
// those variants get distinct keys.
func Key(q dns.Question, dnssec, tcp bool) string {
	i := append([]byte(q.Name), packUint16(q.Qtype)...)

	// Always append exactly one flag byte. The tail (2-byte type + 1-byte
	// flags) then has a fixed size and cannot be confused with trailing
	// bytes of a different name.
	var flags byte
	if dnssec {
		flags |= 1
	}
	if tcp {
		flags |= 2
	}
	i = append(i, flags)

	// Hash i itself. The previous sha1.New().Sum(i) only appended SHA-1("")
	// to i and never hashed the input.
	sum := sha1.Sum(i)
	return string(sum[:])
}
// KeyRRset returns a lookup key for an RRset. The key is the SHA-1 digest of
// the owner name and type of rrs[0], followed by a manual serialization of
// selected rdata fields from every record:
//   - SOA: the serial only.
//   - SRV: priority, weight, port and target.
//   - A/AAAA: the address bytes.
//   - NSEC3: the next hashed owner name.
//
// DNSKEY, NS, TXT and any unlisted type contribute no rdata. RRsets of those
// types with the same owner name and type therefore share a key.
//
// rrs must be non-empty; the function indexes rrs[0].
func KeyRRset(rrs []dns.RR) string {
	i := []byte(rrs[0].Header().Name)
	i = append(i, packUint16(rrs[0].Header().Rrtype)...)
	for _, r := range rrs {
		switch t := r.(type) { // only a few types are handled; serialize these manually
		case *dns.SOA:
			// We only fiddle with the serial, so store just that.
			i = append(i, packUint32(t.Serial)...)
		case *dns.SRV:
			i = append(i, packUint16(t.Priority)...)
			i = append(i, packUint16(t.Weight)...)
			// Port was previously omitted (Weight was appended twice), so
			// SRV sets differing only by port collided.
			i = append(i, packUint16(t.Port)...)
			i = append(i, []byte(t.Target)...)
		case *dns.A:
			i = append(i, []byte(t.A)...)
		case *dns.AAAA:
			i = append(i, []byte(t.AAAA)...)
		case *dns.NSEC3:
			i = append(i, []byte(t.NextDomain)...)
			// Bitmap does not differentiate in SkyDNS.
		case *dns.DNSKEY, *dns.NS, *dns.TXT:
			// Intentionally contribute nothing beyond the owner name/type.
		}
	}
	// Hash i itself. sha1.New().Sum(i) only appended SHA-1("") to i.
	sum := sha1.Sum(i)
	return string(sum[:])
}
// packUint16 encodes i as two bytes, most significant byte first
// (network byte order).
func packUint16(i uint16) []byte {
	hi, lo := byte(i>>8), byte(i)
	return []byte{hi, lo}
}
// packUint32 encodes i as four bytes, most significant byte first
// (network byte order).
func packUint32(i uint32) []byte {
	out := make([]byte, 4)
	for n := range out {
		out[n] = byte(i >> (8 * uint(3-n)))
	}
	return out
}