Skip to content
This repository has been archived by the owner on Oct 23, 2024. It is now read-only.

Commit

Permalink
Merge pull request #307 from mesosphere/rfc5625
Browse files Browse the repository at this point in the history
Transparent proxying of external queries
  • Loading branch information
tsenart committed Oct 7, 2015
2 parents f3e7996 + 02049c4 commit 54c5137
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 295 deletions.
69 changes: 9 additions & 60 deletions exchanger/exchanger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package exchanger

import (
"log"
"net"
"time"

"github.com/mesosphere/mesos-dns/logging"
Expand Down Expand Up @@ -36,24 +35,6 @@ func Decorate(ex Exchanger, ds ...Decorator) Exchanger {
return decorated
}

// Pred is a predicate function type for dns.Msgs.
type Pred func(*dns.Msg) bool

// While returns an Exchanger which attempts the given Exchangers while the given
// predicate function returns true for the returned dns.Msg, an error is returned,
// or all Exchangers are attempted, in which case the return values of the last
// one are returned.
func While(p Pred, exs ...Exchanger) Exchanger {
return Func(func(m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) {
for _, ex := range exs {
if r, rtt, err = ex.Exchange(m, a); err != nil || !p(r) {
break
}
}
return
})
}

// ErrorLogging returns a Decorator which logs an Exchanger's errors to the given
// logger.
func ErrorLogging(l *log.Logger) Decorator {
Expand All @@ -70,50 +51,18 @@ func ErrorLogging(l *log.Logger) Decorator {
}

// Instrumentation returns a Decorator which instruments an Exchanger with the given
// counter.
func Instrumentation(c logging.Counter) Decorator {
return func(ex Exchanger) Exchanger {
return Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
defer c.Inc()
return ex.Exchange(m, a)
})
}
}

// A Recurser returns the addr (host:port) of the next DNS server to recurse a
// Msg to. Empty returns signal that further recursion isn't possible or needed.
type Recurser func(*dns.Msg) string

// Recurse is the default Mesos-DNS Recurser which returns an addr (host:port)
// only when the given dns.Msg doesn't contain authoritative answers and has at
// least one SOA record in its NS section.
func Recurse(r *dns.Msg) string {
if r.Authoritative && len(r.Answer) > 0 {
return ""
}

for _, ns := range r.Ns {
if soa, ok := ns.(*dns.SOA); ok {
return net.JoinHostPort(soa.Ns, "53")
}
}

return ""
}

