// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package securecookie
import (
"crypto/aes"
"crypto/hmac"
"crypto/sha256"
"errors"
"fmt"
"strings"
"testing"
)
// Shared fixtures for the tests in this file.
var (
	// testCookies holds the values that TestSerialization passes through
	// serialize and deserialize.
	testCookies = []interface{}{
		map[string]string{"foo": "bar"},
		map[string]string{"baz": "ding"},
	}

	// testStrings holds the plain payloads used by the MAC, encryption and
	// encoding round-trip tests.
	testStrings = []string{"foo", "bar", "baz"}
)
// TestSecureCookie encodes a map with one instance returned by New, decodes
// it back with the same instance and compares the result with the input, then
// checks that a second instance created with different arguments fails to
// decode the same encoded value.
func TestSecureCookie(t *testing.T) {
	// TODO test too old / too new timestamps

	// compareMaps returns an error when the maps differ in length or when a
	// key of got maps to a different value in want.
	compareMaps := func(got, want map[string]interface{}) error {
		if len(got) != len(want) {
			return errors.New("different maps")
		}
		for k, v := range got {
			if want[k] != v {
				return fmt.Errorf("Different value for key %v: expected %v, got %v", k, want[k], v)
			}
		}
		return nil
	}

	s1 := New([]byte("12345"), []byte("1234567890123456"))
	s2 := New([]byte("54321"), []byte("6543210987654321"))
	value := map[string]interface{}{
		"foo": "bar",
		"baz": 128,
	}

	// Running this multiple times to check if any special character
	// breaks encoding/decoding.
	for i := 0; i < 50; i++ {
		encoded, err := s1.Encode("sid", value)
		if err != nil {
			t.Error(err)
			continue
		}

		dst := make(map[string]interface{})
		if err := s1.Decode("sid", encoded, &dst); err != nil {
			t.Fatalf("%v: %v", err, encoded)
		}
		if err := compareMaps(dst, value); err != nil {
			// Surface the per-key diagnostic instead of discarding it.
			t.Fatalf("Expected %v, got %v: %v", value, dst, err)
		}

		// The instance built from different arguments must not accept the
		// value encoded by s1.
		dst2 := make(map[string]interface{})
		if err := s2.Decode("sid", encoded, &dst2); err == nil {
			t.Fatal("Expected failure decoding.")
		}
	}
}
// TestAuthentication checks, for every entry of testStrings, that the value
// returned by createMac is accepted by verifyMac when both are given the same
// HMAC-SHA256 hash (reset before each call).
func TestAuthentication(t *testing.T) {
	mac := hmac.New(sha256.New, []byte("secret-key"))
	for i := range testStrings {
		text := testStrings[i]

		mac.Reset()
		signature := createMac(mac, []byte(text))

		mac.Reset()
		if err := verifyMac(mac, []byte(text), signature); err != nil {
			t.Error(err)
		}
	}
}
// TestEncription passes every entry of testStrings through encrypt and then
// decrypt with the same AES block and checks that the original text comes
// back.
func TestEncription(t *testing.T) {
	block, err := aes.NewCipher([]byte("1234567890123456"))
	if err != nil {
		t.Fatalf("Block could not be created: %v", err)
	}
	for _, value := range testStrings {
		encrypted, err := encrypt(block, []byte(value))
		if err != nil {
			t.Error(err)
			continue
		}
		decrypted, err := decrypt(block, encrypted)
		if err != nil {
			// Skip the comparison: the decrypt error is already reported
			// and comparing its result would only log a second failure.
			t.Error(err)
			continue
		}
		if string(decrypted) != value {
			t.Errorf("Expected %v, got %v.", value, string(decrypted))
		}
	}
}
// TestSerialization passes every entry of testCookies through serialize and
// then deserialize into a map[string]string, and compares the "%v" formatting
// of the result with that of the input.
func TestSerialization(t *testing.T) {
	for _, value := range testCookies {
		serialized, err := serialize(value)
		if err != nil {
			t.Error(err)
			continue
		}

		deserialized := make(map[string]string)
		if err := deserialize(serialized, &deserialized); err != nil {
			// Skip the comparison: the deserialize error is already
			// reported and comparing would only log a second failure.
			t.Error(err)
			continue
		}
		if fmt.Sprintf("%v", deserialized) != fmt.Sprintf("%v", value) {
			t.Errorf("Expected %v, got %v.", value, deserialized)
		}
	}
}
// TestEncoding checks that decode(encode(x)) returns x for every entry of
// testStrings.
func TestEncoding(t *testing.T) {
	for idx := range testStrings {
		want := testStrings[idx]
		got, err := decode(encode([]byte(want)))
		switch {
		case err != nil:
			t.Error(err)
		case string(got) != want:
			t.Errorf("Expected %v, got %s.", want, string(got))
		}
	}
}
// TestMultiError calls EncodeMulti with two instances created by
// New(nil, nil) and expects a MultiError holding two errors whose combined
// message mentions the missing hash key.
func TestMultiError(t *testing.T) {
	s1, s2 := New(nil, nil), New(nil, nil)
	_, err := EncodeMulti("sid", "value", s1, s2)

	// Use the comma-ok form so that a nil error or an error of another type
	// fails the test with a message instead of panicking.
	merr, ok := err.(MultiError)
	if !ok {
		t.Fatalf("Expected MultiError, got %T: %v", err, err)
	}
	if len(merr) != 2 {
		t.Errorf("Expected 2 errors, got %s.", err)
		return
	}
	if !strings.Contains(err.Error(), "hash key is not set") {
		t.Errorf("Expected missing hash key error, got %s.", err.Error())
	}
}
// ----------------------------------------------------------------------------
// FooBar is the struct type that TestCustomType passes through Encode and
// Decode. The field order matters: the test builds it with a positional
// literal.
type FooBar struct {
	Foo int    // compared against 42 in TestCustomType
	Bar string // compared against "bar" in TestCustomType
}
// TestCustomType encodes a *FooBar and decodes it into a fresh *FooBar with
// the same instance, then checks that both fields survived the round trip.
func TestCustomType(t *testing.T) {
	s1 := New([]byte("12345"), []byte("1234567890123456"))
	// Type is not registered in gob. (!!!)
	src := &FooBar{42, "bar"}

	// Report Encode/Decode errors directly rather than discarding them and
	// only seeing a field mismatch below.
	encoded, err := s1.Encode("sid", src)
	if err != nil {
		t.Fatalf("Encode failed: %v", err)
	}
	dst := &FooBar{}
	if err := s1.Decode("sid", encoded, dst); err != nil {
		t.Fatalf("Decode failed: %v", err)
	}
	if dst.Foo != 42 || dst.Bar != "bar" {
		t.Fatalf("Expected %#v, got %#v", src, dst)
	}
}