mirror of
https://github.com/gogf/gf.git
synced 2025-04-05 03:05:05 +08:00
feat(os/gsession): add RegenerateId/MustRegenerateId
support (#4012)
This commit is contained in:
parent
ba968949f7
commit
b8142bf1fc
@ -45,7 +45,7 @@ func (s *Session) init() error {
|
||||
// Retrieve stored session data from storage.
|
||||
if s.manager.storage != nil {
|
||||
s.data, err = s.manager.storage.GetSession(s.ctx, s.id, s.manager.GetTTL())
|
||||
if err != nil && err != ErrorDisabled {
|
||||
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||
intlog.Errorf(s.ctx, `session restoring failed for id "%s": %+v`, s.id, err)
|
||||
return err
|
||||
}
|
||||
@ -59,7 +59,7 @@ func (s *Session) init() error {
|
||||
} else {
|
||||
// Use default session id creating function of storage.
|
||||
s.id, err = s.manager.storage.New(s.ctx, s.manager.ttl)
|
||||
if err != nil && err != ErrorDisabled {
|
||||
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||
intlog.Errorf(s.ctx, "create session id failed: %+v", err)
|
||||
return err
|
||||
}
|
||||
@ -89,12 +89,12 @@ func (s *Session) Close() error {
|
||||
size := s.data.Size()
|
||||
if s.dirty {
|
||||
err := s.manager.storage.SetSession(s.ctx, s.id, s.data, s.manager.ttl)
|
||||
if err != nil && err != ErrorDisabled {
|
||||
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||
return err
|
||||
}
|
||||
} else if size > 0 {
|
||||
err := s.manager.storage.UpdateTTL(s.ctx, s.id, s.manager.ttl)
|
||||
if err != nil && err != ErrorDisabled {
|
||||
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -108,11 +108,10 @@ func (s *Session) Set(key string, value interface{}) (err error) {
|
||||
return err
|
||||
}
|
||||
if err = s.manager.storage.Set(s.ctx, s.id, key, value, s.manager.ttl); err != nil {
|
||||
if err == ErrorDisabled {
|
||||
s.data.Set(key, value)
|
||||
} else {
|
||||
if !gerror.Is(err, ErrorDisabled) {
|
||||
return err
|
||||
}
|
||||
s.data.Set(key, value)
|
||||
}
|
||||
s.dirty = true
|
||||
return nil
|
||||
@ -124,11 +123,10 @@ func (s *Session) SetMap(data map[string]interface{}) (err error) {
|
||||
return err
|
||||
}
|
||||
if err = s.manager.storage.SetMap(s.ctx, s.id, data, s.manager.ttl); err != nil {
|
||||
if err == ErrorDisabled {
|
||||
s.data.Sets(data)
|
||||
} else {
|
||||
if !gerror.Is(err, ErrorDisabled) {
|
||||
return err
|
||||
}
|
||||
s.data.Sets(data)
|
||||
}
|
||||
s.dirty = true
|
||||
return nil
|
||||
@ -144,11 +142,10 @@ func (s *Session) Remove(keys ...string) (err error) {
|
||||
}
|
||||
for _, key := range keys {
|
||||
if err = s.manager.storage.Remove(s.ctx, s.id, key); err != nil {
|
||||
if err == ErrorDisabled {
|
||||
s.data.Remove(key)
|
||||
} else {
|
||||
if !gerror.Is(err, ErrorDisabled) {
|
||||
return err
|
||||
}
|
||||
s.data.Remove(key)
|
||||
}
|
||||
}
|
||||
s.dirty = true
|
||||
@ -164,7 +161,7 @@ func (s *Session) RemoveAll() (err error) {
|
||||
return err
|
||||
}
|
||||
if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil {
|
||||
if err != ErrorDisabled {
|
||||
if !gerror.Is(err, ErrorDisabled) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -215,7 +212,7 @@ func (s *Session) Data() (sessionData map[string]interface{}, err error) {
|
||||
return nil, err
|
||||
}
|
||||
sessionData, err = s.manager.storage.Data(s.ctx, s.id)
|
||||
if err != nil && err != ErrorDisabled {
|
||||
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||
intlog.Errorf(s.ctx, `%+v`, err)
|
||||
}
|
||||
if sessionData != nil {
|
||||
@ -233,7 +230,7 @@ func (s *Session) Size() (size int, err error) {
|
||||
return 0, err
|
||||
}
|
||||
size, err = s.manager.storage.GetSize(s.ctx, s.id)
|
||||
if err != nil && err != ErrorDisabled {
|
||||
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||
intlog.Errorf(s.ctx, `%+v`, err)
|
||||
}
|
||||
if size > 0 {
|
||||
@ -273,7 +270,7 @@ func (s *Session) Get(key string, def ...interface{}) (value *gvar.Var, err erro
|
||||
return nil, err
|
||||
}
|
||||
v, err := s.manager.storage.Get(s.ctx, s.id, key)
|
||||
if err != nil && err != ErrorDisabled {
|
||||
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||
intlog.Errorf(s.ctx, `%+v`, err)
|
||||
return nil, err
|
||||
}
|
||||
@ -357,3 +354,60 @@ func (s *Session) MustRemove(keys ...string) {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// RegenerateId regenerates a new session id for current session.
|
||||
// It keeps the session data and updates the session id with a new one.
|
||||
// This is commonly used to prevent session fixation attacks and increase security.
|
||||
//
|
||||
// The parameter `deleteOld` specifies whether to delete the old session data:
|
||||
// - If true: the old session data will be deleted immediately
|
||||
// - If false: the old session data will be kept and expire according to its TTL
|
||||
func (s *Session) RegenerateId(deleteOld bool) (newId string, err error) {
|
||||
if err = s.init(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Generate new session id
|
||||
if s.idFunc != nil {
|
||||
newId = s.idFunc(s.manager.ttl)
|
||||
} else {
|
||||
newId, err = s.manager.storage.New(s.ctx, s.manager.ttl)
|
||||
if err != nil && !gerror.Is(err, ErrorDisabled) {
|
||||
return "", err
|
||||
}
|
||||
if newId == "" {
|
||||
newId = NewSessionId()
|
||||
}
|
||||
}
|
||||
|
||||
// If using storage, need to copy data to new id
|
||||
if s.manager.storage != nil {
|
||||
if err = s.manager.storage.SetSession(s.ctx, newId, s.data, s.manager.ttl); err != nil {
|
||||
if !gerror.Is(err, ErrorDisabled) {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
// Delete old session data if requested
|
||||
if deleteOld {
|
||||
if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil {
|
||||
if !gerror.Is(err, ErrorDisabled) {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update session id
|
||||
s.id = newId
|
||||
s.dirty = true
|
||||
return newId, nil
|
||||
}
|
||||
|
||||
// MustRegenerateId performs as function RegenerateId, but it panics if any error occurs.
|
||||
func (s *Session) MustRegenerateId(deleteOld bool) string {
|
||||
newId, err := s.RegenerateId(deleteOld)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return newId
|
||||
}
|
||||
|
@ -7,11 +7,15 @@
|
||||
package gsession
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/test/gtest"
|
||||
)
|
||||
|
||||
var ctx = context.TODO()
|
||||
|
||||
func Test_NewSessionId(t *testing.T) {
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
id1 := NewSessionId()
|
||||
@ -20,3 +24,107 @@ func Test_NewSessionId(t *testing.T) {
|
||||
t.Assert(len(id1), 32)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Session_RegenerateId(t *testing.T) {
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
// 1. Test with memory storage
|
||||
storage := NewStorageMemory()
|
||||
manager := New(time.Hour, storage)
|
||||
session := manager.New(ctx)
|
||||
|
||||
// Store some data
|
||||
err := session.Set("key1", "value1")
|
||||
t.AssertNil(err)
|
||||
err = session.Set("key2", "value2")
|
||||
t.AssertNil(err)
|
||||
|
||||
// Get original session id
|
||||
oldId := session.MustId()
|
||||
|
||||
// Test regenerate with deleteOld = true
|
||||
newId1, err := session.RegenerateId(true)
|
||||
t.AssertNil(err)
|
||||
t.AssertNE(oldId, newId1)
|
||||
|
||||
// Verify data is preserved
|
||||
v1 := session.MustGet("key1")
|
||||
t.Assert(v1.String(), "value1")
|
||||
v2 := session.MustGet("key2")
|
||||
t.Assert(v2.String(), "value2")
|
||||
|
||||
// Verify old session is deleted
|
||||
oldSession := manager.New(ctx)
|
||||
err = oldSession.SetId(oldId)
|
||||
t.AssertNil(err)
|
||||
v3 := oldSession.MustGet("key1")
|
||||
t.Assert(v3.IsNil(), true)
|
||||
|
||||
// Test regenerate with deleteOld = false
|
||||
currentId := newId1
|
||||
newId2, err := session.RegenerateId(false)
|
||||
t.AssertNil(err)
|
||||
t.AssertNE(currentId, newId2)
|
||||
|
||||
// Verify data is preserved in new session
|
||||
v4 := session.MustGet("key1")
|
||||
t.Assert(v4.String(), "value1")
|
||||
|
||||
// Create another session instance with the previous id
|
||||
prevSession := manager.New(ctx)
|
||||
err = prevSession.SetId(currentId)
|
||||
t.AssertNil(err)
|
||||
// Data should still be accessible in previous session
|
||||
v5 := prevSession.MustGet("key1")
|
||||
t.Assert(v5.String(), "value1")
|
||||
})
|
||||
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
// 2. Test with custom id function
|
||||
storage := NewStorageMemory()
|
||||
manager := New(time.Hour, storage)
|
||||
session := manager.New(ctx)
|
||||
|
||||
customId := "custom_session_id"
|
||||
err := session.SetIdFunc(func(ttl time.Duration) string {
|
||||
return customId
|
||||
})
|
||||
t.AssertNil(err)
|
||||
|
||||
newId, err := session.RegenerateId(true)
|
||||
t.AssertNil(err)
|
||||
t.Assert(newId, customId)
|
||||
})
|
||||
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
// 3. Test with disabled storage
|
||||
storage := &StorageBase{} // implements Storage interface but all methods return ErrorDisabled
|
||||
manager := New(time.Hour, storage)
|
||||
session := manager.New(ctx)
|
||||
|
||||
// Should still work even with disabled storage
|
||||
newId, err := session.RegenerateId(true)
|
||||
t.AssertNil(err)
|
||||
t.Assert(len(newId), 32)
|
||||
})
|
||||
}
|
||||
|
||||
// Test MustRegenerateId
|
||||
func Test_Session_MustRegenerateId(t *testing.T) {
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
storage := NewStorageMemory()
|
||||
manager := New(time.Hour, storage)
|
||||
session := manager.New(ctx)
|
||||
|
||||
// Normal case should not panic
|
||||
t.AssertNil(session.Set("key", "value"))
|
||||
newId := session.MustRegenerateId(true)
|
||||
t.Assert(len(newId), 32)
|
||||
|
||||
// Test with disabled storage (should not panic)
|
||||
storage2 := &StorageBase{}
|
||||
manager2 := New(time.Hour, storage2)
|
||||
session2 := manager2.New(ctx)
|
||||
newId2 := session2.MustRegenerateId(true)
|
||||
t.Assert(len(newId2), 32)
|
||||
})
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user