diff --git a/app/src/main/java/eu/faircode/email/MessageClassifier.java b/app/src/main/java/eu/faircode/email/MessageClassifier.java index 16dea84bbe..6310b40538 100644 --- a/app/src/main/java/eu/faircode/email/MessageClassifier.java +++ b/app/src/main/java/eu/faircode/email/MessageClassifier.java @@ -41,6 +41,7 @@ import java.util.Collections; import java.util.Comparator; import java.util.Date; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -48,13 +49,14 @@ import javax.mail.Address; import javax.mail.internet.InternetAddress; public class MessageClassifier { - private static int version = 3; private static boolean loaded = false; private static boolean dirty = false; private static final Map> accountMsgIds = new HashMap<>(); private static final Map> classMessages = new HashMap<>(); - private static final Map>> wordClassFrequency = new HashMap<>(); + private static final Map>> wordClassFrequency = new HashMap<>(); + private static final Map wordIndex = new LinkedHashMap<>(); + private static final int VERSION = 4; private static final int MAX_WORDS = 1000; static synchronized void classify(EntityMessage message, EntityFolder folder, boolean added, Context context) { @@ -192,7 +194,7 @@ public class MessageClassifier { " class=" + message.account + ":" + clazz + " exists=" + (folder != null)); classMessages.get(message.account).remove(clazz); - for (String word : wordClassFrequency.get(message.account).keySet()) + for (int word : wordClassFrequency.get(message.account).keySet()) wordClassFrequency.get(message.account).get(word).remove(clazz); } } @@ -265,8 +267,15 @@ public class MessageClassifier { " text=" + TextUtils.join(", ", stat.words)); } - if (BuildConfig.DEBUG) - Log.i("Classifier words=" + state.words.size() + " " + TextUtils.join(", ", state.words)); + if (BuildConfig.DEBUG) { + StringBuilder sb = new StringBuilder(); + for (Integer word : state.words) { + if (sb.length() > 0) + sb.append(", "); + sb.append(getWord(word)); + } + Log.i("Classifier words=" + state.words.size() + " " + sb); + } // Sort classes by chance Collections.sort(chances, new Comparator() { @@ -320,6 +329,11 @@ public class MessageClassifier { return; } + _processWord(account, added, word == null ? null : getWordIndex(word), state); + } + + private static void _processWord(long account, boolean added, Integer word, State state) { + if (word != null || state.words.size() == 0 || state.words.get(state.words.size() - 1) != null) @@ -331,9 +345,9 @@ public class MessageClassifier { if (state.words.size() < 3) return; - String before = state.words.get(state.words.size() - 3); - String current = state.words.get(state.words.size() - 2); - String after = state.words.get(state.words.size() - 1); + Integer before = state.words.get(state.words.size() - 3); + Integer current = state.words.get(state.words.size() - 2); + Integer after = state.words.get(state.words.size() - 1); if (current == null) return; @@ -373,6 +387,25 @@ public class MessageClassifier { } } + private static int getWordIndex(String word) { + synchronized (wordIndex) { + Integer index = wordIndex.get(word); + if (index == null) { + index = wordIndex.size(); + wordIndex.put(word, index); + } + return index; + } + } + + private static String getWord(Integer index) { + if (index == null) + return ""; + if (index < 0 || index >= wordIndex.size()) + return "<" + index + ">"; + return new ArrayList<>(wordIndex.keySet()).get(index); + } + private static void updateFrequencies(long account, @NonNull String currentClass, boolean added, @NonNull State state) { Integer m = classMessages.get(account).get(currentClass); m = (m == null ? 0 : m) + (added ? 1 : -1); @@ -383,9 +416,9 @@ public class MessageClassifier { Log.i("Classifier " + currentClass + "=" + m + " msgs"); for (int i = 1; i < state.words.size() - 1; i++) { - String before = state.words.get(i - 1); - String current = state.words.get(i); - String after = state.words.get(i + 1); + Integer before = state.words.get(i - 1); + Integer current = state.words.get(i); + Integer after = state.words.get(i + 1); if (current == null) continue; @@ -428,8 +461,8 @@ public class MessageClassifier { try (JsonWriter writer = new JsonWriter(new BufferedWriter(new FileWriter(file)))) { writer.beginObject(); - Log.i("Classifier write version=" + version); - writer.name("version").value(version); + Log.i("Classifier write version=" + VERSION); + writer.name("version").value(VERSION); writer.name("messages"); writer.beginArray(); @@ -446,7 +479,7 @@ public class MessageClassifier { writer.name("words"); writer.beginArray(); for (Long account : wordClassFrequency.keySet()) - for (String word : wordClassFrequency.get(account).keySet()) { + for (int word : wordClassFrequency.get(account).keySet()) { Map classFrequency = wordClassFrequency.get(account).get(word); for (String clazz : classFrequency.keySet()) { Frequency f = classFrequency.get(clazz); @@ -460,14 +493,14 @@ public class MessageClassifier { writer.name("before"); writer.beginObject(); - for (String key : f.before.keySet()) - writer.name(key).value(f.before.get(key)); + for (int key : f.before.keySet()) + writer.name(Integer.toString(key)).value(f.before.get(key)); writer.endObject(); writer.name("after"); writer.beginObject(); - for (String key : f.after.keySet()) - writer.name(key).value(f.after.get(key)); + for (int key : f.after.keySet()) + writer.name(Integer.toString(key)).value(f.after.get(key)); writer.endObject(); writer.endObject(); @@ -475,6 +508,12 @@ public class MessageClassifier { } writer.endArray(); + writer.name("list"); + writer.beginArray(); + for (String word : wordIndex.keySet()) + writer.value(word); + writer.endArray(); + writer.name("classified"); writer.beginArray(); for (Long account : accountMsgIds.keySet()) { @@ -521,7 +560,7 @@ public class MessageClassifier { private static synchronized void _load(File file) throws IOException { Log.i("Classifier read " + file); long start = new Date().getTime(); - version = 0; + int version = 0; if (file.exists()) try (JsonReader reader = new JsonReader(new BufferedReader(new FileReader(file)))) { reader.beginObject(); @@ -568,7 +607,7 @@ public class MessageClassifier { reader.beginArray(); while (reader.hasNext()) { Long account = null; - String word = null; + Integer word = null; String clazz = null; Frequency f = new Frequency(); @@ -579,7 +618,10 @@ public class MessageClassifier { account = reader.nextLong(); break; case "word": - word = reader.nextString(); + if (version > 3) + word = Integer.parseInt(reader.nextString()); + else + word = getWordIndex(reader.nextString()); break; case "class": clazz = reader.nextString(); @@ -592,14 +634,22 @@ public class MessageClassifier { break; case "before": reader.beginObject(); - while (reader.hasNext()) - f.before.put(reader.nextName(), reader.nextInt()); + while (reader.hasNext()) { + int b = (version > 3 + ? Integer.parseInt(reader.nextName()) + : getWordIndex(reader.nextName())); + f.before.put(b, reader.nextInt()); + } reader.endObject(); break; case "after": reader.beginObject(); - while (reader.hasNext()) - f.after.put(reader.nextName(), reader.nextInt()); + while (reader.hasNext()) { + int a = (version > 3 + ? Integer.parseInt(reader.nextName()) + : getWordIndex(reader.nextName())); + f.after.put(a, reader.nextInt()); + } reader.endObject(); break; } @@ -622,6 +672,13 @@ public class MessageClassifier { reader.endArray(); break; + case "list": + reader.beginArray(); + while (reader.hasNext()) + wordIndex.put(reader.nextString(), wordIndex.size()); + reader.endArray(); + break; + case "classified": reader.beginArray(); while (reader.hasNext()) { @@ -660,7 +717,7 @@ public class MessageClassifier { dirty = false; long elapsed = new Date().getTime() - start; - Log.i("Classifier data loaded elapsed=" + elapsed); + Log.i("Classifier data loaded elapsed=" + elapsed + " words=" + wordIndex.size()); } private static void reduce() { @@ -670,7 +727,7 @@ public class MessageClassifier { Map total = new HashMap<>(); Map count = new HashMap<>(); - for (String word : wordClassFrequency.get(account).keySet()) + for (int word : wordClassFrequency.get(account).keySet()) for (String clazz : wordClassFrequency.get(account).get(word).keySet()) { int f = wordClassFrequency.get(account).get(word).get(clazz).count; @@ -691,7 +748,7 @@ public class MessageClassifier { Log.i("Classifier max " + account + ":" + clazz + "=" + max.get(clazz)); int dropped = 0; - for (String word : wordClassFrequency.get(account).keySet()) + for (int word : wordClassFrequency.get(account).keySet()) for (String clazz : new ArrayList<>(wordClassFrequency.get(account).get(word).keySet())) { long m = max.get(clazz); long avg = total.get(clazz) / count.get(clazz); @@ -703,13 +760,6 @@ public class MessageClassifier { dropped++; Log.i("Classifier dropping account=" + account + " word=" + word + " class=" + clazz + " freq=" + freq.count + " avg=" + avg); - } else if (version >= 3 && false) { - for (String b : new ArrayList<>(freq.before.keySet())) - if (freq.before.get(b) < freq.count / 20) - freq.before.remove(b); - for (String a : new ArrayList<>(freq.after.keySet())) - if (freq.after.get(a) < freq.count / 20) - freq.after.remove(a); } } Log.i("Classifier dropped words=" + dropped); @@ -768,6 +818,7 @@ public class MessageClassifier { accountMsgIds.clear(); classMessages.clear(); wordClassFrequency.clear(); + wordIndex.clear(); dirty = true; Log.i("Classifier data cleared"); } @@ -791,17 +842,17 @@ public class MessageClassifier { } private static class State { - private final List words = new ArrayList<>(); + private final List words = new ArrayList<>(); private final Map classStats = new HashMap<>(); } private static class Frequency { private int count = 0; private int duplicates = 0; - private Map before = new HashMap<>(); - private Map after = new HashMap<>(); + private Map before = new HashMap<>(); + private Map after = new HashMap<>(); - private void add(String b, String a, int c, boolean duplicate) { + private void add(Integer b, Integer a, int c, boolean duplicate) { if (count + c < 0) return;