Skip to content

Commit

Permalink
fix: ExtractInitProtectData was broken for video
Browse files Browse the repository at this point in the history
  • Loading branch information
tobbee committed Nov 8, 2024
1 parent c57ef5f commit 8463870
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 83 deletions.
12 changes: 6 additions & 6 deletions mp4/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -658,19 +658,19 @@ func ExtractInitProtectData(inSeg *InitSegment) (*InitProtectData, error) {
for _, c := range stsd.Children {
switch box := c.(type) {
case *VisualSampleEntryBox:
switch box.Type() {
case "avc1":
sinf = box.Sinf
frma := sinf.Frma
if frma.DataFormat == "avc1" {
ipd.ProtFunc, err = getAVCProtFunc(box.AvcC)
if err != nil {
return nil, fmt.Errorf("get AVC protect func: %w", err)
}
default:
return nil, fmt.Errorf("unsupported video codec descriptor %s", box.Type())
} else {
return nil, fmt.Errorf("unsupported video codec descriptor %s", frma.DataFormat)
}
sinf = box.Sinf
case *AudioSampleEntryBox:
ipd.ProtFunc = getAudioProtectRanges
sinf = box.Sinf
ipd.ProtFunc = getAudioProtectRanges
default:
continue
}
Expand Down
181 changes: 104 additions & 77 deletions mp4/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,98 +95,125 @@ func TestFindAVCSubsampleRanges(t *testing.T) {
}
}

func TestEncryptDecryptAVC(t *testing.T) {
testInit := "testdata/init.mp4"
testFile := "testdata/1.m4s"
func TestEncryptDecrypt(t *testing.T) {
videoInit := "testdata/init.mp4"
videoSeg := "testdata/1.m4s"
audioInit := "testdata/aac_init.mp4"
audioSeg := "testdata/aac_1.m4s"
keyHex := "00112233445566778899aabbccddeeff"
ivHex := "7766554433221100"
ivHex8 := "7766554433221100"
ivHex16 := "ffeeddccbbaa99887766554433221100"
kidHex := "11112222333344445555666677778888"
key, _ := hex.DecodeString(keyHex)
iv, _ := hex.DecodeString(ivHex)
kidUUID, _ := NewUUIDFromHex(kidHex)

if len(iv) == 8 {
// Convert to 16 bytes
iv8 := iv
iv = make([]byte, 16)
copy(iv, iv8)
}

testCases := []struct {
desc string
init string
seg string
scheme string
iv string
}{
{scheme: "cenc"},
{scheme: "cbcs"},
{desc: "video, cenc, iv8", init: videoInit, seg: videoSeg, scheme: "cenc", iv: ivHex8},
{desc: "video, cbcs, iv8", init: videoInit, seg: videoSeg, scheme: "cbcs", iv: ivHex8},
{desc: "video, cbcs, iv16", init: videoInit, seg: videoSeg, scheme: "cbcs", iv: ivHex16},
{desc: "audio, cbcs, iv16", init: audioInit, seg: audioSeg, scheme: "cbcs", iv: ivHex16},
}
for _, c := range testCases {
t.Run(c.desc, func(t *testing.T) {
ifh, err := os.Open(c.init)
if err != nil {
t.Fatal(err)
}
init, err := DecodeFile(ifh)
ifh.Close()
if err != nil {
t.Fatal(err)
}
iv, err := hex.DecodeString(c.iv)
if err != nil {
t.Fatal(err)
}
ipf, err := InitProtect(init.Init, key, iv, c.scheme, kidUUID, nil)
if err != nil {
t.Fatal(err)
}
// Write init segment with encyption info
encInitBuf := bytes.Buffer{}
err = init.Encode(&encInitBuf)
if err != nil {
t.Fatal(err)
}

for _, tc := range testCases {
ifh, err := os.Open(testInit)
if err != nil {
t.Fatal(err)
}
init, err := DecodeFile(ifh)
if err != nil {
t.Fatal(err)
}
ifh.Close()
ipf, err := InitProtect(init.Init, key, iv, tc.scheme, kidUUID, nil)
if err != nil {
t.Fatal(err)
}
ifh, err = os.Open(testFile)
if err != nil {
t.Fatal(err)
}
segFile, err := DecodeFile(ifh)
if err != nil {
t.Fatal(err)
}
dInfo, err := DecryptInit(init.Init)
if err != nil {
t.Fatal(err)
}
ifh.Close()
for _, s := range segFile.Segments {
for _, f := range s.Fragments {
rawInput := make([]byte, len(f.Mdat.Data))
copy(rawInput, f.Mdat.Data)
err := EncryptFragment(f, key, iv, ipf)
if err != nil {
t.Error(err)
}
outBuf := bytes.Buffer{}
err = f.Encode(&outBuf)
if err != nil {
t.Error(err)
}
sr := bits.NewFixedSliceReader(outBuf.Bytes())
dff, err := DecodeFileSR(sr)
if err != nil {
t.Error(err)
}
if len(dff.Segments) != 1 {
t.Errorf("Expected 1 segment, got %d", len(dff.Segments))
}
if len(dff.Segments[0].Fragments) != 1 {
t.Errorf("Expected 1 fragment, got %d", len(dff.Segments[0].Fragments))
}
df := dff.Segments[0].Fragments[0]
encData := make([]byte, len(df.Mdat.Data))
copy(encData, df.Mdat.Data)
if bytes.Equal(rawInput, encData) {
t.Errorf("bytes equal after encryption")
// Check that one can extract the protection the InitProtectData from the init segment
ipd, err := ExtractInitProtectData(init.Init)
if err != nil {
t.Fatal(err)
}
diff := deep.Equal(ipd, ipf)
if len(diff) > 0 {
t.Errorf("InitProtectData not equal after extraction")
}

// Encrypt and write media segment
rawSeg, err := os.ReadFile(c.seg)
if err != nil {
t.Fatal(err)
}
rs := bytes.NewBuffer(rawSeg)
seg, err := DecodeFile(rs)
if err != nil {
t.Fatal(err)
}
for _, s := range seg.Segments {
for _, f := range s.Fragments {
err := EncryptFragment(f, key, iv, ipf)
if err != nil {
t.Error(err)
}
}
err = DecryptFragment(df, dInfo, key)
}
outBuf := bytes.Buffer{}
err = seg.Encode(&outBuf)
if err != nil {
t.Error(err)
}
// Get decrypt info from init segment
encInitRaw := encInitBuf.Bytes()
sr := bits.NewFixedSliceReader(encInitRaw)
encInit, err := DecodeFileSR(sr)
if err != nil {
t.Error(err)
}
decInfo, err := DecryptInit(encInit.Init)
if err != nil {
t.Error(err)
}

// Decode and decrypt the written segment
sr = bits.NewFixedSliceReader(outBuf.Bytes())
decode, err := DecodeFileSR(sr)
if err != nil {
t.Error(err)
}
// Decrypt the segment
for _, s := range decode.Segments {
err := DecryptSegment(s, decInfo, key)
if err != nil {
t.Error(err)
}
decData := make([]byte, len(df.Mdat.Data))
copy(decData, df.Mdat.Data)
if !bytes.Equal(rawInput, decData) {
t.Errorf("bytes not equal after encryption+decryption")
}
}
}

decSegBuf := bytes.Buffer{}
err = decode.Encode(&decSegBuf)
if err != nil {
t.Error(err)
}

if !bytes.Equal(rawSeg, decSegBuf.Bytes()) {
t.Errorf("segment not equal after encryption+decryption")
}
})
}
}

Expand Down
Binary file added mp4/testdata/aac_1.m4s
Binary file not shown.
Binary file added mp4/testdata/aac_init.mp4
Binary file not shown.

0 comments on commit 8463870

Please sign in to comment.