1
0
mirror of https://github.com/gogf/gf.git synced 2025-04-05 03:05:05 +08:00

feat(database/gdb): enable transaction propagation when using tx.GetCtx() after Begin (#4121)

This commit is contained in:
CyJaySong 2025-02-11 16:09:27 +08:00 committed by GitHub
parent 0eb229a887
commit a3b3c656d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 48 additions and 2 deletions

View File

@ -1708,3 +1708,44 @@ func Test_Transaction_Isolation(t *testing.T) {
t.AssertNil(err)
})
}
func Test_Transaction_Spread(t *testing.T) {
table := createTable()
defer dropTable(table)
db.SetDebug(true)
defer db.SetDebug(false)
gtest.C(t, func(t *gtest.T) {
var (
err error
ctx = context.TODO()
)
tx, err := db.Begin(ctx)
t.AssertNil(err)
err = db.Transaction(tx.GetCtx(), func(ctx context.Context, tx gdb.TX) error {
_, err = db.Model(table).Ctx(ctx).Data(g.Map{
"id": 1,
"passport": "USER_1",
"password": "PASS_1",
"nickname": "NAME_1",
"create_time": gtime.Now().String(),
}).Insert()
return err
})
t.AssertNil(err)
all, err := tx.Model(table).All()
t.AssertNil(err)
t.Assert(len(all), 1)
t.Assert(all[0]["id"], 1)
err = tx.Rollback()
t.AssertNil(err)
all, err = db.Ctx(ctx).Model(table).All()
t.AssertNil(err)
t.Assert(len(all), 0)
})
}

View File

@ -257,12 +257,14 @@ func WithTX(ctx context.Context, tx TX) context.Context {
}
// Inject transaction object and id into context.
ctx = context.WithValue(ctx, transactionKeyForContext(group), tx)
ctx = context.WithValue(ctx, transactionIdForLoggerCtx, tx.GetCtx().Value(transactionIdForLoggerCtx))
return ctx
}
// WithoutTX removed transaction object from context and returns a new context.
func WithoutTX(ctx context.Context, group string) context.Context {
ctx = context.WithValue(ctx, transactionKeyForContext(group), nil)
ctx = context.WithValue(ctx, transactionIdForLoggerCtx, nil)
return ctx
}

View File

@ -180,14 +180,17 @@ func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutp
formattedSql, in.TxOptions.Isolation.String(), in.TxOptions.ReadOnly,
)
if sqlTx, err = in.Db.BeginTx(ctx, &in.TxOptions); err == nil {
out.Tx = &TXCore{
tx := &TXCore{
db: c.db,
tx: sqlTx,
ctx: context.WithValue(ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1)),
ctx: ctx,
master: in.Db,
transactionId: guid.S(),
cancelFunc: cancelFuncForTimeout,
}
tx.ctx = context.WithValue(ctx, transactionKeyForContext(tx.db.GetGroup()), tx)
tx.ctx = context.WithValue(tx.ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1))
out.Tx = tx
ctx = out.Tx.GetCtx()
}
out.RawResult = sqlTx