Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/image upload ai #199

Open
wants to merge 22 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ plugins {
// }
//}

version "3.0.0-SNAPSHOT"
version "3.0.0-AI-SNAPSHOT"

group "au.org.ala"

Expand Down Expand Up @@ -118,6 +118,10 @@ dependencies {
implementation group: 'org.locationtech.spatial4j', name: 'spatial4j', version: '0.7'
implementation group: 'org.locationtech.jts', name: 'jts-core', version: '1.15.0'
implementation "com.amazonaws:aws-java-sdk-s3:$amazonAwsSdkVersion"
implementation "com.amazonaws:aws-java-sdk-rekognition:$amazonAwsSdkVersion"
implementation "com.amazonaws:aws-java-sdk-core:$amazonAwsSdkVersion"
implementation "com.amazonaws:aws-java-sdk-sagemakerruntime:$amazonAwsSdkVersion"

implementation 'org.javaswift:joss:0.10.4'
runtimeOnly 'org.postgresql:postgresql:42.5.4'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class WebServiceController {
def elasticSearchService
def collectoryService
def authService
def imageRecognitionService

@Operation(
method = "DELETE",
Expand Down Expand Up @@ -2016,6 +2017,41 @@ class WebServiceController {
log.error("Problem storing image " + e.getMessage(), e)
renderResults([success: false, message: "Failed to store image!"], 500)
}
finally {
imageRecognitionService.cleanup(null)
}
}

@RequireApiKey(scopes=["image-service/write"])
def uploadImageAI() {

MultipartFile file
def url = params.imageUrl ?: params.url
if (!url) {
if (request.metaClass.respondsTo(request, 'getFile', String)) {
file = request.getFile('image')
}
}
if(!url && !file) {
renderResults([success: false, message: "No url parameter, therefore expected multipart request!"], HttpStatus.SC_BAD_REQUEST)
return
}
try {
Map response = imageRecognitionService.checkImageContent(file, url)
if (!response.success) {
renderResults([success: false, message: response.message], HttpStatus.SC_BAD_REQUEST)
return
} else {
uploadImage()
}
}
catch (Exception e) {
log.error("Problem storing image " + e.getMessage(), e)
renderResults([success: false, message: "Failed to store image!"], HttpStatus.SC_INTERNAL_SERVER_ERROR)
}
finally {
imageRecognitionService.cleanup(url)
}
}

@Operation(
Expand Down
65 changes: 65 additions & 0 deletions grails-app/init/au/org/ala/images/Application.groovy
Original file line number Diff line number Diff line change
@@ -1,10 +1,75 @@
package au.org.ala.images

import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.auth.AWSStaticCredentialsProvider
import com.amazonaws.auth.BasicAWSCredentials
import com.amazonaws.auth.BasicSessionCredentials
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.amazonaws.regions.Region
import com.amazonaws.regions.Regions
import com.amazonaws.services.rekognition.AmazonRekognitionClient
import com.amazonaws.services.s3.AmazonS3ClientBuilder
import com.amazonaws.services.sagemakerruntime.AmazonSageMakerRuntime
import grails.boot.GrailsApp
import grails.boot.config.GrailsAutoConfiguration
import org.springframework.context.annotation.Bean
import com.amazonaws.services.s3.AmazonS3Client
import com.amazonaws.services.rekognition.AmazonRekognitionClientBuilder
import com.amazonaws.services.sagemakerruntime.AmazonSageMakerRuntimeClientBuilder

class Application extends GrailsAutoConfiguration {
static void main(String[] args) {
GrailsApp.run(Application, args)
}

@Bean
AWSCredentialsProvider awsCredentialsProvider() {
def accessKey = grailsApplication.config.getProperty('aws.access-key') ?: System.getenv('AWS_ACCESS_KEY_ID')
def secretKey = grailsApplication.config.getProperty('aws.secret-key') ?: System.getenv('AWS_SECRET_ACCESS_KEY')
def sessionToken = grailsApplication.config.getProperty('aws.session-token')


if (accessKey && secretKey) {
def credentials
if (sessionToken) {
credentials = new BasicSessionCredentials(accessKey, secretKey, sessionToken)
} else {
credentials = new BasicAWSCredentials(accessKey, secretKey)
}
return new AWSStaticCredentialsProvider(credentials)
} else {
return DefaultAWSCredentialsProviderChain.instance
}
}

@Bean('awsRegion')
Region awsRegion() {
def region = grailsApplication.config.getProperty('aws.region', String, "ap-southeast-2")
return region ? Region.getRegion(Regions.fromName(region)) : Regions.currentRegion
}

@Bean
AmazonRekognitionClient rekognitionClient(AWSCredentialsProvider awsCredentialsProvider, Region awsRegion) {
return AmazonRekognitionClientBuilder.standard()
.withCredentials(awsCredentialsProvider)
.withRegion(awsRegion.toString())
.build();
}

@Bean
AmazonS3Client s3Client(AWSCredentialsProvider awsCredentialsProvider, Region awsRegion) {
return AmazonS3ClientBuilder.standard()
.withCredentials(awsCredentialsProvider)
.withRegion(awsRegion.toString())
.build();
}

@Bean
AmazonSageMakerRuntime sageMakerRuntime(AWSCredentialsProvider awsCredentialsProvider, Region awsRegion) {
return AmazonSageMakerRuntimeClientBuilder.standard()
.withRegion(awsRegion.toString())
.withCredentials(awsCredentialsProvider)
.build()
}

}
212 changes: 212 additions & 0 deletions grails-app/services/au/org/ala/images/ImageRecognitionService.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package au.org.ala.images

import com.amazonaws.services.rekognition.model.DetectFacesRequest
import com.amazonaws.services.rekognition.model.DetectFacesResult
import com.amazonaws.services.rekognition.model.DetectModerationLabelsRequest
import com.amazonaws.services.rekognition.model.DetectModerationLabelsResult
import com.amazonaws.services.rekognition.model.FaceDetail
import com.amazonaws.services.rekognition.model.ModerationLabel
import com.amazonaws.services.s3.AmazonS3Client
import com.amazonaws.services.rekognition.AmazonRekognitionClient
import com.amazonaws.services.s3.model.CannedAccessControlList
import com.amazonaws.services.s3.model.GetObjectRequest
import com.amazonaws.services.s3.model.ObjectMetadata
import com.amazonaws.services.rekognition.model.Image
import com.amazonaws.services.sagemakerruntime.AmazonSageMakerRuntime
import com.amazonaws.services.sagemakerruntime.model.InvokeEndpointRequest
import com.amazonaws.services.sagemakerruntime.model.InvokeEndpointResult
import org.springframework.web.multipart.MultipartFile

import javax.imageio.ImageIO
import com.amazonaws.services.rekognition.model.S3Object
import org.apache.commons.io.IOUtils

import javax.imageio.stream.ImageInputStream
import java.awt.image.BufferedImage
import java.awt.image.BufferedImageOp
import java.awt.image.ColorModel
import java.awt.image.ConvolveOp
import java.awt.image.Kernel
import java.nio.ByteBuffer

class ImageRecognitionService {

AmazonRekognitionClient rekognitionClient
AmazonS3Client s3Client
def grailsApplication
AmazonSageMakerRuntime sageMakerRuntime

private addImageToS3FromUrl(String filePath, String bucket, String tempFileName) {

URL url = new URL(filePath)
BufferedImage img = ImageIO.read(url)
File file = new File("/tmp/${tempFileName}.jpg")
ImageIO.write(img, "jpg", file)
s3Client.putObject(bucket, tempFileName, file)
}

private addImageToS3FromFile(MultipartFile file, String bucket, String tempFileName) {
s3Client.putObject(bucket, tempFileName, new ByteArrayInputStream(file?.bytes),
generateMetadata(file.contentType, null, file.size))
}

def addImageToS3FromBytes(byte[] bytes, String bucket, String tempFileName, String contentType) {
s3Client.putObject(bucket, tempFileName, new ByteArrayInputStream(bytes),
generateMetadata(contentType))
}

private deleteImageS3(String bucket, String filename) {
s3Client.deleteObject(bucket, filename)
}

private deleteImageFile(String filename) {
File file = new File("/tmp/$filename")
file.delete()
}

def cleanup(filePath) {
String tempImageBucket = grailsApplication.config.getProperty('aws.tempImageBucket', String, "ala-image-service-test-uploads-production")
String tempImageName = grailsApplication.config.getProperty('aws.tempImageName', String, "temp-image")
if (filePath) {
deleteImageFile("${tempImageName}.jpg")
}
deleteImageS3(tempImageBucket, tempImageName)
}

def checkImageContent(MultipartFile file, String filePath) {

String tempImageBucket = grailsApplication.config.getProperty('aws.tempImageBucket', String, "ala-image-service-test-uploads-production")
String tempImageName = grailsApplication.config.getProperty('aws.tempImageName', String, "temp-image")

try {
if (filePath) {
addImageToS3FromUrl(filePath, tempImageBucket, tempImageName)
} else {
addImageToS3FromFile(file, tempImageBucket, tempImageName)
}

List labels = detectModLabels(tempImageBucket, tempImageName)
if (labels) {
return [success: false, message: "Detected inappropriate content: $labels"]
}

def roadkillEndpointEnabled = grailsApplication.config.getProperty('aws.sagemaker.enabled', boolean, false)

if(roadkillEndpointEnabled) {
boolean ifRoadKill = detectRoadkill(tempImageBucket, tempImageName)
if (ifRoadKill) {
return [success: false, message: "Detected road kill"]
}
}
return [success: true]
}
catch (Exception e) {
throw e
}
}

private detectModLabels(String bucket, String tempFileName) {
try {
List labels = []
def acceptingLabel = "Blood & Gore"
DetectModerationLabelsRequest request = new DetectModerationLabelsRequest()
.withImage(new Image().withS3Object(new S3Object().withBucket(bucket).withName(tempFileName)))

DetectModerationLabelsResult result = rekognitionClient.detectModerationLabels(request)

for (ModerationLabel label : result.moderationLabels) {
labels.add(label.getName())
}
if(labels.contains(acceptingLabel)) {
labels = []
}
return labels

} catch (Exception e) {
e.printStackTrace()
}
}

def detectFaces(String bucket, String tempFileName) {

DetectFacesRequest faceDetectRequest = new DetectFacesRequest()
.withImage(new Image().withS3Object(new S3Object().withBucket(bucket).withName(tempFileName)))

DetectFacesResult faceDetectResult = rekognitionClient.detectFaces(faceDetectRequest)

return faceDetectResult.faceDetails
}

private detectRoadkill(String bucket, String tempFileName) {

def object = s3Client.getObject(new GetObjectRequest(bucket, tempFileName))
InputStream objectData = object.getObjectContent()
byte[] byteArray = IOUtils.toByteArray(objectData)

InvokeEndpointRequest invokeEndpointRequest = new InvokeEndpointRequest()
invokeEndpointRequest.setContentType("application/octet-stream")
ByteBuffer buf = ByteBuffer.wrap(byteArray)
invokeEndpointRequest.setBody(buf)
invokeEndpointRequest.setEndpointName(grailsApplication.config.getProperty('aws.sagemaker.endpointName', String, ""))
invokeEndpointRequest.setAccept("application/json")

InvokeEndpointResult invokeEndpointResult = sageMakerRuntime.invokeEndpoint(invokeEndpointRequest)
objectData.close()
String response = new String(invokeEndpointResult.getBody().array())
return response == "Roadkill"
}

private ObjectMetadata generateMetadata(String contentType, String contentDisposition = null, Long length = null) {
ObjectMetadata metadata = new ObjectMetadata()
metadata.setContentType(contentType)
if (contentDisposition) {
metadata.setContentDisposition(contentDisposition)
}
if (length != null) {
metadata.setContentLength(length)
}
def acl = CannedAccessControlList.Private
metadata.setHeader('x-amz-acl', acl.toString())
metadata.cacheControl = ('private') + ',max-age=31536000'
return metadata
}

def blurFaces(String bucket, String tempFileName, List<FaceDetail> faceDetails) {

def o = s3Client.getObject(bucket, tempFileName)
ImageInputStream iin = ImageIO.createImageInputStream(o.getObjectContent())
BufferedImage image = ImageIO.read(iin)

int radius = 20
int size = radius + (image.width/50) as int
float weight = 1.0f / (size * size);
float[] matrix = new float[size * size];

for (int i = 0; i < matrix.length; i++) {
matrix[i] = weight;
}

BufferedImageOp op = new ConvolveOp(new Kernel(size, size, matrix), ConvolveOp.EDGE_NO_OP, null)

faceDetails.each { face ->
def fullWidth = image.width
def fullHeight = image.height
def margin = 10
def dest = image.getSubimage(
((face.boundingBox.left * fullWidth) as int) - margin,
((face.boundingBox.top * fullHeight) as int) - margin,
((face.boundingBox.width * fullWidth) as int) + margin * 2,
((face.boundingBox.height * fullHeight) as int) + margin * 2)
ColorModel cm = dest.getColorModel()
def src = new BufferedImage(cm, dest.copyData(dest.getRaster().createCompatibleWritableRaster()), cm.isAlphaPremultiplied(),
null).getSubimage(0,0,dest.getWidth(), dest.getHeight())
op.filter(src, dest)
}

ByteArrayOutputStream baos = new ByteArrayOutputStream()
ImageIO.write(image, "jpg", baos)
byte[] bytes = baos.toByteArray()
baos.close()
return bytes
}
}
9 changes: 9 additions & 0 deletions grails-app/services/au/org/ala/images/ImageService.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class ImageService {
def elasticSearchService
def settingService
def collectoryService
def imageRecognitionService

final static List<String> SUPPORTED_UPDATE_FIELDS = [
"audience",
Expand Down Expand Up @@ -508,6 +509,14 @@ SELECT
try {
lock.lock()

String tempImageBucket = grailsApplication.config.getProperty('aws.tempImageBucket', String, "ala-image-service-test-uploads-production")
String tempImageName = grailsApplication.config.getProperty('aws.tempImageName', String, "temp-image")
imageRecognitionService.addImageToS3FromBytes(bytes, tempImageBucket, tempImageName, contentType)
List faces = imageRecognitionService.detectFaces(tempImageBucket, tempImageName)
if (faces) {
bytes = imageRecognitionService.blurFaces(tempImageBucket, tempImageName, faces)
}

def md5Hash = bytes.encodeAsMD5()

//check for existing image using MD5 hash
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class ImageUploadSpec extends ImagesIntegrationSpec {
HttpClient.create(baseUrl, configuration).toBlocking()
}

@Ignore
//Fail in the jenkins
void "test home page"() {
when:
Expand Down
Loading