Skip to content

Commit

Permalink
Fix unaligned panic on windows
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Feb 10, 2024
1 parent 38c945f commit 9b7c2a0
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions tun_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import (
"net/netip"
"os"
"sync"
"sync/atomic"
"time"
"unsafe"

"github.com/sagernet/sing-tun/internal/winipcfg"
"github.com/sagernet/sing-tun/internal/winsys"
"github.com/sagernet/sing-tun/internal/wintun"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/windnsapi"
Expand All @@ -34,7 +34,7 @@ type NativeTun struct {
rate rateJuggler
running sync.WaitGroup
closeOnce sync.Once
close int32
close atomic.Int32
fwpmSession uintptr
}

Expand Down Expand Up @@ -334,13 +334,13 @@ func (t *NativeTun) ReadPacket() ([]byte, func(), error) {
t.running.Add(1)
defer t.running.Done()
retry:
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return nil, nil, os.ErrClosed
}
start := nanotime()
shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2
shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return nil, nil, os.ErrClosed
}
packet, err := t.session.ReceivePacket()
Expand Down Expand Up @@ -369,13 +369,13 @@ func (t *NativeTun) ReadFunc(block func(b []byte)) error {
t.running.Add(1)
defer t.running.Done()
retry:
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return os.ErrClosed
}
start := nanotime()
shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2
shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return os.ErrClosed
}
packet, err := t.session.ReceivePacket()
Expand Down Expand Up @@ -405,7 +405,7 @@ retry:
func (t *NativeTun) Write(p []byte) (n int, err error) {
t.running.Add(1)
defer t.running.Done()
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return 0, os.ErrClosed
}
t.rate.update(uint64(len(p)))
Expand All @@ -427,7 +427,7 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
t.running.Add(1)
defer t.running.Done()
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return 0, os.ErrClosed
}
var packetSize int
Expand Down Expand Up @@ -461,7 +461,7 @@ func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
func (t *NativeTun) Close() error {
var err error
t.closeOnce.Do(func() {
atomic.StoreInt32(&t.close, 1)
t.close.Store(1)
windows.SetEvent(t.readWait)
t.running.Wait()
t.session.End()
Expand Down Expand Up @@ -491,24 +491,24 @@ func procyield(cycles uint32)
func nanotime() int64

type rateJuggler struct {
current uint64
nextByteCount uint64
nextStartTime int64
changing int32
current atomic.Uint64
nextByteCount atomic.Uint64
nextStartTime atomic.Int64
changing atomic.Int32
}

func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
total := atomic.AddUint64(&rate.nextByteCount, packetLen)
period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
total := rate.nextByteCount.Add(packetLen)
period := uint64(now - rate.nextStartTime.Load())
if period >= rateMeasurementGranularity {
if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
if !rate.changing.CompareAndSwap(0, 1) {
return
}
atomic.StoreInt64(&rate.nextStartTime, now)
atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
atomic.StoreUint64(&rate.nextByteCount, 0)
atomic.StoreInt32(&rate.changing, 0)
rate.nextStartTime.Store(now)
rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
rate.nextByteCount.Store(0)
rate.changing.Store(0)
}
}

Expand Down

0 comments on commit 9b7c2a0

Please sign in to comment.