mirror of
https://github.com/gogf/gf.git
synced 2025-04-05 11:18:50 +08:00
feat(os/gsession): add RegenerateId/MustRegenerateId
support (#4012)
This commit is contained in:
parent
ba968949f7
commit
b8142bf1fc
@ -45,7 +45,7 @@ func (s *Session) init() error {
|
|||||||
// Retrieve stored session data from storage.
|
// Retrieve stored session data from storage.
|
||||||
if s.manager.storage != nil {
|
if s.manager.storage != nil {
|
||||||
s.data, err = s.manager.storage.GetSession(s.ctx, s.id, s.manager.GetTTL())
|
s.data, err = s.manager.storage.GetSession(s.ctx, s.id, s.manager.GetTTL())
|
||||||
if err != nil && err != ErrorDisabled {
|
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||||
intlog.Errorf(s.ctx, `session restoring failed for id "%s": %+v`, s.id, err)
|
intlog.Errorf(s.ctx, `session restoring failed for id "%s": %+v`, s.id, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -59,7 +59,7 @@ func (s *Session) init() error {
|
|||||||
} else {
|
} else {
|
||||||
// Use default session id creating function of storage.
|
// Use default session id creating function of storage.
|
||||||
s.id, err = s.manager.storage.New(s.ctx, s.manager.ttl)
|
s.id, err = s.manager.storage.New(s.ctx, s.manager.ttl)
|
||||||
if err != nil && err != ErrorDisabled {
|
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||||
intlog.Errorf(s.ctx, "create session id failed: %+v", err)
|
intlog.Errorf(s.ctx, "create session id failed: %+v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -89,12 +89,12 @@ func (s *Session) Close() error {
|
|||||||
size := s.data.Size()
|
size := s.data.Size()
|
||||||
if s.dirty {
|
if s.dirty {
|
||||||
err := s.manager.storage.SetSession(s.ctx, s.id, s.data, s.manager.ttl)
|
err := s.manager.storage.SetSession(s.ctx, s.id, s.data, s.manager.ttl)
|
||||||
if err != nil && err != ErrorDisabled {
|
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if size > 0 {
|
} else if size > 0 {
|
||||||
err := s.manager.storage.UpdateTTL(s.ctx, s.id, s.manager.ttl)
|
err := s.manager.storage.UpdateTTL(s.ctx, s.id, s.manager.ttl)
|
||||||
if err != nil && err != ErrorDisabled {
|
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -108,11 +108,10 @@ func (s *Session) Set(key string, value interface{}) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = s.manager.storage.Set(s.ctx, s.id, key, value, s.manager.ttl); err != nil {
|
if err = s.manager.storage.Set(s.ctx, s.id, key, value, s.manager.ttl); err != nil {
|
||||||
if err == ErrorDisabled {
|
if !gerror.Is(err, ErrorDisabled) {
|
||||||
s.data.Set(key, value)
|
|
||||||
} else {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
s.data.Set(key, value)
|
||||||
}
|
}
|
||||||
s.dirty = true
|
s.dirty = true
|
||||||
return nil
|
return nil
|
||||||
@ -124,11 +123,10 @@ func (s *Session) SetMap(data map[string]interface{}) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = s.manager.storage.SetMap(s.ctx, s.id, data, s.manager.ttl); err != nil {
|
if err = s.manager.storage.SetMap(s.ctx, s.id, data, s.manager.ttl); err != nil {
|
||||||
if err == ErrorDisabled {
|
if !gerror.Is(err, ErrorDisabled) {
|
||||||
s.data.Sets(data)
|
|
||||||
} else {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
s.data.Sets(data)
|
||||||
}
|
}
|
||||||
s.dirty = true
|
s.dirty = true
|
||||||
return nil
|
return nil
|
||||||
@ -144,11 +142,10 @@ func (s *Session) Remove(keys ...string) (err error) {
|
|||||||
}
|
}
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
if err = s.manager.storage.Remove(s.ctx, s.id, key); err != nil {
|
if err = s.manager.storage.Remove(s.ctx, s.id, key); err != nil {
|
||||||
if err == ErrorDisabled {
|
if !gerror.Is(err, ErrorDisabled) {
|
||||||
s.data.Remove(key)
|
|
||||||
} else {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
s.data.Remove(key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.dirty = true
|
s.dirty = true
|
||||||
@ -164,7 +161,7 @@ func (s *Session) RemoveAll() (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil {
|
if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil {
|
||||||
if err != ErrorDisabled {
|
if !gerror.Is(err, ErrorDisabled) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -215,7 +212,7 @@ func (s *Session) Data() (sessionData map[string]interface{}, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
sessionData, err = s.manager.storage.Data(s.ctx, s.id)
|
sessionData, err = s.manager.storage.Data(s.ctx, s.id)
|
||||||
if err != nil && err != ErrorDisabled {
|
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||||
intlog.Errorf(s.ctx, `%+v`, err)
|
intlog.Errorf(s.ctx, `%+v`, err)
|
||||||
}
|
}
|
||||||
if sessionData != nil {
|
if sessionData != nil {
|
||||||
@ -233,7 +230,7 @@ func (s *Session) Size() (size int, err error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
size, err = s.manager.storage.GetSize(s.ctx, s.id)
|
size, err = s.manager.storage.GetSize(s.ctx, s.id)
|
||||||
if err != nil && err != ErrorDisabled {
|
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||||
intlog.Errorf(s.ctx, `%+v`, err)
|
intlog.Errorf(s.ctx, `%+v`, err)
|
||||||
}
|
}
|
||||||
if size > 0 {
|
if size > 0 {
|
||||||
@ -273,7 +270,7 @@ func (s *Session) Get(key string, def ...interface{}) (value *gvar.Var, err erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
v, err := s.manager.storage.Get(s.ctx, s.id, key)
|
v, err := s.manager.storage.Get(s.ctx, s.id, key)
|
||||||
if err != nil && err != ErrorDisabled {
|
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||||
intlog.Errorf(s.ctx, `%+v`, err)
|
intlog.Errorf(s.ctx, `%+v`, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -357,3 +354,60 @@ func (s *Session) MustRemove(keys ...string) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegenerateId regenerates a new session id for current session.
|
||||||
|
// It keeps the session data and updates the session id with a new one.
|
||||||
|
// This is commonly used to prevent session fixation attacks and increase security.
|
||||||
|
//
|
||||||
|
// The parameter `deleteOld` specifies whether to delete the old session data:
|
||||||
|
// - If true: the old session data will be deleted immediately
|
||||||
|
// - If false: the old session data will be kept and expire according to its TTL
|
||||||
|
func (s *Session) RegenerateId(deleteOld bool) (newId string, err error) {
|
||||||
|
if err = s.init(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate new session id
|
||||||
|
if s.idFunc != nil {
|
||||||
|
newId = s.idFunc(s.manager.ttl)
|
||||||
|
} else {
|
||||||
|
newId, err = s.manager.storage.New(s.ctx, s.manager.ttl)
|
||||||
|
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if newId == "" {
|
||||||
|
newId = NewSessionId()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If using storage, need to copy data to new id
|
||||||
|
if s.manager.storage != nil {
|
||||||
|
if err = s.manager.storage.SetSession(s.ctx, newId, s.data, s.manager.ttl); err != nil {
|
||||||
|
if !gerror.Is(err, ErrorDisabled) {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Delete old session data if requested
|
||||||
|
if deleteOld {
|
||||||
|
if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil {
|
||||||
|
if !gerror.Is(err, ErrorDisabled) {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update session id
|
||||||
|
s.id = newId
|
||||||
|
s.dirty = true
|
||||||
|
return newId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustRegenerateId performs as function RegenerateId, but it panics if any error occurs.
|
||||||
|
func (s *Session) MustRegenerateId(deleteOld bool) string {
|
||||||
|
newId, err := s.RegenerateId(deleteOld)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return newId
|
||||||
|
}
|
||||||
|
@ -7,11 +7,15 @@
|
|||||||
package gsession
|
package gsession
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/test/gtest"
|
"github.com/gogf/gf/v2/test/gtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ctx = context.TODO()
|
||||||
|
|
||||||
func Test_NewSessionId(t *testing.T) {
|
func Test_NewSessionId(t *testing.T) {
|
||||||
gtest.C(t, func(t *gtest.T) {
|
gtest.C(t, func(t *gtest.T) {
|
||||||
id1 := NewSessionId()
|
id1 := NewSessionId()
|
||||||
@ -20,3 +24,107 @@ func Test_NewSessionId(t *testing.T) {
|
|||||||
t.Assert(len(id1), 32)
|
t.Assert(len(id1), 32)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_Session_RegenerateId(t *testing.T) {
|
||||||
|
gtest.C(t, func(t *gtest.T) {
|
||||||
|
// 1. Test with memory storage
|
||||||
|
storage := NewStorageMemory()
|
||||||
|
manager := New(time.Hour, storage)
|
||||||
|
session := manager.New(ctx)
|
||||||
|
|
||||||
|
// Store some data
|
||||||
|
err := session.Set("key1", "value1")
|
||||||
|
t.AssertNil(err)
|
||||||
|
err = session.Set("key2", "value2")
|
||||||
|
t.AssertNil(err)
|
||||||
|
|
||||||
|
// Get original session id
|
||||||
|
oldId := session.MustId()
|
||||||
|
|
||||||
|
// Test regenerate with deleteOld = true
|
||||||
|
newId1, err := session.RegenerateId(true)
|
||||||
|
t.AssertNil(err)
|
||||||
|
t.AssertNE(oldId, newId1)
|
||||||
|
|
||||||
|
// Verify data is preserved
|
||||||
|
v1 := session.MustGet("key1")
|
||||||
|
t.Assert(v1.String(), "value1")
|
||||||
|
v2 := session.MustGet("key2")
|
||||||
|
t.Assert(v2.String(), "value2")
|
||||||
|
|
||||||
|
// Verify old session is deleted
|
||||||
|
oldSession := manager.New(ctx)
|
||||||
|
err = oldSession.SetId(oldId)
|
||||||
|
t.AssertNil(err)
|
||||||
|
v3 := oldSession.MustGet("key1")
|
||||||
|
t.Assert(v3.IsNil(), true)
|
||||||
|
|
||||||
|
// Test regenerate with deleteOld = false
|
||||||
|
currentId := newId1
|
||||||
|
newId2, err := session.RegenerateId(false)
|
||||||
|
t.AssertNil(err)
|
||||||
|
t.AssertNE(currentId, newId2)
|
||||||
|
|
||||||
|
// Verify data is preserved in new session
|
||||||
|
v4 := session.MustGet("key1")
|
||||||
|
t.Assert(v4.String(), "value1")
|
||||||
|
|
||||||
|
// Create another session instance with the previous id
|
||||||
|
prevSession := manager.New(ctx)
|
||||||
|
err = prevSession.SetId(currentId)
|
||||||
|
t.AssertNil(err)
|
||||||
|
// Data should still be accessible in previous session
|
||||||
|
v5 := prevSession.MustGet("key1")
|
||||||
|
t.Assert(v5.String(), "value1")
|
||||||
|
})
|
||||||
|
|
||||||
|
gtest.C(t, func(t *gtest.T) {
|
||||||
|
// 2. Test with custom id function
|
||||||
|
storage := NewStorageMemory()
|
||||||
|
manager := New(time.Hour, storage)
|
||||||
|
session := manager.New(ctx)
|
||||||
|
|
||||||
|
customId := "custom_session_id"
|
||||||
|
err := session.SetIdFunc(func(ttl time.Duration) string {
|
||||||
|
return customId
|
||||||
|
})
|
||||||
|
t.AssertNil(err)
|
||||||
|
|
||||||
|
newId, err := session.RegenerateId(true)
|
||||||
|
t.AssertNil(err)
|
||||||
|
t.Assert(newId, customId)
|
||||||
|
})
|
||||||
|
|
||||||
|
gtest.C(t, func(t *gtest.T) {
|
||||||
|
// 3. Test with disabled storage
|
||||||
|
storage := &StorageBase{} // implements Storage interface but all methods return ErrorDisabled
|
||||||
|
manager := New(time.Hour, storage)
|
||||||
|
session := manager.New(ctx)
|
||||||
|
|
||||||
|
// Should still work even with disabled storage
|
||||||
|
newId, err := session.RegenerateId(true)
|
||||||
|
t.AssertNil(err)
|
||||||
|
t.Assert(len(newId), 32)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MustRegenerateId
|
||||||
|
func Test_Session_MustRegenerateId(t *testing.T) {
|
||||||
|
gtest.C(t, func(t *gtest.T) {
|
||||||
|
storage := NewStorageMemory()
|
||||||
|
manager := New(time.Hour, storage)
|
||||||
|
session := manager.New(ctx)
|
||||||
|
|
||||||
|
// Normal case should not panic
|
||||||
|
t.AssertNil(session.Set("key", "value"))
|
||||||
|
newId := session.MustRegenerateId(true)
|
||||||
|
t.Assert(len(newId), 32)
|
||||||
|
|
||||||
|
// Test with disabled storage (should not panic)
|
||||||
|
storage2 := &StorageBase{}
|
||||||
|
manager2 := New(time.Hour, storage2)
|
||||||
|
session2 := manager2.New(ctx)
|
||||||
|
newId2 := session2.MustRegenerateId(true)
|
||||||
|
t.Assert(len(newId2), 32)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user