Skip to content
This repository has been archived by the owner on Oct 23, 2024. It is now read-only.

Commit

Permalink
Transparent forwarding
Browse files Browse the repository at this point in the history
This change set follows RFC5625 and implements protocol aware
transparent forwarding of DNS messages to external servers.

With this we get rid of the poorly implemented recursive semantics of
external resolution which were causing issues.
  • Loading branch information
Tomás Senart committed Oct 7, 2015
1 parent f3e7996 commit 02049c4
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 295 deletions.
69 changes: 9 additions & 60 deletions exchanger/exchanger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package exchanger

import (
"log"
"net"
"time"

"github.com/mesosphere/mesos-dns/logging"
Expand Down Expand Up @@ -36,24 +35,6 @@ func Decorate(ex Exchanger, ds ...Decorator) Exchanger {
return decorated
}

// Pred is a predicate function type for dns.Msgs.
type Pred func(*dns.Msg) bool

// While returns an Exchanger which attempts the given Exchangers while the given
// predicate function returns true for the returned dns.Msg, an error is returned,
// or all Exchangers are attempted, in which case the return values of the last
// one are returned.
func While(p Pred, exs ...Exchanger) Exchanger {
return Func(func(m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) {
for _, ex := range exs {
if r, rtt, err = ex.Exchange(m, a); err != nil || !p(r) {
break
}
}
return
})
}

// ErrorLogging returns a Decorator which logs an Exchanger's errors to the given
// logger.
func ErrorLogging(l *log.Logger) Decorator {
Expand All @@ -70,50 +51,18 @@ func ErrorLogging(l *log.Logger) Decorator {
}

// Instrumentation returns a Decorator which instruments an Exchanger with the given
// counter.
func Instrumentation(c logging.Counter) Decorator {
return func(ex Exchanger) Exchanger {
return Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
defer c.Inc()
return ex.Exchange(m, a)
})
}
}

// A Recurser returns the addr (host:port) of the next DNS server to recurse a
// Msg to. Empty returns signal that further recursion isn't possible or needed.
type Recurser func(*dns.Msg) string

// Recurse is the default Mesos-DNS Recurser which returns an addr (host:port)
// only when the given dns.Msg doesn't contain authoritative answers and has at
// least one SOA record in its NS section.
func Recurse(r *dns.Msg) string {
if r.Authoritative && len(r.Answer) > 0 {
return ""
}

for _, ns := range r.Ns {
if soa, ok := ns.(*dns.SOA); ok {
return net.JoinHostPort(soa.Ns, "53")
}
}

return ""
}

