diff --git a/README.md b/README.md index 2847f83..ba0dfbc 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Stegano: The fastest Steganography Library for Go -[![Tests](https://github.com/scott-mescudi/stegano/actions/workflows/go.yml/badge.svg?event=push)](https://github.com/scott-mescudi/stegano/actions/workflows/go.yml) ![GitHub License](https://img.shields.io/github/license/scott-mescudi/stegano) [![Go Reference](https://pkg.go.dev/badge/github.com/scott-mescudi/stegano.svg)](https://pkg.go.dev/github.com/scott-mescudi/stegano) +[![Tests](https://github.com/scott-mescudi/stegano/actions/workflows/go.yml/badge.svg?event=push)](https://github.com/scott-mescudi/stegano/actions/workflows/go.yml) ![GitHub License](https://img.shields.io/github/license/scott-mescudi/stegano) [![Go Reference](https://pkg.go.dev/badge/github.com/scott-mescudi/stegano.svg)](https://pkg.go.dev/github.com/scott-mescudi/stegano)[![Go Report Card](https://goreportcard.com/badge/github.com/scott-mescudi/stegano)](https://goreportcard.com/report/github.com/scott-mescudi/stegano) ## Table of Contents diff --git a/audio.go b/audio.go index 97cd9a0..ede2362 100644 --- a/audio.go +++ b/audio.go @@ -24,7 +24,7 @@ func (s *AudioEmbedHandler) EmbedDataIntoWAVWithDepth(audioFilename, outputFilen return err } - buffer = u.EmbedDataWithDepthAudio(buffer, nd, bitDepth) + buffer, err = u.EmbedDataWithDepthAudio(buffer, nd, bitDepth) err = WriteAudioFile(outputFilename, decoder, buffer) if err != nil { @@ -88,7 +88,10 @@ func (s *AudioEmbedHandler) EmbedDataIntoWAVAtDepth(audioFilename, outputFilenam return err } - buffer = u.EmbedDataAtDepthAudio(buffer, data, bitDepth) + buffer, err = u.EmbedDataAtDepthAudio(buffer, data, bitDepth) + if err != nil { + return ErrInvalidData + } err = WriteAudioFile(outputFilename, decoder, buffer) if err != nil { diff --git a/pkg/AES_test.go b/pkg/AES_test.go new file mode 100644 index 0000000..3dc42cb --- /dev/null +++ b/pkg/AES_test.go @@ -0,0 +1,114 @@ +package pkg + +import ( + "bytes" + "fmt" + "testing" +) + +func TestEncryptDecrypt(t *testing.T) { + tests := []struct { + password string + plaintext []byte + }{ + { + password: "correcthorsebatterystaple", + plaintext: []byte("This is a test message."), + }, + { + password: "password123", + plaintext: []byte("Another test message!"), + }, + { + password: "password123", + plaintext: []byte("Short"), + }, + { + password: "password123", + plaintext: []byte(""), + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("password=%s", tt.password), func(t *testing.T) { + // Encrypt + ciphertext, err := Encrypt(tt.password, tt.plaintext) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + // Decrypt + decryptedText, err := Decrypt(tt.password, ciphertext) + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + // Check if decrypted text matches original plaintext + if !bytes.Equal(decryptedText, tt.plaintext) { + t.Errorf("Decrypted text does not match original plaintext. Expected: %s, got: %s", tt.plaintext, decryptedText) + } + }) + } +} + +func TestDecryptWithIncorrectPassword(t *testing.T) { + plaintext := []byte("Test for incorrect password") + password := "correctpassword" + incorrectPassword := "wrongpassword" + + // Encrypt with the correct password + ciphertext, err := Encrypt(password, plaintext) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + // Decrypt with incorrect password + _, err = Decrypt(incorrectPassword, ciphertext) + if err == nil { + t.Errorf("Expected decryption to fail with incorrect password") + } +} + +func TestDecryptWithCorruptedCiphertext(t *testing.T) { + plaintext := []byte("Test for corrupted ciphertext") + password := "correctpassword" + + // Encrypt with the correct password + ciphertext, err := Encrypt(password, plaintext) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + // Corrupt the ciphertext by changing some bytes + corruptedCiphertext := append([]byte{}, ciphertext...) + corruptedCiphertext[10] = 0xFF + + // Decrypt the corrupted ciphertext + _, err = Decrypt(password, corruptedCiphertext) + if err == nil { + t.Errorf("Expected decryption to fail with corrupted ciphertext") + } +} + +func TestEmptyPlaintext(t *testing.T) { + // Test with empty plaintext + plaintext := []byte("") + password := "password123" + + // Encrypt with the correct password + ciphertext, err := Encrypt(password, plaintext) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + // Decrypt the ciphertext + decryptedText, err := Decrypt(password, ciphertext) + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + // Check if decrypted text is also empty + if len(decryptedText) != 0 { + t.Errorf("Expected empty decrypted text, got: %s", decryptedText) + } +} diff --git a/pkg/audio.go b/pkg/audio.go index 1757bbf..6ab7af6 100644 --- a/pkg/audio.go +++ b/pkg/audio.go @@ -1,10 +1,16 @@ package pkg import ( + "fmt" + "github.com/go-audio/audio" ) -func EmbedDataAtDepthAudio(buffer *audio.IntBuffer, data []byte, depth uint8) *audio.IntBuffer { +func EmbedDataAtDepthAudio(buffer *audio.IntBuffer, data []byte, depth uint8) (*audio.IntBuffer, error) { + if len(data) == 0 { + return nil, fmt.Errorf("Data is empty") + } + dataBits := BytesToBinary(data) lenBits := Int32ToBinary(int32(len(data))) lenBits = append(lenBits, dataBits...) @@ -15,7 +21,7 @@ func EmbedDataAtDepthAudio(buffer *audio.IntBuffer, data []byte, depth uint8) *a } } - return buffer + return buffer, nil } func ExtractDataAtDepthAudio(buffer *audio.IntBuffer, depth uint8) []byte { @@ -38,7 +44,11 @@ func ExtractDataAtDepthAudio(buffer *audio.IntBuffer, depth uint8) []byte { return data } -func EmbedDataWithDepthAudio(buffer *audio.IntBuffer, data []byte, bitDepth uint8) *audio.IntBuffer { +func EmbedDataWithDepthAudio(buffer *audio.IntBuffer, data []byte, bitDepth uint8) (*audio.IntBuffer, error) { + if len(data) == 0 { + return nil, fmt.Errorf("Data is empty") + } + dataBits := BytesToBinary(data) lenBits := Int32ToBinary(int32(len(data))) lenBits = append(lenBits, dataBits...) @@ -59,7 +69,7 @@ func EmbedDataWithDepthAudio(buffer *audio.IntBuffer, data []byte, bitDepth uint } } - return buffer + return buffer, nil } func ExtractDataWithDepthAudio(buffer *audio.IntBuffer, depth uint8) []byte { diff --git a/pkg/audio_test.go b/pkg/audio_test.go new file mode 100644 index 0000000..75b3757 --- /dev/null +++ b/pkg/audio_test.go @@ -0,0 +1,95 @@ +package pkg + +import ( + "bytes" + "fmt" + "github.com/go-audio/audio" + "testing" +) + +func TestEmbedDataWithDepthAudio(t *testing.T) { + tests := []struct { + data []byte + bitDepth uint8 + }{ + { + data: []byte("Hello, world!"), + bitDepth: 1, + }, + { + data: []byte("Another test data"), + bitDepth: 2, + }, + { + data: []byte("Small"), + bitDepth: 3, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("bitDepth=%d", tt.bitDepth), func(t *testing.T) { + // Create a dummy audio buffer with arbitrary values + buffer := &audio.IntBuffer{ + Data: make([]int, len(tt.data)*8000), + Format: &audio.Format{ + SampleRate: 44100, + NumChannels: 1, // Adjusted to use NumChannels instead of SampleWidth + }, + } + + // Embed data in audio buffer with bit depth + embededBuff, _ := EmbedDataWithDepthAudio(buffer, tt.data, tt.bitDepth) + + // Extract data from the audio buffer to verify embedding worked + extractedData := ExtractDataWithDepthAudio(embededBuff, tt.bitDepth) + + // Compare extracted data with original data + if !bytes.Contains(extractedData, tt.data) { + t.Errorf("Extracted data does not match original data. Expected: %s, got: %s", tt.data, extractedData) + } + }) + } +} + +func TestExtractDataWithDepthAudio(t *testing.T) { + // Create a dummy audio buffer with embedded data + data := []byte("Embed this data with bit depth") + bitDepth := uint8(2) + buffer := &audio.IntBuffer{ + Data: make([]int, len(data)*8), + Format: &audio.Format{ + SampleRate: 44100, + NumChannels: 1, // Adjusted to use NumChannels instead of SampleWidth + }, + } + + // Embed data in audio buffer with bit depth + embededBuffer, _ := EmbedDataWithDepthAudio(buffer, data, bitDepth) + + // Extract data from the audio buffer with specific bit depth + extractedData := ExtractDataWithDepthAudio(embededBuffer, bitDepth) + + // Compare extracted data with original data + if !bytes.Contains(extractedData, data) { + t.Errorf("Extracted data does not match original data. Expected: %s, got: %s", data, extractedData) + } +} + +func TestEmptyDataEmbedding(t *testing.T) { + // Test with empty data input + data := []byte("") + depth := uint8(2) + buffer := &audio.IntBuffer{ + Data: make([]int, 100), + Format: &audio.Format{ + SampleRate: 44100, + NumChannels: 1, // Adjusted to use NumChannels instead of SampleWidth + }, + } + + // Embed empty data at depth + _, err := EmbedDataAtDepthAudio(buffer, data, depth) + if err == nil { + t.Error("Embedded empty data instead of returning error") + } +}