Classifier word list

This commit is contained in:
M66B 2023-03-22 15:18:59 +01:00
parent 27307edf56
commit 1221b69dcc
1 changed files with 90 additions and 39 deletions

View File

@ -41,6 +41,7 @@ import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@ -48,13 +49,14 @@ import javax.mail.Address;
import javax.mail.internet.InternetAddress;
public class MessageClassifier {
private static int version = 3;
private static boolean loaded = false;
private static boolean dirty = false;
private static final Map<Long, List<String>> accountMsgIds = new HashMap<>();
private static final Map<Long, Map<String, Integer>> classMessages = new HashMap<>();
private static final Map<Long, Map<String, Map<String, Frequency>>> wordClassFrequency = new HashMap<>();
private static final Map<Long, Map<Integer, Map<String, Frequency>>> wordClassFrequency = new HashMap<>();
private static final Map<String, Integer> wordIndex = new LinkedHashMap<>();
private static final int VERSION = 4;
private static final int MAX_WORDS = 1000;
static synchronized void classify(EntityMessage message, EntityFolder folder, boolean added, Context context) {
@ -192,7 +194,7 @@ public class MessageClassifier {
" class=" + message.account + ":" + clazz +
" exists=" + (folder != null));
classMessages.get(message.account).remove(clazz);
for (String word : wordClassFrequency.get(message.account).keySet())
for (int word : wordClassFrequency.get(message.account).keySet())
wordClassFrequency.get(message.account).get(word).remove(clazz);
}
}
@ -265,8 +267,15 @@ public class MessageClassifier {
" text=" + TextUtils.join(", ", stat.words));
}
if (BuildConfig.DEBUG)
Log.i("Classifier words=" + state.words.size() + " " + TextUtils.join(", ", state.words));
if (BuildConfig.DEBUG) {
StringBuilder sb = new StringBuilder();
for (Integer word : state.words) {
if (sb.length() > 0)
sb.append(", ");
sb.append(getWord(word));
}
Log.i("Classifier words=" + state.words.size() + " " + sb);
}
// Sort classes by chance
Collections.sort(chances, new Comparator<Chance>() {
@ -320,6 +329,11 @@ public class MessageClassifier {
return;
}
_processWord(account, added, word == null ? null : getWordIndex(word), state);
}
private static void _processWord(long account, boolean added, Integer word, State state) {
if (word != null ||
state.words.size() == 0 ||
state.words.get(state.words.size() - 1) != null)
@ -331,9 +345,9 @@ public class MessageClassifier {
if (state.words.size() < 3)
return;
String before = state.words.get(state.words.size() - 3);
String current = state.words.get(state.words.size() - 2);
String after = state.words.get(state.words.size() - 1);
Integer before = state.words.get(state.words.size() - 3);
Integer current = state.words.get(state.words.size() - 2);
Integer after = state.words.get(state.words.size() - 1);
if (current == null)
return;
@ -373,6 +387,25 @@ public class MessageClassifier {
}
}
private static int getWordIndex(String word) {
synchronized (wordIndex) {
Integer index = wordIndex.get(word);
if (index == null) {
index = wordIndex.size();
wordIndex.put(word, index);
}
return index;
}
}
private static String getWord(Integer index) {
if (index == null)
return "<null>";
if (index < 0 || index >= wordIndex.size())
return "<" + index + ">";
return new ArrayList<>(wordIndex.keySet()).get(index);
}
private static void updateFrequencies(long account, @NonNull String currentClass, boolean added, @NonNull State state) {
Integer m = classMessages.get(account).get(currentClass);
m = (m == null ? 0 : m) + (added ? 1 : -1);
@ -383,9 +416,9 @@ public class MessageClassifier {
Log.i("Classifier " + currentClass + "=" + m + " msgs");
for (int i = 1; i < state.words.size() - 1; i++) {
String before = state.words.get(i - 1);
String current = state.words.get(i);
String after = state.words.get(i + 1);
Integer before = state.words.get(i - 1);
Integer current = state.words.get(i);
Integer after = state.words.get(i + 1);
if (current == null)
continue;
@ -428,8 +461,8 @@ public class MessageClassifier {
try (JsonWriter writer = new JsonWriter(new BufferedWriter(new FileWriter(file)))) {
writer.beginObject();
Log.i("Classifier write version=" + version);
writer.name("version").value(version);
Log.i("Classifier write version=" + VERSION);
writer.name("version").value(VERSION);
writer.name("messages");
writer.beginArray();
@ -446,7 +479,7 @@ public class MessageClassifier {
writer.name("words");
writer.beginArray();
for (Long account : wordClassFrequency.keySet())
for (String word : wordClassFrequency.get(account).keySet()) {
for (int word : wordClassFrequency.get(account).keySet()) {
Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(word);
for (String clazz : classFrequency.keySet()) {
Frequency f = classFrequency.get(clazz);
@ -460,14 +493,14 @@ public class MessageClassifier {
writer.name("before");
writer.beginObject();
for (String key : f.before.keySet())
writer.name(key).value(f.before.get(key));
for (int key : f.before.keySet())
writer.name(Integer.toString(key)).value(f.before.get(key));
writer.endObject();
writer.name("after");
writer.beginObject();
for (String key : f.after.keySet())
writer.name(key).value(f.after.get(key));
for (int key : f.after.keySet())
writer.name(Integer.toString(key)).value(f.after.get(key));
writer.endObject();
writer.endObject();
@ -475,6 +508,12 @@ public class MessageClassifier {
}
writer.endArray();
writer.name("list");
writer.beginArray();
for (String word : wordIndex.keySet())
writer.value(word);
writer.endArray();
writer.name("classified");
writer.beginArray();
for (Long account : accountMsgIds.keySet()) {
@ -521,7 +560,7 @@ public class MessageClassifier {
private static synchronized void _load(File file) throws IOException {
Log.i("Classifier read " + file);
long start = new Date().getTime();
version = 0;
int version = 0;
if (file.exists())
try (JsonReader reader = new JsonReader(new BufferedReader(new FileReader(file)))) {
reader.beginObject();
@ -568,7 +607,7 @@ public class MessageClassifier {
reader.beginArray();
while (reader.hasNext()) {
Long account = null;
String word = null;
Integer word = null;
String clazz = null;
Frequency f = new Frequency();
@ -579,7 +618,10 @@ public class MessageClassifier {
account = reader.nextLong();
break;
case "word":
word = reader.nextString();
if (version > 3)
word = Integer.parseInt(reader.nextString());
else
word = getWordIndex(reader.nextString());
break;
case "class":
clazz = reader.nextString();
@ -592,14 +634,22 @@ public class MessageClassifier {
break;
case "before":
reader.beginObject();
while (reader.hasNext())
f.before.put(reader.nextName(), reader.nextInt());
while (reader.hasNext()) {
int b = (version > 3
? Integer.parseInt(reader.nextName())
: getWordIndex(reader.nextName()));
f.before.put(b, reader.nextInt());
}
reader.endObject();
break;
case "after":
reader.beginObject();
while (reader.hasNext())
f.after.put(reader.nextName(), reader.nextInt());
while (reader.hasNext()) {
int a = (version > 3
? Integer.parseInt(reader.nextName())
: getWordIndex(reader.nextName()));
f.after.put(a, reader.nextInt());
}
reader.endObject();
break;
}
@ -622,6 +672,13 @@ public class MessageClassifier {
reader.endArray();
break;
case "list":
reader.beginArray();
while (reader.hasNext())
wordIndex.put(reader.nextString(), wordIndex.size());
reader.endArray();
break;
case "classified":
reader.beginArray();
while (reader.hasNext()) {
@ -660,7 +717,7 @@ public class MessageClassifier {
dirty = false;
long elapsed = new Date().getTime() - start;
Log.i("Classifier data loaded elapsed=" + elapsed);
Log.i("Classifier data loaded elapsed=" + elapsed + " words=" + wordIndex.size());
}
private static void reduce() {
@ -670,7 +727,7 @@ public class MessageClassifier {
Map<String, Long> total = new HashMap<>();
Map<String, Integer> count = new HashMap<>();
for (String word : wordClassFrequency.get(account).keySet())
for (int word : wordClassFrequency.get(account).keySet())
for (String clazz : wordClassFrequency.get(account).get(word).keySet()) {
int f = wordClassFrequency.get(account).get(word).get(clazz).count;
@ -691,7 +748,7 @@ public class MessageClassifier {
Log.i("Classifier max " + account + ":" + clazz + "=" + max.get(clazz));
int dropped = 0;
for (String word : wordClassFrequency.get(account).keySet())
for (int word : wordClassFrequency.get(account).keySet())
for (String clazz : new ArrayList<>(wordClassFrequency.get(account).get(word).keySet())) {
long m = max.get(clazz);
long avg = total.get(clazz) / count.get(clazz);
@ -703,13 +760,6 @@ public class MessageClassifier {
dropped++;
Log.i("Classifier dropping account=" + account +
" word=" + word + " class=" + clazz + " freq=" + freq.count + " avg=" + avg);
} else if (version >= 3 && false) {
for (String b : new ArrayList<>(freq.before.keySet()))
if (freq.before.get(b) < freq.count / 20)
freq.before.remove(b);
for (String a : new ArrayList<>(freq.after.keySet()))
if (freq.after.get(a) < freq.count / 20)
freq.after.remove(a);
}
}
Log.i("Classifier dropped words=" + dropped);
@ -768,6 +818,7 @@ public class MessageClassifier {
accountMsgIds.clear();
classMessages.clear();
wordClassFrequency.clear();
wordIndex.clear();
dirty = true;
Log.i("Classifier data cleared");
}
@ -791,17 +842,17 @@ public class MessageClassifier {
}
private static class State {
private final List<String> words = new ArrayList<>();
private final List<Integer> words = new ArrayList<>();
private final Map<String, Stat> classStats = new HashMap<>();
}
private static class Frequency {
private int count = 0;
private int duplicates = 0;
private Map<String, Integer> before = new HashMap<>();
private Map<String, Integer> after = new HashMap<>();
private Map<Integer, Integer> before = new HashMap<>();
private Map<Integer, Integer> after = new HashMap<>();
private void add(String b, String a, int c, boolean duplicate) {
private void add(Integer b, Integer a, int c, boolean duplicate) {
if (count + c < 0)
return;