From 502357b4fd5a0c3ca75c142efe41ae0890f185a1 Mon Sep 17 00:00:00 2001 From: M66B Date: Wed, 6 Jan 2021 21:35:11 +0100 Subject: [PATCH] Context sensitive classifier --- .../eu/faircode/email/MessageClassifier.java | 171 ++++++++++++------ 1 file changed, 117 insertions(+), 54 deletions(-) diff --git a/app/src/main/java/eu/faircode/email/MessageClassifier.java b/app/src/main/java/eu/faircode/email/MessageClassifier.java index 40f35b97fa..1a3dc0f0a1 100644 --- a/app/src/main/java/eu/faircode/email/MessageClassifier.java +++ b/app/src/main/java/eu/faircode/email/MessageClassifier.java @@ -39,6 +39,7 @@ import java.util.Collections; import java.util.Comparator; import java.util.Date; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; @@ -49,7 +50,7 @@ public class MessageClassifier { private static boolean loaded = false; private static boolean dirty = false; private static final Map> classMessages = new HashMap<>(); - private static final Map>> wordClassFrequency = new HashMap<>(); + private static final Map>> wordClassFrequency = new HashMap<>(); private static final int MIN_MATCHED_WORDS = 10; private static final double CHANCE_THRESHOLD = 2.0; @@ -171,9 +172,8 @@ public class MessageClassifier { } private static String classify(long account, String currentClass, String text, boolean added, Context context) { - int maxMatchedWords = 0; - List words = new ArrayList<>(); - Map classStats = new HashMap<>(); + State state = new State(); + state.words.add(null); BreakIterator boundary = BreakIterator.getWordInstance(); // TODO ICU boundary.setText(text); @@ -181,59 +181,26 @@ public class MessageClassifier { for (int end = boundary.next(); end != BreakIterator.DONE; end = boundary.next()) { String word = text.substring(start, end).toLowerCase(); if (word.length() > 1 && - !words.contains(word) && + !state.words.contains(word) && !word.matches(".*\\d.*")) { - words.add(word); - - Map classFrequency = wordClassFrequency.get(account).get(word); - if (added) { - if (classFrequency == null) { - classFrequency = new HashMap<>(); - wordClassFrequency.get(account).put(word, classFrequency); - } - - for (String clazz : classFrequency.keySet()) { - int frequency = classFrequency.get(clazz); - - Stat stat = classStats.get(clazz); - if (stat == null) { - stat = new Stat(); - classStats.put(clazz, stat); - } - - stat.matchedWords++; - stat.totalFrequency += frequency; - if (BuildConfig.DEBUG) - stat.words.add(word); - - if (stat.matchedWords > maxMatchedWords) - maxMatchedWords = stat.matchedWords; - } - - Integer c = classFrequency.get(currentClass); - c = (c == null ? 1 : c + 1); - classFrequency.put(currentClass, c); - } else { - Integer c = (classFrequency == null ? null : classFrequency.get(currentClass)); - if (c != null) - if (c > 0) - classFrequency.put(currentClass, c - 1); - else - classFrequency.remove(currentClass); - } + state.words.add(word); + process(account, currentClass, added, state); } start = end; } + state.words.add(null); + process(account, currentClass, added, state); + if (!added) return null; - if (maxMatchedWords == 0) + if (state.maxMatchedWords == 0) return null; DB db = DB.getInstance(context); List chances = new ArrayList<>(); - for (String clazz : classStats.keySet()) { + for (String clazz : state.classStats.keySet()) { Integer messages = classMessages.get(account).get(clazz); if (messages == null || messages == 0) { Log.w("Classifier no messages class=" + account + ":" + clazz); @@ -246,13 +213,13 @@ public class MessageClassifier { continue; } - Stat stat = classStats.get(clazz); + Stat stat = state.classStats.get(clazz); boolean consider = (stat.matchedWords >= MIN_MATCHED_WORDS); - double chance = (double) stat.totalFrequency / messages / maxMatchedWords; + double chance = stat.totalFrequency / messages / state.maxMatchedWords; Chance c = new Chance(clazz, chance); EntityLog.log(context, "Classifier " + c + " frequency=" + stat.totalFrequency + "/" + messages + - " matched=" + stat.matchedWords + "/" + maxMatchedWords + + " matched=" + stat.matchedWords + "/" + state.maxMatchedWords + " consider=" + consider + " words=" + TextUtils.join(", ", stat.words)); if (consider) @@ -260,7 +227,7 @@ public class MessageClassifier { } if (BuildConfig.DEBUG) - Log.i("Classifier words=" + TextUtils.join(", ", words)); + Log.i("Classifier words=" + TextUtils.join(", ", state.words)); if (chances.size() <= 1) return null; @@ -281,6 +248,67 @@ public class MessageClassifier { return classification; } + private static void process(long account, String currentClass, boolean added, State state) { + if (state.words.size() < 3) + return; + + String before = state.words.get(state.words.size() - 3); + String current = state.words.get(state.words.size() - 2); + String after = state.words.get(state.words.size() - 1); + + Map classFrequency = wordClassFrequency.get(account).get(current); + if (added) { + if (classFrequency == null) { + classFrequency = new HashMap<>(); + wordClassFrequency.get(account).put(current, classFrequency); + } + + for (String clazz : classFrequency.keySet()) { + Frequency frequency = classFrequency.get(clazz); + + Stat stat = state.classStats.get(clazz); + if (stat == null) { + stat = new Stat(); + state.classStats.put(clazz, stat); + } + + stat.matchedWords++; + + boolean b = (before != null && frequency.before.contains(before)); + boolean a = (after != null && frequency.after.contains(after)); + if (b && a) + stat.totalFrequency += frequency.count; + else if (b || a) + stat.totalFrequency += frequency.count * 0.5; + else + stat.totalFrequency += frequency.count * 0.25; + + if (BuildConfig.DEBUG) + stat.words.add(current); + + if (stat.matchedWords > state.maxMatchedWords) + state.maxMatchedWords = stat.matchedWords; + } + + Frequency c = classFrequency.get(currentClass); + if (c == null) + c = new Frequency(); + c.count++; + if (before != null && !c.before.contains(before)) + c.before.add(before); + if (after != null && !c.after.contains(after)) + c.after.add(after); + classFrequency.put(currentClass, c); + } else { + Frequency c = (classFrequency == null ? null : classFrequency.get(currentClass)); + if (c != null) + if (c.count > 0) + c.count--; + else + classFrequency.remove(currentClass); + } + } + static synchronized void save(Context context) throws JSONException, IOException { if (!dirty) return; @@ -345,13 +373,16 @@ public class MessageClassifier { JSONArray jwords = new JSONArray(); for (Long account : classMessages.keySet()) for (String word : wordClassFrequency.get(account).keySet()) { - Map classFrequency = wordClassFrequency.get(account).get(word); + Map classFrequency = wordClassFrequency.get(account).get(word); for (String clazz : classFrequency.keySet()) { + Frequency f = classFrequency.get(clazz); JSONObject jword = new JSONObject(); jword.put("account", account); jword.put("word", word); jword.put("class", clazz); - jword.put("frequency", classFrequency.get(clazz)); + jword.put("frequency", f.count); + jword.put("before", from(f.before)); + jword.put("after", from(f.after)); jwords.put(jword); } } @@ -363,6 +394,13 @@ public class MessageClassifier { return jroot; } + private static JSONArray from(HashSet list) { + JSONArray jarray = new JSONArray(); + for (String item : list) + jarray.put(item); + return jarray; + } + static void fromJson(JSONObject jroot) throws JSONException { JSONArray jmessages = jroot.getJSONArray("messages"); for (int m = 0; m < jmessages.length(); m++) { @@ -380,18 +418,43 @@ public class MessageClassifier { if (!wordClassFrequency.containsKey(account)) wordClassFrequency.put(account, new HashMap<>()); String word = jword.getString("word"); - Map classFrequency = wordClassFrequency.get(account).get(word); + Map classFrequency = wordClassFrequency.get(account).get(word); if (classFrequency == null) { classFrequency = new HashMap<>(); wordClassFrequency.get(account).put(word, classFrequency); } - classFrequency.put(jword.getString("class"), jword.getInt("frequency")); + Frequency f = new Frequency(); + f.count = jword.getInt("frequency"); + if (jword.has("before")) + f.before = from(jword.getJSONArray("before")); + if (jword.has("after")) + f.after = from(jword.getJSONArray("after")); + classFrequency.put(jword.getString("class"), f); } } + private static HashSet from(JSONArray jarray) throws JSONException { + HashSet result = new HashSet<>(jarray.length()); + for (int i = 0; i < jarray.length(); i++) + result.add((String) jarray.get(i)); + return result; + } + + private static class State { + int maxMatchedWords = 0; + List words = new ArrayList<>(); + Map classStats = new HashMap<>(); + } + + private static class Frequency { + int count; + HashSet before = new HashSet<>(); + HashSet after = new HashSet<>(); + } + private static class Stat { int matchedWords = 0; - int totalFrequency = 0; + double totalFrequency = 0; List words = new ArrayList<>(); }