Skip to content

Commit

Permalink
Session's contexts can now be updated
Browse files Browse the repository at this point in the history
This is needed for DTLS restart, as mentioned in pion/webrtc#1636.
  • Loading branch information
Antonito committed Jun 26, 2021
1 parent d9aae44 commit 98dcba4
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 6 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contribu
* [Mission Liao](https://github.com/mission-liao)
* [Orlando](https://github.com/OrlandoCo)
* [Tarrence van As](https://github.com/tarrencev)
* [Antoine Baché](https://github.com/Antonito)

### License
MIT License - see [LICENSE](LICENSE) for full text
32 changes: 28 additions & 4 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"

"github.com/pion/logging"
"github.com/pion/transport/packetio"
Expand All @@ -17,7 +18,8 @@ type streamSession interface {

type session struct {
localContextMutex sync.Mutex
localContext, remoteContext *Context
localContext *Context
remoteContext atomic.Value // *Context
localOptions, remoteOptions []ContextOption

newStream chan readStream
Expand Down Expand Up @@ -106,17 +108,19 @@ func (s *session) close() error {
}

func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error {
var err error
s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...)
localContext, err := CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...)
if err != nil {
return err
}

s.remoteContext, err = CreateContext(remoteMasterKey, remoteMasterSalt, profile, s.remoteOptions...)
remoteContext, err := CreateContext(remoteMasterKey, remoteMasterSalt, profile, s.remoteOptions...)
if err != nil {
return err
}

s.localContext = localContext
s.remoteContext.Store(remoteContext)

go func() {
defer func() {
close(s.newStream)
Expand Down Expand Up @@ -148,3 +152,23 @@ func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remote

return nil
}

// UpdateContext updates the local and remote context of the session.
func (s *session) UpdateContext(config *Config) error {
localContext, err := CreateContext(config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, config.Profile, s.localOptions...)
if err != nil {
return err
}
remoteContext, err := CreateContext(config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, config.Profile, s.remoteOptions...)
if err != nil {
return err
}

s.localContextMutex.Lock()
s.localContext = localContext
s.localContextMutex.Unlock()

s.remoteContext.Store(remoteContext)

return nil
}
4 changes: 3 additions & 1 deletion session_srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ func destinationSSRC(pkts []rtcp.Packet) []uint32 {
}

func (s *SessionSRTCP) decrypt(buf []byte) error {
decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil)
// Safe since remoteContext always contains a *Context.
remoteContext := s.remoteContext.Load().(*Context)
decrypted, err := remoteContext.DecryptRTCP(buf, buf, nil)
if err != nil {
return err
}
Expand Down
4 changes: 3 additions & 1 deletion session_srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ func (s *SessionSRTP) decrypt(buf []byte) error {
return errFailedTypeAssertion
}

decrypted, err := s.remoteContext.decryptRTP(buf, buf, h)
// Safe since remoteContext always contains a *Context.
remoteContext := s.remoteContext.Load().(*Context)
decrypted, err := remoteContext.decryptRTP(buf, buf, h)
if err != nil {
return err
}
Expand Down

0 comments on commit 98dcba4

Please sign in to comment.