Skip to content

Commit

Permalink
Rename Writer
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Nov 2, 2023
1 parent 8257d95 commit d3657a8
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions transport/shadowsocks/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ const payloadSizeMask = 0x3FFF // 16*1024 - 1
// The largest buffer we could need is for decrypting a max-length payload.
var readBufPool = slicepool.MakePool(payloadSizeMask + maxTagSize)

// Writer is an [io.Writer] that also implements [io.ReaderFrom] to
// ReaderFrom is an [io.ReaderFrom] that also implements [io.ReaderFrom] to
// allow for piping the data without extra allocations and copies.
// The LazyWrite and Flush methods allow a header to be
// added but delayed until the first write, for concatenation.
// All methods except Flush must be called from a single thread.
type Writer struct {
type ReaderFrom struct {
// This type is single-threaded except when needFlush is true.
// mu protects needFlush, and also protects everything
// else while needFlush could be true.
Expand All @@ -56,23 +56,23 @@ type Writer struct {
}

var (
_ io.ReaderFrom = (*Writer)(nil)
_ io.ReaderFrom = (*ReaderFrom)(nil)
)

// NewReaderFrom creates a [Writer] that encrypts the given [io.Writer] using
// the shadowsocks protocol with the given encryption key.
func NewReaderFrom(rf io.ReaderFrom, key *EncryptionKey) *Writer {
return &Writer{rf: rf, key: key, saltGenerator: RandomSaltGenerator}
func NewReaderFrom(rf io.ReaderFrom, key *EncryptionKey) *ReaderFrom {
return &ReaderFrom{rf: rf, key: key, saltGenerator: RandomSaltGenerator}
}

// SetSaltGenerator sets the salt generator to be used. Must be called before the first write.
func (sw *Writer) SetSaltGenerator(saltGenerator SaltGenerator) {
func (sw *ReaderFrom) SetSaltGenerator(saltGenerator SaltGenerator) {
sw.saltGenerator = saltGenerator
}

// init generates a random salt, sets up the AEAD object and writes
// the salt to the inner Writer.
func (sw *Writer) init() (err error) {
func (sw *ReaderFrom) init() (err error) {
if sw.aead == nil {
salt := make([]byte, sw.key.SaltSize())
if err := sw.saltGenerator.GetSalt(salt); err != nil {
Expand All @@ -97,15 +97,15 @@ func (sw *Writer) init() (err error) {

// encryptBlock encrypts `plaintext` in-place. The slice must have enough capacity
// for the tag. Returns the total ciphertext length.
func (sw *Writer) encryptBlock(plaintext []byte) int {
func (sw *ReaderFrom) encryptBlock(plaintext []byte) int {
out := sw.aead.Seal(plaintext[:0], sw.counter, plaintext, nil)
increment(sw.counter)
return len(out)
}

// LazyWrite queues p to be written, but doesn't send it until Flush() is
// called, a non-lazy write is made, or the buffer is filled.
func (sw *Writer) LazyWrite(p []byte) (int, error) {
func (sw *ReaderFrom) LazyWrite(p []byte) (int, error) {
if err := sw.init(); err != nil {
return 0, err
}
Expand Down Expand Up @@ -133,7 +133,7 @@ func (sw *Writer) LazyWrite(p []byte) (int, error) {
}

// Flush sends the pending data, if any. This method is thread-safe.
func (sw *Writer) Flush() error {
func (sw *ReaderFrom) Flush() error {
sw.mu.Lock()
defer sw.mu.Unlock()
if !sw.needFlush {
Expand All @@ -152,7 +152,7 @@ func isZero(b []byte) bool {
}

// Returns the slices of sw.buf in which to place plaintext for encryption.
func (sw *Writer) buffers() (sizeBuf, payloadBuf []byte) {
func (sw *ReaderFrom) buffers() (sizeBuf, payloadBuf []byte) {
// sw.buf starts with the salt.
saltSize := sw.key.SaltSize()

Expand All @@ -165,7 +165,7 @@ func (sw *Writer) buffers() (sizeBuf, payloadBuf []byte) {
}

// ReadFrom implements the [io.ReaderFrom] interface.
func (sw *Writer) ReadFrom(r io.Reader) (int64, error) {
func (sw *ReaderFrom) ReadFrom(r io.Reader) (int64, error) {
if err := sw.init(); err != nil {
return 0, err
}
Expand Down Expand Up @@ -215,15 +215,15 @@ func (sw *Writer) ReadFrom(r io.Reader) (int64, error) {

// Adds as much of `plaintext` into the buffer as will fit, and increases
// sw.pending accordingly. Returns the number of bytes consumed.
func (sw *Writer) enqueue(plaintext []byte) int {
func (sw *ReaderFrom) enqueue(plaintext []byte) int {
_, payloadBuf := sw.buffers()
n := copy(payloadBuf[sw.pending:], plaintext)
sw.pending += n
return n
}

// Encrypts all pending data and writes it to the output.
func (sw *Writer) flush() error {
func (sw *ReaderFrom) flush() error {
if sw.pending == 0 {
return nil
}
Expand Down

0 comments on commit d3657a8

Please sign in to comment.