diff --git a/tun_windows.go b/tun_windows.go
index 7e1a0c3..db33954 100644
--- a/tun_windows.go
+++ b/tun_windows.go
@@ -9,7 +9,6 @@ import (
 	"net/netip"
 	"os"
 	"sync"
-	"sync/atomic"
 	"time"
 	"unsafe"
 
@@ -17,6 +16,7 @@ import (
 	"github.com/sagernet/sing-tun/internal/winsys"
 	"github.com/sagernet/sing-tun/internal/wintun"
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/atomic"
 	"github.com/sagernet/sing/common/buf"
 	E "github.com/sagernet/sing/common/exceptions"
 	"github.com/sagernet/sing/common/windnsapi"
@@ -34,7 +34,7 @@ type NativeTun struct {
 	rate        rateJuggler
 	running     sync.WaitGroup
 	closeOnce   sync.Once
-	close       int32
+	close       atomic.Int32
 	fwpmSession uintptr
 }
 
@@ -334,13 +334,13 @@ func (t *NativeTun) ReadPacket() ([]byte, func(), error) {
 	t.running.Add(1)
 	defer t.running.Done()
 retry:
-	if atomic.LoadInt32(&t.close) == 1 {
+	if t.close.Load() == 1 {
 		return nil, nil, os.ErrClosed
 	}
 	start := nanotime()
-	shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2
+	shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
 	for {
-		if atomic.LoadInt32(&t.close) == 1 {
+		if t.close.Load() == 1 {
 			return nil, nil, os.ErrClosed
 		}
 		packet, err := t.session.ReceivePacket()
@@ -369,13 +369,13 @@ func (t *NativeTun) ReadFunc(block func(b []byte)) error {
 	t.running.Add(1)
 	defer t.running.Done()
 retry:
-	if atomic.LoadInt32(&t.close) == 1 {
+	if t.close.Load() == 1 {
 		return os.ErrClosed
 	}
 	start := nanotime()
-	shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2
+	shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
 	for {
-		if atomic.LoadInt32(&t.close) == 1 {
+		if t.close.Load() == 1 {
 			return os.ErrClosed
 		}
 		packet, err := t.session.ReceivePacket()
@@ -405,7 +405,7 @@ retry:
 func (t *NativeTun) Write(p []byte) (n int, err error) {
 	t.running.Add(1)
 	defer t.running.Done()
-	if atomic.LoadInt32(&t.close) == 1 {
+	if t.close.Load() == 1 {
 		return 0, os.ErrClosed
 	}
 	t.rate.update(uint64(len(p)))
@@ -427,7 +427,7 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
 func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
 	t.running.Add(1)
 	defer t.running.Done()
-	if atomic.LoadInt32(&t.close) == 1 {
+	if t.close.Load() == 1 {
 		return 0, os.ErrClosed
 	}
 	var packetSize int
@@ -461,7 +461,7 @@ func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
 func (t *NativeTun) Close() error {
 	var err error
 	t.closeOnce.Do(func() {
-		atomic.StoreInt32(&t.close, 1)
+		t.close.Store(1)
 		windows.SetEvent(t.readWait)
 		t.running.Wait()
 		t.session.End()
@@ -491,24 +491,24 @@ func procyield(cycles uint32)
 func nanotime() int64
 
 type rateJuggler struct {
-	current       uint64
-	nextByteCount uint64
-	nextStartTime int64
-	changing      int32
+	current       atomic.Uint64
+	nextByteCount atomic.Uint64
+	nextStartTime atomic.Int64
+	changing      atomic.Int32
 }
 
 func (rate *rateJuggler) update(packetLen uint64) {
 	now := nanotime()
-	total := atomic.AddUint64(&rate.nextByteCount, packetLen)
-	period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
+	total := rate.nextByteCount.Add(packetLen)
+	period := uint64(now - rate.nextStartTime.Load())
 	if period >= rateMeasurementGranularity {
-		if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
+		if !rate.changing.CompareAndSwap(0, 1) {
			return
 		}
-		atomic.StoreInt64(&rate.nextStartTime, now)
-		atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
-		atomic.StoreUint64(&rate.nextByteCount, 0)
-		atomic.StoreInt32(&rate.changing, 0)
+		rate.nextStartTime.Store(now)
+		rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
+		rate.nextByteCount.Store(0)
+		rate.changing.Store(0)
 	}
 }
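
For reference, the method set this diff switches to (Load, Store, Add, CompareAndSwap) matches the typed atomics added to the standard library in Go 1.19 (sync/atomic.Int32/Int64/Uint64); the sing common/atomic import used here exposes the same calls, as the changed lines show. Below is a minimal, self-contained sketch of the same CAS-guarded rate-measurement pattern built only on the standard library. It is an illustration, not the sing-tun implementation: the rateMeasurementGranularity value and the nanotime helper are assumed stand-ins for the package-private ones in tun_windows.go.

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// Assumed window length for the sketch; sing-tun defines its own constant.
const rateMeasurementGranularity = uint64(2 * time.Second / time.Nanosecond)

// Stand-in for the runtime-linked nanotime used in the real file.
func nanotime() int64 { return time.Now().UnixNano() }

type rateJuggler struct {
	current       atomic.Uint64 // bytes/second measured over the last window
	nextByteCount atomic.Uint64 // bytes accumulated in the current window
	nextStartTime atomic.Int64  // start of the current window, in nanoseconds
	changing      atomic.Int32  // CAS guard: only one goroutine rolls the window
}

func (rate *rateJuggler) update(packetLen uint64) {
	now := nanotime()
	total := rate.nextByteCount.Add(packetLen)
	period := uint64(now - rate.nextStartTime.Load())
	if period >= rateMeasurementGranularity {
		// Whoever wins the CAS publishes the new rate; everyone else returns
		// immediately instead of blocking, keeping the hot path lock-free.
		if !rate.changing.CompareAndSwap(0, 1) {
			return
		}
		rate.nextStartTime.Store(now)
		rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
		rate.nextByteCount.Store(0)
		rate.changing.Store(0)
	}
}

func main() {
	var rate rateJuggler
	rate.nextStartTime.Store(nanotime())
	rate.update(1500) // record one hypothetical 1500-byte packet
	fmt.Println("current rate (B/s):", rate.current.Load())
}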