// Recursion returns a Decorator which recurses until the given Recurser returns
// an empty string or max attempts have been reached.
func Recursion(max int, rec Recurser) Decorator {
// counters.
func Instrumentation(total, success, failure logging.Counter) Decorator {
return func(ex Exchanger) Exchanger {
return Func(func(m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) {
for i := 0; i <= max; i++ {
if r, rtt, err = ex.Exchange(m, a); err != nil {
break
} else if a = rec(r); a == "" {
break
defer func() {
if total.Inc(); err != nil {
failure.Inc()
} else {
success.Inc()
}
}
return r, rtt, err
}()
return ex.Exchange(m, a)
})
}
}
186 changes: 40 additions & 146 deletions exchanger/exchanger_test.go
Original file line number Diff line number Diff line change
@@ -1,169 +1,63 @@
package exchanger

import (
"bytes"
"errors"
"net"
"reflect"
"log"
"testing"
"time"

. "github.com/mesosphere/mesos-dns/dnstest"
"github.com/mesosphere/mesos-dns/logging"
"github.com/miekg/dns"
)

func TestWhile(t *testing.T) {
for i, tt := range []struct {
pred Pred
exs []Exchanger
want exchanged
}{
{ // error
nil,
stubs(exchanged{nil, 0, errors.New("foo")}),
exchanged{nil, 0, errors.New("foo")},
},
{ // always true predicate
func(*dns.Msg) bool { return true },
stubs(exchanged{nil, 0, nil}, exchanged{nil, 1, nil}),
exchanged{nil, 1, nil},
},
{ // nil exchangers
nil,
nil,
exchanged{nil, 0, nil},
},
{ // empty exchangers
nil,
stubs(),
exchanged{nil, 0, nil},
},
{ // false predicate
func(calls int) Pred {
return func(*dns.Msg) bool {
calls++
return calls != 2
}
}(0),
stubs(exchanged{nil, 0, nil}, exchanged{nil, 1, nil}, exchanged{nil, 2, nil}),
exchanged{nil, 1, nil},
},
} {
var got exchanged
got.m, got.rtt, got.err = While(tt.pred, tt.exs...).Exchange(nil, "")
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want)
func TestErrorLogging(t *testing.T) {
{ // with error
var buf bytes.Buffer
_, _, _ = ErrorLogging(log.New(&buf, "", 0))(
stub(exchanged{err: errors.New("timeout")})).Exchange(nil, "1.2.3.4")

want := "timeout: exchanging (*dns.Msg)(nil) with \"1.2.3.4\"\n"
if got := buf.String(); got != want {
t.Errorf("got %q, want %q", got, want)
}
}
}
{ // no error
var buf bytes.Buffer
_, _, _ = ErrorLogging(log.New(&buf, "", 0))(
stub(exchanged{})).Exchange(nil, "1.2.3.4")

func TestRecurse(t *testing.T) {
for i, tt := range []struct {
*dns.Msg
want string
}{
{ // Authoritative with answers
Message(
Header(true, 0),
Answers(
A(RRHeader("localhost", dns.TypeA, 0), net.IPv6loopback.To4()),
),
NSs(
SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0),
),
),
"",
},
{ // Authoritative, empty answers, no SOA records
Message(
Header(true, 0),
NSs(
NS(RRHeader("", dns.TypeNS, 0), "next"),
),
),
"",
},
{ // Not authoritative, no SOA record
Message(Header(false, 0)),
"",
},
{ // Not authoritative, one SOA record
Message(
Header(false, 0),
NSs(SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0)),
),
"next:53",
},
{ // Authoritative, empty answers, one SOA record
Message(
Header(true, 0),
NSs(
NS(RRHeader("", dns.TypeNS, 0), "foo"),
SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0),
),
),
"next:53",
},
} {
if got := Recurse(tt.Msg); got != tt.want {
t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want)
if got, want := buf.String(), ""; got != want {
t.Errorf("got %q, want %q", got, want)
}
}
}

func TestRecursion(t *testing.T) {
for i, tt := range []struct {
max int
rec Recurser
ex Exchanger
want exchanged
}{
{
0,
func(*dns.Msg) string { return "next" },
seq(stubs(exchanged{rtt: 1})...),
exchanged{rtt: 1},
},
{
1,
func(*dns.Msg) string { return "next" },
seq(stubs(exchanged{rtt: 0}, exchanged{rtt: 1}, exchanged{rtt: 2})...),
exchanged{rtt: 1},
},
{
0,
nil,
seq(stubs(exchanged{err: errors.New("foo")})...),
exchanged{err: errors.New("foo")},
},
{
2,
func(calls int) Recurser {
return func(*dns.Msg) string {
if calls++; calls <= 1 {
return "next"
}
return ""
}
}(0),
seq(stubs(exchanged{rtt: 0}, exchanged{rtt: 1}, exchanged{rtt: 2})...),
exchanged{rtt: 1},
},
} {
var got exchanged
got.m, got.rtt, got.err = Recursion(tt.max, tt.rec)(tt.ex).Exchange(nil, "")
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want)
func TestInstrumentation(t *testing.T) {
{ // with error
var total, success, failure logging.LogCounter
_, _, _ = Instrumentation(&total, &success, &failure)(
stub(exchanged{err: errors.New("timeout")})).Exchange(nil, "1.2.3.4")

want := []string{"1", "0", "1"}
for i, c := range []*logging.LogCounter{&total, &success, &failure} {
if got, want := c.String(), want[i]; got != want {
t.Errorf("test #%d: got %q, want %q", i, got, want)
}
}
}
}
{ // no error
var total, success, failure logging.LogCounter
_, _, _ = Instrumentation(&total, &success, &failure)(
stub(exchanged{})).Exchange(nil, "1.2.3.4")

func seq(exs ...Exchanger) Exchanger {
var i int
return Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
ex := exs[i]
i++
return ex.Exchange(m, a)
})
want := []string{"1", "1", "0"}
for i, c := range []*logging.LogCounter{&total, &success, &failure} {
if got, want := c.String(), want[i]; got != want {
t.Errorf("test #%d: got %q, want %q", i, got, want)
}
}
}
}

func stubs(ed ...exchanged) []Exchanger {
Expand Down
49 changes: 49 additions & 0 deletions exchanger/forwarder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package exchanger

import (
"fmt"
"net"

"github.com/miekg/dns"
)

// A Forwarder is a DNS message forwarder that transparently proxies messages
// to DNS servers.
type Forwarder func(*dns.Msg, string) (*dns.Msg, error)

// Forward is an utility method that calls f itself.
func (f Forwarder) Forward(m *dns.Msg, proto string) (*dns.Msg, error) {
return f(m, proto)
}

// NewForwarder returns a new Forwarder for the given addrs with the given
// Exchangers map which maps network protocols to Exchangers.
//
// Every message will be exchanged with each address until no error is returned.
// If no addresses or no matching protocol exchanger exist, a *ForwardError will
// be returned.
func NewForwarder(addrs []string, exs map[string]Exchanger) Forwarder {
return func(m *dns.Msg, proto string) (r *dns.Msg, err error) {
ex, ok := exs[proto]
if !ok || len(addrs) == 0 {
return nil, &ForwardError{Addrs: addrs, Proto: proto}
}
for _, a := range addrs {
if r, _, err = ex.Exchange(m, net.JoinHostPort(a, "53")); err == nil {
break
}
}
return
}
}

// A ForwardError is returned by Forwarders when they can't forward.
type ForwardError struct {
Addrs []string
Proto string
}

// Error implements the error interface.
func (e ForwardError) Error() string {
return fmt.Sprintf("can't forward to %v over %q", e.Addrs, e.Proto)
}
Loading

0 comments on commit 54c5137

Please sign in to comment.