Context sensitive classifier

This commit is contained in:
M66B 2021-01-06 21:35:11 +01:00
parent 63beb8f87d
commit 502357b4fd
1 changed files with 117 additions and 54 deletions

View File

@ -39,6 +39,7 @@ import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@ -49,7 +50,7 @@ public class MessageClassifier {
private static boolean loaded = false;
private static boolean dirty = false;
private static final Map<Long, Map<String, Integer>> classMessages = new HashMap<>();
private static final Map<Long, Map<String, Map<String, Integer>>> wordClassFrequency = new HashMap<>();
private static final Map<Long, Map<String, Map<String, Frequency>>> wordClassFrequency = new HashMap<>();
private static final int MIN_MATCHED_WORDS = 10;
private static final double CHANCE_THRESHOLD = 2.0;
@ -171,9 +172,8 @@ public class MessageClassifier {
}
private static String classify(long account, String currentClass, String text, boolean added, Context context) {
int maxMatchedWords = 0;
List<String> words = new ArrayList<>();
Map<String, Stat> classStats = new HashMap<>();
State state = new State();
state.words.add(null);
BreakIterator boundary = BreakIterator.getWordInstance(); // TODO ICU
boundary.setText(text);
@ -181,59 +181,26 @@ public class MessageClassifier {
for (int end = boundary.next(); end != BreakIterator.DONE; end = boundary.next()) {
String word = text.substring(start, end).toLowerCase();
if (word.length() > 1 &&
!words.contains(word) &&
!state.words.contains(word) &&
!word.matches(".*\\d.*")) {
words.add(word);
Map<String, Integer> classFrequency = wordClassFrequency.get(account).get(word);
if (added) {
if (classFrequency == null) {
classFrequency = new HashMap<>();
wordClassFrequency.get(account).put(word, classFrequency);
}
for (String clazz : classFrequency.keySet()) {
int frequency = classFrequency.get(clazz);
Stat stat = classStats.get(clazz);
if (stat == null) {
stat = new Stat();
classStats.put(clazz, stat);
}
stat.matchedWords++;
stat.totalFrequency += frequency;
if (BuildConfig.DEBUG)
stat.words.add(word);
if (stat.matchedWords > maxMatchedWords)
maxMatchedWords = stat.matchedWords;
}
Integer c = classFrequency.get(currentClass);
c = (c == null ? 1 : c + 1);
classFrequency.put(currentClass, c);
} else {
Integer c = (classFrequency == null ? null : classFrequency.get(currentClass));
if (c != null)
if (c > 0)
classFrequency.put(currentClass, c - 1);
else
classFrequency.remove(currentClass);
}
state.words.add(word);
process(account, currentClass, added, state);
}
start = end;
}
state.words.add(null);
process(account, currentClass, added, state);
if (!added)
return null;
if (maxMatchedWords == 0)
if (state.maxMatchedWords == 0)
return null;
DB db = DB.getInstance(context);
List<Chance> chances = new ArrayList<>();
for (String clazz : classStats.keySet()) {
for (String clazz : state.classStats.keySet()) {
Integer messages = classMessages.get(account).get(clazz);
if (messages == null || messages == 0) {
Log.w("Classifier no messages class=" + account + ":" + clazz);
@ -246,13 +213,13 @@ public class MessageClassifier {
continue;
}
Stat stat = classStats.get(clazz);
Stat stat = state.classStats.get(clazz);
boolean consider = (stat.matchedWords >= MIN_MATCHED_WORDS);
double chance = (double) stat.totalFrequency / messages / maxMatchedWords;
double chance = stat.totalFrequency / messages / state.maxMatchedWords;
Chance c = new Chance(clazz, chance);
EntityLog.log(context, "Classifier " + c +
" frequency=" + stat.totalFrequency + "/" + messages +
" matched=" + stat.matchedWords + "/" + maxMatchedWords +
" matched=" + stat.matchedWords + "/" + state.maxMatchedWords +
" consider=" + consider +
" words=" + TextUtils.join(", ", stat.words));
if (consider)
@ -260,7 +227,7 @@ public class MessageClassifier {
}
if (BuildConfig.DEBUG)
Log.i("Classifier words=" + TextUtils.join(", ", words));
Log.i("Classifier words=" + TextUtils.join(", ", state.words));
if (chances.size() <= 1)
return null;
@ -281,6 +248,67 @@ public class MessageClassifier {
return classification;
}
private static void process(long account, String currentClass, boolean added, State state) {
if (state.words.size() < 3)
return;
String before = state.words.get(state.words.size() - 3);
String current = state.words.get(state.words.size() - 2);
String after = state.words.get(state.words.size() - 1);
Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(current);
if (added) {
if (classFrequency == null) {
classFrequency = new HashMap<>();
wordClassFrequency.get(account).put(current, classFrequency);
}
for (String clazz : classFrequency.keySet()) {
Frequency frequency = classFrequency.get(clazz);
Stat stat = state.classStats.get(clazz);
if (stat == null) {
stat = new Stat();
state.classStats.put(clazz, stat);
}
stat.matchedWords++;
boolean b = (before != null && frequency.before.contains(before));
boolean a = (after != null && frequency.after.contains(after));
if (b && a)
stat.totalFrequency += frequency.count;
else if (b || a)
stat.totalFrequency += frequency.count * 0.5;
else
stat.totalFrequency += frequency.count * 0.25;
if (BuildConfig.DEBUG)
stat.words.add(current);
if (stat.matchedWords > state.maxMatchedWords)
state.maxMatchedWords = stat.matchedWords;
}
Frequency c = classFrequency.get(currentClass);
if (c == null)
c = new Frequency();
c.count++;
if (before != null && !c.before.contains(before))
c.before.add(before);
if (after != null && !c.after.contains(after))
c.after.add(after);
classFrequency.put(currentClass, c);
} else {
Frequency c = (classFrequency == null ? null : classFrequency.get(currentClass));
if (c != null)
if (c.count > 0)
c.count--;
else
classFrequency.remove(currentClass);
}
}
static synchronized void save(Context context) throws JSONException, IOException {
if (!dirty)
return;
@ -345,13 +373,16 @@ public class MessageClassifier {
JSONArray jwords = new JSONArray();
for (Long account : classMessages.keySet())
for (String word : wordClassFrequency.get(account).keySet()) {
Map<String, Integer> classFrequency = wordClassFrequency.get(account).get(word);
Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(word);
for (String clazz : classFrequency.keySet()) {
Frequency f = classFrequency.get(clazz);
JSONObject jword = new JSONObject();
jword.put("account", account);
jword.put("word", word);
jword.put("class", clazz);
jword.put("frequency", classFrequency.get(clazz));
jword.put("frequency", f.count);
jword.put("before", from(f.before));
jword.put("after", from(f.after));
jwords.put(jword);
}
}
@ -363,6 +394,13 @@ public class MessageClassifier {
return jroot;
}
private static JSONArray from(HashSet<String> list) {
JSONArray jarray = new JSONArray();
for (String item : list)
jarray.put(item);
return jarray;
}
static void fromJson(JSONObject jroot) throws JSONException {
JSONArray jmessages = jroot.getJSONArray("messages");
for (int m = 0; m < jmessages.length(); m++) {
@ -380,18 +418,43 @@ public class MessageClassifier {
if (!wordClassFrequency.containsKey(account))
wordClassFrequency.put(account, new HashMap<>());
String word = jword.getString("word");
Map<String, Integer> classFrequency = wordClassFrequency.get(account).get(word);
Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(word);
if (classFrequency == null) {
classFrequency = new HashMap<>();
wordClassFrequency.get(account).put(word, classFrequency);
}
classFrequency.put(jword.getString("class"), jword.getInt("frequency"));
Frequency f = new Frequency();
f.count = jword.getInt("frequency");
if (jword.has("before"))
f.before = from(jword.getJSONArray("before"));
if (jword.has("after"))
f.after = from(jword.getJSONArray("after"));
classFrequency.put(jword.getString("class"), f);
}
}
private static HashSet<String> from(JSONArray jarray) throws JSONException {
HashSet<String> result = new HashSet<>(jarray.length());
for (int i = 0; i < jarray.length(); i++)
result.add((String) jarray.get(i));
return result;
}
private static class State {
int maxMatchedWords = 0;
List<String> words = new ArrayList<>();
Map<String, Stat> classStats = new HashMap<>();
}
private static class Frequency {
int count;
HashSet<String> before = new HashSet<>();
HashSet<String> after = new HashSet<>();
}
private static class Stat {
int matchedWords = 0;
int totalFrequency = 0;
double totalFrequency = 0;
List<String> words = new ArrayList<>();
}