diff options
-rw-r--r-- | securecookie.go | 60 | ||||
-rw-r--r-- | securecookie_test.go | 10 |
2 files changed, 48 insertions, 22 deletions
diff --git a/securecookie.go b/securecookie.go index 620c508..2509608 100644 --- a/securecookie.go +++ b/securecookie.go @@ -177,60 +177,76 @@ func (s *SecureCookie) Encode(name string, value interface{}) (string, error) { // it was stored. The value argument is the encoded cookie value. The dst // argument is where the cookie will be decoded. It must be a pointer. func (s *SecureCookie) Decode(name, value string, dst interface{}) error { + // retErr is the error which will be returned. + // It will be the first error that will occur (if any). + var retErr error + + // setErr saves the error only if there was no previous error. + // Otherwise retErr would be overwritten by subsequent errors. + setErr := func(err error) { + if retErr == nil { + retErr = err + } + } + if s.err != nil { - return s.err + setErr(s.err) } if s.hashKey == nil { s.err = errHashKeyNotSet - return s.err + setErr(s.err) } // 1. Check length. if s.maxLength != 0 && len(value) > s.maxLength { - return errors.New("securecookie: the value is too long") + setErr(errors.New("securecookie: the value is too long")) } // 2. Decode from base64. b, err := decode([]byte(value)) if err != nil { - return err + setErr(err) + // Dummy b to avoid errors + b = []byte("||") } // 3. Verify MAC. Value is "date|value|mac". parts := bytes.SplitN(b, []byte("|"), 3) if len(parts) != 3 { - return ErrMacInvalid + setErr(ErrMacInvalid) } h := hmac.New(s.hashFunc, s.hashKey) b = append([]byte(name+"|"), b[:len(b)-len(parts[2])-1]...) if err = verifyMac(h, b, parts[2]); err != nil { - return err + setErr(err) } // 4. Verify date ranges. var t1 int64 if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { - return errors.New("securecookie: invalid timestamp") + setErr(errors.New("securecookie: invalid timestamp")) } t2 := s.timestamp() if s.minAge != 0 && t1 > t2-s.minAge { - return errors.New("securecookie: timestamp is too new") + setErr(errors.New("securecookie: timestamp is too new")) } if s.maxAge != 0 && t1 < t2-s.maxAge { - return errors.New("securecookie: expired timestamp") + setErr(errors.New("securecookie: expired timestamp")) } // 5. Decrypt (optional). b, err = decode(parts[1]) if err != nil { - return err + setErr(err) } if s.block != nil { if b, err = decrypt(s.block, b); err != nil { - return err + setErr(err) } } - // 6. Deserialize. - if err = deserialize(b, dst); err != nil { - return err + + // Check for errors before deserialization to avoid unwanted side effects + if retErr != nil { + return retErr } - // Done. - return nil + + // 6. Deserialize. + return deserialize(b, dst) } // timestamp returns the current timestamp, in seconds. @@ -375,11 +391,11 @@ func EncodeMulti(name string, value interface{}, codecs ...Codec) (string, error var errors MultiError for _, codec := range codecs { - if encoded, err := codec.Encode(name, value); err == nil { + encoded, err := codec.Encode(name, value) + if err == nil { return encoded, nil - } else { - errors = append(errors, err) } + errors = append(errors, err) } return "", errors } @@ -395,11 +411,11 @@ func DecodeMulti(name string, value string, dst interface{}, codecs ...Codec) er var errors MultiError for _, codec := range codecs { - if err := codec.Decode(name, value, dst); err == nil { + err := codec.Decode(name, value, dst) + if err == nil { return nil - } else { - errors = append(errors, err) } + errors = append(errors, err) } return errors } diff --git a/securecookie_test.go b/securecookie_test.go index fe0cdb1..381320d 100644 --- a/securecookie_test.go +++ b/securecookie_test.go @@ -157,6 +157,16 @@ func TestMultiNoCodecs(t *testing.T) { } } +func TestMissingKey(t *testing.T) { + s1 := New(nil, nil) + + var dst []byte + err := s1.Decode("sid", "value", &dst) + if err != errHashKeyNotSet { + t.Fatalf("Expected %#v, got %#v", errHashKeyNotSet, err) + } +} + // ---------------------------------------------------------------------------- type FooBar struct { |