forked from TheoKanning/openai-java
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathOpenAiService.java
197 lines (161 loc) · 7.14 KB
/
OpenAiService.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
package com.theokanning.openai;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.theokanning.openai.answer.AnswerRequest;
import com.theokanning.openai.answer.AnswerResult;
import com.theokanning.openai.classification.ClassificationRequest;
import com.theokanning.openai.classification.ClassificationResult;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.edit.EditRequest;
import com.theokanning.openai.edit.EditResult;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.engine.Engine;
import com.theokanning.openai.file.File;
import com.theokanning.openai.finetune.FineTuneEvent;
import com.theokanning.openai.finetune.FineTuneRequest;
import com.theokanning.openai.finetune.FineTuneResult;
import com.theokanning.openai.model.Model;
import com.theokanning.openai.moderation.ModerationRequest;
import com.theokanning.openai.moderation.ModerationResult;
import com.theokanning.openai.search.SearchRequest;
import com.theokanning.openai.search.SearchResult;
import okhttp3.*;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
 * Blocking convenience wrapper around {@link OpenAiApi}.
 *
 * <p>Every method delegates to the corresponding {@link OpenAiApi} call and waits for the result
 * with RxJava's {@code blockingGet()}, so calls block the current thread. Errors emitted by the
 * underlying call are rethrown by {@code blockingGet()} (checked exceptions arrive wrapped in a
 * {@link RuntimeException}).
 */
public class OpenAiService {
    // Package-private and non-final on purpose: kept as in the original so same-package code
    // (e.g. tests) that reads or replaces it keeps compiling.
    OpenAiApi api;

    /**
     * Creates a new OpenAiService that wraps OpenAiApi, using a 10 second read timeout.
     *
     * @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
     */
    public OpenAiService(String token) {
        this(token, 10);
    }

    /**
     * Creates a new OpenAiService that wraps OpenAiApi.
     *
     * <p>The JSON mapper ignores unknown response properties, omits null request fields and uses
     * snake_case property names. The HTTP client authenticates every request through
     * {@code AuthenticationInterceptor} and keeps at most 5 idle pooled connections alive for 1 second.
     *
     * @param token   OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
     * @param timeout http read timeout in seconds, 0 means no timeout
     */
    public OpenAiService(String token, int timeout) {
        ObjectMapper mapper = new ObjectMapper();
        mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
        mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);

