-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from EducaTour/apta
Prediction Feature
- Loading branch information
Showing
12 changed files
with
265 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -140,3 +140,6 @@ GitHub.sublime-settings | |
|
||
# certbot configuration | ||
certbot/* | ||
|
||
# gcloud credential | ||
credential*.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from urllib.parse import urljoin | ||
|
||
from django.conf import settings | ||
from storages.backends.gcloud import GoogleCloudStorage | ||
from storages.utils import setting | ||
|
||
|
||
class GoogleCloudMediaFileStorage(GoogleCloudStorage): | ||
""" | ||
Google file storage class which gives a media file path from MEDIA_URL not google generated one. | ||
""" | ||
|
||
bucket_name = setting("GS_BUCKET_NAME") | ||
|
||
def url(self, name): | ||
""" | ||
Gives correct MEDIA_URL and not google generated url. | ||
""" | ||
return urljoin(settings.MEDIA_URL, name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import logging | ||
import tempfile | ||
|
||
from django.conf import settings | ||
from google.cloud import storage | ||
|
||
from .prediction import Prediction | ||
|
||
LABELS = sorted( | ||
[ | ||
"benteng_vredeburg", | ||
"candi_borobudur", | ||
"candi_prambanan", | ||
"garuda_wisnu_kencana", | ||
"gedung_sate", | ||
"istana_maimun", | ||
"jam_gadang", | ||
"keong_mas", | ||
"keraton_jogja", | ||
"kota_tua", | ||
"lawang_sewu", | ||
"masjid_istiqlal", | ||
"masjid_menara_kudus", | ||
"masjid_raya_baiturrahman", | ||
"menara_siger_lampung", | ||
"monas", | ||
"monumen_bandung_lautan_api", | ||
"monumen_gong_perdamaian", | ||
"monumen_nol_km", | ||
"monumen_simpang_lima_gumul", | ||
"patung_ikan_surabaya", | ||
"patung_yesus_memberkati", | ||
"tugu_jogja", | ||
"tugu_khatulistiwa", | ||
"tugu_pahlawan_surabaya", | ||
] | ||
) | ||
|
||
# # Initialize the model | ||
# Model = Prediction( | ||
# model_path="./prediction/model_predic/best_model_2.h5", | ||
# target_size=(224, 224), | ||
# classes=LABELS) | ||
|
||
try: | ||
client = storage.Client( | ||
credentials=settings.GS_CREDENTIALS, project=settings.GS_PROJECT_ID | ||
) | ||
bucket = client.get_bucket(settings.GS_BUCKET_NAME) | ||
# Construct the source path in GCS | ||
gcs_source_path = "model/best_model_2.h5" | ||
|
||
# Download the model file from GCS into a temporary file | ||
with tempfile.NamedTemporaryFile(suffix=".h5") as temp_model_file: | ||
blob = bucket.blob(gcs_source_path) | ||
blob.download_to_filename(temp_model_file.name) | ||
|
||
# Use the downloaded temporary file as the model path | ||
Model = Prediction( | ||
model_path=temp_model_file.name, target_size=(224, 224), classes=LABELS | ||
) | ||
except Exception as e: | ||
logger.error(f"Error reading from GCS: {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Generated by Django 5.0.6 on 2024-06-02 18:27 | ||
|
||
import prediction.models | ||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
initial = True | ||
|
||
dependencies = [ | ||
] | ||
|
||
operations = [ | ||
migrations.CreateModel( | ||
name='Prediction', | ||
fields=[ | ||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), | ||
('result', models.CharField(max_length=200)), | ||
('rate', models.DecimalField(decimal_places=2, max_digits=4)), | ||
('image', models.URLField()), | ||
('createdAt', models.DateTimeField(auto_now_add=True)), | ||
], | ||
), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
from django.db import models | ||
|
||
# Create your models here. | ||
|
||
class Prediction(models.Model): | ||
result = models.CharField(max_length=200) | ||
rate = models.DecimalField(max_digits=4, decimal_places=2) | ||
image = models.URLField() | ||
createdAt = models.DateTimeField(auto_now_add=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import tensorflow as tf | ||
|
||
|
||
class Prediction: | ||
def __init__(self, model_path: str, classes: list, target_size: tuple = (224, 224)): | ||
self.model_path = model_path | ||
self.classes = classes | ||
self.target_size = target_size | ||
|
||
try: | ||
self.__loaded_model = tf.keras.models.load_model(self.model_path) | ||
except Exception as e: | ||
print(e) | ||
|
||
self.__prepared_img = None | ||
|
||
def model_summary(self): | ||
return self.__loaded_model.summary() | ||
|
||
def __preprocess_img(self, img): | ||
img = tf.io.read_file(img) | ||
img = tf.io.decode_image(img) | ||
img = tf.image.resize(img, self.target_size) | ||
return img | ||
|
||
def predict_class(self, img): | ||
img = self.__preprocess_img(img) | ||
pred_prob = self.__loaded_model.predict(tf.expand_dims(img, axis=0)) | ||
pred_cat = pred_prob.argmax(axis=-1) | ||
pred_class = self.classes[pred_cat[0]] | ||
confidence_score = pred_prob.max() * 100 | ||
confidence_score = f"{confidence_score:.2f}" | ||
return pred_class, confidence_score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from prediction.models import Prediction | ||
from rest_framework import serializers | ||
|
||
|
||
class PredictionSerializer(serializers.ModelSerializer): | ||
class Meta: | ||
model = Prediction | ||
fields = "__all__" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,8 @@ | ||
urlpatterns = [] | ||
from django.urls import path | ||
|
||
from .views import AddCapture | ||
|
||
urlpatterns = [ | ||
path("", AddCapture.as_view(), name="landmark"), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,89 @@ | ||
from django.shortcuts import render | ||
import logging | ||
import os | ||
import uuid | ||
|
||
# Create your views here. | ||
from django.conf import settings | ||
from google.cloud import storage | ||
from rest_framework import status | ||
from rest_framework.response import Response | ||
from rest_framework.views import APIView | ||
|
||
from . import Model | ||
from .serializers import PredictionSerializer | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AddCapture(APIView): | ||
def post(self, request): | ||
if "image" not in request.FILES: | ||
return Response( | ||
{"message": "No file provided"}, status=status.HTTP_400_BAD_REQUEST | ||
) | ||
|
||
save_directory = "/photo/" | ||
|
||
if not os.path.exists(save_directory): | ||
os.makedirs(save_directory) | ||
|
||
uploaded_file = request.FILES["image"] | ||
|
||
new_filename = f"{uuid.uuid4()}.{uploaded_file.name.split('.')[-1]}" | ||
|
||
file_path = os.path.join(save_directory, new_filename) | ||
|
||
with open(file_path, "wb") as f: | ||
for chunk in uploaded_file.chunks(): | ||
f.write(chunk) | ||
|
||
try: | ||
predicted_class_name, confidence_score = Model.predict_class(file_path) | ||
|
||
if float(confidence_score) < 65.0: | ||
return Response( | ||
{ | ||
"message": "Confidence score below 65%. Please provide a clearer image." | ||
}, | ||
status=status.HTTP_400_BAD_REQUEST, | ||
) | ||
except Exception as e: | ||
logger.error(f"Error in prediction: {e}") | ||
return Response( | ||
{"message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR | ||
) | ||
|
||
try: | ||
client = storage.Client( | ||
credentials=settings.GS_CREDENTIALS, project=settings.GS_PROJECT_ID | ||
) | ||
bucket = client.get_bucket(settings.GS_BUCKET_NAME) | ||
gcs_destination_path = f"predictions/{predicted_class_name}/{new_filename}" | ||
|
||
blob = bucket.blob(gcs_destination_path) | ||
blob.upload_from_filename(file_path) | ||
image_url = blob.public_url | ||
except Exception as e: | ||
logger.error(f"Error uploading to GCS: {e}") | ||
return Response( | ||
{"message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR | ||
) | ||
|
||
os.remove(file_path) | ||
|
||
if os.path.exists(file_path) == 0: | ||
print("File does not exist") | ||
|
||
data = { | ||
"result": predicted_class_name, | ||
"rate": confidence_score, | ||
"image": image_url, | ||
} | ||
photo_serializer = PredictionSerializer(data=data) | ||
|
||
if photo_serializer.is_valid(): | ||
photo_serializer.save() | ||
return Response(photo_serializer.data, status=status.HTTP_201_CREATED) | ||
else: | ||
logger.error(f"Serializer errors: {photo_serializer.errors}") | ||
return Response(photo_serializer.errors, status=status.HTTP_400_BAD_REQUEST) |
Binary file not shown.