Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wxiaoguang committed Dec 12, 2024
1 parent 22bf2ca commit c11b896
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 8 deletions.
51 changes: 43 additions & 8 deletions modules/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"os"
"os/exec"
"path/filepath"
"reflect"
"strconv"
"strings"
"sync"
Expand All @@ -33,9 +34,21 @@ import (
gossh "golang.org/x/crypto/ssh"
)

type contextKey string

const giteaKeyID = contextKey("gitea-key-id")
// The ssh auth overall works like this:
// NewServerConn:
// serverHandshake+serverAuthenticate:
// PublicKeyCallback:
// PublicKeyHandler (our code):
// clear(ctx.Permissions) and set ctx.Permissions.giteaKeyID = keyID
// pubKey.Verify
// return ctx.Permissions // only reaches here, the pub key is really authenticated
// set conn.Permissions from serverAuthenticate
// sessionHandler(conn)
//
// Then sessionHandler should only use the "verified keyID" from the conn.
// Otherwise, if a users provides 2 keys A and B, if A succeeds to authenticate, sessionHandler will see B's keyID

const giteaPermissionExtensionKeyID = "gitea-perm-ext-key-id"

func getExitStatusFromError(err error) int {
if err == nil {
Expand All @@ -61,8 +74,26 @@ func getExitStatusFromError(err error) int {
return waitStatus.ExitStatus()
}

type sessionPartial struct {
sync.Mutex
gossh.Channel
conn *gossh.ServerConn
}

func ptr[T any](intf any) *T {
// https://pkg.go.dev/unsafe#Pointer
// (1) Conversion of a *T1 to Pointer to *T2.
// Provided that T2 is no larger than T1 and that the two share an equivalent memory layout,
// this conversion allows reinterpreting data of one type as data of another type.
v := reflect.ValueOf(intf)
p := v.UnsafePointer()
return (*T)(p)
}

func sessionHandler(session ssh.Session) {
keyID := fmt.Sprintf("%d", session.Context().Value(giteaKeyID).(int64))
// it can't use session.Permissions() because it only use the ctx one, so we must use the original ssh conn
sshConn := ptr[sessionPartial](session)
keyID := sshConn.conn.Permissions.Extensions[giteaPermissionExtensionKeyID]

command := session.RawCommand()

Expand Down Expand Up @@ -164,6 +195,12 @@ func sessionHandler(session ssh.Session) {
}

func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
setPermExt := func(keyID int64) {
ctx.Permissions().Permissions.Extensions = map[string]string{
giteaPermissionExtensionKeyID: fmt.Sprint(keyID),
}
}

if log.IsDebug() { // <- FingerprintSHA256 is kinda expensive so only calculate it if necessary
log.Debug("Handle Public Key: Fingerprint: %s from %s", gossh.FingerprintSHA256(key), ctx.RemoteAddr())
}
Expand Down Expand Up @@ -238,8 +275,7 @@ func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
if log.IsDebug() { // <- FingerprintSHA256 is kinda expensive so only calculate it if necessary
log.Debug("Successfully authenticated: %s Certificate Fingerprint: %s Principal: %s", ctx.RemoteAddr(), gossh.FingerprintSHA256(key), principal)
}
ctx.SetValue(giteaKeyID, pkey.ID)

setPermExt(pkey.ID)
return true
}

Expand All @@ -266,8 +302,7 @@ func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
if log.IsDebug() { // <- FingerprintSHA256 is kinda expensive so only calculate it if necessary
log.Debug("Successfully authenticated: %s Public Key Fingerprint: %s", ctx.RemoteAddr(), gossh.FingerprintSHA256(key))
}
ctx.SetValue(giteaKeyID, pkey.ID)

setPermExt(pkey.ID)
return true
}

Expand Down
32 changes: 32 additions & 0 deletions modules/ssh/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package ssh

import (
"testing"

"github.com/stretchr/testify/assert"
)

type S1 struct {
a, b, c int
}

func (s S1) S1Func() {}

type S1Intf interface {
S1Func()
}

type S2 struct {
a, b int
}

func TestPtr(t *testing.T) {
s1 := &S1{1, 2, 3}
var intf S1Intf = s1
s2 := ptr[S2](intf)
assert.Equal(t, 1, s2.a)
assert.Equal(t, 2, s2.b)
}

0 comments on commit c11b896

Please sign in to comment.