diff options
-rw-r--r-- | securecookie.go | 174 | ||||
-rw-r--r-- | securecookie_test.go | 43 |
2 files changed, 186 insertions, 31 deletions
diff --git a/securecookie.go b/securecookie.go index cfee1d8..ed2db69 100644 --- a/securecookie.go +++ b/securecookie.go @@ -15,19 +15,102 @@ import ( "encoding/base64" "encoding/gob" "encoding/json" - "errors" "fmt" "hash" "io" "strconv" + "strings" "time" ) -var ( - errNoCodecs = errors.New("securecookie: no codecs provided") - errHashKeyNotSet = errors.New("securecookie: hash key is not set") +// Error is the interface of all errors returned by functions in this library. +type Error interface { + error + + // IsUsage returns true for errors indicating the client code probably + // uses this library incorrectly. For example, the client may have + // failed to provide a valid hash key, or may have failed to configure + // the Serializer adequately for encoding value. + IsUsage() bool + + // IsDecode returns true for errors indicating that a cookie could not + // be decoded and validated. Since cookies are usually untrusted + // user-provided input, errors of this type should be expected. + // Usually, the proper action is simply to reject the request. + IsDecode() bool + + // IsInternal returns true for unexpected errors occurring in the + // securecookie implementation. + IsInternal() bool + + // Cause, if it returns a non-nil value, indicates that this error was + // propagated from some underlying library. If this method returns nil, + // this error was raised directly by this library. + // + // Cause is provided principally for debugging/logging purposes; it is + // rare that application logic should perform meaningfully different + // logic based on Cause. See, for example, the caveats described on + // (MultiError).Cause(). + Cause() error +} + +// errorType is a bitmask giving the error type(s) of an errorImpl value. +type errorType int + +const ( + usageError = errorType(1 << iota) + decodeError + internalError +) + +type errorImpl struct { + typ errorType + msg string + cause error +} + +func (e errorImpl) IsUsage() bool { return (e.typ & usageError) != 0 } +func (e errorImpl) IsDecode() bool { return (e.typ & decodeError) != 0 } +func (e errorImpl) IsInternal() bool { return (e.typ & internalError) != 0 } - ErrMacInvalid = errors.New("securecookie: the value is not valid") +func (e errorImpl) Cause() error { return e.cause } + +func (e errorImpl) Error() string { + parts := []string{"securecookie: "} + if e.msg == "" { + parts = append(parts, "error") + } else { + parts = append(parts, e.msg) + } + if c := e.Cause(); c != nil { + parts = append(parts, " - caused by: ", c.Error()) + } + return strings.Join(parts, "") +} + +// Asserts that errorImpl is an Error implementation. +var _ Error = errorImpl{} + +var ( + errGeneratingIV = errorImpl{typ: internalError, msg: "failed to generate random iv"} + + errNoCodecs = errorImpl{typ: usageError, msg: "no codecs provided"} + errHashKeyNotSet = errorImpl{typ: usageError, msg: "hash key is not set"} + errBlockKeyNotSet = errorImpl{typ: usageError, msg: "block key is not set"} + errEncodedValueTooLong = errorImpl{typ: usageError, msg: "the value is too long"} + + errValueToDecodeTooLong = errorImpl{typ: decodeError, msg: "the value is too long"} + errTimestampInvalid = errorImpl{typ: decodeError, msg: "invalid timestamp"} + errTimestampTooNew = errorImpl{typ: decodeError, msg: "timestamp is too new"} + errTimestampExpired = errorImpl{typ: decodeError, msg: "expired timestamp"} + errDecryptionFailed = errorImpl{typ: decodeError, msg: "the value could not be decrypted"} + + // ErrMacInvalid indicates that cookie decoding failed because the HMAC + // could not be extracted and verified. Direct use of this error + // variable is deprecated; it is public only for legacy compatibility, + // and may be privatized in the future, as it is rarely useful to + // distinguish between this error and other Error implementations. + ErrMacInvalid = errorImpl{typ: decodeError, msg: "the value is not valid"} ) // Codec defines an interface to encode and decode cookie values. @@ -134,18 +217,19 @@ func (s *SecureCookie) HashFunc(f func() hash.Hash) *SecureCookie { // Default is crypto/aes.New. func (s *SecureCookie) BlockFunc(f func([]byte) (cipher.Block, error)) *SecureCookie { if s.blockKey == nil { - s.err = errors.New("securecookie: block key is not set") + s.err = errBlockKeyNotSet } else if block, err := f(s.blockKey); err == nil { s.block = block } else { - s.err = err + s.err = errorImpl{cause: err, typ: usageError} } return s } // Encoding sets the encoding/serialization method for cookies. // -// Default is encoding/gob. +// Default is encoding/gob. To encode special structures using encoding/gob, +// they must be registered first using gob.Register(). func (s *SecureCookie) SetSerializer(sz Serializer) *SecureCookie { s.sz = sz @@ -154,13 +238,16 @@ func (s *SecureCookie) SetSerializer(sz Serializer) *SecureCookie { // Encode encodes a cookie value. // -// It serializes, optionally encrypts, signs with a message authentication code, and -// finally encodes the value. +// It serializes, optionally encrypts, signs with a message authentication code, +// and finally encodes the value. // // The name argument is the cookie name. It is stored with the encoded value. // The value argument is the value to be encoded. It can be any value that can -// be encoded using encoding/gob. To store special structures, they must be -// registered first using gob.Register(). +// be encoded using the currently selected serializer; see SetSerializer(). +// +// It is the client's responsibility to ensure that value, when encoded using +// the current serialization/encryption settings on s and then base64-encoded, +// is shorter than the maximum permissible length. func (s *SecureCookie) Encode(name string, value interface{}) (string, error) { if s.err != nil { return "", s.err @@ -173,12 +260,12 @@ func (s *SecureCookie) Encode(name string, value interface{}) (string, error) { var b []byte // 1. Serialize. if b, err = s.sz.Serialize(value); err != nil { - return "", err + return "", errorImpl{cause: err, typ: usageError} } // 2. Encrypt (optional). if s.block != nil { if b, err = encrypt(s.block, b); err != nil { - return "", err + return "", errorImpl{cause: err, typ: usageError} } } b = encode(b) @@ -191,7 +278,7 @@ func (s *SecureCookie) Encode(name string, value interface{}) (string, error) { b = encode(b) // 5. Check length. if s.maxLength != 0 && len(b) > s.maxLength { - return "", errors.New("securecookie: the value is too long") + return "", errEncodedValueTooLong } // Done. return string(b), nil @@ -215,7 +302,7 @@ func (s *SecureCookie) Decode(name, value string, dst interface{}) error { } // 1. Check length. if s.maxLength != 0 && len(value) > s.maxLength { - return errors.New("securecookie: the value is too long") + return errValueToDecodeTooLong } // 2. Decode from base64. b, err := decode([]byte(value)) @@ -235,14 +322,14 @@ func (s *SecureCookie) Decode(name, value string, dst interface{}) error { // 4. Verify date ranges. var t1 int64 if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { - return errors.New("securecookie: invalid timestamp") + return errTimestampInvalid } t2 := s.timestamp() if s.minAge != 0 && t1 > t2-s.minAge { - return errors.New("securecookie: timestamp is too new") + return errTimestampTooNew } if s.maxAge != 0 && t1 < t2-s.maxAge { - return errors.New("securecookie: expired timestamp") + return errTimestampExpired } // 5. Decrypt (optional). b, err = decode(parts[1]) @@ -256,7 +343,7 @@ func (s *SecureCookie) Decode(name, value string, dst interface{}) error { } // 6. Deserialize. if err = s.sz.Deserialize(b, dst); err != nil { - return err + return errorImpl{cause: err, typ: decodeError} } // Done. return nil @@ -299,7 +386,7 @@ func verifyMac(h hash.Hash, value []byte, mac []byte) error { func encrypt(block cipher.Block, value []byte) ([]byte, error) { iv := GenerateRandomKey(block.BlockSize()) if iv == nil { - return nil, errors.New("securecookie: failed to generate random iv") + return nil, errGeneratingIV } // Encrypt it. stream := cipher.NewCTR(block, iv) @@ -324,7 +411,7 @@ func decrypt(block cipher.Block, value []byte) ([]byte, error) { stream.XORKeyStream(value, value) return value, nil } - return nil, errors.New("securecookie: the value could not be decrypted") + return nil, errDecryptionFailed } // Serialization -------------------------------------------------------------- @@ -334,7 +421,7 @@ func (e GobEncoder) Serialize(src interface{}) ([]byte, error) { buf := new(bytes.Buffer) enc := gob.NewEncoder(buf) if err := enc.Encode(src); err != nil { - return nil, err + return nil, errorImpl{cause: err, typ: usageError} } return buf.Bytes(), nil } @@ -343,7 +430,7 @@ func (e GobEncoder) Serialize(src interface{}) ([]byte, error) { func (e GobEncoder) Deserialize(src []byte, dst interface{}) error { dec := gob.NewDecoder(bytes.NewBuffer(src)) if err := dec.Decode(dst); err != nil { - return err + return errorImpl{cause: err, typ: decodeError} } return nil } @@ -353,9 +440,8 @@ func (e JSONEncoder) Serialize(src interface{}) ([]byte, error) { buf := new(bytes.Buffer) enc := json.NewEncoder(buf) if err := enc.Encode(src); err != nil { - return nil, err + return nil, errorImpl{cause: err, typ: usageError} } - return buf.Bytes(), nil } @@ -363,7 +449,7 @@ func (e JSONEncoder) Serialize(src interface{}) ([]byte, error) { func (e JSONEncoder) Deserialize(src []byte, dst interface{}) error { dec := json.NewDecoder(bytes.NewReader(src)) if err := dec.Decode(dst); err != nil { - return err + return errorImpl{cause: err, typ: decodeError} } return nil } @@ -382,7 +468,7 @@ func decode(value []byte) ([]byte, error) { decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value))) b, err := base64.URLEncoding.Decode(decoded, value) if err != nil { - return nil, err + return nil, errorImpl{cause: err, typ: decodeError, msg: "base64 decode failed"} } return decoded[:b], nil } @@ -390,6 +476,7 @@ func decode(value []byte) ([]byte, error) { // Helpers -------------------------------------------------------------------- // GenerateRandomKey creates a random key with the given length in bytes. +// On failure, returns nil. func GenerateRandomKey(length int) []byte { k := make([]byte, length) if _, err := io.ReadFull(rand.Reader, k); err != nil { @@ -417,6 +504,8 @@ func CodecsFromPairs(keyPairs ...[]byte) []Codec { // // The codecs are tried in order. Multiple codecs are accepted to allow // key rotation. +// +// On error, may return a MultiError. func EncodeMulti(name string, value interface{}, codecs ...Codec) (string, error) { if len(codecs) == 0 { return "", errNoCodecs @@ -437,6 +526,8 @@ func EncodeMulti(name string, value interface{}, codecs ...Codec) (string, error // // The codecs are tried in order. Multiple codecs are accepted to allow // key rotation. +// +// On error, may return a MultiError. func DecodeMulti(name string, value string, dst interface{}, codecs ...Codec) error { if len(codecs) == 0 { return errNoCodecs @@ -456,6 +547,20 @@ func DecodeMulti(name string, value string, dst interface{}, codecs ...Codec) er // MultiError groups multiple errors. type MultiError []error +func (m MultiError) IsUsage() bool { return m.any(func(e Error) bool { return e.IsUsage() }) } +func (m MultiError) IsDecode() bool { return m.any(func(e Error) bool { return e.IsDecode() }) } +func (m MultiError) IsInternal() bool { return m.any(func(e Error) bool { return e.IsInternal() }) } + +// Cause returns nil for MultiError; there is no unique underlying cause in the +// general case. +// +// Note: we could conceivably return a non-nil Cause only when there is exactly +// one child error with a Cause. However, it would be brittle for client code +// to rely on the arity of causes inside a MultiError, so we have opted not to +// provide this functionality. Clients which really wish to access the Causes +// of the underlying errors are free to iterate through the errors themselves. +func (m MultiError) Cause() error { return nil } + func (m MultiError) Error() string { s, n := "", 0 for _, e := range m { @@ -476,3 +581,16 @@ func (m MultiError) Error() string { } return fmt.Sprintf("%s (and %d other errors)", s, n-1) } + +// any returns true if any element of m is an Error for which pred returns true. +func (m MultiError) any(pred func(Error) bool) bool { + for _, e := range m { + if ourErr, ok := e.(Error); ok && pred(ourErr) { + return true + } + } + return false +} + +// Asserts that MultiError is an Error implementation. +var _ Error = MultiError{} diff --git a/securecookie_test.go b/securecookie_test.go index e482397..5778000 100644 --- a/securecookie_test.go +++ b/securecookie_test.go @@ -52,6 +52,21 @@ func TestSecureCookie(t *testing.T) { if err3 == nil { t.Fatalf("Expected failure decoding.") } + err4, ok := err3.(Error) + if !ok { + t.Fatalf("Expected error to implement Error, got: %#v", err3) + } + if !err4.IsDecode() { + t.Fatalf("Expected DecodeError, got: %#v", err4) + } + + // Test other error type flags. + if err4.IsUsage() { + t.Fatalf("Expected IsUsage() == false, got: %#v", err4) + } + if err4.IsInternal() { + t.Fatalf("Expected IsInternal() == false, got: %#v", err4) + } } } @@ -69,9 +84,18 @@ func TestDecodeInvalid(t *testing.T) { s := New([]byte("12345"), nil) var dst string for i, v := range invalidCookies { - err := s.Decode("name", base64.StdEncoding.EncodeToString([]byte(v)), &dst) - if err == nil { - t.Fatalf("%d: expected failure decoding", i) + for _, enc := range []*base64.Encoding{ + base64.StdEncoding, + base64.URLEncoding, + } { + err := s.Decode("name", enc.EncodeToString([]byte(v)), &dst) + if err == nil { + t.Fatalf("%d: expected failure decoding", i) + } + err2, ok := err.(Error) + if !ok || !err2.IsDecode() { + t.Fatalf("%d: Expected IsDecode(), got: %#v", i, err) + } } } } @@ -174,6 +198,16 @@ func TestMultiError(t *testing.T) { if strings.Index(err.Error(), "hash key is not set") == -1 { t.Errorf("Expected missing hash key error, got %s.", err.Error()) } + ourErr, ok := err.(Error) + if !ok || !ourErr.IsUsage() { + t.Fatalf("Expected error to be a usage error; got %#v", err) + } + if ourErr.IsDecode() { + t.Errorf("Expected error NOT to be a decode error; got %#v", ourErr) + } + if ourErr.IsInternal() { + t.Errorf("Expected error NOT to be an internal error; got %#v", ourErr) + } } } @@ -198,6 +232,9 @@ func TestMissingKey(t *testing.T) { if err != errHashKeyNotSet { t.Fatalf("Expected %#v, got %#v", errHashKeyNotSet, err) } + if err2, ok := err.(Error); !ok || !err2.IsUsage() { + t.Errorf("Expected missing hash key to be IsUsage(); was %#v", err) + } } // ---------------------------------------------------------------------------- |