Skip to content

Commit

Permalink
Adds support for SSH ProxyCommand to be able to use "Bastion" servers
Browse files Browse the repository at this point in the history
  • Loading branch information
vo-va committed Dec 12, 2024
1 parent 42e5637 commit 163f4e4
Show file tree
Hide file tree
Showing 7 changed files with 294 additions and 43 deletions.
11 changes: 6 additions & 5 deletions pkg/config/playbook.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,12 @@ type Target struct {

// Destination defines destination info
type Destination struct {
Name string `yaml:"name" toml:"name"`
Host string `yaml:"host" toml:"host"`
Port int `yaml:"port" toml:"port"`
User string `yaml:"user" toml:"user"`
Tags []string `yaml:"tags" toml:"tags"`
Name string `yaml:"name" toml:"name"`
Host string `yaml:"host" toml:"host"`
Port int `yaml:"port" toml:"port"`
User string `yaml:"user" toml:"user"`
Tags []string `yaml:"tags" toml:"tags"`
ProxyCommand []string `yaml:"proxy_command" toml:"proxy_command"`
}

// Overrides defines override for task passed from cli
Expand Down
88 changes: 77 additions & 11 deletions pkg/executor/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"net"
"os"
"os/exec"
"strings"
"time"

Expand All @@ -22,6 +23,31 @@ type Connector struct {
logs Logs
}

// In the ProxyCommand variables can be used %h, %p, %r (%r - username)

Check failure on line 26 in pkg/executor/connector.go

View workflow job for this annotation

GitHub Actions / build

exported: comment on exported function SubstituteProxyCommand should be of the form "SubstituteProxyCommand ..." (revive)
// before executing the command they needs to be replaced with the actual values
func SubstituteProxyCommand(username, address string, proxyCommand []string) ([]string, error) {
if len(proxyCommand) == 0 {
return []string{}, nil
}

host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("failed to split hostAddr and port: %w", err)
}

cmdArgs := make([]string, len(proxyCommand))

for i, arg := range proxyCommand {
arg = strings.Replace(arg, "%h", host, -1)
if port != "" {
arg = strings.Replace(arg, "%p", port, -1)
}
arg = strings.Replace(arg, "%r", username, -1)
cmdArgs[i] = arg
}
return cmdArgs, nil
}

