From 02049c45d219cfb38adc15bd94b6b8a7775e9b53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Senart?= Date: Tue, 6 Oct 2015 19:38:14 +0200 Subject: [PATCH] Transparent forwarding This change set follows RFC5625 and implements protocol aware transparent forwarding of DNS messages to external servers. With this we get rid of the poorly implemented recursive semantics of external resolution which were causing issues. --- exchanger/exchanger.go | 69 ++----------- exchanger/exchanger_test.go | 186 ++++++++---------------------------- exchanger/forwarder.go | 49 ++++++++++ exchanger/forwarder_test.go | 95 ++++++++++++++++++ logging/logging.go | 36 +++---- resolver/resolver.go | 103 ++++++++------------ resolver/resolver_test.go | 12 +-- 7 files changed, 255 insertions(+), 295 deletions(-) create mode 100644 exchanger/forwarder.go create mode 100644 exchanger/forwarder_test.go diff --git a/exchanger/exchanger.go b/exchanger/exchanger.go index b4c642e1..2aa582c6 100644 --- a/exchanger/exchanger.go +++ b/exchanger/exchanger.go @@ -2,7 +2,6 @@ package exchanger import ( "log" - "net" "time" "github.com/mesosphere/mesos-dns/logging" @@ -36,24 +35,6 @@ func Decorate(ex Exchanger, ds ...Decorator) Exchanger { return decorated } -// Pred is a predicate function type for dns.Msgs. -type Pred func(*dns.Msg) bool - -// While returns an Exchanger which attempts the given Exchangers while the given -// predicate function returns true for the returned dns.Msg, an error is returned, -// or all Exchangers are attempted, in which case the return values of the last -// one are returned. -func While(p Pred, exs ...Exchanger) Exchanger { - return Func(func(m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) { - for _, ex := range exs { - if r, rtt, err = ex.Exchange(m, a); err != nil || !p(r) { - break - } - } - return - }) -} - // ErrorLogging returns a Decorator which logs an Exchanger's errors to the given // logger. func ErrorLogging(l *log.Logger) Decorator { @@ -70,50 +51,18 @@ func ErrorLogging(l *log.Logger) Decorator { } // Instrumentation returns a Decorator which instruments an Exchanger with the given -// counter. -func Instrumentation(c logging.Counter) Decorator { - return func(ex Exchanger) Exchanger { - return Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { - defer c.Inc() - return ex.Exchange(m, a) - }) - } -} - -// A Recurser returns the addr (host:port) of the next DNS server to recurse a -// Msg to. Empty returns signal that further recursion isn't possible or needed. -type Recurser func(*dns.Msg) string - -// Recurse is the default Mesos-DNS Recurser which returns an addr (host:port) -// only when the given dns.Msg doesn't contain authoritative answers and has at -// least one SOA record in its NS section. -func Recurse(r *dns.Msg) string { - if r.Authoritative && len(r.Answer) > 0 { - return "" - } - - for _, ns := range r.Ns { - if soa, ok := ns.(*dns.SOA); ok { - return net.JoinHostPort(soa.Ns, "53") - } - } - - return "" -} - -// Recursion returns a Decorator which recurses until the given Recurser returns -// an empty string or max attempts have been reached. -func Recursion(max int, rec Recurser) Decorator { +// counters. +func Instrumentation(total, success, failure logging.Counter) Decorator { return func(ex Exchanger) Exchanger { return Func(func(m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) { - for i := 0; i <= max; i++ { - if r, rtt, err = ex.Exchange(m, a); err != nil { - break - } else if a = rec(r); a == "" { - break + defer func() { + if total.Inc(); err != nil { + failure.Inc() + } else { + success.Inc() } - } - return r, rtt, err + }() + return ex.Exchange(m, a) }) } } diff --git a/exchanger/exchanger_test.go b/exchanger/exchanger_test.go index 36af3ebb..1c974327 100644 --- a/exchanger/exchanger_test.go +++ b/exchanger/exchanger_test.go @@ -1,169 +1,63 @@ package exchanger import ( + "bytes" "errors" - "net" - "reflect" + "log" "testing" "time" - . "github.com/mesosphere/mesos-dns/dnstest" + "github.com/mesosphere/mesos-dns/logging" "github.com/miekg/dns" ) -func TestWhile(t *testing.T) { - for i, tt := range []struct { - pred Pred - exs []Exchanger - want exchanged - }{ - { // error - nil, - stubs(exchanged{nil, 0, errors.New("foo")}), - exchanged{nil, 0, errors.New("foo")}, - }, - { // always true predicate - func(*dns.Msg) bool { return true }, - stubs(exchanged{nil, 0, nil}, exchanged{nil, 1, nil}), - exchanged{nil, 1, nil}, - }, - { // nil exchangers - nil, - nil, - exchanged{nil, 0, nil}, - }, - { // empty exchangers - nil, - stubs(), - exchanged{nil, 0, nil}, - }, - { // false predicate - func(calls int) Pred { - return func(*dns.Msg) bool { - calls++ - return calls != 2 - } - }(0), - stubs(exchanged{nil, 0, nil}, exchanged{nil, 1, nil}, exchanged{nil, 2, nil}), - exchanged{nil, 1, nil}, - }, - } { - var got exchanged - got.m, got.rtt, got.err = While(tt.pred, tt.exs...).Exchange(nil, "") - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want) +func TestErrorLogging(t *testing.T) { + { // with error + var buf bytes.Buffer + _, _, _ = ErrorLogging(log.New(&buf, "", 0))( + stub(exchanged{err: errors.New("timeout")})).Exchange(nil, "1.2.3.4") + + want := "timeout: exchanging (*dns.Msg)(nil) with \"1.2.3.4\"\n" + if got := buf.String(); got != want { + t.Errorf("got %q, want %q", got, want) } } -} + { // no error + var buf bytes.Buffer + _, _, _ = ErrorLogging(log.New(&buf, "", 0))( + stub(exchanged{})).Exchange(nil, "1.2.3.4") -func TestRecurse(t *testing.T) { - for i, tt := range []struct { - *dns.Msg - want string - }{ - { // Authoritative with answers - Message( - Header(true, 0), - Answers( - A(RRHeader("localhost", dns.TypeA, 0), net.IPv6loopback.To4()), - ), - NSs( - SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0), - ), - ), - "", - }, - { // Authoritative, empty answers, no SOA records - Message( - Header(true, 0), - NSs( - NS(RRHeader("", dns.TypeNS, 0), "next"), - ), - ), - "", - }, - { // Not authoritative, no SOA record - Message(Header(false, 0)), - "", - }, - { // Not authoritative, one SOA record - Message( - Header(false, 0), - NSs(SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0)), - ), - "next:53", - }, - { // Authoritative, empty answers, one SOA record - Message( - Header(true, 0), - NSs( - NS(RRHeader("", dns.TypeNS, 0), "foo"), - SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0), - ), - ), - "next:53", - }, - } { - if got := Recurse(tt.Msg); got != tt.want { - t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want) + if got, want := buf.String(), ""; got != want { + t.Errorf("got %q, want %q", got, want) } } } -func TestRecursion(t *testing.T) { - for i, tt := range []struct { - max int - rec Recurser - ex Exchanger - want exchanged - }{ - { - 0, - func(*dns.Msg) string { return "next" }, - seq(stubs(exchanged{rtt: 1})...), - exchanged{rtt: 1}, - }, - { - 1, - func(*dns.Msg) string { return "next" }, - seq(stubs(exchanged{rtt: 0}, exchanged{rtt: 1}, exchanged{rtt: 2})...), - exchanged{rtt: 1}, - }, - { - 0, - nil, - seq(stubs(exchanged{err: errors.New("foo")})...), - exchanged{err: errors.New("foo")}, - }, - { - 2, - func(calls int) Recurser { - return func(*dns.Msg) string { - if calls++; calls <= 1 { - return "next" - } - return "" - } - }(0), - seq(stubs(exchanged{rtt: 0}, exchanged{rtt: 1}, exchanged{rtt: 2})...), - exchanged{rtt: 1}, - }, - } { - var got exchanged - got.m, got.rtt, got.err = Recursion(tt.max, tt.rec)(tt.ex).Exchange(nil, "") - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want) +func TestInstrumentation(t *testing.T) { + { // with error + var total, success, failure logging.LogCounter + _, _, _ = Instrumentation(&total, &success, &failure)( + stub(exchanged{err: errors.New("timeout")})).Exchange(nil, "1.2.3.4") + + want := []string{"1", "0", "1"} + for i, c := range []*logging.LogCounter{&total, &success, &failure} { + if got, want := c.String(), want[i]; got != want { + t.Errorf("test #%d: got %q, want %q", i, got, want) + } } } -} + { // no error + var total, success, failure logging.LogCounter + _, _, _ = Instrumentation(&total, &success, &failure)( + stub(exchanged{})).Exchange(nil, "1.2.3.4") -func seq(exs ...Exchanger) Exchanger { - var i int - return Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { - ex := exs[i] - i++ - return ex.Exchange(m, a) - }) + want := []string{"1", "1", "0"} + for i, c := range []*logging.LogCounter{&total, &success, &failure} { + if got, want := c.String(), want[i]; got != want { + t.Errorf("test #%d: got %q, want %q", i, got, want) + } + } + } } func stubs(ed ...exchanged) []Exchanger { diff --git a/exchanger/forwarder.go b/exchanger/forwarder.go new file mode 100644 index 00000000..8cc0dc72 --- /dev/null +++ b/exchanger/forwarder.go @@ -0,0 +1,49 @@ +package exchanger + +import ( + "fmt" + "net" + + "github.com/miekg/dns" +) + +// A Forwarder is a DNS message forwarder that transparently proxies messages +// to DNS servers. +type Forwarder func(*dns.Msg, string) (*dns.Msg, error) + +// Forward is an utility method that calls f itself. +func (f Forwarder) Forward(m *dns.Msg, proto string) (*dns.Msg, error) { + return f(m, proto) +} + +// NewForwarder returns a new Forwarder for the given addrs with the given +// Exchangers map which maps network protocols to Exchangers. +// +// Every message will be exchanged with each address until no error is returned. +// If no addresses or no matching protocol exchanger exist, a *ForwardError will +// be returned. +func NewForwarder(addrs []string, exs map[string]Exchanger) Forwarder { + return func(m *dns.Msg, proto string) (r *dns.Msg, err error) { + ex, ok := exs[proto] + if !ok || len(addrs) == 0 { + return nil, &ForwardError{Addrs: addrs, Proto: proto} + } + for _, a := range addrs { + if r, _, err = ex.Exchange(m, net.JoinHostPort(a, "53")); err == nil { + break + } + } + return + } +} + +// A ForwardError is returned by Forwarders when they can't forward. +type ForwardError struct { + Addrs []string + Proto string +} + +// Error implements the error interface. +func (e ForwardError) Error() string { + return fmt.Sprintf("can't forward to %v over %q", e.Addrs, e.Proto) +} diff --git a/exchanger/forwarder_test.go b/exchanger/forwarder_test.go new file mode 100644 index 00000000..9c3c49f0 --- /dev/null +++ b/exchanger/forwarder_test.go @@ -0,0 +1,95 @@ +package exchanger + +import ( + "errors" + "reflect" + "testing" + "time" + + "github.com/kylelemons/godebug/pretty" + . "github.com/mesosphere/mesos-dns/dnstest" + "github.com/miekg/dns" +) + +func TestForwarder(t *testing.T) { + exs := func(e exchanged, protos ...string) map[string]Exchanger { + es := make(map[string]Exchanger, len(protos)) + for _, proto := range protos { + es[proto] = stub(e) + } + return es + } + + msg := Message(Question("foo.bar", dns.TypeA)) + for i, tt := range []struct { + addrs []string + exs map[string]Exchanger + proto string + r *dns.Msg + err error + }{ + { // no matching protocol + nil, exs(exchanged{}, "udp"), "tcp", nil, &ForwardError{nil, "tcp"}, + }, + { // matching protocol, no addrs + nil, exs(exchanged{}, "udp"), "udp", nil, &ForwardError{nil, "udp"}, + }, + { // matching protocol, no addrs + []string{}, exs(exchanged{}, "udp"), "udp", nil, &ForwardError{[]string{}, "udp"}, + }, + { // matching protocol, one addr, no error exchanging + addrs: []string{"1.2.3.4"}, + exs: exs(exchanged{m: msg}, "udp"), + proto: "udp", + r: msg, + }, + { // matching protocol, one addr, error exchanging + addrs: []string{"1.2.3.4"}, + exs: exs(exchanged{err: errors.New("timeout")}, "udp"), + proto: "udp", + err: errors.New("timeout"), + }, + { // matching protocol, two addrs, error exchanging with the first only + addrs: []string{"1.2.3.4", "2.3.4.5"}, + exs: map[string]Exchanger{ + "udp": Func(func(_ *dns.Msg, a string) (*dns.Msg, time.Duration, error) { + switch a { + case "1.2.3.4": + return nil, 0, errors.New("timeout") + default: + return msg, 0, nil + } + }), + }, + proto: "udp", + r: msg, + }, + { // matching protocol, two addrs, error exchanging with all of them + addrs: []string{"1.2.3.4", "2.3.4.5"}, + exs: map[string]Exchanger{ + "udp": Func(func(_ *dns.Msg, a string) (*dns.Msg, time.Duration, error) { + switch a { + case "1.2.3.4": + return nil, 0, errors.New("timeout") + default: + return nil, 0, errors.New("eof") + } + }), + }, + proto: "udp", + err: errors.New("eof"), + }, + } { + var got forwarded + got.r, got.err = NewForwarder(tt.addrs, tt.exs).Forward(nil, tt.proto) + if want := (forwarded{r: tt.r, err: tt.err}); !reflect.DeepEqual(got, want) { + t.Logf("test #%d\n", i) + t.Error(pretty.Compare(got, want)) + } + } +} + +type forwarded struct { + r *dns.Msg + err error +} diff --git a/logging/logging.go b/logging/logging.go index eabdd848..0abd8f07 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -46,28 +46,28 @@ func (lc *LogCounter) String() string { // LogOut holds metrics captured in an instrumented runtime. type LogOut struct { - MesosRequests Counter - MesosSuccess Counter - MesosNXDomain Counter - MesosFailed Counter - NonMesosRequests Counter - NonMesosSuccess Counter - NonMesosNXDomain Counter - NonMesosFailed Counter - NonMesosRecursed Counter + MesosRequests Counter + MesosSuccess Counter + MesosNXDomain Counter + MesosFailed Counter + NonMesosRequests Counter + NonMesosSuccess Counter + NonMesosNXDomain Counter + NonMesosFailed Counter + NonMesosForwarded Counter } // CurLog is the default package level LogOut. var CurLog = LogOut{ - MesosRequests: &LogCounter{}, - MesosSuccess: &LogCounter{}, - MesosNXDomain: &LogCounter{}, - MesosFailed: &LogCounter{}, - NonMesosRequests: &LogCounter{}, - NonMesosSuccess: &LogCounter{}, - NonMesosNXDomain: &LogCounter{}, - NonMesosFailed: &LogCounter{}, - NonMesosRecursed: &LogCounter{}, + MesosRequests: &LogCounter{}, + MesosSuccess: &LogCounter{}, + MesosNXDomain: &LogCounter{}, + MesosFailed: &LogCounter{}, + NonMesosRequests: &LogCounter{}, + NonMesosSuccess: &LogCounter{}, + NonMesosNXDomain: &LogCounter{}, + NonMesosFailed: &LogCounter{}, + NonMesosForwarded: &LogCounter{}, } // PrintCurLog prints out the current LogOut and then resets diff --git a/resolver/resolver.go b/resolver/resolver.go index 2e861be7..b1ad5794 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -29,9 +29,7 @@ type Resolver struct { rs *records.RecordGenerator rsLock sync.RWMutex rng *rand.Rand - - // pluggable external DNS resolution, mainly for unit testing - extResolver exchanger.Exchanger + fwd exchanger.Forwarder } // New returns a Resolver with the given version and configuration. @@ -44,40 +42,41 @@ func New(version string, config records.Config) *Resolver { masters: append([]string{""}, config.Masters...), } - if !config.ExternalOn { - return r - } - timeout := 5 * time.Second if config.Timeout != 0 { timeout = time.Duration(config.Timeout) * time.Second } - r.extResolver = newClient(timeout) + rs := config.Resolvers + if !config.ExternalOn { + rs = rs[:0] + } + r.fwd = exchanger.NewForwarder(rs, exchangers(timeout, "udp", "tcp")) return r } -func newClient(timeout time.Duration) exchanger.Exchanger { - clients := make([]exchanger.Exchanger, 2) - for i, proto := range [...]string{"udp", "tcp"} { // See RFC5966 - clients[i] = &dns.Client{ - Net: proto, - DialTimeout: timeout, - ReadTimeout: timeout, - WriteTimeout: timeout, - } +func exchangers(timeout time.Duration, protos ...string) map[string]exchanger.Exchanger { + exs := make(map[string]exchanger.Exchanger, len(protos)) + for _, proto := range protos { + exs[proto] = exchanger.Decorate( + &dns.Client{ + Net: proto, + DialTimeout: timeout, + ReadTimeout: timeout, + WriteTimeout: timeout, + }, + exchanger.ErrorLogging(logging.Error), + exchanger.Instrumentation( + logging.CurLog.NonMesosForwarded, + logging.CurLog.NonMesosSuccess, + logging.CurLog.NonMesosFailed, + ), + ) } - return exchanger.Decorate( - exchanger.While(truncated, clients...), - exchanger.Recursion(3, exchanger.Recurse), - exchanger.ErrorLogging(logging.Error), - exchanger.Instrumentation(logging.CurLog.NonMesosRecursed), - ) + return exs } -func truncated(m *dns.Msg) bool { return m.Truncated } - // return the current (read-only) record set. attempts to write to the returned // object will likely result in a data race. func (res *Resolver) records() *records.RecordGenerator { @@ -246,52 +245,28 @@ func shuffleAnswers(rng *rand.Rand, answers []dns.RR) []dns.RR { return answers } -// HandleNonMesos handles non-mesos queries by recursing to a configured -// external resolver. +// HandleNonMesos handles non-mesos queries by forwarding to configured +// external DNS servers. func (res *Resolver) HandleNonMesos(w dns.ResponseWriter, r *dns.Msg) { - var err error - var m *dns.Msg - - // tracing info logging.CurLog.NonMesosRequests.Inc() - - // If external request are disabled - if res.extResolver == nil { - m = new(dns.Msg) - // set refused - m.SetRcode(r, 5) - } else { - for _, resolver := range res.config.Resolvers { - nameserver := net.JoinHostPort(resolver, "53") - m, _, err = res.extResolver.Exchange(r, nameserver) - if err == nil { - break - } - } - } - - // extResolver returns nil Msg sometimes cause of perf - if m == nil { - m = new(dns.Msg) - m.SetRcode(r, 2) - err = fmt.Errorf("failed external DNS lookup of %q: %v", r.Question[0].Name, err) - } + m, err := res.fwd(r, w.RemoteAddr().Network()) if err != nil { - logging.Error.Println(r.Question[0].Name) - logging.Error.Println(err) - logging.CurLog.NonMesosFailed.Inc() - } else { - // nxdomain - if len(m.Answer) == 0 { - logging.CurLog.NonMesosNXDomain.Inc() - } else { - logging.CurLog.NonMesosSuccess.Inc() - } + m = new(dns.Msg).SetRcode(r, rcode(err)) + } else if len(m.Answer) == 0 { + logging.CurLog.NonMesosNXDomain.Inc() } - reply(w, m) } +func rcode(err error) int { + switch err.(type) { + case *exchanger.ForwardError: + return dns.RcodeRefused + default: + return dns.RcodeServerFailure + } +} + // HandleMesos is a resolver request handler that responds to a resource // question with resource answer(s) // it can handle {A, SRV, ANY} diff --git a/resolver/resolver_test.go b/resolver/resolver_test.go index 0c49f424..f7b74502 100644 --- a/resolver/resolver_test.go +++ b/resolver/resolver_test.go @@ -11,11 +11,9 @@ import ( "reflect" "strconv" "testing" - "time" "github.com/kylelemons/godebug/pretty" . "github.com/mesosphere/mesos-dns/dnstest" - "github.com/mesosphere/mesos-dns/exchanger" "github.com/mesosphere/mesos-dns/logging" "github.com/mesosphere/mesos-dns/records" "github.com/mesosphere/mesos-dns/records/labels" @@ -77,19 +75,19 @@ func TestShuffleAnswers(t *testing.T) { func TestHandlers(t *testing.T) { res := fakeDNS(t) - res.extResolver = exchanger.Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { + res.fwd = func(m *dns.Msg, net string) (*dns.Msg, error) { rr1, err := res.formatA("google.com.", "1.1.1.1") if err != nil { - return nil, 0, err + return nil, err } rr2, err := res.formatA("google.com.", "2.2.2.2") if err != nil { - return nil, 0, err + return nil, err } msg := &dns.Msg{Answer: []dns.RR{rr1, rr2}} msg.SetReply(m) - return msg, 0, nil - }) + return msg, nil + } for i, tt := range []struct { dns.HandlerFunc