Skip to content

Commit

Permalink
refactor: reduce cloning usage and fix vision queue issue
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Feb 1, 2025
1 parent 01bbb6a commit 14d77ea
Show file tree
Hide file tree
Showing 18 changed files with 150 additions and 119 deletions.
2 changes: 1 addition & 1 deletion screenpipe-audio/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ pub async fn record_and_transcribe(

pub async fn start_realtime_recording(
audio_stream: Arc<AudioStream>,
languages: Vec<Language>,
languages: Arc<Vec<Language>>,
is_running: Arc<AtomicBool>,
deepgram_api_key: Option<String>,
) -> Result<()> {
Expand Down
4 changes: 2 additions & 2 deletions screenpipe-audio/src/deepgram/realtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use tokio::sync::oneshot;

pub async fn stream_transcription_deepgram(
stream: Arc<AudioStream>,
languages: Vec<Language>,
languages: Arc<Vec<Language>>,
is_running: Arc<AtomicBool>,
deepgram_api_key: Option<String>,
) -> Result<()> {
Expand All @@ -44,7 +44,7 @@ pub async fn start_deepgram_stream(
device: Arc<AudioDevice>,
sample_rate: u32,
is_running: Arc<AtomicBool>,
_languages: Vec<Language>,
_languages: Arc<Vec<Language>>,
deepgram_api_key: Option<String>,
) -> Result<()> {
let api_key = deepgram_api_key.unwrap_or(CUSTOM_DEEPGRAM_API_TOKEN.to_string());
Expand Down
2 changes: 1 addition & 1 deletion screenpipe-audio/src/realtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::sync::{atomic::AtomicBool, Arc};

pub async fn realtime_stt(
stream: Arc<AudioStream>,
languages: Vec<Language>,
languages: Arc<Vec<Language>>,
is_running: Arc<AtomicBool>,
deepgram_api_key: Option<String>,
) -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion screenpipe-audio/tests/realtime_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn test_realtime_transcription() {
Arc::new(device),
sample_rate,
is_running_clone,
vec![],
Arc::new([].to_vec()),
Some(deepgram_api_key),
)
.await;
Expand Down
3 changes: 2 additions & 1 deletion screenpipe-integrations/src/unstructured_ocr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ use std::io::Cursor;
use std::io::Read;
use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;
use tempfile::NamedTempFile;
use tokio::time::{timeout, Duration};

pub async fn perform_ocr_cloud(
image: &DynamicImage,
languages: Vec<Language>,
languages: Arc<Vec<Language>>,
) -> Result<(String, String, Option<f64>)> {
let api_key = match env::var("UNSTRUCTURED_API_KEY") {
Ok(key) => key,
Expand Down
4 changes: 2 additions & 2 deletions screenpipe-server/src/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,12 @@ pub async fn handle_index_command(
// Do OCR processing directly
let (text, _, confidence): (String, String, Option<f64>) = match engine.clone() {
#[cfg(target_os = "macos")]
OcrEngine::AppleNative => perform_ocr_apple(frame, &[]),
OcrEngine::AppleNative => perform_ocr_apple(frame, Arc::new([].to_vec())),
#[cfg(target_os = "windows")]
OcrEngine::WindowsNative => perform_ocr_windows(&frame).await.unwrap(),
_ => {
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
return perform_ocr_tesseract(&frame, vec![]);
return perform_ocr_tesseract(&frame, Arc::new([].to_vec()));

Check failure on line 191 in screenpipe-server/src/add.rs

View workflow job for this annotation

GitHub Actions / test-linux

mismatched types

Check failure on line 191 in screenpipe-server/src/add.rs

View workflow job for this annotation

GitHub Actions / test-ubuntu

mismatched types

panic!("unsupported ocr engine");

Check warning on line 193 in screenpipe-server/src/add.rs

View workflow job for this annotation

GitHub Actions / test-linux

unreachable statement

Check warning on line 193 in screenpipe-server/src/add.rs

View workflow job for this annotation

GitHub Actions / test-ubuntu

unreachable statement
}
Expand Down
6 changes: 3 additions & 3 deletions screenpipe-server/src/bin/screenpipe-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ async fn main() -> anyhow::Result<()> {
video_chunk_duration: Duration::from_secs(cli.video_chunk_duration),
use_pii_removal: cli.use_pii_removal,
capture_unfocused_windows: cli.capture_unfocused_windows,
languages: languages.clone(),
languages: Arc::new(languages.clone()),
};

let audio_config = AudioConfig {
Expand All @@ -697,8 +697,8 @@ async fn main() -> anyhow::Result<()> {
let vision_config = VisionConfig {
disabled: cli.disable_vision,
ocr_engine: Arc::new(cli.ocr_engine.clone().into()),
ignored_windows: cli.ignored_windows.clone(),
include_windows: cli.included_windows.clone(),
ignored_windows: Arc::new(cli.ignored_windows.clone()),
include_windows: Arc::new(cli.included_windows.clone()),
};

let recording_future = start_continuous_recording(
Expand Down
160 changes: 88 additions & 72 deletions screenpipe-server/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use screenpipe_core::pii_removal::remove_pii;
use screenpipe_core::{AudioDevice, DeviceManager, DeviceType, Language};
use screenpipe_events::{poll_meetings_events, send_event};
use screenpipe_vision::core::WindowOcr;
use screenpipe_vision::OcrEngine;
use screenpipe_vision::{CaptureResult, OcrEngine};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
Expand All @@ -29,7 +29,7 @@ pub struct RecordingConfig {
pub video_chunk_duration: Duration,
pub use_pii_removal: bool,
pub capture_unfocused_windows: bool,
pub languages: Vec<Language>,
pub languages: Arc<Vec<Language>>,
}

#[derive(Clone)]
Expand All @@ -49,8 +49,8 @@ pub struct AudioConfig {
pub struct VisionConfig {
pub disabled: bool,
pub ocr_engine: Arc<OcrEngine>,
pub ignored_windows: Vec<String>,
pub include_windows: Vec<String>,
pub ignored_windows: Arc<Vec<String>>,
pub include_windows: Arc<Vec<String>>,
}

#[derive(Clone)]
Expand All @@ -61,10 +61,10 @@ pub struct VideoRecordingConfig {
pub ocr_engine: Arc<OcrEngine>,
pub monitor_id: u32,
pub use_pii_removal: bool,
pub ignored_windows: Vec<String>,
pub include_windows: Vec<String>,
pub ignored_windows: Arc<Vec<String>>,
pub include_windows: Arc<Vec<String>>,
pub video_chunk_duration: Duration,
pub languages: Vec<Language>,
pub languages: Arc<Vec<Language>>,
pub capture_unfocused_windows: bool,
}

Expand Down Expand Up @@ -165,10 +165,10 @@ async fn record_vision(
db: Arc<DatabaseManager>,
output_path: Arc<String>,
fps: f64,
languages: Vec<Language>,
languages: Arc<Vec<Language>>,
capture_unfocused_windows: bool,
ignored_windows: Vec<String>,
include_windows: Vec<String>,
ignored_windows: Arc<Vec<String>>,
include_windows: Arc<Vec<String>>,
video_chunk_duration: Duration,
use_pii_removal: bool,
) -> Result<()> {
Expand Down Expand Up @@ -199,11 +199,11 @@ async fn record_vision(
let db_manager_video = Arc::clone(&db);
let output_path_video = Arc::clone(&output_path);
let ocr_engine = Arc::clone(&ocr_engine);
let ignored_windows_video = ignored_windows.to_vec();
let include_windows_video = include_windows.to_vec();

let languages = languages.clone();
let device_manager_vision_clone = device_manager.clone();
let ignored_windows = ignored_windows.clone();
let include_windows = include_windows.clone();
let handle = tokio::spawn(async move {
let config = VideoRecordingConfig {
db: db_manager_video,
Expand All @@ -212,8 +212,8 @@ async fn record_vision(
ocr_engine,
monitor_id,
use_pii_removal,
ignored_windows: ignored_windows_video,
include_windows: include_windows_video,
ignored_windows,
include_windows,
video_chunk_duration,
languages,
capture_unfocused_windows,
Expand Down Expand Up @@ -306,65 +306,81 @@ async fn record_video(
_ => continue, // Ignore other devices or monitors
}
}
_ = tokio::time::sleep(Duration::from_secs_f64(1.0 / config.fps)) => {
if let Some(frame) = video_capture.ocr_frame_queue.pop() {
for window_result in &frame.window_ocr_results {
match config.db.insert_frame(&device_name, None).await {
Ok(frame_id) => {
let text_json =
serde_json::to_string(&window_result.text_json).unwrap_or_default();

let text = if config.use_pii_removal {
&remove_pii(&window_result.text)
} else {
&window_result.text
};

let _ = send_event(
"ocr_result",
WindowOcr {
image: Some(frame.image.clone()),
text: text.clone(),
text_json: window_result.text_json.clone(),
app_name: window_result.app_name.clone(),
window_name: window_result.window_name.clone(),
focused: window_result.focused,
confidence: window_result.confidence,
timestamp: frame.timestamp,
},
);

if let Err(e) = config.db
.insert_ocr_text(
frame_id,
text,
&text_json,
&window_result.app_name,
&window_result.window_name,
Arc::clone(&config.ocr_engine),
window_result.focused, // Add this line
)
.await
{
error!(
"Failed to insert OCR text: {}, skipping window {} of frame {}",
e, window_result.window_name, frame_id
);
continue;
}
}
// we should process faster than the fps we use to do OCR
_ = tokio::time::sleep(Duration::from_secs_f64(1.0 / (config.fps * 2.0))) => {
let frame = match video_capture.ocr_frame_queue.pop() {
Some(f) => f,
None => continue,
};

process_ocr_frame(
frame,
&config.db,
&device_name,
config.use_pii_removal,
config.ocr_engine.clone(),
).await;
}
}
}
}

Err(e) => {
warn!("Failed to insert frame: {}", e);
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
}
}
}
}
}
/// Persists the OCR results of one captured frame and broadcasts them as events.
///
/// For each per-window OCR result in `frame`:
/// 1. inserts a frame row for `device_name` — on failure, logs a warning,
///    backs off 100 ms (so a failing database doesn't spin this loop hot),
///    and skips that window;
/// 2. optionally strips PII from the recognized text when `use_pii_removal` is set;
/// 3. emits a best-effort `ocr_result` event (send failures are deliberately ignored);
/// 4. stores the OCR text in the database, logging and skipping the window on error.
///
/// Pacing between frames is the caller's responsibility (the polling loop in
/// `record_video` sleeps between queue pops); this function does no per-frame throttling.
async fn process_ocr_frame(
    frame: Arc<CaptureResult>,
    db: &DatabaseManager,
    device_name: &str,
    use_pii_removal: bool,
    ocr_engine: Arc<OcrEngine>,
) {
    for window_result in &frame.window_ocr_results {
        let frame_id = match db.insert_frame(device_name, None).await {
            Ok(id) => id,
            Err(e) => {
                warn!("Failed to insert frame: {}", e);
                // Brief back-off before trying the next window so repeated
                // database failures don't busy-loop.
                tokio::time::sleep(Duration::from_millis(100)).await;
                continue;
            }
        };

        let text_json = serde_json::to_string(&window_result.text_json).unwrap_or_default();

        let text = if use_pii_removal {
            remove_pii(&window_result.text)
        } else {
            window_result.text.clone()
        };

        // Best-effort broadcast; nothing depends on a subscriber being present.
        let _ = send_event(
            "ocr_result",
            WindowOcr {
                image: Some(frame.image.clone()),
                text: text.clone(),
                text_json: window_result.text_json.clone(),
                app_name: window_result.app_name.clone(),
                window_name: window_result.window_name.clone(),
                focused: window_result.focused,
                confidence: window_result.confidence,
                timestamp: frame.timestamp,
            },
        );

        if let Err(e) = db
            .insert_ocr_text(
                frame_id,
                &text,
                &text_json,
                &window_result.app_name,
                &window_result.window_name,
                ocr_engine.clone(),
                window_result.focused,
            )
            .await
        {
            error!(
                "Failed to insert OCR text: {}, skipping window {} of frame {}",
                e, window_result.window_name, frame_id
            );
        }
    }
}

Expand All @@ -377,7 +393,7 @@ async fn record_audio(
audio_transcription_engine: Arc<AudioTranscriptionEngine>,
realtime_audio_enabled: bool,
realtime_audio_devices: Vec<Arc<AudioDevice>>,
languages: Vec<Language>,
languages: Arc<Vec<Language>>,
deepgram_api_key: Option<String>,
) -> Result<()> {
let mut handles: HashMap<String, JoinHandle<()>> = HashMap::new();
Expand Down
19 changes: 12 additions & 7 deletions screenpipe-server/src/video.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ impl VideoCapture {
new_chunk_callback: impl Fn(&str) + Send + Sync + 'static,
ocr_engine: Arc<OcrEngine>,
monitor_id: u32,
ignore_list: Vec<String>,
include_list: Vec<String>,
languages: Vec<Language>,
ignore_list: Arc<Vec<String>>,
include_list: Arc<Vec<String>>,
languages: Arc<Vec<Language>>,
capture_unfocused_windows: bool,
) -> Self {
let fps = if fps.is_finite() && fps > 0.0 {
fps
} else {
warn!("[monitor_id: {}] Invalid FPS value: {}. Using default of 1.0", monitor_id, fps);
warn!(
"[monitor_id: {}] Invalid FPS value: {}. Using default of 1.0",
monitor_id, fps
);
1.0
};
let interval = Duration::from_secs_f64(1.0 / fps);
Expand All @@ -66,8 +69,9 @@ impl VideoCapture {
let shutdown_rx_capture = shutdown_rx.clone();
let shutdown_rx_queue = shutdown_rx.clone();
let shutdown_rx_video = shutdown_rx.clone();

let languages_clone = languages.clone();
let result_sender_inner = result_sender.clone();

let capture_handle = tokio::spawn(async move {
let mut rx = shutdown_rx_capture;
loop {
Expand All @@ -80,15 +84,16 @@ impl VideoCapture {
}
let result_sender = result_sender_inner.clone();
let window_filters_clone = Arc::clone(&window_filters_clone);
let languages_clone = languages_clone.clone();

tokio::select! {
_ = continuous_capture(
result_sender,
interval,
(*ocr_engine).clone(),
ocr_engine.clone(),
monitor_id,
window_filters_clone,
languages.clone(),
languages_clone.clone(),
capture_unfocused_windows,
rx.clone(),
) => {
Expand Down
7 changes: 6 additions & 1 deletion screenpipe-server/tests/video_utils_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ async fn test_extract_frames() -> Result<()> {
#[cfg(target_os = "macos")]
#[tokio::test]
async fn test_extract_frames_and_ocr() -> Result<()> {
use std::sync::Arc;

use screenpipe_vision::perform_ocr_apple;
setup_test_env().await?;
let video_path = create_test_video().await?;
Expand Down Expand Up @@ -133,7 +135,10 @@ async fn test_extract_frames_and_ocr() -> Result<()> {
};

// perform ocr using apple native (macos only)
let (text, _, confidence) = perform_ocr_apple(&captured_window.image, &[Language::English]);
let (text, _, confidence) = perform_ocr_apple(
&captured_window.image,
Arc::new([Language::English].to_vec()),
);

println!("ocr confidence: {}", confidence.unwrap_or(0.0));
println!("extracted text: {}", text);
Expand Down
Loading

0 comments on commit 14d77ea

Please sign in to comment.