// Copyright 2012 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package sessions import ( "bytes" "encoding/gob" "net/http" "testing" ) // ---------------------------------------------------------------------------- // ResponseRecorder // ---------------------------------------------------------------------------- // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // ResponseRecorder is an implementation of http.ResponseWriter that // records its mutations for later inspection in tests. type ResponseRecorder struct { Code int // the HTTP response code from WriteHeader HeaderMap http.Header // the HTTP response headers Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to Flushed bool } // NewRecorder returns an initialized ResponseRecorder. func NewRecorder() *ResponseRecorder { return &ResponseRecorder{ HeaderMap: make(http.Header), Body: new(bytes.Buffer), } } // DefaultRemoteAddr is the default remote address to return in RemoteAddr if // an explicit DefaultRemoteAddr isn't set on ResponseRecorder. const DefaultRemoteAddr = "1.2.3.4" // Header returns the response headers. func (rw *ResponseRecorder) Header() http.Header { return rw.HeaderMap } // Write always succeeds and writes to rw.Body, if not nil. func (rw *ResponseRecorder) Write(buf []byte) (int, error) { if rw.Body != nil { rw.Body.Write(buf) } if rw.Code == 0 { rw.Code = http.StatusOK } return len(buf), nil } // WriteHeader sets rw.Code. func (rw *ResponseRecorder) WriteHeader(code int) { rw.Code = code } // Flush sets rw.Flushed to true. func (rw *ResponseRecorder) Flush() { rw.Flushed = true } // ---------------------------------------------------------------------------- type FlashMessage struct { Type int Message string } func TestFlashes(t *testing.T) { var req *http.Request var rsp *ResponseRecorder var hdr http.Header var err error var ok bool var cookies []string var session *Session var flashes []interface{} store := NewCookieStore([]byte("secret-key")) // Round 1 ---------------------------------------------------------------- req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) } // Get a flash. flashes = session.Flashes() if len(flashes) != 0 { t.Errorf("Expected empty flashes; Got %v", flashes) } // Add some flashes. session.AddFlash("foo") session.AddFlash("bar") // Custom key. session.AddFlash("baz", "custom_key") // Save. if err = Save(req, rsp); err != nil { t.Fatalf("Error saving session: %v", err) } hdr = rsp.Header() cookies, ok = hdr["Set-Cookie"] if !ok || len(cookies) != 1 { t.Fatalf("No cookies. Header:", hdr) } // Round 2 ---------------------------------------------------------------- req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) req.Header.Add("Cookie", cookies[0]) rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) } // Check all saved values. flashes = session.Flashes() if len(flashes) != 2 { t.Fatalf("Expected flashes; Got %v", flashes) } if flashes[0] != "foo" || flashes[1] != "bar" { t.Errorf("Expected foo,bar; Got %v", flashes) } flashes = session.Flashes() if len(flashes) != 0 { t.Errorf("Expected dumped flashes; Got %v", flashes) } // Custom key. flashes = session.Flashes("custom_key") if len(flashes) != 1 { t.Errorf("Expected flashes; Got %v", flashes) } else if flashes[0] != "baz" { t.Errorf("Expected baz; Got %v", flashes) } flashes = session.Flashes("custom_key") if len(flashes) != 0 { t.Errorf("Expected dumped flashes; Got %v", flashes) } // Round 3 ---------------------------------------------------------------- // Custom type req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) } // Get a flash. flashes = session.Flashes() if len(flashes) != 0 { t.Errorf("Expected empty flashes; Got %v", flashes) } // Add some flashes. session.AddFlash(&FlashMessage{42, "foo"}) // Save. if err = Save(req, rsp); err != nil { t.Fatalf("Error saving session: %v", err) } hdr = rsp.Header() cookies, ok = hdr["Set-Cookie"] if !ok || len(cookies) != 1 { t.Fatalf("No cookies. Header:", hdr) } // Round 4 ---------------------------------------------------------------- // Custom type req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) req.Header.Add("Cookie", cookies[0]) rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) } // Check all saved values. flashes = session.Flashes() if len(flashes) != 1 { t.Fatalf("Expected flashes; Got %v", flashes) } custom := flashes[0].(FlashMessage) if custom.Type != 42 || custom.Message != "foo" { t.Errorf("Expected %#v, got %#v", FlashMessage{42, "foo"}, custom) } } func init() { gob.Register(FlashMessage{}) }