Classifier word list

This commit is contained in:
M66B 2023-03-22 15:18:59 +01:00
parent 27307edf56
commit 1221b69dcc
1 changed files with 90 additions and 39 deletions

View File

@ -41,6 +41,7 @@ import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.Date; import java.util.Date;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -48,13 +49,14 @@ import javax.mail.Address;
import javax.mail.internet.InternetAddress; import javax.mail.internet.InternetAddress;
public class MessageClassifier { public class MessageClassifier {
private static int version = 3;
private static boolean loaded = false; private static boolean loaded = false;
private static boolean dirty = false; private static boolean dirty = false;
private static final Map<Long, List<String>> accountMsgIds = new HashMap<>(); private static final Map<Long, List<String>> accountMsgIds = new HashMap<>();
private static final Map<Long, Map<String, Integer>> classMessages = new HashMap<>(); private static final Map<Long, Map<String, Integer>> classMessages = new HashMap<>();
private static final Map<Long, Map<String, Map<String, Frequency>>> wordClassFrequency = new HashMap<>(); private static final Map<Long, Map<Integer, Map<String, Frequency>>> wordClassFrequency = new HashMap<>();
private static final Map<String, Integer> wordIndex = new LinkedHashMap<>();
private static final int VERSION = 4;
private static final int MAX_WORDS = 1000; private static final int MAX_WORDS = 1000;
static synchronized void classify(EntityMessage message, EntityFolder folder, boolean added, Context context) { static synchronized void classify(EntityMessage message, EntityFolder folder, boolean added, Context context) {
@ -192,7 +194,7 @@ public class MessageClassifier {
" class=" + message.account + ":" + clazz + " class=" + message.account + ":" + clazz +
" exists=" + (folder != null)); " exists=" + (folder != null));
classMessages.get(message.account).remove(clazz); classMessages.get(message.account).remove(clazz);
for (String word : wordClassFrequency.get(message.account).keySet()) for (int word : wordClassFrequency.get(message.account).keySet())
wordClassFrequency.get(message.account).get(word).remove(clazz); wordClassFrequency.get(message.account).get(word).remove(clazz);
} }
} }
@ -265,8 +267,15 @@ public class MessageClassifier {
" text=" + TextUtils.join(", ", stat.words)); " text=" + TextUtils.join(", ", stat.words));
} }
if (BuildConfig.DEBUG) if (BuildConfig.DEBUG) {
Log.i("Classifier words=" + state.words.size() + " " + TextUtils.join(", ", state.words)); StringBuilder sb = new StringBuilder();
for (Integer word : state.words) {
if (sb.length() > 0)
sb.append(", ");
sb.append(getWord(word));
}
Log.i("Classifier words=" + state.words.size() + " " + sb);
}
// Sort classes by chance // Sort classes by chance
Collections.sort(chances, new Comparator<Chance>() { Collections.sort(chances, new Comparator<Chance>() {
@ -320,6 +329,11 @@ public class MessageClassifier {
return; return;
} }
_processWord(account, added, word == null ? null : getWordIndex(word), state);
}
private static void _processWord(long account, boolean added, Integer word, State state) {
if (word != null || if (word != null ||
state.words.size() == 0 || state.words.size() == 0 ||
state.words.get(state.words.size() - 1) != null) state.words.get(state.words.size() - 1) != null)
@ -331,9 +345,9 @@ public class MessageClassifier {
if (state.words.size() < 3) if (state.words.size() < 3)
return; return;
String before = state.words.get(state.words.size() - 3); Integer before = state.words.get(state.words.size() - 3);
String current = state.words.get(state.words.size() - 2); Integer current = state.words.get(state.words.size() - 2);
String after = state.words.get(state.words.size() - 1); Integer after = state.words.get(state.words.size() - 1);
if (current == null) if (current == null)
return; return;
@ -373,6 +387,25 @@ public class MessageClassifier {
} }
} }
private static int getWordIndex(String word) {
synchronized (wordIndex) {
Integer index = wordIndex.get(word);
if (index == null) {
index = wordIndex.size();
wordIndex.put(word, index);
}
return index;
}
}
private static String getWord(Integer index) {
if (index == null)
return "<null>";
if (index < 0 || index >= wordIndex.size())
return "<" + index + ">";
return new ArrayList<>(wordIndex.keySet()).get(index);
}
private static void updateFrequencies(long account, @NonNull String currentClass, boolean added, @NonNull State state) { private static void updateFrequencies(long account, @NonNull String currentClass, boolean added, @NonNull State state) {
Integer m = classMessages.get(account).get(currentClass); Integer m = classMessages.get(account).get(currentClass);
m = (m == null ? 0 : m) + (added ? 1 : -1); m = (m == null ? 0 : m) + (added ? 1 : -1);
@ -383,9 +416,9 @@ public class MessageClassifier {
Log.i("Classifier " + currentClass + "=" + m + " msgs"); Log.i("Classifier " + currentClass + "=" + m + " msgs");
for (int i = 1; i < state.words.size() - 1; i++) { for (int i = 1; i < state.words.size() - 1; i++) {
String before = state.words.get(i - 1); Integer before = state.words.get(i - 1);
String current = state.words.get(i); Integer current = state.words.get(i);
String after = state.words.get(i + 1); Integer after = state.words.get(i + 1);
if (current == null) if (current == null)
continue; continue;
@ -428,8 +461,8 @@ public class MessageClassifier {
try (JsonWriter writer = new JsonWriter(new BufferedWriter(new FileWriter(file)))) { try (JsonWriter writer = new JsonWriter(new BufferedWriter(new FileWriter(file)))) {
writer.beginObject(); writer.beginObject();
Log.i("Classifier write version=" + version); Log.i("Classifier write version=" + VERSION);
writer.name("version").value(version); writer.name("version").value(VERSION);
writer.name("messages"); writer.name("messages");
writer.beginArray(); writer.beginArray();
@ -446,7 +479,7 @@ public class MessageClassifier {
writer.name("words"); writer.name("words");
writer.beginArray(); writer.beginArray();
for (Long account : wordClassFrequency.keySet()) for (Long account : wordClassFrequency.keySet())
for (String word : wordClassFrequency.get(account).keySet()) { for (int word : wordClassFrequency.get(account).keySet()) {
Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(word); Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(word);
for (String clazz : classFrequency.keySet()) { for (String clazz : classFrequency.keySet()) {
Frequency f = classFrequency.get(clazz); Frequency f = classFrequency.get(clazz);
@ -460,14 +493,14 @@ public class MessageClassifier {
writer.name("before"); writer.name("before");
writer.beginObject(); writer.beginObject();
for (String key : f.before.keySet()) for (int key : f.before.keySet())
writer.name(key).value(f.before.get(key)); writer.name(Integer.toString(key)).value(f.before.get(key));
writer.endObject(); writer.endObject();
writer.name("after"); writer.name("after");
writer.beginObject(); writer.beginObject();
for (String key : f.after.keySet()) for (int key : f.after.keySet())
writer.name(key).value(f.after.get(key)); writer.name(Integer.toString(key)).value(f.after.get(key));
writer.endObject(); writer.endObject();
writer.endObject(); writer.endObject();
@ -475,6 +508,12 @@ public class MessageClassifier {
} }
writer.endArray(); writer.endArray();
writer.name("list");
writer.beginArray();
for (String word : wordIndex.keySet())
writer.value(word);
writer.endArray();
writer.name("classified"); writer.name("classified");
writer.beginArray(); writer.beginArray();
for (Long account : accountMsgIds.keySet()) { for (Long account : accountMsgIds.keySet()) {
@ -521,7 +560,7 @@ public class MessageClassifier {
private static synchronized void _load(File file) throws IOException { private static synchronized void _load(File file) throws IOException {
Log.i("Classifier read " + file); Log.i("Classifier read " + file);
long start = new Date().getTime(); long start = new Date().getTime();
version = 0; int version = 0;
if (file.exists()) if (file.exists())
try (JsonReader reader = new JsonReader(new BufferedReader(new FileReader(file)))) { try (JsonReader reader = new JsonReader(new BufferedReader(new FileReader(file)))) {
reader.beginObject(); reader.beginObject();
@ -568,7 +607,7 @@ public class MessageClassifier {
reader.beginArray(); reader.beginArray();
while (reader.hasNext()) { while (reader.hasNext()) {
Long account = null; Long account = null;
String word = null; Integer word = null;
String clazz = null; String clazz = null;
Frequency f = new Frequency(); Frequency f = new Frequency();
@ -579,7 +618,10 @@ public class MessageClassifier {
account = reader.nextLong(); account = reader.nextLong();
break; break;
case "word": case "word":
word = reader.nextString(); if (version > 3)
word = Integer.parseInt(reader.nextString());
else
word = getWordIndex(reader.nextString());
break; break;
case "class": case "class":
clazz = reader.nextString(); clazz = reader.nextString();
@ -592,14 +634,22 @@ public class MessageClassifier {
break; break;
case "before": case "before":
reader.beginObject(); reader.beginObject();
while (reader.hasNext()) while (reader.hasNext()) {
f.before.put(reader.nextName(), reader.nextInt()); int b = (version > 3
? Integer.parseInt(reader.nextName())
: getWordIndex(reader.nextName()));
f.before.put(b, reader.nextInt());
}
reader.endObject(); reader.endObject();
break; break;
case "after": case "after":
reader.beginObject(); reader.beginObject();
while (reader.hasNext()) while (reader.hasNext()) {
f.after.put(reader.nextName(), reader.nextInt()); int a = (version > 3
? Integer.parseInt(reader.nextName())
: getWordIndex(reader.nextName()));
f.after.put(a, reader.nextInt());
}
reader.endObject(); reader.endObject();
break; break;
} }
@ -622,6 +672,13 @@ public class MessageClassifier {
reader.endArray(); reader.endArray();
break; break;
case "list":
reader.beginArray();
while (reader.hasNext())
wordIndex.put(reader.nextString(), wordIndex.size());
reader.endArray();
break;
case "classified": case "classified":
reader.beginArray(); reader.beginArray();
while (reader.hasNext()) { while (reader.hasNext()) {
@ -660,7 +717,7 @@ public class MessageClassifier {
dirty = false; dirty = false;
long elapsed = new Date().getTime() - start; long elapsed = new Date().getTime() - start;
Log.i("Classifier data loaded elapsed=" + elapsed); Log.i("Classifier data loaded elapsed=" + elapsed + " words=" + wordIndex.size());
} }
private static void reduce() { private static void reduce() {
@ -670,7 +727,7 @@ public class MessageClassifier {
Map<String, Long> total = new HashMap<>(); Map<String, Long> total = new HashMap<>();
Map<String, Integer> count = new HashMap<>(); Map<String, Integer> count = new HashMap<>();
for (String word : wordClassFrequency.get(account).keySet()) for (int word : wordClassFrequency.get(account).keySet())
for (String clazz : wordClassFrequency.get(account).get(word).keySet()) { for (String clazz : wordClassFrequency.get(account).get(word).keySet()) {
int f = wordClassFrequency.get(account).get(word).get(clazz).count; int f = wordClassFrequency.get(account).get(word).get(clazz).count;
@ -691,7 +748,7 @@ public class MessageClassifier {
Log.i("Classifier max " + account + ":" + clazz + "=" + max.get(clazz)); Log.i("Classifier max " + account + ":" + clazz + "=" + max.get(clazz));
int dropped = 0; int dropped = 0;
for (String word : wordClassFrequency.get(account).keySet()) for (int word : wordClassFrequency.get(account).keySet())
for (String clazz : new ArrayList<>(wordClassFrequency.get(account).get(word).keySet())) { for (String clazz : new ArrayList<>(wordClassFrequency.get(account).get(word).keySet())) {
long m = max.get(clazz); long m = max.get(clazz);
long avg = total.get(clazz) / count.get(clazz); long avg = total.get(clazz) / count.get(clazz);
@ -703,13 +760,6 @@ public class MessageClassifier {
dropped++; dropped++;
Log.i("Classifier dropping account=" + account + Log.i("Classifier dropping account=" + account +
" word=" + word + " class=" + clazz + " freq=" + freq.count + " avg=" + avg); " word=" + word + " class=" + clazz + " freq=" + freq.count + " avg=" + avg);
} else if (version >= 3 && false) {
for (String b : new ArrayList<>(freq.before.keySet()))
if (freq.before.get(b) < freq.count / 20)
freq.before.remove(b);
for (String a : new ArrayList<>(freq.after.keySet()))
if (freq.after.get(a) < freq.count / 20)
freq.after.remove(a);
} }
} }
Log.i("Classifier dropped words=" + dropped); Log.i("Classifier dropped words=" + dropped);
@ -768,6 +818,7 @@ public class MessageClassifier {
accountMsgIds.clear(); accountMsgIds.clear();
classMessages.clear(); classMessages.clear();
wordClassFrequency.clear(); wordClassFrequency.clear();
wordIndex.clear();
dirty = true; dirty = true;
Log.i("Classifier data cleared"); Log.i("Classifier data cleared");
} }
@ -791,17 +842,17 @@ public class MessageClassifier {
} }
private static class State { private static class State {
private final List<String> words = new ArrayList<>(); private final List<Integer> words = new ArrayList<>();
private final Map<String, Stat> classStats = new HashMap<>(); private final Map<String, Stat> classStats = new HashMap<>();
} }
private static class Frequency { private static class Frequency {
private int count = 0; private int count = 0;
private int duplicates = 0; private int duplicates = 0;
private Map<String, Integer> before = new HashMap<>(); private Map<Integer, Integer> before = new HashMap<>();
private Map<String, Integer> after = new HashMap<>(); private Map<Integer, Integer> after = new HashMap<>();
private void add(String b, String a, int c, boolean duplicate) { private void add(Integer b, Integer a, int c, boolean duplicate) {
if (count + c < 0) if (count + c < 0)
return; return;