From 8cab3786e9167c2b558060d65ac5b4952425e65d Mon Sep 17 00:00:00 2001 From: M66B Date: Thu, 9 May 2024 10:20:16 +0200 Subject: [PATCH] Gemini: refactoring --- .../eu/faircode/email/FragmentCompose.java | 14 +-- .../email/FragmentDialogSummarize.java | 8 +- .../main/java/eu/faircode/email/Gemini.java | 88 +++++++++++++------ 3 files changed, 77 insertions(+), 33 deletions(-) diff --git a/app/src/main/java/eu/faircode/email/FragmentCompose.java b/app/src/main/java/eu/faircode/email/FragmentCompose.java index 7f788f3a2c..81e5bf95c1 100644 --- a/app/src/main/java/eu/faircode/email/FragmentCompose.java +++ b/app/src/main/java/eu/faircode/email/FragmentCompose.java @@ -2769,7 +2769,7 @@ public class FragmentCompose extends FragmentBase { args.putString("body", body); args.putBoolean("selection", selection); - new SimpleTask() { + new SimpleTask() { @Override protected void onPreExecute(Bundle args) { chatting = true; @@ -2783,7 +2783,7 @@ public class FragmentCompose extends FragmentBase { } @Override - protected String[] onExecute(Context context, Bundle args) throws Throwable { + protected Gemini.Message[] onExecute(Context context, Bundle args) throws Throwable { long id = args.getLong("id"); String body = args.getString("body"); boolean selection = args.getBoolean("selection"); @@ -2792,15 +2792,17 @@ public class FragmentCompose extends FragmentBase { String model = prefs.getString("gemini_model", "gemini-pro"); float temperature = prefs.getFloat("gemini_temperature", 0.5f); - return Gemini.generate(context, model, new String[]{Gemini.truncateParagraphs(body)}, temperature); + Gemini.Message message = new Gemini.Message(Gemini.USER, new String[]{Gemini.truncateParagraphs(body)}); + + return Gemini.generate(context, model, new Gemini.Message[]{message}, temperature, 1); } @Override - protected void onExecuted(Bundle args, String[] result) { - if (result == null || result.length == 0) + protected void onExecuted(Bundle args, Gemini.Message[] messages) { + if (messages == null || messages.length == 0) return; - String text = result[0] + String text = TextUtils.join("\n", messages[0].getContent()) .replaceAll("^\\n+", "").replaceAll("\\n+$", ""); Editable edit = etBody.getText(); diff --git a/app/src/main/java/eu/faircode/email/FragmentDialogSummarize.java b/app/src/main/java/eu/faircode/email/FragmentDialogSummarize.java index b4e57f40c8..28948430e3 100644 --- a/app/src/main/java/eu/faircode/email/FragmentDialogSummarize.java +++ b/app/src/main/java/eu/faircode/email/FragmentDialogSummarize.java @@ -109,8 +109,12 @@ public class FragmentDialogSummarize extends FragmentDialogBase { float temperature = prefs.getFloat("gemini_temperature", 0.5f); String prompt = prefs.getString("gemini_summarize", Gemini.SUMMARY_PROMPT); - String[] result = Gemini.generate(context, model, new String[]{prompt, text}, temperature); - return TextUtils.join("\n", result); + Gemini.Message message = new Gemini.Message(Gemini.USER, new String[]{prompt, text}); + + Gemini.Message[] result = Gemini.generate(context, model, new Gemini.Message[]{message}, temperature, 1); + if (result.length == 0) + return null; + return TextUtils.join("\n", result[0].getContent()); } return null; diff --git a/app/src/main/java/eu/faircode/email/Gemini.java b/app/src/main/java/eu/faircode/email/Gemini.java index afbe725888..3c9f0f40e0 100644 --- a/app/src/main/java/eu/faircode/email/Gemini.java +++ b/app/src/main/java/eu/faircode/email/Gemini.java @@ -35,7 +35,9 @@ import java.io.IOException; import java.io.InputStream; import java.net.HttpURLConnection; import java.net.URL; +import java.util.ArrayList; import java.util.Date; +import java.util.List; import java.util.Objects; public class Gemini { @@ -59,23 +61,28 @@ public class Gemini { (!TextUtils.isEmpty(apikey) || !Objects.equals(getUri(context), BuildConfig.GEMINI_ENDPOINT))); } - static String[] generate(Context context, String model, String[] texts, float temperature) throws JSONException, IOException { - JSONArray jpart = new JSONArray(); - for (String text : texts) { - JSONObject jtext = new JSONObject(); - jtext.put("text", text); - jpart.put(jtext); - } - - JSONObject jcontent0 = new JSONObject(); - jcontent0.put("parts", jpart); - jcontent0.put("role", USER); + static Message[] generate(Context context, String model, Message[] messages, Float temperature, int n) throws JSONException, IOException { JSONArray jcontents = new JSONArray(); - jcontents.put(jcontent0); + for (Message message : messages) { + JSONArray jparts = new JSONArray(); + for (String text : message.getContent()) { + JSONObject jtext = new JSONObject(); + jtext.put("text", text); + jparts.put(jtext); + } + + JSONObject jcontent = new JSONObject(); + jcontent.put("parts", jparts); + jcontent.put("role", message.role); + + jcontents.put(jcontent); + } // https://ai.google.dev/api/python/google/generativeai/GenerationConfig JSONObject jconfig = new JSONObject(); - jconfig.put("temperature", temperature); + if (temperature != null) + jconfig.put("temperature", temperature); + jconfig.put("candidate_count", n); JSONArray jsafety = new JSONArray(); @@ -108,19 +115,26 @@ public class Gemini { JSONObject jresponse = call(context, "POST", path, jrequest); + List result = new ArrayList<>(); + JSONArray jcandidates = jresponse.optJSONArray("candidates"); - if (jcandidates == null || jcandidates.length() < 1) - throw new IOException(jresponse.toString(2)); - JSONObject jcontent = jcandidates.getJSONObject(0).optJSONObject("content"); - if (jcontent == null) - throw new IOException(jresponse.toString(2)); - JSONArray jparts = jcontent.optJSONArray("parts"); - if (jparts == null || jparts.length() < 1) - throw new IOException(jresponse.toString(2)); - JSONObject jtext = jparts.getJSONObject(0); - if (!jtext.has("text")) - throw new IOException(jresponse.toString(2)); - return new String[]{jtext.getString("text")}; + for (int i = 0; i < jcandidates.length(); i++) { + JSONObject jcandidate = jcandidates.getJSONObject(i); + JSONObject jcontent = jcandidate.getJSONObject("content"); + + String role = jcontent.getString("role"); + + List texts = new ArrayList<>(); + JSONArray jparts = jcontent.getJSONArray("parts"); + for (int j = 0; j < jparts.length(); j++) { + JSONObject jpart = jparts.getJSONObject(j); + texts.add(jpart.getString("text")); + } + + result.add(new Message(role, texts.toArray(new String[0]))); + } + + return result.toArray(new Message[0]); } private static String getUri(Context context) { @@ -202,4 +216,28 @@ public class Gemini { return sb.toString(); } + + static class Message { + private final String role; // model, user + private final String[] content; + + public Message(String role, String[] content) { + this.role = role; + this.content = content; + } + + public String getRole() { + return this.role; + } + + public String[] getContent() { + return this.content; + } + + @NonNull + @Override + public String toString() { + return this.role + ": " + (this.content == null ? null : TextUtils.join(", ", this.content)); + } + } }