diff --git a/log/root.go b/log/root.go index 91209c46ad..c09a36b114 100644 --- a/log/root.go +++ b/log/root.go @@ -3,18 +3,24 @@ package log import ( "log/slog" "os" - "sync/atomic" + "sync" ) -var root atomic.Value +var ( + rootLock sync.RWMutex + root Logger +) func init() { - root.Store(&logger{slog.New(DiscardHandler())}) + root = &logger{slog.New(DiscardHandler())} } // SetDefault sets the default global logger func SetDefault(l Logger) { - root.Store(l) + rootLock.Lock() + defer rootLock.Unlock() + + root = l if lg, ok := l.(*logger); ok { slog.SetDefault(lg.inner) } @@ -22,7 +28,10 @@ func SetDefault(l Logger) { // Root returns the root logger func Root() Logger { - return root.Load().(Logger) + rootLock.RLock() + defer rootLock.RUnlock() + + return root } // The following functions bypass the exported logger methods (logger.Debug, diff --git a/log/root_test.go b/log/root_test.go new file mode 100644 index 0000000000..b9b22af669 --- /dev/null +++ b/log/root_test.go @@ -0,0 +1,19 @@ +package log + +import ( + "testing" +) + +// SetDefault should properly set the default logger when custom loggers are +// provided. +func TestSetDefaultCustomLogger(t *testing.T) { + type customLogger struct { + Logger // Implement the Logger interface + } + + customLog := &customLogger{} + SetDefault(customLog) + if Root() != customLog { + t.Error("expected custom logger to be set as default") + } +}