diff options
-rw-r--r-- | securecookie.go | 43 |
1 files changed, 27 insertions, 16 deletions
diff --git a/securecookie.go b/securecookie.go index a76083b..a96deb1 100644 --- a/securecookie.go +++ b/securecookie.go @@ -177,60 +177,71 @@ func (s *SecureCookie) Encode(name string, value interface{}) (string, error) { // it was stored. The value argument is the encoded cookie value. The dst // argument is where the cookie will be decoded. It must be a pointer. func (s *SecureCookie) Decode(name, value string, dst interface{}) error { + var retErr error + setErr := func(err error) { + if retErr == nil { + retErr = err + } + } + if s.err != nil { - return s.err + setErr(s.err) } if s.hashKey == nil { s.err = errHashKeyNotSet - return s.err + setErr(s.err) } // 1. Check length. if s.maxLength != 0 && len(value) > s.maxLength { - return errors.New("securecookie: the value is too long") + setErr(errors.New("securecookie: the value is too long")) } // 2. Decode from base64. b, err := decode([]byte(value)) if err != nil { - return err + setErr(err) + // Dummy b to avoid errors + b = []byte("||") } // 3. Verify MAC. Value is "date|value|mac". parts := bytes.SplitN(b, []byte("|"), 3) if len(parts) != 3 { - return ErrMacInvalid + setErr(ErrMacInvalid) } h := hmac.New(s.hashFunc, s.hashKey) b = append([]byte(name+"|"), b[:len(b)-len(parts[2])-1]...) if err = verifyMac(h, b, parts[2]); err != nil { - return err + setErr(err) } // 4. Verify date ranges. var t1 int64 if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { - return errors.New("securecookie: invalid timestamp") + setErr(errors.New("securecookie: invalid timestamp")) } t2 := s.timestamp() if s.minAge != 0 && t1 > t2-s.minAge { - return errors.New("securecookie: timestamp is too new") + setErr(errors.New("securecookie: timestamp is too new")) } if s.maxAge != 0 && t1 < t2-s.maxAge { - return errors.New("securecookie: expired timestamp") + setErr(errors.New("securecookie: expired timestamp")) } // 5. Decrypt (optional). b, err = decode(parts[1]) if err != nil { - return err + setErr(err) } if s.block != nil { if b, err = decrypt(s.block, b); err != nil { - return err + setErr(err) } } - // 6. Deserialize. - if err = deserialize(b, dst); err != nil { - return err + + // Check for errors before deserialization to avoid unwanted side effects + if retErr != nil { + return retErr } - // Done. - return nil + + // 6. Deserialize. + return deserialize(b, dst) } // timestamp returns the current timestamp, in seconds. |