Gemini: refactoring

This commit is contained in:
M66B 2024-05-09 10:20:16 +02:00
parent 861716aad0
commit 8cab3786e9
3 changed files with 77 additions and 33 deletions

View File

@ -2769,7 +2769,7 @@ public class FragmentCompose extends FragmentBase {
args.putString("body", body);
args.putBoolean("selection", selection);
new SimpleTask<String[]>() {
new SimpleTask<Gemini.Message[]>() {
@Override
protected void onPreExecute(Bundle args) {
chatting = true;
@ -2783,7 +2783,7 @@ public class FragmentCompose extends FragmentBase {
}
@Override
protected String[] onExecute(Context context, Bundle args) throws Throwable {
protected Gemini.Message[] onExecute(Context context, Bundle args) throws Throwable {
long id = args.getLong("id");
String body = args.getString("body");
boolean selection = args.getBoolean("selection");
@ -2792,15 +2792,17 @@ public class FragmentCompose extends FragmentBase {
String model = prefs.getString("gemini_model", "gemini-pro");
float temperature = prefs.getFloat("gemini_temperature", 0.5f);
return Gemini.generate(context, model, new String[]{Gemini.truncateParagraphs(body)}, temperature);
Gemini.Message message = new Gemini.Message(Gemini.USER, new String[]{Gemini.truncateParagraphs(body)});
return Gemini.generate(context, model, new Gemini.Message[]{message}, temperature, 1);
}
@Override
protected void onExecuted(Bundle args, String[] result) {
if (result == null || result.length == 0)
protected void onExecuted(Bundle args, Gemini.Message[] messages) {
if (messages == null || messages.length == 0)
return;
String text = result[0]
String text = TextUtils.join("\n", messages[0].getContent())
.replaceAll("^\\n+", "").replaceAll("\\n+$", "");
Editable edit = etBody.getText();

View File

@ -109,8 +109,12 @@ public class FragmentDialogSummarize extends FragmentDialogBase {
float temperature = prefs.getFloat("gemini_temperature", 0.5f);
String prompt = prefs.getString("gemini_summarize", Gemini.SUMMARY_PROMPT);
String[] result = Gemini.generate(context, model, new String[]{prompt, text}, temperature);
return TextUtils.join("\n", result);
Gemini.Message message = new Gemini.Message(Gemini.USER, new String[]{prompt, text});
Gemini.Message[] result = Gemini.generate(context, model, new Gemini.Message[]{message}, temperature, 1);
if (result.length == 0)
return null;
return TextUtils.join("\n", result[0].getContent());
}
return null;

View File

@ -35,7 +35,9 @@ import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Objects;
public class Gemini {
@ -59,23 +61,28 @@ public class Gemini {
(!TextUtils.isEmpty(apikey) || !Objects.equals(getUri(context), BuildConfig.GEMINI_ENDPOINT)));
}
static String[] generate(Context context, String model, String[] texts, float temperature) throws JSONException, IOException {
JSONArray jpart = new JSONArray();
for (String text : texts) {
JSONObject jtext = new JSONObject();
jtext.put("text", text);
jpart.put(jtext);
}
JSONObject jcontent0 = new JSONObject();
jcontent0.put("parts", jpart);
jcontent0.put("role", USER);
static Message[] generate(Context context, String model, Message[] messages, Float temperature, int n) throws JSONException, IOException {
JSONArray jcontents = new JSONArray();
jcontents.put(jcontent0);
for (Message message : messages) {
JSONArray jparts = new JSONArray();
for (String text : message.getContent()) {
JSONObject jtext = new JSONObject();
jtext.put("text", text);
jparts.put(jtext);
}
JSONObject jcontent = new JSONObject();
jcontent.put("parts", jparts);
jcontent.put("role", message.role);
jcontents.put(jcontent);
}
// https://ai.google.dev/api/python/google/generativeai/GenerationConfig
JSONObject jconfig = new JSONObject();
jconfig.put("temperature", temperature);
if (temperature != null)
jconfig.put("temperature", temperature);
jconfig.put("candidate_count", n);
JSONArray jsafety = new JSONArray();
@ -108,19 +115,26 @@ public class Gemini {
JSONObject jresponse = call(context, "POST", path, jrequest);
List<Message> result = new ArrayList<>();
JSONArray jcandidates = jresponse.optJSONArray("candidates");
if (jcandidates == null || jcandidates.length() < 1)
throw new IOException(jresponse.toString(2));
JSONObject jcontent = jcandidates.getJSONObject(0).optJSONObject("content");
if (jcontent == null)
throw new IOException(jresponse.toString(2));
JSONArray jparts = jcontent.optJSONArray("parts");
if (jparts == null || jparts.length() < 1)
throw new IOException(jresponse.toString(2));
JSONObject jtext = jparts.getJSONObject(0);
if (!jtext.has("text"))
throw new IOException(jresponse.toString(2));
return new String[]{jtext.getString("text")};
for (int i = 0; i < jcandidates.length(); i++) {
JSONObject jcandidate = jcandidates.getJSONObject(i);
JSONObject jcontent = jcandidate.getJSONObject("content");
String role = jcontent.getString("role");
List<String> texts = new ArrayList<>();
JSONArray jparts = jcontent.getJSONArray("parts");
for (int j = 0; j < jparts.length(); j++) {
JSONObject jpart = jparts.getJSONObject(j);
texts.add(jpart.getString("text"));
}
result.add(new Message(role, texts.toArray(new String[0])));
}
return result.toArray(new Message[0]);
}
private static String getUri(Context context) {
@ -202,4 +216,28 @@ public class Gemini {
return sb.toString();
}
static class Message {
private final String role; // model, user
private final String[] content;
public Message(String role, String[] content) {
this.role = role;
this.content = content;
}
public String getRole() {
return this.role;
}
public String[] getContent() {
return this.content;
}
@NonNull
@Override
public String toString() {
return this.role + ": " + (this.content == null ? null : TextUtils.join(", ", this.content));
}
}
}