diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index aa5aba1fd7..2256d5fcc9 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -387,7 +387,12 @@ def serialize_pte_binary( constant_segment_data, constant_segment_offsets = _extract_constant_segment( program.constant_buffer, tensor_alignment=constant_tensor_alignment ) - if len(constant_segment_data) > 0: + + # If there are no constants, len(constant_segment_data) = 0. However, there may + # be non-constants, in which case len(constant_segment_offsets) = 1, containing + # the placeholder value 0. Ensure the placeholder value is put into + # program.constant_segment.offsets. + if len(constant_segment_offsets) > 0: # Update program.constant_segment with constant subsegment offset information. program.constant_segment = SubsegmentOffsets( segment_index=len(segments), offsets=constant_segment_offsets diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index c4f4df0d0b..afd8e3d282 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -583,6 +583,33 @@ def test_round_trip_with_segments(self) -> None: program2 = deserialize_pte_binary(pte_data) self.assert_programs_equal(program, program2) + def test_no_constants(self) -> None: + program = get_test_program() + # Insert placeholder for non-const tensors. + add_constant_data(program, [b""]) + + pte_data = bytes( + serialize_pte_binary( + program, + extract_delegate_segments=True, + segment_alignment=SEGMENT_ALIGNMENT, + constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, + ) + ) + # The input Program should not be modified. + self.assertEqual(program.segments, []) + + # Peek inside the actual flatbuffer data to see the segments. + flatbuffer_program = _json_to_program(_program_flatbuffer_to_json(pte_data)) + + # Constant buffer should be empty. + self.assertEqual(len(flatbuffer_program.constant_buffer), 0) + + # Constant segment should contain the placeholder. + self.assertEqual(flatbuffer_program.constant_segment.segment_index, 0) + self.assertEqual(len(flatbuffer_program.constant_segment.offsets), 1) + self.assertEqual(flatbuffer_program.constant_segment.offsets[0], 0) + def test_unused_inline_delegate_blobs_with_segments(self) -> None: # Create a program with some delegate data blobs. program = get_test_program() diff --git a/runtime/executor/test/program_test.cpp b/runtime/executor/test/program_test.cpp index 2cc9b4369d..80f91f1af6 100644 --- a/runtime/executor/test/program_test.cpp +++ b/runtime/executor/test/program_test.cpp @@ -379,9 +379,31 @@ TEST_F(ProgramTest, DEPRECATEDLoad) { EXPECT_EQ(program_res.error(), Error::Ok); } +TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) { + Result program = + Program::load(add_loader_.get(), kDefaultVerification); + ASSERT_EQ(program.error(), Error::Ok); + + // Load constant segment data should fail. + const auto segment_info = DataLoader::SegmentInfo( + DataLoader::SegmentInfo::Type::Constant, + /*segment_index=*/0); + Result segment = + ProgramTestFriend::LoadSegment(&program.get(), segment_info); + EXPECT_NE(segment.error(), Error::Ok); + + const executorch_flatbuffer::Program* flatbuffer_program = + ProgramTestFriend::GetInternalProgram(&program.get()); + + // The constant buffer should be empty. + EXPECT_EQ(flatbuffer_program->constant_buffer()->size(), 0); + + // Expect 1 constant segment, placeholder for non-const tensors. + EXPECT_EQ(flatbuffer_program->segments()->size(), 1); +} + TEST_F(ProgramTest, LoadConstantSegment) { - // Load the serialized ModuleLinear data, with constants in the segment and no - // constants in the flatbuffer. + // Load the serialized ModuleLinear data, with constants in the segment. const char* linear_path = std::getenv("ET_MODULE_LINEAR_PATH"); Result linear_loader = FileDataLoader::from(linear_path); ASSERT_EQ(linear_loader.error(), Error::Ok); @@ -504,8 +526,8 @@ TEST_F(ProgramTest, LoadFromMutableSegment) { const executorch_flatbuffer::Program* flatbuffer_program = ProgramTestFriend::GetInternalProgram(&program.get()); - // Expect 1 segment. 1 mutable segment and no constant segment. - EXPECT_EQ(flatbuffer_program->segments()->size(), 1); + // Expect 2 segments. 1 mutable segment and 1 constant segment. + EXPECT_EQ(flatbuffer_program->segments()->size(), 2); // Expect a mutable data segment. EXPECT_EQ(flatbuffer_program->mutable_data_segments()->size(), 1);