diff --git a/core/src/execution/datafusion/shuffle_writer.rs b/core/src/execution/datafusion/shuffle_writer.rs
index 967340979..99ac885b5 100644
--- a/core/src/execution/datafusion/shuffle_writer.rs
+++ b/core/src/execution/datafusion/shuffle_writer.rs
@@ -575,6 +575,8 @@ struct ShuffleRepartitioner {
     hashes_buf: Vec<u32>,
     /// Partition ids for each row in the current batch
     partition_ids: Vec<u64>,
+    /// The configured batch size
+    batch_size: usize,
 }
 
 struct ShuffleRepartitionerMetrics {
@@ -642,17 +644,41 @@ impl ShuffleRepartitioner {
             reservation,
             hashes_buf,
             partition_ids,
+            batch_size,
         }
     }
 
+    /// Shuffles rows in input batch into corresponding partition buffer.
+    /// This function will slice input batch according to configured batch size and then
+    /// shuffle rows into corresponding partition buffer.
+    async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
+        let mut start = 0;
+        while start < batch.num_rows() {
+            let end = (start + self.batch_size).min(batch.num_rows());
+            let batch = batch.slice(start, end - start);
+            self.partitioning_batch(batch).await?;
+            start = end;
+        }
+        Ok(())
+    }
+
     /// Shuffles rows in input batch into corresponding partition buffer.
     /// This function first calculates hashes for rows and then takes rows in same
     /// partition as a record batch which is appended into partition buffer.
-    async fn insert_batch(&mut self, input: RecordBatch) -> Result<()> {
+    /// This should not be called directly. Use `insert_batch` instead.
+    async fn partitioning_batch(&mut self, input: RecordBatch) -> Result<()> {
         if input.num_rows() == 0 {
             // skip empty batch
             return Ok(());
         }
+
+        if input.num_rows() > self.batch_size {
+            return Err(DataFusionError::Internal(
+                "Input batch size exceeds configured batch size. Call `insert_batch` instead."
+                    .to_string(),
+            ));
+        }
+
         let _timer = self.metrics.baseline.elapsed_compute().timer();
 
         // NOTE: in shuffle writer exec, the output_rows metrics represents the
@@ -951,8 +977,7 @@ async fn external_shuffle(
     );
 
     while let Some(batch) = input.next().await {
-        let batch = batch?;
-        repartitioner.insert_batch(batch).await?;
+        repartitioner.insert_batch(batch?).await?;
     }
     repartitioner.shuffle_write().await
 }
@@ -1387,6 +1412,11 @@ impl RecordBatchStream for EmptyStream {
 #[cfg(test)]
 mod test {
     use super::*;
+    use datafusion::physical_plan::common::collect;
+    use datafusion::physical_plan::memory::MemoryExec;
+    use datafusion::prelude::SessionContext;
+    use datafusion_physical_expr::expressions::Column;
+    use tokio::runtime::Runtime;
 
     #[test]
     fn test_slot_size() {
@@ -1415,4 +1445,32 @@ mod test {
                 assert_eq!(slot_size, *expected);
             })
     }
+
+    #[test]
+    fn test_insert_larger_batch() {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
+        let mut b = StringBuilder::new();
+        for i in 0..10000 {
+            b.append_value(format!("{i}"));
+        }
+        let array = b.finish();
+        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();
+
+        let mut batches = Vec::new();
+        batches.push(batch.clone());
+
+        let partitions = &[batches];
+        let exec = ShuffleWriterExec::try_new(
+            Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()),
+            Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 16),
+            "/tmp/data.out".to_string(),
+            "/tmp/index.out".to_string(),
+        )
+        .unwrap();
+        let ctx = SessionContext::new();
+        let task_ctx = ctx.task_ctx();
+        let stream = exec.execute(0, task_ctx).unwrap();
+        let rt = Runtime::new().unwrap();
+        rt.block_on(collect(stream)).unwrap();
+    }
 }