Classifier simplification/fixes

This commit is contained in:
M66B 2021-01-07 10:58:56 +01:00
parent 5cce1c4ad6
commit 7683a7f047
1 changed file with 20 additions and 58 deletions

View File

@ -49,10 +49,8 @@ import javax.mail.internet.InternetAddress;
public class MessageClassifier { public class MessageClassifier {
private static boolean loaded = false; private static boolean loaded = false;
private static boolean dirty = false; private static boolean dirty = false;
private static final Map<Long, Map<String, Integer>> classMessages = new HashMap<>();
private static final Map<Long, Map<String, Map<String, Frequency>>> wordClassFrequency = new HashMap<>(); private static final Map<Long, Map<String, Map<String, Frequency>>> wordClassFrequency = new HashMap<>();
private static final int MIN_MATCHED_WORDS = 10;
private static final double CHANCE_THRESHOLD = 2.0; private static final double CHANCE_THRESHOLD = 2.0;
static void classify(EntityMessage message, EntityFolder folder, EntityFolder target, Context context) { static void classify(EntityMessage message, EntityFolder folder, EntityFolder target, Context context) {
@ -113,8 +111,6 @@ public class MessageClassifier {
load(context); load(context);
// Initialize data if needed // Initialize data if needed
if (!classMessages.containsKey(folder.account))
classMessages.put(folder.account, new HashMap<>());
if (!wordClassFrequency.containsKey(folder.account)) if (!wordClassFrequency.containsKey(folder.account))
wordClassFrequency.put(folder.account, new HashMap<>()); wordClassFrequency.put(folder.account, new HashMap<>());
@ -131,18 +127,6 @@ public class MessageClassifier {
" re=" + message.auto_classified + " re=" + message.auto_classified +
" elapsed=" + elapsed); " elapsed=" + elapsed);
// Update message count
Integer m = classMessages.get(folder.account).get(folder.name);
if (target == null) {
m = (m == null ? 1 : m + 1);
classMessages.get(folder.account).put(folder.name, m);
} else {
if (m != null && m > 0)
classMessages.get(folder.account).put(folder.name, m - 1);
}
EntityLog.log(context, "Classifier classify=" + folder.name +
" messages=" + classMessages.get(folder.account).get(folder.name));
dirty = true; dirty = true;
// Auto classify // Auto classify
@ -172,6 +156,15 @@ public class MessageClassifier {
} }
private static String classify(long account, String currentClass, String text, boolean added, Context context) { private static String classify(long account, String currentClass, String text, boolean added, Context context) {
int maxMessages = 0;
for (String word : wordClassFrequency.get(account).keySet()) {
for (String clazz : wordClassFrequency.get(account).get(word).keySet()) {
int count = wordClassFrequency.get(account).get(word).get(clazz).count;
if (count > maxMessages)
maxMessages = count;
}
}
State state = new State(); State state = new State();
state.words.add(null); state.words.add(null);
@ -180,7 +173,7 @@ public class MessageClassifier {
boundary.setText(text); boundary.setText(text);
int start = boundary.first(); int start = boundary.first();
for (int end = boundary.next(); end != java.text.BreakIterator.DONE; end = boundary.next()) { for (int end = boundary.next(); end != java.text.BreakIterator.DONE; end = boundary.next()) {
String word = text.substring(start, end).toLowerCase(); String word = text.substring(start, end).trim().toLowerCase();
if (word.length() > 1 && if (word.length() > 1 &&
!state.words.contains(word) && !state.words.contains(word) &&
!word.matches(".*\\d.*")) { !word.matches(".*\\d.*")) {
@ -195,7 +188,7 @@ public class MessageClassifier {
boundary.setText(text); boundary.setText(text);
int start = boundary.first(); int start = boundary.first();
for (int end = boundary.next(); end != android.icu.text.BreakIterator.DONE; end = boundary.next()) { for (int end = boundary.next(); end != android.icu.text.BreakIterator.DONE; end = boundary.next()) {
String word = text.substring(start, end).toLowerCase(); String word = text.substring(start, end).trim().toLowerCase();
if (word.length() > 1 && if (word.length() > 1 &&
!state.words.contains(word) && !state.words.contains(word) &&
!word.matches(".*\\d.*")) { !word.matches(".*\\d.*")) {
@ -212,16 +205,6 @@ public class MessageClassifier {
if (!added) if (!added)
return null; return null;
if (state.maxMatchedWords < MIN_MATCHED_WORDS)
return null;
int maxMessages = 0;
for (String clazz : state.classStats.keySet()) {
Integer messages = classMessages.get(account).get(clazz);
if (messages != null && messages > maxMessages)
maxMessages = messages;
}
if (maxMessages == 0) { if (maxMessages == 0) {
Log.e("Classifier no messages account=" + account); Log.e("Classifier no messages account=" + account);
} }
@ -236,13 +219,14 @@ public class MessageClassifier {
} }
Stat stat = state.classStats.get(clazz); Stat stat = state.classStats.get(clazz);
double chance = stat.totalFrequency / maxMessages / state.maxMatchedWords;
double chance = stat.totalFrequency / maxMessages / state.words.size();
Chance c = new Chance(clazz, chance); Chance c = new Chance(clazz, chance);
EntityLog.log(context, "Classifier " + c +
" frequency=" + stat.totalFrequency + "/" + maxMessages +
" matched=" + stat.matchedWords + "/" + state.maxMatchedWords +
" words=" + TextUtils.join(", ", stat.words));
chances.add(c); chances.add(c);
EntityLog.log(context, "Classifier " + c +
" frequency=" + stat.totalFrequency + "/" + maxMessages + " msgs" +
" matched=" + stat.matchedWords + "/" + state.words.size() + " words" +
" text=" + TextUtils.join(", ", stat.words));
} }
if (BuildConfig.DEBUG) if (BuildConfig.DEBUG)
@ -294,8 +278,8 @@ public class MessageClassifier {
int c = frequency.count; int c = frequency.count;
Integer b = (before == null ? null : frequency.before.get(before)); Integer b = (before == null ? null : frequency.before.get(before));
Integer a = (after == null ? null : frequency.after.get(after)); Integer a = (after == null ? null : frequency.after.get(after));
stat.totalFrequency += double f = ((b == null ? 0 : b) + c + (a == null ? 0 : a)) / 3.0;
((b == null ? 0.0 : (double) b / c) + c + (a == null ? 0.0 : (double) a / c)) / 3; stat.totalFrequency += f;
stat.matchedWords++; stat.matchedWords++;
if (stat.matchedWords > state.maxMatchedWords) if (stat.matchedWords > state.maxMatchedWords)
@ -333,7 +317,6 @@ public class MessageClassifier {
if (loaded || dirty) if (loaded || dirty)
return; return;
classMessages.clear();
wordClassFrequency.clear(); wordClassFrequency.clear();
File file = getFile(context); File file = getFile(context);
@ -347,7 +330,6 @@ public class MessageClassifier {
} }
static synchronized void clear(Context context) { static synchronized void clear(Context context) {
classMessages.clear();
wordClassFrequency.clear(); wordClassFrequency.clear();
dirty = true; dirty = true;
Log.i("Classifier data cleared"); Log.i("Classifier data cleared");
@ -369,18 +351,8 @@ public class MessageClassifier {
} }
static JSONObject toJson() throws JSONException { static JSONObject toJson() throws JSONException {
JSONArray jmessages = new JSONArray();
for (Long account : classMessages.keySet())
for (String clazz : classMessages.get(account).keySet()) {
JSONObject jmessage = new JSONObject();
jmessage.put("account", account);
jmessage.put("class", clazz);
jmessage.put("count", classMessages.get(account).get(clazz));
jmessages.put(jmessage);
}
JSONArray jwords = new JSONArray(); JSONArray jwords = new JSONArray();
for (Long account : classMessages.keySet()) for (Long account : wordClassFrequency.keySet())
for (String word : wordClassFrequency.get(account).keySet()) { for (String word : wordClassFrequency.get(account).keySet()) {
Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(word); Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(word);
for (String clazz : classFrequency.keySet()) { for (String clazz : classFrequency.keySet()) {
@ -397,7 +369,6 @@ public class MessageClassifier {
} }
JSONObject jroot = new JSONObject(); JSONObject jroot = new JSONObject();
jroot.put("messages", jmessages);
jroot.put("words", jwords); jroot.put("words", jwords);
return jroot; return jroot;
@ -411,15 +382,6 @@ public class MessageClassifier {
} }
static void fromJson(JSONObject jroot) throws JSONException { static void fromJson(JSONObject jroot) throws JSONException {
JSONArray jmessages = jroot.getJSONArray("messages");
for (int m = 0; m < jmessages.length(); m++) {
JSONObject jmessage = (JSONObject) jmessages.get(m);
long account = jmessage.getLong("account");
if (!classMessages.containsKey(account))
classMessages.put(account, new HashMap<>());
classMessages.get(account).put(jmessage.getString("class"), jmessage.getInt("count"));
}
JSONArray jwords = jroot.getJSONArray("words"); JSONArray jwords = jroot.getJSONArray("words");
for (int w = 0; w < jwords.length(); w++) { for (int w = 0; w < jwords.length(); w++) {
JSONObject jword = (JSONObject) jwords.get(w); JSONObject jword = (JSONObject) jwords.get(w);