        OkHttpClient client = new OkHttpClient.Builder()
                .addInterceptor(new AuthenticationInterceptor(token))
                .connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS))
                .readTimeout(timeout, TimeUnit.SECONDS)
                .build();

        Retrofit retrofit = new Retrofit.Builder()
                .baseUrl("https://api.openai.com/")
                .client(client)
                .addConverterFactory(JacksonConverterFactory.create(mapper))
                .addCallAdapterFactory(RxJava2CallAdapterFactory.create())
                .build();

        this.api = retrofit.create(OpenAiApi.class);
    }

    /**
     * Creates a new OpenAiService that wraps OpenAiApi.
     *
     * @param api OpenAiApi instance to use for all methods
     */
    public OpenAiService(OpenAiApi api) {
        this.api = api;
    }

    /**
     * Lists models, blocking until the response arrives.
     *
     * @return the {@code data} list of the list-models response
     */
    public List<Model> listModels() {
        return api.listModels().blockingGet().data;
    }

    /**
     * Retrieves a single model, blocking until the response arrives.
     *
     * @param modelId id of the model to fetch
     * @return the model returned by the API
     */
    public Model getModel(String modelId) {
        return api.getModel(modelId).blockingGet();
    }

    /**
     * Creates a completion, blocking until the response arrives.
     *
     * @param request completion request; the model is taken from the request itself
     * @return the completion result
     */
    public CompletionResult createCompletion(CompletionRequest request) {
        return api.createCompletion(request).blockingGet();
    }

    /**
     * Creates a completion against an explicit engine.
     *
     * @param engineId engine to run the completion on
     * @param request  completion request
     * @return the completion result
     * @deprecated Use {@link OpenAiService#createCompletion(CompletionRequest)} and
     *             {@link CompletionRequest#model} instead.
     */
    @Deprecated
    public CompletionResult createCompletion(String engineId, CompletionRequest request) {
        return api.createCompletion(engineId, request).blockingGet();
    }

    /**
     * Creates an edit, blocking until the response arrives.
     *
     * @param request edit request; the model is taken from the request itself
     * @return the edit result
     */
    public EditResult createEdit(EditRequest request) {
        return api.createEdit(request).blockingGet();
    }

    /**
     * Creates an edit against an explicit engine.
     *
     * @param engineId engine to run the edit on
     * @param request  edit request
     * @return the edit result
     * @deprecated Use {@link OpenAiService#createEdit(EditRequest)} and {@link EditRequest#model} instead.
     */
    @Deprecated
    public EditResult createEdit(String engineId, EditRequest request) {
        return api.createEdit(engineId, request).blockingGet();
    }

    /**
     * Creates embeddings, blocking until the response arrives.
     *
     * @param request embedding request; the model is taken from the request itself
     * @return the embedding result
     */
    public EmbeddingResult createEmbeddings(EmbeddingRequest request) {
        return api.createEmbeddings(request).blockingGet();
    }

    /**
     * Creates embeddings against an explicit engine.
     *
     * @param engineId engine to compute the embeddings with
     * @param request  embedding request
     * @return the embedding result
     * @deprecated Use {@link OpenAiService#createEmbeddings(EmbeddingRequest)} and
     *             {@link EmbeddingRequest#model} instead.
     */
    @Deprecated
    public EmbeddingResult createEmbeddings(String engineId, EmbeddingRequest request) {
        return api.createEmbeddings(engineId, request).blockingGet();
    }

    /**
     * Lists uploaded files, blocking until the response arrives.
     *
     * @return the {@code data} list of the list-files response
     */
    public List<File> listFiles() {
        return api.listFiles().blockingGet().data;
    }

    /**
     * Uploads a local file, blocking until the call completes.
     *
     * <p>The file is sent as the multipart part named {@code "file"}. Its filename is the file's base
     * name ({@link java.io.File#getName()}), not the full path, so the caller's local directory layout
     * is not sent to the server.
     *
     * @param purpose  purpose value sent as a form part alongside the file
     * @param filepath path of the local file to upload
     * @return the uploaded file description returned by the API
     * @throws IllegalArgumentException if {@code filepath} does not name an existing regular file
     */
    public File uploadFile(String purpose, String filepath) {
        java.io.File file = new java.io.File(filepath);
        // Fail fast with a clear message instead of an I/O error surfacing later, while OkHttp
        // writes the request body.
        if (!file.isFile()) {
            throw new IllegalArgumentException("File does not exist or is not a regular file: " + filepath);
        }

        RequestBody purposeBody = RequestBody.create(okhttp3.MultipartBody.FORM, purpose);
        // NOTE: "text" has no subtype, so MediaType.parse returns null and the file part is sent
        // without a Content-Type header. Kept unchanged so the request the server sees does not change.
        RequestBody fileBody = RequestBody.create(MediaType.parse("text"), file);
        MultipartBody.Part body = MultipartBody.Part.createFormData("file", file.getName(), fileBody);
        return api.uploadFile(purposeBody, body).blockingGet();
    }

    /**
     * Deletes an uploaded file, blocking until the response arrives.
     *
     * @param fileId id of the file to delete
     * @return the delete result returned by the API
     */
    public DeleteResult deleteFile(String fileId) {
        return api.deleteFile(fileId).blockingGet();
    }

    /**
     * Retrieves metadata for an uploaded file, blocking until the response arrives.
     *
     * @param fileId id of the file to fetch
     * @return the file description returned by the API
     */
    public File retrieveFile(String fileId) {
        return api.retrieveFile(fileId).blockingGet();
    }

    /**
     * Creates a fine-tune job, blocking until the response arrives.
     *
     * @param request fine-tune request
     * @return the fine-tune result returned by the API
     */
    public FineTuneResult createFineTune(FineTuneRequest request) {
        return api.createFineTune(request).blockingGet();
    }

    /**
     * Creates a completion through {@link OpenAiApi#createFineTuneCompletion}, blocking until the
     * response arrives.
     *
     * @param request completion request
     * @return the completion result
     */
    public CompletionResult createFineTuneCompletion(CompletionRequest request) {
        return api.createFineTuneCompletion(request).blockingGet();
    }

    /**
     * Lists fine-tune jobs, blocking until the response arrives.
     *
     * @return the {@code data} list of the list-fine-tunes response
     */
    public List<FineTuneResult> listFineTunes() {
        return api.listFineTunes().blockingGet().data;
    }

    /**
     * Retrieves a fine-tune job, blocking until the response arrives.
     *
     * @param fineTuneId id of the fine-tune job
     * @return the fine-tune result returned by the API
     */
    public FineTuneResult retrieveFineTune(String fineTuneId) {
        return api.retrieveFineTune(fineTuneId).blockingGet();
    }

    /**
     * Cancels a fine-tune job, blocking until the response arrives.
     *
     * @param fineTuneId id of the fine-tune job to cancel
     * @return the fine-tune result returned by the API
     */
    public FineTuneResult cancelFineTune(String fineTuneId) {
        return api.cancelFineTune(fineTuneId).blockingGet();
    }

    /**
     * Lists the events of a fine-tune job, blocking until the response arrives.
     *
     * @param fineTuneId id of the fine-tune job
     * @return the {@code data} list of the list-events response
     */
    public List<FineTuneEvent> listFineTuneEvents(String fineTuneId) {
        return api.listFineTuneEvents(fineTuneId).blockingGet().data;
    }

    /**
     * Deletes via {@link OpenAiApi#deleteFineTune}, blocking until the response arrives.
     *
     * @param fineTuneId id passed to the delete call
     * @return the delete result returned by the API
     */
    public DeleteResult deleteFineTune(String fineTuneId) {
        return api.deleteFineTune(fineTuneId).blockingGet();
    }

    /**
     * Creates a moderation, blocking until the response arrives.
     *
     * @param request moderation request
     * @return the moderation result
     */
    public ModerationResult createModeration(ModerationRequest request) {
        return api.createModeration(request).blockingGet();
    }

    /**
     * Lists engines, blocking until the response arrives.
     *
     * @return the {@code data} list of the list-engines response
     * @deprecated Retained for backward compatibility; avoid in new code.
     */
    @Deprecated
    public List<Engine> getEngines() {
        return api.getEngines().blockingGet().data;
    }

    /**
     * Retrieves a single engine, blocking until the response arrives.
     *
     * @param engineId id of the engine to fetch
     * @return the engine returned by the API
     * @deprecated Retained for backward compatibility; avoid in new code.
     */
    @Deprecated
    public Engine getEngine(String engineId) {
        return api.getEngine(engineId).blockingGet();
    }

    /**
     * Creates an answer, blocking until the response arrives.
     *
     * @param request answer request
     * @return the answer result
     * @deprecated Retained for backward compatibility; avoid in new code.
     */
    @Deprecated
    public AnswerResult createAnswer(AnswerRequest request) {
        return api.createAnswer(request).blockingGet();
    }

    /**
     * Creates a classification, blocking until the response arrives.
     *
     * @param request classification request
     * @return the classification result
     * @deprecated Retained for backward compatibility; avoid in new code.
     */
    @Deprecated
    public ClassificationResult createClassification(ClassificationRequest request) {
        return api.createClassification(request).blockingGet();
    }

    /**
     * Runs a search against an explicit engine, blocking until the response arrives.
     *
     * @param engineId engine to search with
     * @param request  search request
     * @return the {@code data} list of the search response
     * @deprecated Retained for backward compatibility; avoid in new code.
     */
    @Deprecated
    public List<SearchResult> search(String engineId, SearchRequest request) {
        return api.search(engineId, request).blockingGet().data;
    }
}