OpenAI: embeddings

This commit is contained in:
M66B 2023-03-10 09:07:45 +01:00
parent 93b1187032
commit b075824fd7
1 changed files with 36 additions and 0 deletions

View File

@ -109,6 +109,20 @@ public class OpenAI {
}
}
static double[] getEmbedding(Context context, String text, String model) throws JSONException, IOException {
// https://platform.openai.com/docs/api-reference/embeddings
JSONObject jrequest = new JSONObject();
jrequest.put("input", text);
jrequest.put("model", model == null ? "text-embedding-ada-002" : model);
JSONObject jresponse = call(context, "POST", "v1/embeddings", jrequest);
JSONObject jdata = jresponse.getJSONArray("data").getJSONObject(0);
JSONArray jembedding = jdata.getJSONArray("embedding");
double[] result = new double[jembedding.length()];
for (int i = 0; i < jembedding.length(); i++)
result[i] = jembedding.getDouble(i);
return result;
}
static Message[] completeChat(Context context, String model, Message[] messages, Float temperature, int n) throws JSONException, IOException {
// https://platform.openai.com/docs/guides/chat/introduction
// https://platform.openai.com/docs/api-reference/chat/create
@ -238,4 +252,26 @@ public class OpenAI {
return this.role + ": " + this.content;
}
}
static class Embedding {
public static double getSimilarity(double[] v1, double[] v2) {
if (v1.length != v2.length)
throw new IllegalArgumentException("Invalid vector length=" + v1.length + "/" + v2.length);
double dotProduct = dotProduct(v1, v2);
double magV1 = magnitude(v1);
double magV2 = magnitude(v2);
return dotProduct / (magV1 * magV2);
}
private static double dotProduct(double[] v1, double[] v2) {
float val = 0;
for (int i = 0; i <= v1.length - 1; i++)
val += v1[i] * v2[i];
return val;
}
private static double magnitude(double[] v) {
return Math.sqrt(dotProduct(v, v));
}
}
}