aboutsummaryrefslogblamecommitdiff
path: root/securecookie_test.go
blob: 33ce4fc2bc57452331ec02e2916a002dfa456c0c (plain) (tree)
1
2
3
4
5
6
7
8
9
10









                                                           
                         
             
                 
                 


                 



                                                                     








                                                 



















                                                                                
                                                   






                                                                    














                                                                                     


         











                                                     













                                                                              











                                                                                     



                 












                                                             
                                   


















                                                                                          
                                         
             
                                       




                                              
                                                                      


                                                              






















                                                                                        




















                                                                                        








                                                                                         









                                                                                               


         












                                                                          







                                                                        


                                                                                   

 


















                                                                               
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package securecookie

import (
	"crypto/aes"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"reflect"
	"strings"
	"testing"
)

// Compile-time checks that both cookieError and MultiError satisfy the
// package's Error interface.
var (
	_ Error = cookieError{}
	_ Error = MultiError{}
)

var (
	// testCookies holds the values round-tripped by the serializer tests.
	testCookies = []interface{}{
		map[string]string{"foo": "bar"},
		map[string]string{"baz": "ding"},
	}

	// testStrings holds the plain inputs fed to the MAC, encryption and
	// encoding tests.
	testStrings = []string{"foo", "bar", "baz"}
)

// TestSecureCookie round-trips a value through one codec, then checks that a
// codec built with different keys rejects the same cookie with an Error that
// reports IsDecode() but neither IsUsage() nor IsInternal().
func TestSecureCookie(t *testing.T) {
	// TODO test too old / too new timestamps
	codecA := New([]byte("12345"), []byte("1234567890123456"))
	codecB := New([]byte("54321"), []byte("6543210987654321"))
	original := map[string]interface{}{
		"foo": "bar",
		"baz": 128,
	}

	// Repeat the round trip several times to check whether any special
	// character breaks encoding/decoding.
	for iter := 0; iter < 50; iter++ {
		cookie, encErr := codecA.Encode("sid", original)
		if encErr != nil {
			t.Error(encErr)
			continue
		}

		got := make(map[string]interface{})
		if decErr := codecA.Decode("sid", cookie, &got); decErr != nil {
			t.Fatalf("%v: %v", decErr, cookie)
		}
		if !reflect.DeepEqual(got, original) {
			t.Fatalf("Expected %v, got %v.", original, got)
		}

		// A codec with unrelated keys must refuse the cookie.
		foreign := make(map[string]interface{})
		badErr := codecB.Decode("sid", cookie, &foreign)
		if badErr == nil {
			t.Fatalf("Expected failure decoding.")
		}
		typed, isTyped := badErr.(Error)
		if !isTyped {
			t.Fatalf("Expected error to implement Error, got: %#v", badErr)
		}
		switch {
		case !typed.IsDecode():
			t.Fatalf("Expected DecodeError, got: %#v", typed)
		case typed.IsUsage():
			t.Fatalf("Expected IsUsage() == false, got: %#v", typed)
		case typed.IsInternal():
			t.Fatalf("Expected IsInternal() == false, got: %#v", typed)
		}
	}
}

// TestSecureCookieNilKey checks that encoding with a codec created without a
// hash key fails with errHashKeyNotSet.
func TestSecureCookieNilKey(t *testing.T) {
	codec := New(nil, nil)
	payload := map[string]interface{}{"foo": "bar", "baz": 128}
	if _, err := codec.Encode("sid", payload); err != errHashKeyNotSet {
		t.Fatal("Wrong error returned:", err)
	}
}

// TestDecodeInvalid base64-encodes a set of malformed cookie payloads with both
// the standard and the URL-safe alphabets, passes each to Decode, and requires
// every one to be rejected with an Error reporting IsDecode().
func TestDecodeInvalid(t *testing.T) {
	// Raw (pre-base64) payloads that must never be accepted.
	invalidCookies := []string{"", " ", "\n", "||", "|||", "cookie"}
	encodings := []*base64.Encoding{base64.StdEncoding, base64.URLEncoding}

	codec := New([]byte("12345"), nil)
	var dst string
	for idx := range invalidCookies {
		raw := []byte(invalidCookies[idx])
		for _, encoding := range encodings {
			err := codec.Decode("name", encoding.EncodeToString(raw), &dst)
			if err == nil {
				t.Fatalf("%d: expected failure decoding", idx)
			}
			if typed, ok := err.(Error); !ok || !typed.IsDecode() {
				t.Fatalf("%d: Expected IsDecode(), got: %#v", idx, err)
			}
		}
	}
}

// TestAuthentication checks that verifyMac accepts the MAC that createMac
// produced for the same input. The HMAC state is reset before each call.
func TestAuthentication(t *testing.T) {
	mac := hmac.New(sha256.New, []byte("secret-key"))
	for idx := range testStrings {
		msg := []byte(testStrings[idx])
		mac.Reset()
		sig := createMac(mac, msg)
		mac.Reset()
		if err := verifyMac(mac, msg, sig); err != nil {
			t.Error(err)
		}
	}
}

// TestEncryption checks that decrypt inverts encrypt, using the same AES block,
// for every entry in testStrings.
func TestEncryption(t *testing.T) {
	block, err := aes.NewCipher([]byte("1234567890123456"))
	if err != nil {
		t.Fatalf("Block could not be created: %v", err)
	}
	for _, value := range testStrings {
		encrypted, err := encrypt(block, []byte(value))
		if err != nil {
			t.Error(err)
			continue
		}
		decrypted, err := decrypt(block, encrypted)
		if err != nil {
			// Comparing the output of a failed decryption would only add a
			// misleading second error for the same case.
			t.Error(err)
			continue
		}
		if string(decrypted) != value {
			t.Errorf("Expected %v, got %v.", value, string(decrypted))
		}
	}
}

