1
0
mirror of https://github.com/gogf/gf.git synced 2025-04-05 03:05:05 +08:00

feat(os/gsession): add RegenerateId/MustRegenerateId support (#4012)

This commit is contained in:
John Guo 2024-12-06 14:16:03 +08:00 committed by GitHub
parent ba968949f7
commit b8142bf1fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 179 additions and 17 deletions

View File

@ -45,7 +45,7 @@ func (s *Session) init() error {
// Retrieve stored session data from storage.
if s.manager.storage != nil {
s.data, err = s.manager.storage.GetSession(s.ctx, s.id, s.manager.GetTTL())
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
intlog.Errorf(s.ctx, `session restoring failed for id "%s": %+v`, s.id, err)
return err
}
@ -59,7 +59,7 @@ func (s *Session) init() error {
} else {
// Use default session id creating function of storage.
s.id, err = s.manager.storage.New(s.ctx, s.manager.ttl)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
intlog.Errorf(s.ctx, "create session id failed: %+v", err)
return err
}
@ -89,12 +89,12 @@ func (s *Session) Close() error {
size := s.data.Size()
if s.dirty {
err := s.manager.storage.SetSession(s.ctx, s.id, s.data, s.manager.ttl)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
return err
}
} else if size > 0 {
err := s.manager.storage.UpdateTTL(s.ctx, s.id, s.manager.ttl)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
return err
}
}
@ -108,11 +108,10 @@ func (s *Session) Set(key string, value interface{}) (err error) {
return err
}
if err = s.manager.storage.Set(s.ctx, s.id, key, value, s.manager.ttl); err != nil {
if err == ErrorDisabled {
s.data.Set(key, value)
} else {
if !gerror.Is(err, ErrorDisabled) {
return err
}
s.data.Set(key, value)
}
s.dirty = true
return nil
@ -124,11 +123,10 @@ func (s *Session) SetMap(data map[string]interface{}) (err error) {
return err
}
if err = s.manager.storage.SetMap(s.ctx, s.id, data, s.manager.ttl); err != nil {
if err == ErrorDisabled {
s.data.Sets(data)
} else {
if !gerror.Is(err, ErrorDisabled) {
return err
}
s.data.Sets(data)
}
s.dirty = true
return nil
@ -144,11 +142,10 @@ func (s *Session) Remove(keys ...string) (err error) {
}
for _, key := range keys {
if err = s.manager.storage.Remove(s.ctx, s.id, key); err != nil {
if err == ErrorDisabled {
s.data.Remove(key)
} else {
if !gerror.Is(err, ErrorDisabled) {
return err
}
s.data.Remove(key)
}
}
s.dirty = true
@ -164,7 +161,7 @@ func (s *Session) RemoveAll() (err error) {
return err
}
if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil {
if err != ErrorDisabled {
if !gerror.Is(err, ErrorDisabled) {
return err
}
}
@ -215,7 +212,7 @@ func (s *Session) Data() (sessionData map[string]interface{}, err error) {
return nil, err
}
sessionData, err = s.manager.storage.Data(s.ctx, s.id)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
intlog.Errorf(s.ctx, `%+v`, err)
}
if sessionData != nil {
@ -233,7 +230,7 @@ func (s *Session) Size() (size int, err error) {
return 0, err
}
size, err = s.manager.storage.GetSize(s.ctx, s.id)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
intlog.Errorf(s.ctx, `%+v`, err)
}
if size > 0 {
@ -273,7 +270,7 @@ func (s *Session) Get(key string, def ...interface{}) (value *gvar.Var, err erro
return nil, err
}
v, err := s.manager.storage.Get(s.ctx, s.id, key)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
intlog.Errorf(s.ctx, `%+v`, err)
return nil, err
}
@ -357,3 +354,60 @@ func (s *Session) MustRemove(keys ...string) {
panic(err)
}
}
// RegenerateId regenerates a new session id for current session.
// It keeps the session data and updates the session id with a new one.
// This is commonly used to prevent session fixation attacks and increase security.
//
// The parameter `deleteOld` specifies whether to delete the old session data:
// - If true: the old session data will be deleted immediately
// - If false: the old session data will be kept and expire according to its TTL
func (s *Session) RegenerateId(deleteOld bool) (newId string, err error) {
if err = s.init(); err != nil {
return "", err
}
// Generate new session id
if s.idFunc != nil {
newId = s.idFunc(s.manager.ttl)
} else {
newId, err = s.manager.storage.New(s.ctx, s.manager.ttl)
if err != nil && !gerror.Is(err, ErrorDisabled) {
return "", err
}
if newId == "" {
newId = NewSessionId()
}
}
// If using storage, need to copy data to new id
if s.manager.storage != nil {
if err = s.manager.storage.SetSession(s.ctx, newId, s.data, s.manager.ttl); err != nil {
if !gerror.Is(err, ErrorDisabled) {
return "", err
}
}
// Delete old session data if requested
if deleteOld {
if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil {
if !gerror.Is(err, ErrorDisabled) {
return "", err
}
}
}
}
// Update session id
s.id = newId
s.dirty = true
return newId, nil
}
// MustRegenerateId performs as function RegenerateId, but it panics if any error occurs.
func (s *Session) MustRegenerateId(deleteOld bool) string {
newId, err := s.RegenerateId(deleteOld)
if err != nil {
panic(err)
}
return newId
}

View File

@ -7,11 +7,15 @@
package gsession
import (
"context"
"testing"
"time"
"github.com/gogf/gf/v2/test/gtest"
)
var ctx = context.TODO()
func Test_NewSessionId(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
id1 := NewSessionId()
@ -20,3 +24,107 @@ func Test_NewSessionId(t *testing.T) {
t.Assert(len(id1), 32)
})
}
func Test_Session_RegenerateId(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
// 1. Test with memory storage
storage := NewStorageMemory()
manager := New(time.Hour, storage)
session := manager.New(ctx)
// Store some data
err := session.Set("key1", "value1")
t.AssertNil(err)
err = session.Set("key2", "value2")
t.AssertNil(err)
// Get original session id
oldId := session.MustId()
// Test regenerate with deleteOld = true
newId1, err := session.RegenerateId(true)
t.AssertNil(err)
t.AssertNE(oldId, newId1)
// Verify data is preserved
v1 := session.MustGet("key1")
t.Assert(v1.String(), "value1")
v2 := session.MustGet("key2")
t.Assert(v2.String(), "value2")
// Verify old session is deleted
oldSession := manager.New(ctx)
err = oldSession.SetId(oldId)
t.AssertNil(err)
v3 := oldSession.MustGet("key1")
t.Assert(v3.IsNil(), true)
// Test regenerate with deleteOld = false
currentId := newId1
newId2, err := session.RegenerateId(false)
t.AssertNil(err)
t.AssertNE(currentId, newId2)
// Verify data is preserved in new session
v4 := session.MustGet("key1")
t.Assert(v4.String(), "value1")
// Create another session instance with the previous id
prevSession := manager.New(ctx)
err = prevSession.SetId(currentId)
t.AssertNil(err)
// Data should still be accessible in previous session
v5 := prevSession.MustGet("key1")
t.Assert(v5.String(), "value1")
})
gtest.C(t, func(t *gtest.T) {
// 2. Test with custom id function
storage := NewStorageMemory()
manager := New(time.Hour, storage)
session := manager.New(ctx)
customId := "custom_session_id"
err := session.SetIdFunc(func(ttl time.Duration) string {
return customId
})
t.AssertNil(err)
newId, err := session.RegenerateId(true)
t.AssertNil(err)
t.Assert(newId, customId)
})
gtest.C(t, func(t *gtest.T) {
// 3. Test with disabled storage
storage := &StorageBase{} // implements Storage interface but all methods return ErrorDisabled
manager := New(time.Hour, storage)
session := manager.New(ctx)
// Should still work even with disabled storage
newId, err := session.RegenerateId(true)
t.AssertNil(err)
t.Assert(len(newId), 32)
})
}
// Test MustRegenerateId
func Test_Session_MustRegenerateId(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
storage := NewStorageMemory()
manager := New(time.Hour, storage)
session := manager.New(ctx)
// Normal case should not panic
t.AssertNil(session.Set("key", "value"))
newId := session.MustRegenerateId(true)
t.Assert(len(newId), 32)
// Test with disabled storage (should not panic)
storage2 := &StorageBase{}
manager2 := New(time.Hour, storage2)
session2 := manager2.New(ctx)
newId2 := session2.MustRegenerateId(true)
t.Assert(len(newId2), 32)
})
}