// Recursion returns a Decorator which recurses until the given Recurser returns
// an empty string or max attempts have been reached.
func Recursion(max int, rec Recurser) Decorator {
// counters.
func Instrumentation(total, success, failure logging.Counter) Decorator {
return func(ex Exchanger) Exchanger {
return Func(func(m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) {
for i := 0; i <= max; i++ {
if r, rtt, err = ex.Exchange(m, a); err != nil {
break
} else if a = rec(r); a == "" {
break
defer func() {
if total.Inc(); err != nil {
failure.Inc()
} else {
success.Inc()
}
}
return r, rtt, err
}()
return ex.Exchange(m, a)
})
}
}
186 changes: 40 additions & 146 deletions exchanger/exchanger_test.go
Original file line number Diff line number Diff line change
@@ -1,169 +1,63 @@
package exchanger

import (
"bytes"
"errors"
"net"
"reflect"
"log"
"testing"
"time"

. "github.com/mesosphere/mesos-dns/dnstest"
"github.com/mesosphere/mesos-dns/logging"
"github.com/miekg/dns"
)

func TestWhile(t *testing.T) {
for i, tt := range []struct {
pred Pred
exs []Exchanger
want exchanged
}{
{ // error
nil,
stubs(exchanged{nil, 0, errors.New("foo")}),
exchanged{nil, 0, errors.New("foo")},
},
{ // always true predicate
func(*dns.Msg) bool { return true },
stubs(exchanged{nil, 0, nil}, exchanged{nil, 1, nil}),
exchanged{nil, 1, nil},
},
{ // nil exchangers
nil,
nil,
exchanged{nil, 0, nil},
},
{ // empty exchangers
nil,
stubs(),
exchanged{nil, 0, nil},
},
{ // false predicate
func(calls int) Pred {
return func(*dns.Msg) bool {
calls++
return calls != 2
}
}(0),
stubs(exchanged{nil, 0, nil}, exchanged{nil, 1, nil}, exchanged{nil, 2, nil}),
exchanged{nil, 1, nil},
},
} {
var got exchanged
got.m, got.rtt, got.err = While(tt.pred, tt.exs...).Exchange(nil, "")
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want)
func TestErrorLogging(t *testing.T) {
{ // with error
var buf bytes.Buffer
_, _, _ = ErrorLogging(log.New(&buf, "", 0))(
stub(exchanged{err: errors.New("timeout")})).Exchange(nil, "1.2.3.4")

want := "timeout: exchanging (*dns.Msg)(nil) with \"1.2.3.4\"\n"
if got := buf.String(); got != want {
t.Errorf("got %q, want %q", got, want)
}
}
}
{ // no error
var buf bytes.Buffer
_, _, _ = ErrorLogging(log.New(&buf, "", 0))(
stub(exchanged{})).Exchange(nil, "1.2.3.4")

func TestRecurse(t *testing.T) {
for i, tt := range []struct {
*dns.Msg
want string
}{
{ // Authoritative with answers
Message(
Header(true, 0),
Answers(
A(RRHeader("localhost", dns.TypeA, 0), net.IPv6loopback.To4()),
),
NSs(
SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0),
),
),
"",
},
{ // Authoritative, empty answers, no SOA records
Message(
Header(true, 0),
NSs(
NS(RRHeader("", dns.TypeNS, 0), "next"),
),
),
"",
},
{ // Not authoritative, no SOA record
Message(Header(false, 0)),
"",
},
{ // Not authoritative, one SOA record
Message(
Header(false, 0),
NSs(SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0)),
),
"next:53",
},
{ // Authoritative, empty answers, one SOA record
Message(
Header(true, 0),
NSs(
NS(RRHeader("", dns.TypeNS, 0), "foo"),
SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0),
),
),
"next:53",
},
} {
if got := Recurse(tt.Msg); got != tt.want {
t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want)
if got, want := buf.String(), ""; got != want {
t.Errorf("got %q, want %q", got, want)
}
}
}

func TestRecursion(t *testing.T) {
for i, tt := range []struct {
max int
rec Recurser
ex Exchanger
want exchanged
}{
{
0,
func(*dns.Msg) string { return "next" },
seq(stubs(exchanged{rtt: 1})...),
exchanged{rtt: 1},
},
{
1,
func(*dns.Msg) string { return "next" },
seq(stubs(exchanged{rtt: 0}, exchanged{rtt: 1}, exchanged{rtt: 2})...),
exchanged{rtt: 1},
},
{
0,
nil,
seq(stubs(exchanged{err: errors.New("foo")})...),
exchanged{err: errors.New("foo")},
},
{
2,
func(calls int) Recurser {
return func(*dns.Msg) string {
if calls++; calls <= 1 {
return "next"
}
return ""
}
}(0),
seq(stubs(exchanged{rtt: 0}, exchanged{rtt: 1}, exchanged{rtt: 2})...),
exchanged{rtt: 1},
},
} {
var got exchanged
got.m, got.rtt, got.err = Recursion(tt.max, tt.rec)(tt.ex).Exchange(nil, "")
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want)
func TestInstrumentation(t *testing.T) {
{ // with error
var total, success, failure logging.LogCounter
_, _, _ = Instrumentation(&total, &success, &failure)(
stub(exchanged{err: errors.New("timeout")})).Exchange(nil, "1.2.3.4")

want := []string{"1", "0", "1"}
for i, c := range []*logging.LogCounter{&total, &success, &failure} {
if got, want := c.String(), want[i]; got != want {
t.Errorf("test #%d: got %q, want %q", i, got, want)
}
}
}
}
{ // no error
var total, success, failure logging.LogCounter
_, _, _ = Instrumentation(&total, &success, &failure)(
stub(exchanged{})).Exchange(nil, "1.2.3.4")

func seq(exs ...Exchanger) Exchanger {
var i int
return Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
ex := exs[i]
i++
return ex.Exchange(m, a)
})
want := []string{"1", "1", "0"}
for i, c := range []*logging.LogCounter{&total, &success, &failure} {
if got, want := c.String(), want[i]; got != want {
t.Errorf("test #%d: got %q, want %q", i, got, want)
}
}
}
}

func stubs(ed ...exchanged) []Exchanger {
Expand Down
49 changes: 49 additions & 0 deletions exchanger/forwarder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package exchanger

import (
"fmt"
"net"

"github.com/miekg/dns"
)

// A Forwarder is a DNS message forwarder that transparently proxies messages
// to DNS servers.
type Forwarder func(*dns.Msg, string) (*dns.Msg, error)

// Forward is an utility method that calls f itself.
func (f Forwarder) Forward(m *dns.Msg, proto string) (*dns.Msg, error) {
return f(m, proto)
}

// NewForwarder returns a new Forwarder for the given addrs with the given
// Exchangers map which maps network protocols to Exchangers.
//
// Every message will be exchanged with each address until no error is returned.
// If no addresses or no matching protocol exchanger exist, a *ForwardError will
// be returned.
func NewForwarder(addrs []string, exs map[string]Exchanger) Forwarder {
return func(m *dns.Msg, proto string) (r *dns.Msg, err error) {
ex, ok := exs[proto]
if !ok || len(addrs) == 0 {
return nil, &ForwardError{Addrs: addrs, Proto: proto}
}
for _, a := range addrs {
if r, _, err = ex.Exchange(m, net.JoinHostPort(a, "53")); err == nil {
break
}
}
return
}
}

// A ForwardError is returned by Forwarders when they can't forward.
type ForwardError struct {
Addrs []string
Proto string
}

// Error implements the error interface.
func (e ForwardError) Error() string {
return fmt.Sprintf("can't forward to %v over %q", e.Addrs, e.Proto)
}
Loading

0 comments on commit 02049c4

Please sign in to comment.