Skip to content

Commit

Permalink
[WIP] predict next words
Browse files Browse the repository at this point in the history
  • Loading branch information
dongyuwei committed May 1, 2024
1 parent db5db28 commit 75a9edd
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/InputController.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#import "ConversionEngine.h"

@interface InputController : IMKInputController {
NSMutableString *_sentenceBuffer;
NSMutableString *_composedBuffer;
NSMutableString *_originalBuffer;
NSInteger _insertionIndex;
Expand All @@ -17,6 +18,9 @@
AnnotationWinController *_annotationWin;
}

- (NSMutableString *)sentenceBuffer;
- (void)setSentenceBuffer:(NSString *)string;

- (NSMutableString *)composedBuffer;
- (void)setComposedBuffer:(NSString *)string;
- (NSMutableString *)originalBuffer;
Expand Down
98 changes: 98 additions & 0 deletions src/InputController.mm
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ - (BOOL)onKeyEvent:(NSEvent *)event client:(id)sender {
if (hasBufferedText) {
[self appendToComposedBuffer:characters];
[self commitCompositionWithoutSpace:sender];
[self setSentenceBuffer: @""];
return YES;
}
}
Expand Down Expand Up @@ -228,6 +229,27 @@ - (void)commitComposition:(id)sender {
[sender insertText:text replacementRange:NSMakeRange(NSNotFound, NSNotFound)];

[self reset];

NSLog(@"Current Sentence Buffer: %@", self.sentenceBuffer);
if ([self doesSentenceBufferIncludeSpace]) {
[self fetchPredictionsForText:self.sentenceBuffer completion:^(NSDictionary *responseDict, NSArray *bertArray, NSError *error) {
if (error) {
NSLog(@"Error: %@", error.localizedDescription);
} else {
NSLog(@"BERT: %@", bertArray);
dispatch_async(dispatch_get_main_queue(), ^{
[sharedCandidates setCandidateData:bertArray];
[sharedCandidates show:kIMKLocateCandidatesBelowHint];
});
}
}];
}

}

- (BOOL)doesSentenceBufferIncludeSpace {
NSRange range = [self.sentenceBuffer rangeOfString:@" "];
return range.location != NSNotFound;
}

- (void)commitCompositionWithoutSpace:(id)sender {
Expand All @@ -242,6 +264,66 @@ - (void)commitCompositionWithoutSpace:(id)sender {
[self reset];
}

- (NSString *) fetchAPIURL {
NSUserDefaults *defaults = [NSUserDefaults standardUserDefaults];
NSString *apiURL = [defaults stringForKey:@"NEXT_WORD_PREDICTION_SERVICE_URL"];
if (apiURL) {
return apiURL;
} else {
return @"http://127.0.0.1:8080/get_end_predictions";
}
}

- (void)fetchPredictionsForText:(NSString *)text completion:(void(^)(NSDictionary *responseDict, NSArray *bertArray, NSError *error))completionHandler {
NSString *urlString = [self fetchAPIURL];
NSURL *url = [NSURL URLWithString:urlString];
NSMutableURLRequest *request = [NSMutableURLRequest requestWithURL:url];
request.HTTPMethod = @"POST";
[request setValue:@"application/json" forHTTPHeaderField:@"Content-Type"];

NSDictionary *jsonBody = @{@"input_text": text, @"top_k": @"9"};
NSError *jsonError;
NSData *jsonData = [NSJSONSerialization dataWithJSONObject:jsonBody options:0 error:&jsonError];

if (jsonError) {
completionHandler(nil, nil, jsonError);
return;
}

request.HTTPBody = jsonData;

NSURLSession *session = [NSURLSession sharedSession];
NSURLSessionDataTask *task = [session dataTaskWithRequest:request completionHandler:^(NSData *data, NSURLResponse *response, NSError *error) {
if (error) {
completionHandler(nil, nil, error);
return;
}

NSError *jsonParsingError;
NSDictionary *responseDict = [NSJSONSerialization JSONObjectWithData:data options:0 error:&jsonParsingError];

if (jsonParsingError) {
completionHandler(nil, nil, jsonParsingError);
} else {
NSArray *bertArray = nil;
NSArray *bertCNArray = nil;

// Parsing the bert string
NSString *bertString = [responseDict objectForKey:@"bert"];
if (bertString) {
bertArray = [bertString componentsSeparatedByString:@"\n"];
bertArray = [bertArray filteredArrayUsingPredicate:[NSPredicate predicateWithBlock:^BOOL(id evaluatedObject, NSDictionary *bindings) {
return [evaluatedObject length] > 0 && ![evaluatedObject isEqualToString:@"[UNK]"];
}]];
}

completionHandler(responseDict, bertArray, nil);
}
}];

[task resume];
}

- (void)reset {
[self setComposedBuffer:@""];
[self setOriginalBuffer:@""];
Expand All @@ -264,6 +346,22 @@ - (NSMutableString *)composedBuffer {

- (void)setComposedBuffer:(NSString *)string {
NSMutableString *buffer = [self composedBuffer];
if (string && string.length > 0) {
NSString * sentence = self.sentenceBuffer;
[self setSentenceBuffer: [NSString stringWithFormat:@"%@ %@", sentence, string]];
}
[buffer setString:string];
}

- (NSMutableString *)sentenceBuffer {
if (_sentenceBuffer == nil) {
_sentenceBuffer = [[NSMutableString alloc] init];
}
return _sentenceBuffer;
}

- (void)setSentenceBuffer:(NSString *)string {
NSMutableString *buffer = [self sentenceBuffer];
[buffer setString:string];
}

Expand Down

0 comments on commit 75a9edd

Please sign in to comment.