diff --git a/directmedia.go b/directmedia.go index aa7a9d9..28af0e4 100644 --- a/directmedia.go +++ b/directmedia.go @@ -18,6 +18,7 @@ package main import ( "context" + "crypto/sha256" "errors" "fmt" "io" @@ -48,6 +49,8 @@ type DirectMediaAPI struct { log zerolog.Logger proxy http.Client + signatureKey [32]byte + attachmentCache map[AttachmentCacheKey]AttachmentCacheValue attachmentCacheLock sync.Mutex } @@ -88,6 +91,7 @@ func newDirectMediaAPI(br *DiscordBridge) *DirectMediaAPI { os.Exit(11) return nil } + dma.signatureKey = sha256.Sum256(parsed.Priv.Seed()) dma.ks = &federation.KeyServer{ KeyProvider: &federation.StaticServerKey{ ServerName: dma.cfg.ServerName, @@ -139,7 +143,7 @@ func newDirectMediaAPI(br *DiscordBridge) *DirectMediaAPI { func (dma *DirectMediaAPI) makeMXC(data MediaIDData) id.ContentURI { return id.ContentURI{ Homeserver: dma.cfg.ServerName, - FileID: data.Wrap().String(), + FileID: data.Wrap().SignedString(dma.signatureKey), } } @@ -319,7 +323,7 @@ func (dma *DirectMediaAPI) FetchNewAttachmentURL(ctx context.Context, meta *Atta func (dma *DirectMediaAPI) GetMediaURL(ctx context.Context, encodedMediaID string) (url string, expiry time.Time, err error) { var mediaID *MediaID - mediaID, err = ParseMediaID(encodedMediaID) + mediaID, err = ParseMediaID(encodedMediaID, dma.signatureKey) if err != nil { err = &RespError{ Code: mautrix.MNotFound.ErrCode, diff --git a/directmedia_id.go b/directmedia_id.go index 7c09118..260315b 100644 --- a/directmedia_id.go +++ b/directmedia_id.go @@ -18,6 +18,8 @@ package main import ( "bytes" + "crypto/hmac" + "crypto/sha256" "encoding/base64" "encoding/binary" "errors" @@ -51,11 +53,18 @@ type MediaID struct { Data MediaIDData } -func ParseMediaID(id string) (*MediaID, error) { +func ParseMediaID(id string, key [32]byte) (*MediaID, error) { data, err := base64.RawURLEncoding.DecodeString(id) if err != nil { return nil, fmt.Errorf("failed to decode base64: %w", err) } + hasher := hmac.New(sha256.New, key[:]) + checksum := data[len(data)-TruncatedHashLength:] + data = data[:len(data)-TruncatedHashLength] + hasher.Write(data) + if !hmac.Equal(checksum, hasher.Sum(nil)[:TruncatedHashLength]) { + return nil, ErrMediaIDChecksumMismatch + } mid := &MediaID{} err = mid.Read(bytes.NewReader(data)) if err != nil { @@ -64,9 +73,14 @@ func ParseMediaID(id string) (*MediaID, error) { return mid, nil } -func (mid *MediaID) String() string { - buf := bytes.NewBuffer(make([]byte, 0, mid.Data.Size())) +const TruncatedHashLength = 16 + +func (mid *MediaID) SignedString(key [32]byte) string { + buf := bytes.NewBuffer(make([]byte, 0, mid.Size())) mid.Write(buf) + hasher := hmac.New(sha256.New, key[:]) + hasher.Write(buf.Bytes()) + buf.Write(hasher.Sum(nil)[:TruncatedHashLength]) return base64.RawURLEncoding.EncodeToString(buf.Bytes()) } @@ -77,9 +91,14 @@ func (mid *MediaID) Write(to io.Writer) { mid.Data.Write(to) } +func (mid *MediaID) Size() int { + return len(MediaIDPrefix) + 2 + mid.Data.Size() + TruncatedHashLength +} + var ( - ErrInvalidMediaID = errors.New("invalid media ID") - ErrUnsupportedMediaID = errors.New("unsupported media ID") + ErrInvalidMediaID = errors.New("invalid media ID") + ErrMediaIDChecksumMismatch = errors.New("invalid checksum in media ID") + ErrUnsupportedMediaID = errors.New("unsupported media ID") ) func (mid *MediaID) Read(from io.Reader) error { diff --git a/example-config.yaml b/example-config.yaml index e390c7a..9e2dc4c 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -183,6 +183,7 @@ bridge: # Optionally, you can force redirects and not allow proxying at all by setting this to false. allow_proxy: true # Matrix server signing key to make the federation tester pass, same format as synapse's .signing.key file. + # This key is also used to sign the mxc:// URIs to ensure only the bridge can generate them. server_key: generate # Settings for converting animated stickers. animated_sticker: