A deep learning project that classifies tomatoes as either ripe or unripe using computer vision techniques. This project was developed to automate the sorting process in agricultural settings, inspired by a hypothetical business scenario for FreshHarvest Inc. By leveraging Convolutional Neural Networks (CNNs), transfer learning with pre-trained ResNet models, data augmentation, and interpretability tools such as Integrated Gradients, the model aims to improve quality control of tomato produce.
- Project Overview
- Business Problem & Motivation
- Dataset and Data Preparation
- Model Architecture and Training
- Model Interpretability
- Quantized Models & Efficiency
- Semantic Segmentation Example
- Results and Outcome
- Future Work
- Additional Resources
- Requirements
- Dataset Access
- Grayscale Usage
- Conclusion
This project presents a complete pipeline for classifying tomatoes by ripeness using deep learning. The model is developed in Google Colab and integrates several key steps:
- Data Ingestion and Preprocessing: Reading images and labels from a Kaggle dataset.
- Model Development: Building a CNN from scratch and fine-tuning a pre-trained ResNet50.
- Model Interpretability: Using Integrated Gradients from the Captum library to understand feature attributions.
- Quantization: Discussing both pre-training and post-training quantization to improve efficiency on resource-constrained devices.
- Visualization: Monitoring training progress and embedding visualizations with TensorBoard.
An example image of the project workflow is shown below:
*(Workflow diagram)*
Modern agricultural practices increasingly rely on automation to enhance crop quality and yield. FreshHarvest Inc.—a hypothetical agricultural technology company—is looking to replace manual tomato sorting with an automated system. Manual sorting is:
- Time-consuming
- Labor-intensive
- Prone to human error
By deploying a robust tomato classification model, the company aims to:
- Reduce labor costs.
- Increase sorting speed.
- Improve the overall quality of tomatoes sent to market.
The dataset comprises 177 images of tomatoes along with corresponding label files indicating whether a tomato is ripe or unripe. The data was sourced from Kaggle.
- Data Wrangling:
  - Removed corrupted images.
  - Ensured accurate matching between images and labels.
  - Verified image quality.
- Exploratory Data Analysis (EDA):
  - Visualized the distribution of classes (ripe vs. unripe) to check for imbalance (see the class-distribution sketch after the sample-image code below).
  - Displayed sample images along with their labels.
Example code to visualize an image sample:
```python
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# `first_sample` is the path to the first image file in the dataset
test_img = Image.open(first_sample)
plt.imshow(np.asarray(test_img))
plt.show()
```
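The class-balance check mentioned in the EDA step can be done with a simple bar chart. A minimal sketch, assuming the labels have been collected into a list named `labels` containing the strings "ripe" and "unripe" (the variable name is chosen for illustration):

```python
from collections import Counter

import matplotlib.pyplot as plt

# `labels` is assumed to be a list of class strings, e.g. ["ripe", "unripe", ...]
counts = Counter(labels)

plt.bar(counts.keys(), counts.values())
plt.title("Class distribution: ripe vs. unripe")
plt.ylabel("Number of images")
plt.show()
```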
Preprocessing steps are critical for model performance:
- Resizing: Images are resized to 224x224 pixels.
- Normalization: Images are normalized using ImageNet’s mean and standard deviation values.
- Data Augmentation: Techniques such as rotation, flipping, zooming, and shifting are applied to artificially increase dataset diversity.
A typical transformation pipeline is defined as:
```python
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
```
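The augmentation techniques listed above (rotation, flipping, zooming, shifting) are not part of the pipeline shown. A minimal sketch of a training-time pipeline that adds them; the specific parameter values are illustrative, not the project's exact settings:

```python
import torchvision.transforms as transforms

# Training-time pipeline with augmentation (parameter values chosen for illustration)
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),              # flipping
    transforms.RandomRotation(degrees=15),                # rotation
    transforms.RandomAffine(degrees=0,
                            translate=(0.1, 0.1),         # shifting
                            scale=(0.9, 1.1)),            # zooming
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
```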
The project begins with a vanilla CNN developed from scratch. This network includes:
- Two convolutional layers with ReLU activations and max pooling.
- Fully connected layers that output a binary prediction (ripe or unripe).
This simple model serves as a foundation before exploring more advanced methods like transfer learning.
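A minimal sketch of such a network, assuming 224x224 RGB inputs and a single logit output for use with `BCEWithLogitsLoss` (the layer sizes are illustrative, not the project's exact architecture):

```python
import torch.nn as nn

class SimpleTomatoCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Two convolutional blocks with ReLU activations and max pooling
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                 # 224 -> 112
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                 # 112 -> 56
        )
        # Fully connected layers producing a single logit (ripe vs. unripe)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 56 * 56, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        return self.classifier(self.features(x))
```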
The training loop involves:
- Loss Function: Binary Cross-Entropy with logits.
- Optimizer: Adam, with learning rate adjustments based on training dynamics.
- TensorBoard: Used to visualize training and validation loss metrics.
An excerpt from the training loop:
```python
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_dataloader):
        optimizer.zero_grad()              # reset gradients from the previous step
        output = model(inputs)             # forward pass
        loss = loss_fn(output, labels)     # BCEWithLogitsLoss expects float labels
        loss.backward()                    # backpropagate
        optimizer.step()                   # update weights
```
TensorBoard logging example:
```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/tomato_classification')

# avg_loss and avg_vloss are running averages of the training and validation loss
# (see the validation sketch below); epoch and i come from the training loop above
writer.add_scalars('Training vs. Validation Loss',
                   {'Training': avg_loss, 'Validation': avg_vloss},
                   epoch * len(train_dataloader) + i)
```
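The validation-side average referenced above is not shown in the excerpt. A minimal sketch of how `avg_vloss` could be computed at the end of each epoch, assuming a `val_dataloader` exists (the name is chosen for illustration):

```python
import torch

model.eval()                              # disable dropout/batch-norm updates
running_vloss = 0.0
with torch.no_grad():                     # no gradients needed for validation
    for vinputs, vlabels in val_dataloader:
        voutput = model(vinputs)
        running_vloss += loss_fn(voutput, vlabels).item()

avg_vloss = running_vloss / len(val_dataloader)
model.train()                             # back to training mode
```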
To improve performance and speed up training, the project leverages transfer learning with ResNet50.
```python
import torch.nn as nn
import torchvision.models as models

# Load ResNet50 pre-trained on ImageNet
model = models.resnet50(pretrained=True)

# Replace the final fully connected layer with a 2-class head (ripe vs. unripe)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)

# Freeze all layers, then unfreeze only the new classification head
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True
```
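With the backbone frozen, only the new head needs to be passed to the optimizer. A minimal sketch; the learning rate and choice of cross-entropy for the 2-unit head are illustrative assumptions, not the project's exact settings:

```python
import torch
import torch.nn as nn

# Only the parameters of the new head are trainable, so only they go to the optimizer
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

# With a 2-unit output layer, a standard choice is cross-entropy loss
loss_fn = nn.CrossEntropyLoss()
```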
To interpret model decisions, we use Integrated Gradients from Captum.
```python
from captum.attr import IntegratedGradients

# `image` is a preprocessed input tensor of shape (1, 3, 224, 224);
# `label` is the target class index
integrated_gradients = IntegratedGradients(model)
attributions_ig = integrated_gradients.attribute(image, target=label, n_steps=200)
```
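The attributions can then be rendered as a heat map to see which pixels drove the prediction. A minimal sketch using Captum's visualization helpers, assuming the `attributions_ig` tensor produced above:

```python
import numpy as np
from captum.attr import visualization as viz

# Captum's visualizer expects an HxWxC numpy array
attr_np = np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))

# Heat map of positive attributions (pixels pushing the prediction toward the target class)
viz.visualize_image_attr(attr_np, method="heat_map", sign="positive", show_colorbar=True)
```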
Quantization reduces model size and speeds up inference by converting 32-bit floating point weights to lower-precision formats (e.g., 8-bit integers). Two strategies are covered:

- Post-training Quantization (PTQ): Quantizes a trained model.
- Quantization-Aware Training (QAT): Trains the model while simulating quantization effects.

Important concepts:

- Symmetric vs. Asymmetric Quantization:
  - Symmetric: The zero point is typically zero.
  - Asymmetric: The zero point is calculated based on the data distribution.
- Fake Quantization: Used in QAT to simulate quantization effects during training.
Code snippet for tensor quantization:
```python
import torch

# Quantize a float tensor to 8-bit unsigned integers (scale and zero point chosen for illustration)
quint8_tensor = torch.quantize_per_tensor(preprocessed_img, scale=1.0, zero_point=0, dtype=torch.quint8)

# Convert back to floating point
dequantized_tensor = quint8_tensor.dequantize()
```
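As an example of post-training quantization applied to a whole model, dynamic quantization of the linear layers is one of the simplest options in PyTorch. A minimal sketch, assuming the trained `model` from above (for a ResNet this affects the final fully connected layer):

```python
import torch
import torch.nn as nn

# Post-training dynamic quantization: weights of Linear layers are stored as int8,
# and their activations are quantized on the fly at inference time
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
```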
Beyond classification, the project demonstrates semantic segmentation using FCN ResNet50. This model labels each pixel in an image, with the output visualized as a segmentation mask.
Example code:
```python
import torch
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image

# `batch` is a preprocessed image tensor of shape (1, 3, H, W); `class_to_idx`
# maps category names from the weights' metadata to output channels
model = fcn_resnet50(weights=FCN_ResNet50_Weights.DEFAULT).eval()
with torch.no_grad():
    prediction = model(batch)["out"]
mask = prediction.softmax(dim=1)[0, class_to_idx["dog"]]
to_pil_image(mask).show()
```
This example shows how similar techniques can be applied to more complex computer vision tasks.
- Training Outcome:
  - Training and validation losses steadily decreased, demonstrating effective learning.
  - Accuracy: The models achieved up to 83% accuracy on the validation set.
  - Although the goal was to reach above 90%, the achieved accuracy is promising given the dataset size and model complexity.
- Visualization:
  - TensorBoard was used to monitor the training process, providing valuable insights into model performance.
*(Example loss curve visualization from TensorBoard)*
Potential improvements include:

- Expanding the Dataset: A larger, more diverse dataset may improve model generalization.
- Robotic Integration: Combining the model with AI-driven robots using OpenCV for real-time image processing and control.
- Advanced Architectures: Experimenting with state-of-the-art models and further fine-tuning with quantization-aware training (QAT) for deployment on edge devices.
- Deployment: Developing a prototype or web service to demonstrate real-time tomato sorting.
To deepen your understanding, consider exploring these resources:

Books:
- Python Data Science Handbook: Essential Tools for Working with Data by Jake VanderPlas
- Hands-On Machine Learning with Scikit-Learn, Keras & TensorFlow (2nd Edition) by Aurélien Géron (the second part of this book covers advanced topics like RNNs, CNNs, and deep neural networks)

Research Papers:
- Attention Is All You Need

Libraries and Frameworks:
- PyTorch
- TorchVision
- Captum
To run this project, ensure that your environment meets the following requirements:

- Python: 3.6 or higher
- PyTorch: 1.7.0 or higher
- TorchVision: 0.8.1 or higher
- Captum: 0.4.0 or higher
- Pandas: 1.1.5 or higher
- Matplotlib: 3.3.2 or higher
- Pillow: 7.2.0 or higher
- TensorBoard: 2.3.0 or higher
- Google Colab: Recommended for running the notebook
Install the required packages with:
```bash
pip install torch torchvision captum matplotlib tensorboard pandas pillow
```
Due to file size constraints, the dataset is not included in this repository. You can download the dataset from Kaggle:

- Riped and Unriped Tomato Dataset
Converting images to grayscale reduces complexity by focusing on intensity rather than color. The grayscale conversion formula is:
$$\text{Gray} = 0.2989 \times R + 0.5870 \times G + 0.1140 \times B$$
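A minimal sketch of applying this conversion, either with the formula directly or via torchvision's built-in transform; `first_sample` refers to the sample image path used in the EDA snippet above:

```python
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

# `first_sample` is the path to a sample image, as in the EDA example
img = Image.open(first_sample).convert("RGB")

# Explicit luminance-weighted conversion using the formula above
rgb = np.asarray(img, dtype=np.float32)
gray = 0.2989 * rgb[..., 0] + 0.5870 * rgb[..., 1] + 0.1140 * rgb[..., 2]

# Equivalent (up to rounding) using torchvision's built-in transform
gray_img = transforms.Grayscale(num_output_channels=1)(img)
```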
*(Grayscale image examples)*
This project demonstrates a comprehensive approach to automating tomato sorting in agricultural settings. By combining classical CNN architectures with transfer learning, interpretability methods, and quantization techniques, we have built a system capable of distinguishing between ripe and unripe tomatoes. Although the current model achieves 83% accuracy, further data collection, experimentation, and integration with hardware could push performance even higher.
At a larger scale, it is entirely feasible to integrate this tomato classification model with AI-driven robotics and real-time image processing using OpenCV.