aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--securecookie.go174
-rw-r--r--securecookie_test.go43
2 files changed, 186 insertions, 31 deletions
diff --git a/securecookie.go b/securecookie.go
index cfee1d8..ed2db69 100644
--- a/securecookie.go
+++ b/securecookie.go
@@ -15,19 +15,102 @@ import (
"encoding/base64"
"encoding/gob"
"encoding/json"
- "errors"
"fmt"
"hash"
"io"
"strconv"
+ "strings"
"time"
)
-var (
- errNoCodecs = errors.New("securecookie: no codecs provided")
- errHashKeyNotSet = errors.New("securecookie: hash key is not set")
+// Error is the interface of all errors returned by functions in this library.
+type Error interface {
+ error
+
+ // IsUsage returns true for errors indicating the client code probably
+ // uses this library incorrectly. For example, the client may have
+ // failed to provide a valid hash key, or may have failed to configure
+ // the Serializer adequately for encoding value.
+ IsUsage() bool
+
+ // IsDecode returns true for errors indicating that a cookie could not
+ // be decoded and validated. Since cookies are usually untrusted
+ // user-provided input, errors of this type should be expected.
+ // Usually, the proper action is simply to reject the request.
+ IsDecode() bool
+
+ // IsInternal returns true for unexpected errors occurring in the
+ // securecookie implementation.
+ IsInternal() bool
+
+ // Cause, if it returns a non-nil value, indicates that this error was
+ // propagated from some underlying library. If this method returns nil,
+ // this error was raised directly by this library.
+ //
+ // Cause is provided principally for debugging/logging purposes; it is
+ // rare that application logic should perform meaningfully different
+ // logic based on Cause. See, for example, the caveats described on
+ // (MultiError).Cause().
+ Cause() error
+}
+
+// errorType is a bitmask giving the error type(s) of an errorImpl value.
+type errorType int
+
+const (
+ usageError = errorType(1 << iota)
+ decodeError
+ internalError
+)
+
+type errorImpl struct {
+ typ errorType
+ msg string
+ cause error
+}
+
+func (e errorImpl) IsUsage() bool { return (e.typ & usageError) != 0 }
+func (e errorImpl) IsDecode() bool { return (e.typ & decodeError) != 0 }
+func (e errorImpl) IsInternal() bool { return (e.typ & internalError) != 0 }
- ErrMacInvalid = errors.New("securecookie: the value is not valid")
+func (e errorImpl) Cause() error { return e.cause }
+
+func (e errorImpl) Error() string {
+ parts := []string{"securecookie: "}
+ if e.msg == "" {
+ parts = append(parts, "error")
+ } else {
+ parts = append(parts, e.msg)
+ }
+ if c := e.Cause(); c != nil {
+ parts = append(parts, " - caused by: ", c.Error())
+ }
+ return strings.Join(parts, "")
+}
+
+// Asserts that errorImpl is an Error implementation.
+var _ Error = errorImpl{}
+
+var (
+ errGeneratingIV = errorImpl{typ: internalError, msg: "failed to generate random iv"}
+
+ errNoCodecs = errorImpl{typ: usageError, msg: "no codecs provided"}
+ errHashKeyNotSet = errorImpl{typ: usageError, msg: "hash key is not set"}
+ errBlockKeyNotSet = errorImpl{typ: usageError, msg: "block key is not set"}
+ errEncodedValueTooLong = errorImpl{typ: usageError, msg: "the value is too long"}
+
+ errValueToDecodeTooLong = errorImpl{typ: decodeError, msg: "the value is too long"}
+ errTimestampInvalid = errorImpl{typ: decodeError, msg: "invalid timestamp"}
+ errTimestampTooNew = errorImpl{typ: decodeError, msg: "timestamp is too new"}
+ errTimestampExpired = errorImpl{typ: decodeError, msg: "expired timestamp"}
+ errDecryptionFailed = errorImpl{typ: decodeError, msg: "the value could not be decrypted"}
+
+ // ErrMacInvalid indicates that cookie decoding failed because the HMAC
+ // could not be extracted and verified. Direct use of this error
+ // variable is deprecated; it is public only for legacy compatibility,
+ // and may be privatized in the future, as it is rarely useful to
+ // distinguish between this error and other Error implementations.
+ ErrMacInvalid = errorImpl{typ: decodeError, msg: "the value is not valid"}
)
// Codec defines an interface to encode and decode cookie values.
@@ -134,18 +217,19 @@ func (s *SecureCookie) HashFunc(f func() hash.Hash) *SecureCookie {
// Default is crypto/aes.New.
func (s *SecureCookie) BlockFunc(f func([]byte) (cipher.Block, error)) *SecureCookie {
if s.blockKey == nil {
- s.err = errors.New("securecookie: block key is not set")
+ s.err = errBlockKeyNotSet
} else if block, err := f(s.blockKey); err == nil {
s.block = block
} else {
- s.err = err
+ s.err = errorImpl{cause: err, typ: usageError}
}
return s
}
// Encoding sets the encoding/serialization method for cookies.
//
-// Default is encoding/gob.
+// Default is encoding/gob. To encode special structures using encoding/gob,
+// they must be registered first using gob.Register().
func (s *SecureCookie) SetSerializer(sz Serializer) *SecureCookie {
s.sz = sz
@@ -154,13 +238,16 @@ func (s *SecureCookie) SetSerializer(sz Serializer) *SecureCookie {
// Encode encodes a cookie value.
//
-// It serializes, optionally encrypts, signs with a message authentication code, and
-// finally encodes the value.
+// It serializes, optionally encrypts, signs with a message authentication code,
+// and finally encodes the value.
//
// The name argument is the cookie name. It is stored with the encoded value.
// The value argument is the value to be encoded. It can be any value that can
-// be encoded using encoding/gob. To store special structures, they must be
-// registered first using gob.Register().
+// be encoded using the currently selected serializer; see SetSerializer().
+//
+// It is the client's responsibility to ensure that value, when encoded using
+// the current serialization/encryption settings on s and then base64-encoded,
+// is shorter than the maximum permissible length.
func (s *SecureCookie) Encode(name string, value interface{}) (string, error) {
if s.err != nil {
return "", s.err
@@ -173,12 +260,12 @@ func (s *SecureCookie) Encode(name string, value interface{}) (string, error) {
var b []byte
// 1. Serialize.
if b, err = s.sz.Serialize(value); err != nil {
- return "", err
+ return "", errorImpl{cause: err, typ: usageError}
}
// 2. Encrypt (optional).
if s.block != nil {
if b, err = encrypt(s.block, b); err != nil {
- return "", err
+ return "", errorImpl{cause: err, typ: usageError}
}
}
b = encode(b)
@@ -191,7 +278,7 @@ func (s *SecureCookie) Encode(name string, value interface{}) (string, error) {
b = encode(b)
// 5. Check length.
if s.maxLength != 0 && len(b) > s.maxLength {
- return "", errors.New("securecookie: the value is too long")
+ return "", errEncodedValueTooLong
}
// Done.
return string(b), nil
@@ -215,7 +302,7 @@ func (s *SecureCookie) Decode(name, value string, dst interface{}) error {
}
// 1. Check length.
if s.maxLength != 0 && len(value) > s.maxLength {
- return errors.New("securecookie: the value is too long")
+ return errValueToDecodeTooLong
}
// 2. Decode from base64.
b, err := decode([]byte(value))
@@ -235,14 +322,14 @@ func (s *SecureCookie) Decode(name, value string, dst interface{}) error {
// 4. Verify date ranges.
var t1 int64
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
- return errors.New("securecookie: invalid timestamp")
+ return errTimestampInvalid
}
t2 := s.timestamp()
if s.minAge != 0 && t1 > t2-s.minAge {
- return errors.New("securecookie: timestamp is too new")
+ return errTimestampTooNew
}
if s.maxAge != 0 && t1 < t2-s.maxAge {
- return errors.New("securecookie: expired timestamp")
+ return errTimestampExpired
}
// 5. Decrypt (optional).
b, err = decode(parts[1])
@@ -256,7 +343,7 @@ func (s *SecureCookie) Decode(name, value string, dst interface{}) error {
}
// 6. Deserialize.
if err = s.sz.Deserialize(b, dst); err != nil {
- return err
+ return errorImpl{cause: err, typ: decodeError}
}
// Done.
return nil
@@ -299,7 +386,7 @@ func verifyMac(h hash.Hash, value []byte, mac []byte) error {
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
iv := GenerateRandomKey(block.BlockSize())
if iv == nil {
- return nil, errors.New("securecookie: failed to generate random iv")
+ return nil, errGeneratingIV
}
// Encrypt it.
stream := cipher.NewCTR(block, iv)
@@ -324,7 +411,7 @@ func decrypt(block cipher.Block, value []byte) ([]byte, error) {
stream.XORKeyStream(value, value)
return value, nil
}
- return nil, errors.New("securecookie: the value could not be decrypted")
+ return nil, errDecryptionFailed
}
// Serialization --------------------------------------------------------------
@@ -334,7 +421,7 @@ func (e GobEncoder) Serialize(src interface{}) ([]byte, error) {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
if err := enc.Encode(src); err != nil {
- return nil, err
+ return nil, errorImpl{cause: err, typ: usageError}
}
return buf.Bytes(), nil
}
@@ -343,7 +430,7 @@ func (e GobEncoder) Serialize(src interface{}) ([]byte, error) {
func (e GobEncoder) Deserialize(src []byte, dst interface{}) error {
dec := gob.NewDecoder(bytes.NewBuffer(src))
if err := dec.Decode(dst); err != nil {
- return err
+ return errorImpl{cause: err, typ: decodeError}
}
return nil
}
@@ -353,9 +440,8 @@ func (e JSONEncoder) Serialize(src interface{}) ([]byte, error) {
buf := new(bytes.Buffer)
enc := json.NewEncoder(buf)
if err := enc.Encode(src); err != nil {
- return nil, err
+ return nil, errorImpl{cause: err, typ: usageError}
}
-
return buf.Bytes(), nil
}
@@ -363,7 +449,7 @@ func (e JSONEncoder) Serialize(src interface{}) ([]byte, error) {
func (e JSONEncoder) Deserialize(src []byte, dst interface{}) error {
dec := json.NewDecoder(bytes.NewReader(src))
if err := dec.Decode(dst); err != nil {
- return err
+ return errorImpl{cause: err, typ: decodeError}
}
return nil
}
@@ -382,7 +468,7 @@ func decode(value []byte) ([]byte, error) {
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
b, err := base64.URLEncoding.Decode(decoded, value)
if err != nil {
- return nil, err
+ return nil, errorImpl{cause: err, typ: decodeError, msg: "base64 decode failed"}
}
return decoded[:b], nil
}
@@ -390,6 +476,7 @@ func decode(value []byte) ([]byte, error) {
// Helpers --------------------------------------------------------------------
// GenerateRandomKey creates a random key with the given length in bytes.
+// On failure, returns nil.
func GenerateRandomKey(length int) []byte {
k := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
@@ -417,6 +504,8 @@ func CodecsFromPairs(keyPairs ...[]byte) []Codec {
//
// The codecs are tried in order. Multiple codecs are accepted to allow
// key rotation.
+//
+// On error, may return a MultiError.
func EncodeMulti(name string, value interface{}, codecs ...Codec) (string, error) {
if len(codecs) == 0 {
return "", errNoCodecs
@@ -437,6 +526,8 @@ func EncodeMulti(name string, value interface{}, codecs ...Codec) (string, error
//
// The codecs are tried in order. Multiple codecs are accepted to allow
// key rotation.
+//
+// On error, may return a MultiError.
func DecodeMulti(name string, value string, dst interface{}, codecs ...Codec) error {
if len(codecs) == 0 {
return errNoCodecs
@@ -456,6 +547,20 @@ func DecodeMulti(name string, value string, dst interface{}, codecs ...Codec) er
// MultiError groups multiple errors.
type MultiError []error
+func (m MultiError) IsUsage() bool { return m.any(func(e Error) bool { return e.IsUsage() }) }
+func (m MultiError) IsDecode() bool { return m.any(func(e Error) bool { return e.IsDecode() }) }
+func (m MultiError) IsInternal() bool { return m.any(func(e Error) bool { return e.IsInternal() }) }
+
+// Cause returns nil for MultiError; there is no unique underlying cause in the
+// general case.
+//
+// Note: we could conceivably return a non-nil Cause only when there is exactly
+// one child error with a Cause. However, it would be brittle for client code
+// to rely on the arity of causes inside a MultiError, so we have opted not to
+// provide this functionality. Clients which really wish to access the Causes
+// of the underlying errors are free to iterate through the errors themselves.
+func (m MultiError) Cause() error { return nil }
+
func (m MultiError) Error() string {
s, n := "", 0
for _, e := range m {
@@ -476,3 +581,16 @@ func (m MultiError) Error() string {
}
return fmt.Sprintf("%s (and %d other errors)", s, n-1)
}
+
+// any returns true if any element of m is an Error for which pred returns true.
+func (m MultiError) any(pred func(Error) bool) bool {
+ for _, e := range m {
+ if ourErr, ok := e.(Error); ok && pred(ourErr) {
+ return true
+ }
+ }
+ return false
+}
+
+// Asserts that MultiError is an Error implementation.
+var _ Error = MultiError{}
diff --git a/securecookie_test.go b/securecookie_test.go
index e482397..5778000 100644
--- a/securecookie_test.go
+++ b/securecookie_test.go
@@ -52,6 +52,21 @@ func TestSecureCookie(t *testing.T) {
if err3 == nil {
t.Fatalf("Expected failure decoding.")
}
+ err4, ok := err3.(Error)
+ if !ok {
+ t.Fatalf("Expected error to implement Error, got: %#v", err3)
+ }
+ if !err4.IsDecode() {
+ t.Fatalf("Expected DecodeError, got: %#v", err4)
+ }
+
+ // Test other error type flags.
+ if err4.IsUsage() {
+ t.Fatalf("Expected IsUsage() == false, got: %#v", err4)
+ }
+ if err4.IsInternal() {
+ t.Fatalf("Expected IsInternal() == false, got: %#v", err4)
+ }
}
}
@@ -69,9 +84,18 @@ func TestDecodeInvalid(t *testing.T) {
s := New([]byte("12345"), nil)
var dst string
for i, v := range invalidCookies {
- err := s.Decode("name", base64.StdEncoding.EncodeToString([]byte(v)), &dst)
- if err == nil {
- t.Fatalf("%d: expected failure decoding", i)
+ for _, enc := range []*base64.Encoding{
+ base64.StdEncoding,
+ base64.URLEncoding,
+ } {
+ err := s.Decode("name", enc.EncodeToString([]byte(v)), &dst)
+ if err == nil {
+ t.Fatalf("%d: expected failure decoding", i)
+ }
+ err2, ok := err.(Error)
+ if !ok || !err2.IsDecode() {
+ t.Fatalf("%d: Expected IsDecode(), got: %#v", i, err)
+ }
}
}
}
@@ -174,6 +198,16 @@ func TestMultiError(t *testing.T) {
if strings.Index(err.Error(), "hash key is not set") == -1 {
t.Errorf("Expected missing hash key error, got %s.", err.Error())
}
+ ourErr, ok := err.(Error)
+ if !ok || !ourErr.IsUsage() {
+ t.Fatalf("Expected error to be a usage error; got %#v", err)
+ }
+ if ourErr.IsDecode() {
+ t.Errorf("Expected error NOT to be a decode error; got %#v", ourErr)
+ }
+ if ourErr.IsInternal() {
+ t.Errorf("Expected error NOT to be an internal error; got %#v", ourErr)
+ }
}
}
@@ -198,6 +232,9 @@ func TestMissingKey(t *testing.T) {
if err != errHashKeyNotSet {
t.Fatalf("Expected %#v, got %#v", errHashKeyNotSet, err)
}
+ if err2, ok := err.(Error); !ok || !err2.IsUsage() {
+ t.Errorf("Expected missing hash key to be IsUsage(); was %#v", err)
+ }
}
// ----------------------------------------------------------------------------