Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add files via upload #3

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions yolo_cnn_lstm/test_cnn_lstm.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using CNN-LSTM to classify violent clips"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"# During training, frames were sampled at three per second, each resized to 64x64 pixels, and the\n",
"# LSTM consumed windows of 10 frames. NOTE(review): preprocess_video below simply reads the first\n",
"# `frame_count` consecutive frames rather than sampling at 3 fps -- confirm this matches training.\n",
"def preprocess_video(video_path, frame_count=10, frame_size=(64, 64)):\n",
" cap = cv2.VideoCapture(video_path)\n",
" frames = []\n",
" while len(frames) < frame_count:\n",
" ret, frame = cap.read()\n",
" if not ret:\n",
" break\n",
" frame = cv2.resize(frame, frame_size)\n",
" frames.append(frame)\n",
" cap.release()\n",
" \n",
" if len(frames) == 0:\n",
" return np.zeros((frame_count, frame_size[0], frame_size[1], 3))\n",
" elif len(frames) < frame_count:\n",
" frames.extend([np.zeros_like(frames[0])]*(frame_count - len(frames)))\n",
" return np.array(frames)\n",
"\n",
"def load_test_data(folder_path, frame_count=10, frame_size=(64, 64)):\n",
" test_data = []\n",
" video_files = [f for f in os.listdir(folder_path) if f.endswith('.mp4') or f.endswith('.avi')]\n",
" for video_file in video_files:\n",
" video_path = os.path.join(folder_path, video_file)\n",
" frames = preprocess_video(video_path, frame_count, frame_size)\n",
" test_data.append(frames)\n",
" return np.array(test_data), video_files\n",
"\n",
"# Folder containing your test videos\n",
"test_video_folder = \"D:/Yolo/test/result\"\n",
"test_data, test_video_files = load_test_data(test_video_folder)\n",
"test_data = test_data / 255.0\n",
"\n",
"# This model can achieve over 90% accuracy on the training set and 77% on the validation set.\n",
"model_path = \"D:/CNN-LSTM/models/violence_detection_model_conv_64_lstm_64.h5\"\n",
"model = tf.keras.models.load_model(model_path, compile=False)\n",
"\n",
"# Recompile the model with a compatible optimizer and loss function\n",
"model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
"\n",
"predictions = model.predict(test_data)\n",
"predicted_classes = np.argmax(predictions, axis=1)\n",
"\n",
"class_names = [\"high-level violence\", \"low-level violence\", \"non-violence\"]\n",
"for video_file, predicted_class in zip(test_video_files, predicted_classes):\n",
" print(f\"Video: {video_file}, Predicted Class: {class_names[predicted_class]}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "test",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
390 changes: 390 additions & 0 deletions yolo_cnn_lstm/train_cnn_lstm.ipynb

Large diffs are not rendered by default.

160 changes: 160 additions & 0 deletions yolo_cnn_lstm/yolo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Training the Yolo model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ultralytics\n",
"ultralytics.checks()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from ultralytics import YOLO\n",
"\n",
"from IPython.display import display, Image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!yolo task=detect mode=train model=yolov8s.pt data=Violence.yaml epochs=50 imgsz=416 plots=True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Using the YOLO model to filter violent clips"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import os\n",
"from ultralytics import YOLO\n",
"\n",
"# Load the trained YOLOv8 model\n",
"model = YOLO('D:\\\\Yolo\\\\runs\\\\detect\\\\train\\\\weights\\\\best.pt')\n",
"\n",
"# Input video file path\n",
"input_video_path = 'D:\\\\Yolo\\\\test\\\\videos'\n",
"output_clip_path = 'D:\\\\Yolo\\\\test\\\\result'\n",
"\n",
"# Create the output folder (if it doesn't exist)\n",
"os.makedirs(output_clip_path, exist_ok=True)\n",
"\n",
"# Get all video files\n",
"video_files = [f for f in os.listdir(input_video_path) if f.endswith('.mp4')]\n",
"\n",
"for video_file in video_files:\n",
" cap = cv2.VideoCapture(os.path.join(input_video_path, video_file))\n",
" fps = cap.get(cv2.CAP_PROP_FPS)\n",
" frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
" frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
"\n",
" frame_count = 0\n",
" clip_number = 1\n",
" clip_frames = []\n",
" start_time = None\n",
" last_detection_time = None\n",
"\n",
" while cap.isOpened():\n",
" ret, frame = cap.read()\n",
" if not ret:\n",
" break\n",
"\n",
" # Perform object detection on the current frame\n",
" results = model(frame)\n",
"\n",
" # Extract detection results\n",
" detections = results[0].boxes.data.cpu().numpy()\n",
" current_time = frame_count / fps\n",
" if len(detections) > 0:\n",
" # If an object is detected, save the current frame\n",
" clip_frames.append(frame)\n",
" if start_time is None:\n",
" start_time = current_time\n",
" last_detection_time = current_time\n",
" else:\n",
" # If no object is detected but the time interval is less than 2 seconds, save the current frame\n",
" if last_detection_time is not None and current_time - last_detection_time < 2:\n",
" clip_frames.append(frame)\n",
" else:\n",
" # If the time interval is greater than 2 seconds and there are accumulated frames, save these frames as a video clip\n",
" if len(clip_frames) > 0:\n",
" end_time = current_time if last_detection_time is None else last_detection_time\n",
" start_time_formatted = f'{int(start_time // 60):02d}_{int(start_time % 60):02d}'\n",
" end_time_formatted = f'{int(end_time // 60):02d}_{int(end_time % 60):02d}'\n",
" clip_output_path = os.path.join(output_clip_path, f'{os.path.splitext(video_file)[0]}_clip_{clip_number}_{start_time_formatted}_to_{end_time_formatted}.mp4')\n",
" out = cv2.VideoWriter(clip_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))\n",
" for clip_frame in clip_frames:\n",
" out.write(clip_frame)\n",
" out.release()\n",
" clip_frames = []\n",
" clip_number += 1\n",
" start_time = None\n",
" last_detection_time = None\n",
"\n",
" frame_count += 1\n",
"\n",
" # Process the remaining frames at the end of the video\n",
" if len(clip_frames) > 0:\n",
" end_time = current_time if last_detection_time is None else last_detection_time\n",
" start_time_formatted = f'{int(start_time // 60):02d}_{int(start_time % 60):02d}'\n",
" end_time_formatted = f'{int(end_time // 60):02d}_{int(end_time % 60):02d}'\n",
" clip_output_path = os.path.join(output_clip_path, f'{os.path.splitext(video_file)[0]}_clip_{clip_number}_{start_time_formatted}_to_{end_time_formatted}.mp4')\n",
" out = cv2.VideoWriter(clip_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))\n",
" for clip_frame in clip_frames:\n",
" out.write(clip_frame)\n",
" out.release()\n",
"\n",
" cap.release()\n",
"\n",
"cv2.destroyAllWindows()\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "VD",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}