diff --git a/app/src/main/java/eu/faircode/email/AI.java b/app/src/main/java/eu/faircode/email/AI.java index d383245e91..b4dfc30043 100644 --- a/app/src/main/java/eu/faircode/email/AI.java +++ b/app/src/main/java/eu/faircode/email/AI.java @@ -29,6 +29,7 @@ import androidx.preference.PreferenceManager; import org.json.JSONException; import org.jsoup.nodes.Document; +import org.jsoup.select.Elements; import java.io.File; import java.io.IOException; @@ -43,8 +44,22 @@ public class AI { } static String completeChat(Context context, long id, CharSequence body) throws JSONException, IOException { + File file = EntityMessage.getFile(context, id); + Document d = JsoupEx.parse(file); + Elements ref = d.select("div[fairemail=reference]"); + d = Document.createShell(""); + d.appendChildren(ref); + + HtmlHelper.removeSignatures(d); + HtmlHelper.truncate(d, MAX_SUMMARIZE_TEXT_SIZE); + if (body == null || TextUtils.isEmpty(body.toString().trim())) - body = "?"; + if (OpenAI.isAvailable(context)) + body = OpenAI.DEFAULT_ANSWER_PROMPT; + else if (Gemini.isAvailable(context)) + body = Gemini.DEFAULT_ANSWER_PROMPT; + else + body = "?"; if (OpenAI.isAvailable(context)) { SharedPreferences prefs = PreferenceManager.getDefaultSharedPreferences(context); @@ -52,16 +67,21 @@ public class AI { float temperature = prefs.getFloat("openai_temperature", OpenAI.DEFAULT_TEMPERATURE); boolean multimodal = prefs.getBoolean("openai_multimodal", false); - OpenAI.Message message; - if (body instanceof Spannable && multimodal) - message = new OpenAI.Message(OpenAI.USER, - OpenAI.Content.get((Spannable) body, id, context)); - else - message = new OpenAI.Message(OpenAI.USER, new OpenAI.Content[]{ - new OpenAI.Content(OpenAI.CONTENT_TEXT, body.toString())}); + List messages = new ArrayList<>(); - OpenAI.Message[] completions = - OpenAI.completeChat(context, model, new OpenAI.Message[]{message}, temperature, 1); + if (body instanceof Spannable && multimodal) + messages.add(new OpenAI.Message(OpenAI.USER, + OpenAI.Content.get((Spannable) body, id, context))); + else + messages.add(new OpenAI.Message(OpenAI.USER, new OpenAI.Content[]{ + new OpenAI.Content(OpenAI.CONTENT_TEXT, body.toString())})); + + if (!ref.isEmpty()) + messages.add(new OpenAI.Message(OpenAI.USER, new OpenAI.Content[]{ + new OpenAI.Content(OpenAI.CONTENT_TEXT, ref.text())})); + + OpenAI.Message[] completions = OpenAI.completeChat(context, + model, messages.toArray(new OpenAI.Message[0]), temperature, 1); StringBuilder sb = new StringBuilder(); for (OpenAI.Message completion : completions) @@ -79,9 +99,16 @@ public class AI { String model = prefs.getString("gemini_model", Gemini.DEFAULT_MODEL); float temperature = prefs.getFloat("gemini_temperature", Gemini.DEFAULT_TEMPERATURE); - Gemini.Message message = new Gemini.Message(Gemini.USER, - new String[]{Gemini.truncateParagraphs(body.toString())}); - Gemini.Message[] completions = Gemini.generate(context, model, new Gemini.Message[]{message}, temperature, 1); + List messages = new ArrayList<>(); + + messages.add(new Gemini.Message(Gemini.USER, + new String[]{Gemini.truncateParagraphs(body.toString())})); + + if (!ref.isEmpty()) + messages.add(new Gemini.Message(Gemini.USER, new String[]{ref.text()})); + + Gemini.Message[] completions = Gemini.generate(context, + model, messages.toArray(new Gemini.Message[0]), temperature, 1); StringBuilder sb = new StringBuilder(); for (Gemini.Message completion : completions) diff --git a/app/src/main/java/eu/faircode/email/Gemini.java b/app/src/main/java/eu/faircode/email/Gemini.java index 458c0716bd..b3d3d99d23 100644 --- a/app/src/main/java/eu/faircode/email/Gemini.java +++ b/app/src/main/java/eu/faircode/email/Gemini.java @@ -45,6 +45,7 @@ public class Gemini { static final String DEFAULT_MODEL = "gemini-pro"; static final float DEFAULT_TEMPERATURE = 0.9f; static final String DEFAULT_SUMMARY_PROMPT = "Summarize the following text:"; + static final String DEFAULT_ANSWER_PROMPT = "Answer this message:"; static final String MODEL = "model"; static final String USER = "user"; diff --git a/app/src/main/java/eu/faircode/email/OpenAI.java b/app/src/main/java/eu/faircode/email/OpenAI.java index f460e0f1fd..eda704bf38 100644 --- a/app/src/main/java/eu/faircode/email/OpenAI.java +++ b/app/src/main/java/eu/faircode/email/OpenAI.java @@ -50,6 +50,7 @@ public class OpenAI { static final String DEFAULT_MODEL = "gpt-4o"; static final float DEFAULT_TEMPERATURE = 0.5f; static final String DEFAULT_SUMMARY_PROMPT = "Summarize the following text:"; + static final String DEFAULT_ANSWER_PROMPT = "Answer this message:"; static final String ASSISTANT = "assistant"; static final String USER = "user";