Skip to content

Commit

Permalink
Merge pull request #1 from yashkant/update-py3
Browse files Browse the repository at this point in the history
Update to Python 3 and rip out the Lua code
  • Loading branch information
yashkant authored Jul 7, 2019
2 parents f5db5a0 + e0f4f7d commit ad95136
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 23 deletions.
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,12 @@ ques_feat.json
models/*.caffemodel
models/*.lua
models/*.prototxt
*.zip

# Pycharm
.idea/

# Installed packages
pytorch/
migrations/
!migrations/__init__.py
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ If you find this code useful, consider citing our work:

BSD

## Helpful Issues
Problems installing uwsgi: https://github.com/unbit/uwsgi/issues/1770

Problems with asgiref: https://stackoverflow.com/questions/41335478/importerror-no-module-named-asgiref-base-layer
## Credits

- Visual Chatbot Image: "[Robot-clip-art-book-covers-feJCV3-clipart](https://commons.wikimedia.org/wiki/File:Robot-clip-art-book-covers-feJCV3-clipart.png)" by [Wikimedia Commons](https://commons.wikimedia.org) is licensed under [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/deed.en)
Expand Down
4 changes: 2 additions & 2 deletions chat/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from chat.utils import log_to_terminal

def ws_connect(message):
print "User connnected via Socket"
print("User connnected via Socket")


def ws_message(message):
print "Message recieved from client side and the content is ", message.content['text']
print("Message recieved from client side and the content is ", message.content['text'])
# prefix, label = message['path'].strip('/').split('/')
socketid = message.content['text']

Expand Down
6 changes: 3 additions & 3 deletions chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import os
import traceback
import random
import urllib2
import urllib


def home(request, template_name="chat/index.html"):
Expand All @@ -28,7 +28,7 @@ def home(request, template_name="chat/index.html"):
job_id = request.POST.get("job_id")
history = request.POST.get("history", "")

img_path = urllib2.unquote(img_path)
img_path = urllib.parse.unquote(img_path)
abs_image_path = str(img_path)

q_tokens = word_tokenize(str(question))
Expand All @@ -38,7 +38,7 @@ def home(request, template_name="chat/index.html"):
response = svqa(str(question), str(history),
str(abs_image_path), socketid, job_id)
return JsonResponse({"success": True})
except Exception, err:
except Exception:
return JsonResponse({"success": False})

elif request.method == "GET":
Expand Down
57 changes: 57 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
class VisDialDummyModel:
    """Stand-in for the original Lua/Torch VisDial model.

    Stores every construction argument verbatim on the instance and
    returns a canned answer from :meth:`predict`, so the worker pipeline
    can be exercised without Torch/Lua installed.
    """

    def __init__(self, inputJson, loadPath, beamSize, beamLen, sampleWords,
                 temperature, gpuid, backend, proto_file, model_file,
                 maxThreads, encoder, decoder):
        # Announce construction, then keep each argument for later inspection.
        print("init visdial model")
        self.inputJson = inputJson
        self.loadPath = loadPath
        self.beamSize = beamSize
        self.beamLen = beamLen
        self.sampleWords = sampleWords
        self.temperature = temperature
        self.gpuid = gpuid
        self.backend = backend
        self.proto_file = proto_file
        self.model_file = model_file
        self.maxThreads = maxThreads
        self.encoder = encoder
        self.decoder = decoder

    def predict(self, img, history, question):
        """Return a canned answer dict shaped like the real model's output.

        ``history`` is assumed to be an iterable of strings (it is joined
        with ``''.join``); the new question/answer turn is appended to it.
        """
        print("predict-visdial called!")
        print("img: ", img)
        print("hist: ", history)
        print("ques: ", question)
        canned_answer = "dummy ans here to your dummy question there"
        new_turn = question + " " + canned_answer
        result = {
            'answer': canned_answer,
            'question': question,
            'history': ''.join(history) + new_turn,
            'input_image': img,
        }
        print(result)
        return result

class CaptioningTorchDummyModel:
    """Stand-in for the original Torch captioning model.

    Records its configuration and emits a fixed caption from
    :meth:`predict`, letting the captioning worker run without
    Torch/Lua installed.
    """

    def __init__(self, model_path, backend, input_sz, layer, seed, gpuid):
        print("init caption model")
        # Keep the configuration on the instance for later inspection.
        self.model_path = model_path
        self.backend = backend
        self.input_sz = input_sz
        self.layer = layer
        self.seed = seed
        self.gpuid = gpuid
        # The real implementation would load weights here; the dummy
        # delegates to loadModel, which only logs.
        self.loadModel(model_path)

    def loadModel(self, model_path):
        """Pretend to load model weights; only logs the requested path."""
        print("load-model called!")
        print("model_path: ", model_path)

    def predict(self, input_image_path, input_sz1, input_sz2):
        """Return a fixed caption for *input_image_path*.

        The two size arguments are accepted for interface compatibility
        and ignored, matching the original dummy behaviour.
        """
        print("predict-caption called!")
        return {
            'input_image': input_image_path,
            'pred_caption': "dummy caption here",
        }
15 changes: 6 additions & 9 deletions worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,14 @@
from django.conf import settings
from chat.utils import log_to_terminal
from chat.models import Job, Dialog

from models import VisDialDummyModel
import chat.constants as constants
import PyTorch
import PyTorchHelpers
import pika
import time
import yaml
import json
import traceback

VisDialModel = PyTorchHelpers.load_lua_class(
constants.VISDIAL_LUA_PATH, 'VisDialTorchModel')
VisDialModel = VisDialDummyModel

VisDialATorchModel = VisDialModel(
constants.VISDIAL_CONFIG['input_json'],
Expand All @@ -47,6 +43,7 @@

django.db.close_old_connections()


def callback(ch, method, properties, body):
try:
body = yaml.safe_load(body)
Expand All @@ -69,12 +66,12 @@ def callback(ch, method, properties, body):
job = Job.objects.get(id=int(body['job_id']))
Dialog.objects.create(job=job, question=result['question'], answer=result['answer'].replace("<START>", "").replace("<END>", ""))
except:
print str(traceback.print_exc())
print(str(traceback.print_exc()))

django.db.close_old_connections()

except Exception, err:
print str(traceback.print_exc())
except Exception:
print(str(traceback.print_exc()))

channel.basic_consume(callback,
queue='visdial_task_queue')
Expand Down
18 changes: 9 additions & 9 deletions worker_captioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,21 @@

from django.conf import settings
from chat.utils import log_to_terminal
from chat.models import Job, Dialog
from chat.models import Job

import chat.constants as constants

import PyTorch
import PyTorchHelpers
import pika
import time
import yaml
import json
import traceback

from models import CaptioningTorchDummyModel

django.db.close_old_connections()

CaptioningModel = PyTorchHelpers.load_lua_class(
constants.CAPTIONING_LUA_PATH, 'CaptioningTorchModel')
CaptioningModel = CaptioningTorchDummyModel

CaptioningTorchModel = CaptioningModel(
constants.CAPTIONING_CONFIG['model_path'],
constants.CAPTIONING_CONFIG['backend'],
Expand Down Expand Up @@ -52,16 +51,17 @@ def callback(ch, method, properties, body):
result['input_image'] = str(result['input_image']).replace(settings.BASE_DIR, '')
log_to_terminal(body['socketid'], {"result": json.dumps(result)})
ch.basic_ack(delivery_tag=method.delivery_tag)
print('succesfull callback')

try:
Job.objects.filter(id=int(body['job_id'])).update(caption=result['pred_caption'])
except Exception as e:
print str(traceback.print_exc())
print(str(traceback.print_exc()))

django.db.close_old_connections()

except Exception, err:
print str(traceback.print_exc())
except Exception:
print(str(traceback.print_exc()))

channel.basic_consume(callback,
queue='visdial_captioning_task_queue')
Expand Down

0 comments on commit ad95136

Please sign in to comment.