// NewConnector creates a new Connector for a given user and private key.
func NewConnector(privateKey string, timeout time.Duration, logs Logs) (res *Connector, err error) {
res = &Connector{privateKey: privateKey, timeout: timeout, logs: logs}
Expand Down Expand Up @@ -52,9 +78,9 @@ func (c *Connector) WithAgentForwarding() *Connector {
}

// Connect connects to a remote hostAddr and returns a remote executer, caller must close.
func (c *Connector) Connect(ctx context.Context, hostAddr, hostName, user string) (*Remote, error) {
log.Printf("[DEBUG] connect to %q (%s), user %q", hostAddr, hostName, user)
client, err := c.sshClient(ctx, hostAddr, user)
func (c *Connector) Connect(ctx context.Context, hostAddr, hostName, user string, proxyCommand []string) (*Remote, error) {
log.Printf("[DEBUG] connect to %q (%s), user %q, proxy command: %s", hostAddr, hostName, user, proxyCommand)
client, err := c.sshClient(ctx, hostAddr, user, proxyCommand)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -89,27 +115,67 @@ func (c *Connector) forwardAgent(client *ssh.Client) error {
return nil
}

func (c *Connector) sshClient(ctx context.Context, host, user string) (session *ssh.Client, err error) {
func sshDialWithProxy(ctx context.Context, host string, cmdArgs []string, config *ssh.ClientConfig) (*ssh.Client, error) {

client, server := net.Pipe()

cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)

Check failure on line 122 in pkg/executor/connector.go

View workflow job for this annotation

GitHub Actions / build

G204: Subprocess launched with a potential tainted input or cmd arguments (gosec)
cmd.Stdin = server
cmd.Stdout = server
cmd.Stderr = os.Stderr

if err := cmd.Start(); err != nil {
return nil, err
}

ncc, chans, reqs, err := ssh.NewClientConn(client, host, config)
if err != nil {
return nil, err
}

return ssh.NewClient(ncc, chans, reqs), nil

}

func (c *Connector) sshClient(ctx context.Context, host, user string, proxyCommand []string) (session *ssh.Client, err error) {
var conn net.Conn
var client *ssh.Client

log.Printf("[DEBUG] create ssh session to %s, user %s", host, user)
log.Printf("[DEBUG] ProxyCommand %s ", proxyCommand)
if !strings.Contains(host, ":") {
host += ":22"
}

dialer := net.Dialer{Timeout: c.timeout}
conn, err := dialer.DialContext(ctx, "tcp", host)
cmdArgs, err := SubstituteProxyCommand(user, host, proxyCommand)
if err != nil {
return nil, fmt.Errorf("failed to dial: %w", err)
return nil, fmt.Errorf("failed to parse proxy command: %w", err)
}

conf, err := c.sshConfig(user, c.privateKey)
if err != nil {
return nil, fmt.Errorf("failed to create ssh config: %w", err)
}
ncc, chans, reqs, err := ssh.NewClientConn(conn, host, conf)
if err != nil {
return nil, fmt.Errorf("failed to create client connection to %s: %v", host, err)

if len(proxyCommand) == 0 {
dialer := net.Dialer{Timeout: c.timeout}
conn, err = dialer.DialContext(ctx, "tcp", host)
if err != nil {
return nil, fmt.Errorf("failed to dial: %w", err)
}

ncc, chans, reqs, err := ssh.NewClientConn(conn, host, conf)

Check failure on line 167 in pkg/executor/connector.go

View workflow job for this annotation

GitHub Actions / build

shadow: declaration of "err" shadows declaration at line 140 (govet)
if err != nil {
return nil, fmt.Errorf("failed to create client connection to %s: %v", host, err)
}
client = ssh.NewClient(ncc, chans, reqs)

} else {
client, err = sshDialWithProxy(ctx, host, cmdArgs, conf)
if err != nil {
return nil, fmt.Errorf("failed to create client connection wtth proxy command %s, to %s: %v", proxyCommand, host, err)
}
}
client := ssh.NewClient(ncc, chans, reqs)

if err := c.forwardAgent(client); err != nil {
return nil, fmt.Errorf("failed to forward agent to %s: %v", host, err)
Expand Down
119 changes: 114 additions & 5 deletions pkg/executor/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package executor

import (
"context"
"strings"
"testing"
"time"

Expand All @@ -16,15 +17,15 @@ func TestConnector_Connect(t *testing.T) {
t.Run("good connection", func(t *testing.T) {
c, err := NewConnector("testdata/test_ssh_key", time.Second*10, MakeLogs(true, false, nil))
require.NoError(t, err)
sess, err := c.Connect(ctx, hostAndPort, "h1", "test")
sess, err := c.Connect(ctx, hostAndPort, "h1", "test", []string{})
require.NoError(t, err)
defer sess.Close()
})

t.Run("bad user", func(t *testing.T) {
c, err := NewConnector("testdata/test_ssh_key", time.Second*10, MakeLogs(true, false, nil))
require.NoError(t, err)
_, err = c.Connect(ctx, hostAndPort, "h1", "test33")
_, err = c.Connect(ctx, hostAndPort, "h1", "test33", []string{})
require.ErrorContains(t, err, "ssh: unable to authenticate")
})

Expand All @@ -36,21 +37,129 @@ func TestConnector_Connect(t *testing.T) {
t.Run("wrong port", func(t *testing.T) {
c, err := NewConnector("testdata/test_ssh_key", time.Second*10, MakeLogs(true, false, nil))
require.NoError(t, err)
_, err = c.Connect(ctx, "127.0.0.1:12345", "h1", "test")
_, err = c.Connect(ctx, "127.0.0.1:12345", "h1", "test", []string{})
require.ErrorContains(t, err, "failed to dial: dial tcp 127.0.0.1:12345")
})

t.Run("timeout", func(t *testing.T) {
c, err := NewConnector("testdata/test_ssh_key", time.Nanosecond, MakeLogs(true, false, nil))
require.NoError(t, err)
_, err = c.Connect(ctx, hostAndPort, "h1", "test")
_, err = c.Connect(ctx, hostAndPort, "h1", "test", []string{})
require.ErrorContains(t, err, "i/o timeout")
})

t.Run("unreachable host", func(t *testing.T) {
c, err := NewConnector("testdata/test_ssh_key", time.Second, MakeLogs(true, false, nil))
require.NoError(t, err)
_, err = c.Connect(ctx, "10.255.255.1:22", "h1", "test")
_, err = c.Connect(ctx, "10.255.255.1:22", "h1", "test", []string{})
require.ErrorContains(t, err, "failed to dial: dial tcp 10.255.255.1:22: i/o timeout")
})
}

func TestConnector_ConnectWithProxy(t *testing.T) {
ctx := context.Background()

bastionHostAndPort, _, teardown := start2TestContainers(t)
defer teardown()

t.Run("good connection", func(t *testing.T) {
c, err := NewConnector("testdata/test_ssh_key", time.Second*10, MakeLogs(true, false, nil))
require.NoError(t, err)
sess, err := c.Connect(ctx, bastionHostAndPort, "bastion-host", "test", []string{})
require.NoError(t, err)
defer sess.Close()
})

// To test proxy command, the chain of connection will be next:
// localhost -> localhost:<random_port> (this is also the bastion host) -> target-host:2222
// In a real-world application, "target-host:2222" will be replaced with "%h:%p", but since
// testcontainers returns "localhost:<random_port>" manually, overriding it.

// "ssh -W" requires enabling AllowTcpForwarding, to enable it, modification was applied:
// see pkg/executor/remote_test.go, env variable DOCKER_MODS on test container.
// The "bastion-host" is a local host, and we are using a standard SSH client which tries to verify the host key;
// to bypass this check, "-o StrictHostKeyChecking=no” was added to the proxy command.

// There is a situation that I am not sure if it is a bug or should be handled on client/spot side.
// If ssh server on proxy server works, but forbid TCP forwarding, go ssh client will connect but will not abort
// the connection or return error, it will just print to terminal
// "channel open failed: open failed: administratively prohibited: open failed".

t.Run("good connection with proxy", func(t *testing.T) {
c, err := NewConnector("testdata/test_ssh_key", time.Second*10, MakeLogs(true, false, nil))
require.NoError(t, err)

bastionAddr := strings.Split(bastionHostAndPort, ":")

proxyCommand := []string{
"ssh",
"-W",
"target-host:2222",
"test@localhost",
"-p",
bastionAddr[1],
"-i",
"testdata/test_ssh_key",
"-o",
"StrictHostKeyChecking=no",
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

sess, err := c.Connect(ctx, "target-host:2222", "target-host", "test", proxyCommand)
require.NoError(t, err)
defer sess.Close()
})

}

func TestSubstituteProxyCommand(t *testing.T) {
tests := []struct {
username string
address string
proxyCommand []string
expected []string
expectError bool
}{
{
username: "user",
address: "example.com:22",
proxyCommand: []string{"ssh", "-W", "%h:%p", "%[email protected]"},
expected: []string{"ssh", "-W", "example.com:22", "[email protected]"},
expectError: false,
},
{
username: "user",
address: "example.com:22",
proxyCommand: []string{"ssh", "-W", "%h:%p", "%[email protected]", "random arg with spaces"},
expected: []string{"ssh", "-W", "example.com:22", "[email protected]", "random arg with spaces"},
expectError: false,
},
{
username: "user",
address: "example.com",
proxyCommand: []string{"ssh", "-W", "%h:%p", "%[email protected]"},
expected: nil,
expectError: true,
},
{
username: "user",
address: "example.com:22",
proxyCommand: []string{},
expected: []string{},
expectError: false,
},
}

for _, test := range tests {
t.Run(test.address, func(t *testing.T) {
result, err := SubstituteProxyCommand(test.username, test.address, test.proxyCommand)
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, result)
}
})
}
}
Loading

0 comments on commit 163f4e4

Please sign in to comment.