Skip to content

Commit

Permalink
Merge pull request #2 from EducaTour/apta
Browse files Browse the repository at this point in the history
Prediction Feature
  • Loading branch information
swusjask authored Jun 13, 2024
2 parents a89e3be + 2bce1cd commit ac89f06
Show file tree
Hide file tree
Showing 12 changed files with 265 additions and 3 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,6 @@ GitHub.sublime-settings

# certbot configuration
certbot/*

# gcloud credential
credential*.json
2 changes: 2 additions & 0 deletions app/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ RUN pip install --upgrade pip
COPY ./requirements.txt .
RUN pip install -r requirements.txt

RUN set TF_ENABLE_ONEDNN_OPTS=0

# copy entrypoint.sh
COPY ./entrypoint.sh .
RUN sed -i 's/\r$//g' /usr/src/app/entrypoint.sh
Expand Down
19 changes: 19 additions & 0 deletions app/ce_tour/gcloud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from urllib.parse import urljoin

from django.conf import settings
from storages.backends.gcloud import GoogleCloudStorage
from storages.utils import setting


class GoogleCloudMediaFileStorage(GoogleCloudStorage):
"""
Google file storage class which gives a media file path from MEDIA_URL not google generated one.
"""

bucket_name = setting("GS_BUCKET_NAME")

def url(self, name):
"""
Gives correct MEDIA_URL and not google generated url.
"""
return urljoin(settings.MEDIA_URL, name)
11 changes: 11 additions & 0 deletions app/ce_tour/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@
},
]

# configuration for gcloud storage
from google.oauth2 import service_account

GS_CREDENTIALS = service_account.Credentials.from_service_account_file(
os.getenv("GS_CREDENTIALS")
)

DEFAULT_FILE_STORAGE = "ce_tour.gcloud.GoogleCloudMediaFileStorage"
GS_PROJECT_ID = os.getenv("GS_PROJECT_ID")
GS_BUCKET_NAME = os.getenv("GS_BUCKET_NAME")
MEDIA_URL = "https://storage.googleapis.com/{}/".format(GS_BUCKET_NAME)

# Internationalization
# https://docs.djangoproject.com/en/5.0/topics/i18n/
Expand Down
63 changes: 63 additions & 0 deletions app/prediction/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging
import tempfile

from django.conf import settings
from google.cloud import storage

from .prediction import Prediction

LABELS = sorted(
[
"benteng_vredeburg",
"candi_borobudur",
"candi_prambanan",
"garuda_wisnu_kencana",
"gedung_sate",
"istana_maimun",
"jam_gadang",
"keong_mas",
"keraton_jogja",
"kota_tua",
"lawang_sewu",
"masjid_istiqlal",
"masjid_menara_kudus",
"masjid_raya_baiturrahman",
"menara_siger_lampung",
"monas",
"monumen_bandung_lautan_api",
"monumen_gong_perdamaian",
"monumen_nol_km",
"monumen_simpang_lima_gumul",
"patung_ikan_surabaya",
"patung_yesus_memberkati",
"tugu_jogja",
"tugu_khatulistiwa",
"tugu_pahlawan_surabaya",
]
)

# # Initialize the model
# Model = Prediction(
# model_path="./prediction/model_predic/best_model_2.h5",
# target_size=(224, 224),
# classes=LABELS)

try:
client = storage.Client(
credentials=settings.GS_CREDENTIALS, project=settings.GS_PROJECT_ID
)
bucket = client.get_bucket(settings.GS_BUCKET_NAME)
# Construct the source path in GCS
gcs_source_path = "model/best_model_2.h5"

# Download the model file from GCS into a temporary file
with tempfile.NamedTemporaryFile(suffix=".h5") as temp_model_file:
blob = bucket.blob(gcs_source_path)
blob.download_to_filename(temp_model_file.name)

# Use the downloaded temporary file as the model path
Model = Prediction(
model_path=temp_model_file.name, target_size=(224, 224), classes=LABELS
)
except Exception as e:
logger.error(f"Error reading from GCS: {e}")
25 changes: 25 additions & 0 deletions app/prediction/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Generated by Django 5.0.6 on 2024-06-02 18:27

import prediction.models
from django.db import migrations, models


class Migration(migrations.Migration):

initial = True

dependencies = [
]

operations = [
migrations.CreateModel(
name='Prediction',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('result', models.CharField(max_length=200)),
('rate', models.DecimalField(decimal_places=2, max_digits=4)),
('image', models.URLField()),
('createdAt', models.DateTimeField(auto_now_add=True)),
],
),
]
7 changes: 6 additions & 1 deletion app/prediction/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from django.db import models

# Create your models here.

class Prediction(models.Model):
result = models.CharField(max_length=200)
rate = models.DecimalField(max_digits=4, decimal_places=2)
image = models.URLField()
createdAt = models.DateTimeField(auto_now_add=True)
33 changes: 33 additions & 0 deletions app/prediction/prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import tensorflow as tf


class Prediction:
def __init__(self, model_path: str, classes: list, target_size: tuple = (224, 224)):
self.model_path = model_path
self.classes = classes
self.target_size = target_size

try:
self.__loaded_model = tf.keras.models.load_model(self.model_path)
except Exception as e:
print(e)

self.__prepared_img = None

def model_summary(self):
return self.__loaded_model.summary()

def __preprocess_img(self, img):
img = tf.io.read_file(img)
img = tf.io.decode_image(img)
img = tf.image.resize(img, self.target_size)
return img

def predict_class(self, img):
img = self.__preprocess_img(img)
pred_prob = self.__loaded_model.predict(tf.expand_dims(img, axis=0))
pred_cat = pred_prob.argmax(axis=-1)
pred_class = self.classes[pred_cat[0]]
confidence_score = pred_prob.max() * 100
confidence_score = f"{confidence_score:.2f}"
return pred_class, confidence_score
8 changes: 8 additions & 0 deletions app/prediction/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from prediction.models import Prediction
from rest_framework import serializers


class PredictionSerializer(serializers.ModelSerializer):
class Meta:
model = Prediction
fields = "__all__"
7 changes: 7 additions & 0 deletions app/prediction/urls.py
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@
urlpatterns = []
from django.urls import path

from .views import AddCapture

urlpatterns = [
path("", AddCapture.as_view(), name="landmark"),
]
90 changes: 88 additions & 2 deletions app/prediction/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,89 @@
from django.shortcuts import render
import logging
import os
import uuid

# Create your views here.
from django.conf import settings
from google.cloud import storage
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView

from . import Model
from .serializers import PredictionSerializer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class AddCapture(APIView):
def post(self, request):
if "image" not in request.FILES:
return Response(
{"message": "No file provided"}, status=status.HTTP_400_BAD_REQUEST
)

save_directory = "/photo/"

if not os.path.exists(save_directory):
os.makedirs(save_directory)

uploaded_file = request.FILES["image"]

new_filename = f"{uuid.uuid4()}.{uploaded_file.name.split('.')[-1]}"

file_path = os.path.join(save_directory, new_filename)

with open(file_path, "wb") as f:
for chunk in uploaded_file.chunks():
f.write(chunk)

try:
predicted_class_name, confidence_score = Model.predict_class(file_path)

if float(confidence_score) < 65.0:
return Response(
{
"message": "Confidence score below 65%. Please provide a clearer image."
},
status=status.HTTP_400_BAD_REQUEST,
)
except Exception as e:
logger.error(f"Error in prediction: {e}")
return Response(
{"message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)

try:
client = storage.Client(
credentials=settings.GS_CREDENTIALS, project=settings.GS_PROJECT_ID
)
bucket = client.get_bucket(settings.GS_BUCKET_NAME)
gcs_destination_path = f"predictions/{predicted_class_name}/{new_filename}"

blob = bucket.blob(gcs_destination_path)
blob.upload_from_filename(file_path)
image_url = blob.public_url
except Exception as e:
logger.error(f"Error uploading to GCS: {e}")
return Response(
{"message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)

os.remove(file_path)

if os.path.exists(file_path) == 0:
print("File does not exist")

data = {
"result": predicted_class_name,
"rate": confidence_score,
"image": image_url,
}
photo_serializer = PredictionSerializer(data=data)

if photo_serializer.is_valid():
photo_serializer.save()
return Response(photo_serializer.data, status=status.HTTP_201_CREATED)
else:
logger.error(f"Serializer errors: {photo_serializer.errors}")
return Response(photo_serializer.errors, status=status.HTTP_400_BAD_REQUEST)
Binary file modified app/requirements.txt
Binary file not shown.

0 comments on commit ac89f06

Please sign in to comment.