// TestGobSerialization checks that GobEncoder round-trips every value in
// testCookies into a map[string]string that prints the same as the original.
func TestGobSerialization(t *testing.T) {
	var sz GobEncoder
	for _, value := range testCookies {
		serialized, err := sz.Serialize(value)
		if err != nil {
			t.Error(err)
			continue
		}
		deserialized := make(map[string]string)
		if err := sz.Deserialize(serialized, &deserialized); err != nil {
			// Skip the comparison: a failed Deserialize would otherwise
			// produce a second, misleading mismatch error.
			t.Error(err)
			continue
		}
		if fmt.Sprintf("%v", deserialized) != fmt.Sprintf("%v", value) {
			t.Errorf("Expected %v, got %v.", value, deserialized)
		}
	}
}

// TestJSONSerialization checks that JSONEncoder round-trips every value in
// testCookies into a map[string]string that prints the same as the original.
func TestJSONSerialization(t *testing.T) {
	var sz JSONEncoder
	for _, value := range testCookies {
		serialized, err := sz.Serialize(value)
		if err != nil {
			t.Error(err)
			continue
		}
		deserialized := make(map[string]string)
		if err := sz.Deserialize(serialized, &deserialized); err != nil {
			// Skip the comparison: a failed Deserialize would otherwise
			// produce a second, misleading mismatch error.
			t.Error(err)
			continue
		}
		if fmt.Sprintf("%v", deserialized) != fmt.Sprintf("%v", value) {
			t.Errorf("Expected %v, got %v.", value, deserialized)
		}
	}
}

// TestEncoding checks that decode inverts encode for each entry in testStrings.
func TestEncoding(t *testing.T) {
	for idx := range testStrings {
		want := testStrings[idx]
		got, err := decode(encode([]byte(want)))
		switch {
		case err != nil:
			t.Error(err)
		case string(got) != want:
			t.Errorf("Expected %v, got %s.", want, string(got))
		}
	}
}

// TestMultiError checks that EncodeMulti over two codecs without a hash key
// returns a MultiError holding both failures. It also checks that the combined
// error mentions the missing hash key and reports IsUsage() but neither
// IsDecode() nor IsInternal().
func TestMultiError(t *testing.T) {
	s1, s2 := New(nil, nil), New(nil, nil)
	_, err := EncodeMulti("sid", "value", s1, s2)
	// Use the comma-ok assertion so a nil error or an unexpected error type
	// fails the test instead of panicking.
	multi, ok := err.(MultiError)
	if !ok {
		t.Fatalf("Expected a MultiError, got %#v", err)
	}
	if len(multi) != 2 {
		t.Fatalf("Expected 2 errors, got %s.", err)
	}
	if !strings.Contains(err.Error(), "hash key is not set") {
		t.Errorf("Expected missing hash key error, got %s.", err.Error())
	}
	ourErr, ok := err.(Error)
	if !ok || !ourErr.IsUsage() {
		t.Fatalf("Expected error to be a usage error; got %#v", err)
	}
	if ourErr.IsDecode() {
		t.Errorf("Expected error NOT to be a decode error; got %#v", ourErr)
	}
	if ourErr.IsInternal() {
		t.Errorf("Expected error NOT to be an internal error; got %#v", ourErr)
	}
}

// TestMultiNoCodecs checks that EncodeMulti and DecodeMulti both return
// errNoCodecs when called without any codecs.
func TestMultiNoCodecs(t *testing.T) {
	if _, err := EncodeMulti("foo", "bar"); err != errNoCodecs {
		t.Errorf("EncodeMulti: bad value for error, got: %v", err)
	}

	var dst []byte
	if err := DecodeMulti("foo", "bar", &dst); err != errNoCodecs {
		t.Errorf("DecodeMulti: bad value for error, got: %v", err)
	}
}

// TestMissingKey checks that decoding with a codec that has no hash key fails
// with errHashKeyNotSet, and that this error reports IsUsage().
func TestMissingKey(t *testing.T) {
	codec := New(nil, nil)

	var dst []byte
	err := codec.Decode("sid", "value", &dst)
	if err != errHashKeyNotSet {
		t.Fatalf("Expected %#v, got %#v", errHashKeyNotSet, err)
	}
	typed, ok := err.(Error)
	if !ok || !typed.IsUsage() {
		t.Errorf("Expected missing hash key to be IsUsage(); was %#v", err)
	}
}

// ----------------------------------------------------------------------------

// FooBar is a plain struct that TestCustomType round-trips through
// Encode/Decode. Per that test's comment, the type is not registered with gob.
type FooBar struct {
	Foo int
	Bar string
}

// TestCustomType checks that a pointer to a struct type that is not registered
// with gob survives an Encode/Decode round trip.
func TestCustomType(t *testing.T) {
	s1 := New([]byte("12345"), []byte("1234567890123456"))
	// Type is not registered in gob. (!!!)
	src := &FooBar{42, "bar"}
	// Check both errors explicitly so a failure reports its cause instead of
	// surfacing only as a zero-valued dst below.
	encoded, err := s1.Encode("sid", src)
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}

	dst := &FooBar{}
	if err := s1.Decode("sid", encoded, dst); err != nil {
		t.Fatalf("Decode: %v", err)
	}
	if dst.Foo != 42 || dst.Bar != "bar" {
		t.Fatalf("Expected %#v, got %#v", src, dst)
	}
}