diff --git a/mp4/crypto.go b/mp4/crypto.go index d5b3fe26..7a069c8f 100644 --- a/mp4/crypto.go +++ b/mp4/crypto.go @@ -658,19 +658,19 @@ func ExtractInitProtectData(inSeg *InitSegment) (*InitProtectData, error) { for _, c := range stsd.Children { switch box := c.(type) { case *VisualSampleEntryBox: - switch box.Type() { - case "avc1": + sinf = box.Sinf + frma := sinf.Frma + if frma.DataFormat == "avc1" { ipd.ProtFunc, err = getAVCProtFunc(box.AvcC) if err != nil { return nil, fmt.Errorf("get AVC protect func: %w", err) } - default: - return nil, fmt.Errorf("unsupported video codec descriptor %s", box.Type()) + } else { + return nil, fmt.Errorf("unsupported video codec descriptor %s", frma.DataFormat) } - sinf = box.Sinf case *AudioSampleEntryBox: - ipd.ProtFunc = getAudioProtectRanges sinf = box.Sinf + ipd.ProtFunc = getAudioProtectRanges default: continue } diff --git a/mp4/crypto_test.go b/mp4/crypto_test.go index 11fe2e83..4dfdf927 100644 --- a/mp4/crypto_test.go +++ b/mp4/crypto_test.go @@ -95,98 +95,125 @@ func TestFindAVCSubsampleRanges(t *testing.T) { } } -func TestEncryptDecryptAVC(t *testing.T) { - testInit := "testdata/init.mp4" - testFile := "testdata/1.m4s" +func TestEncryptDecrypt(t *testing.T) { + videoInit := "testdata/init.mp4" + videoSeg := "testdata/1.m4s" + audioInit := "testdata/aac_init.mp4" + audioSeg := "testdata/aac_1.m4s" keyHex := "00112233445566778899aabbccddeeff" - ivHex := "7766554433221100" + ivHex8 := "7766554433221100" + ivHex16 := "ffeeddccbbaa99887766554433221100" kidHex := "11112222333344445555666677778888" key, _ := hex.DecodeString(keyHex) - iv, _ := hex.DecodeString(ivHex) kidUUID, _ := NewUUIDFromHex(kidHex) - if len(iv) == 8 { - // Convert to 16 bytes - iv8 := iv - iv = make([]byte, 16) - copy(iv, iv8) - } - testCases := []struct { + desc string + init string + seg string scheme string + iv string }{ - {scheme: "cenc"}, - {scheme: "cbcs"}, + {desc: "video, cenc, iv8", init: videoInit, seg: videoSeg, scheme: "cenc", iv: ivHex8}, + {desc: "video, cbcs, iv8", init: videoInit, seg: videoSeg, scheme: "cbcs", iv: ivHex8}, + {desc: "video, cbcs, iv16", init: videoInit, seg: videoSeg, scheme: "cbcs", iv: ivHex16}, + {desc: "audio, cbcs, iv16", init: audioInit, seg: audioSeg, scheme: "cbcs", iv: ivHex16}, } + for _, c := range testCases { + t.Run(c.desc, func(t *testing.T) { + ifh, err := os.Open(c.init) + if err != nil { + t.Fatal(err) + } + init, err := DecodeFile(ifh) + ifh.Close() + if err != nil { + t.Fatal(err) + } + iv, err := hex.DecodeString(c.iv) + if err != nil { + t.Fatal(err) + } + ipf, err := InitProtect(init.Init, key, iv, c.scheme, kidUUID, nil) + if err != nil { + t.Fatal(err) + } + // Write init segment with encyption info + encInitBuf := bytes.Buffer{} + err = init.Encode(&encInitBuf) + if err != nil { + t.Fatal(err) + } - for _, tc := range testCases { - ifh, err := os.Open(testInit) - if err != nil { - t.Fatal(err) - } - init, err := DecodeFile(ifh) - if err != nil { - t.Fatal(err) - } - ifh.Close() - ipf, err := InitProtect(init.Init, key, iv, tc.scheme, kidUUID, nil) - if err != nil { - t.Fatal(err) - } - ifh, err = os.Open(testFile) - if err != nil { - t.Fatal(err) - } - segFile, err := DecodeFile(ifh) - if err != nil { - t.Fatal(err) - } - dInfo, err := DecryptInit(init.Init) - if err != nil { - t.Fatal(err) - } - ifh.Close() - for _, s := range segFile.Segments { - for _, f := range s.Fragments { - rawInput := make([]byte, len(f.Mdat.Data)) - copy(rawInput, f.Mdat.Data) - err := EncryptFragment(f, key, iv, ipf) - if err != nil { - t.Error(err) - } - outBuf := bytes.Buffer{} - err = f.Encode(&outBuf) - if err != nil { - t.Error(err) - } - sr := bits.NewFixedSliceReader(outBuf.Bytes()) - dff, err := DecodeFileSR(sr) - if err != nil { - t.Error(err) - } - if len(dff.Segments) != 1 { - t.Errorf("Expected 1 segment, got %d", len(dff.Segments)) - } - if len(dff.Segments[0].Fragments) != 1 { - t.Errorf("Expected 1 fragment, got %d", len(dff.Segments[0].Fragments)) - } - df := dff.Segments[0].Fragments[0] - encData := make([]byte, len(df.Mdat.Data)) - copy(encData, df.Mdat.Data) - if bytes.Equal(rawInput, encData) { - t.Errorf("bytes equal after encryption") + // Check that one can extract the protection the InitProtectData from the init segment + ipd, err := ExtractInitProtectData(init.Init) + if err != nil { + t.Fatal(err) + } + diff := deep.Equal(ipd, ipf) + if len(diff) > 0 { + t.Errorf("InitProtectData not equal after extraction") + } + + // Encrypt and write media segment + rawSeg, err := os.ReadFile(c.seg) + if err != nil { + t.Fatal(err) + } + rs := bytes.NewBuffer(rawSeg) + seg, err := DecodeFile(rs) + if err != nil { + t.Fatal(err) + } + for _, s := range seg.Segments { + for _, f := range s.Fragments { + err := EncryptFragment(f, key, iv, ipf) + if err != nil { + t.Error(err) + } } - err = DecryptFragment(df, dInfo, key) + } + outBuf := bytes.Buffer{} + err = seg.Encode(&outBuf) + if err != nil { + t.Error(err) + } + // Get decrypt info from init segment + encInitRaw := encInitBuf.Bytes() + sr := bits.NewFixedSliceReader(encInitRaw) + encInit, err := DecodeFileSR(sr) + if err != nil { + t.Error(err) + } + decInfo, err := DecryptInit(encInit.Init) + if err != nil { + t.Error(err) + } + + // Decode and decrypt the written segment + sr = bits.NewFixedSliceReader(outBuf.Bytes()) + decode, err := DecodeFileSR(sr) + if err != nil { + t.Error(err) + } + // Decrypt the segment + for _, s := range decode.Segments { + err := DecryptSegment(s, decInfo, key) if err != nil { t.Error(err) } - decData := make([]byte, len(df.Mdat.Data)) - copy(decData, df.Mdat.Data) - if !bytes.Equal(rawInput, decData) { - t.Errorf("bytes not equal after encryption+decryption") - } } - } + + decSegBuf := bytes.Buffer{} + err = decode.Encode(&decSegBuf) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(rawSeg, decSegBuf.Bytes()) { + t.Errorf("segment not equal after encryption+decryption") + } + }) } } diff --git a/mp4/testdata/aac_1.m4s b/mp4/testdata/aac_1.m4s new file mode 100644 index 00000000..dea547d2 Binary files /dev/null and b/mp4/testdata/aac_1.m4s differ diff --git a/mp4/testdata/aac_init.mp4 b/mp4/testdata/aac_init.mp4 new file mode 100644 index 00000000..e332e81d Binary files /dev/null and b/mp4/testdata/aac_init.mp4 differ