mirror of https://github.com/M66B/FairEmail.git
OpenAI: embeddings
This commit is contained in:
parent
93b1187032
commit
b075824fd7
|
@ -109,6 +109,20 @@ public class OpenAI {
|
|||
}
|
||||
}
|
||||
|
||||
static double[] getEmbedding(Context context, String text, String model) throws JSONException, IOException {
|
||||
// https://platform.openai.com/docs/api-reference/embeddings
|
||||
JSONObject jrequest = new JSONObject();
|
||||
jrequest.put("input", text);
|
||||
jrequest.put("model", model == null ? "text-embedding-ada-002" : model);
|
||||
JSONObject jresponse = call(context, "POST", "v1/embeddings", jrequest);
|
||||
JSONObject jdata = jresponse.getJSONArray("data").getJSONObject(0);
|
||||
JSONArray jembedding = jdata.getJSONArray("embedding");
|
||||
double[] result = new double[jembedding.length()];
|
||||
for (int i = 0; i < jembedding.length(); i++)
|
||||
result[i] = jembedding.getDouble(i);
|
||||
return result;
|
||||
}
|
||||
|
||||
static Message[] completeChat(Context context, String model, Message[] messages, Float temperature, int n) throws JSONException, IOException {
|
||||
// https://platform.openai.com/docs/guides/chat/introduction
|
||||
// https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
@ -238,4 +252,26 @@ public class OpenAI {
|
|||
return this.role + ": " + this.content;
|
||||
}
|
||||
}
|
||||
|
||||
static class Embedding {
|
||||
public static double getSimilarity(double[] v1, double[] v2) {
|
||||
if (v1.length != v2.length)
|
||||
throw new IllegalArgumentException("Invalid vector length=" + v1.length + "/" + v2.length);
|
||||
double dotProduct = dotProduct(v1, v2);
|
||||
double magV1 = magnitude(v1);
|
||||
double magV2 = magnitude(v2);
|
||||
return dotProduct / (magV1 * magV2);
|
||||
}
|
||||
|
||||
private static double dotProduct(double[] v1, double[] v2) {
|
||||
float val = 0;
|
||||
for (int i = 0; i <= v1.length - 1; i++)
|
||||
val += v1[i] * v2[i];
|
||||
return val;
|
||||
}
|
||||
|
||||
private static double magnitude(double[] v) {
|
||||
return Math.sqrt(dotProduct(v, v));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue