diff --git a/app/src/main/jni/netguard/netguard.c b/app/src/main/jni/netguard/netguard.c index 87e02d39..43ab01da 100644 --- a/app/src/main/jni/netguard/netguard.c +++ b/app/src/main/jni/netguard/netguard.c @@ -45,17 +45,17 @@ // Global variables -static JavaVM *jvm; -pthread_t thread_id = 0; -pthread_mutex_t lock; -jboolean stopping = 0; -jboolean signaled = 0; +static JavaVM *jvm = NULL; +static pthread_t thread_id = 0; +static pthread_mutex_t lock; +static jboolean stopping = 0; +static jboolean signaled = 0; -struct udp_session *udp_session = NULL; -struct tcp_session *tcp_session = NULL; +static struct udp_session *udp_session = NULL; +static struct tcp_session *tcp_session = NULL; -int loglevel = 0; -FILE *pcap_file = NULL; +static int loglevel = 0; +static FILE *pcap_file = NULL; // JNI @@ -2377,11 +2377,15 @@ jint get_uid(const int protocol, const int version, return uid; } +static jmethodID midProtect = NULL; + int protect_socket(const struct arguments *args, int socket) { jclass cls = (*args->env)->GetObjectClass(args->env, args->instance); - jmethodID mid = jniGetMethodID(args->env, cls, "protect", "(I)Z"); + if (midProtect == NULL) + midProtect = jniGetMethodID(args->env, cls, "protect", "(I)Z"); - jboolean isProtected = (*args->env)->CallBooleanMethod(args->env, args->instance, mid, socket); + jboolean isProtected = (*args->env)->CallBooleanMethod( + args->env, args->instance, midProtect, socket); jniCheckException(args->env); if (!isProtected) { @@ -2432,51 +2436,16 @@ jclass jniFindClass(JNIEnv *env, const char *name) { return cls; } -jmethodID method_protect = NULL; -jmethodID method_logPacket = NULL; -jmethodID method_isDomainBlocked = NULL; -jmethodID method_isAddressAllowed = NULL; - jmethodID jniGetMethodID(JNIEnv *env, jclass cls, const char *name, const char *signature) { - if (strcmp(name, "protect") == 0 && method_protect != NULL) - return method_protect; - - if (strcmp(name, "logPacket") == 0 && method_logPacket != NULL) - return method_logPacket; - - if (strcmp(name, "isDomainBlocked") == 0 && method_isDomainBlocked != NULL) - return method_isDomainBlocked; - - if (strcmp(name, "isAddressAllowed") == 0 && method_isAddressAllowed != NULL) - return method_isAddressAllowed; - jmethodID method = (*env)->GetMethodID(env, cls, name, signature); if (method == NULL) { log_android(ANDROID_LOG_ERROR, "Method %s %s not found", name, signature); jniCheckException(env); - } else { - if (strcmp(name, "protect") == 0) { - method_protect = method; - log_android(ANDROID_LOG_INFO, "Cached method ID protect"); - } - else if (strcmp(name, "logPacket") == 0) { - method_logPacket = method; - log_android(ANDROID_LOG_INFO, "Cached method ID logPacket"); - } - else if (strcmp(name, "isDomainBlocked") == 0) { - method_isDomainBlocked = method; - log_android(ANDROID_LOG_INFO, "Cached method ID isDomainBlocked"); - } - else if (strcmp(name, "isAddressAllowed") == 0) { - method_isAddressAllowed = method; - log_android(ANDROID_LOG_INFO, "Cached method ID isAddressAllowed"); - } } return method; } jfieldID jniGetFieldID(JNIEnv *env, jclass cls, const char *name, const char *type) { - // TODO cache field IDs jfieldID field = (*env)->GetFieldID(env, cls, name, type); if (field == NULL) log_android(ANDROID_LOG_ERROR, "Field %s type %s not found", name, type); @@ -2531,6 +2500,8 @@ void log_android(int prio, const char *fmt, ...) { } } +static jmethodID midLogPacket = NULL; + void log_packet(const struct arguments *args, jobject jpacket) { #ifdef PROFILE float mselapsed; @@ -2541,9 +2512,10 @@ void log_packet(const struct arguments *args, jobject jpacket) { jclass clsService = (*args->env)->GetObjectClass(args->env, args->instance); const char *signature = "(Leu/faircode/netguard/Packet;)V"; - jmethodID method = jniGetMethodID(args->env, clsService, "logPacket", signature); + if (midLogPacket == NULL) + midLogPacket = jniGetMethodID(args->env, clsService, "logPacket", signature); - (*args->env)->CallVoidMethod(args->env, args->instance, method, jpacket); + (*args->env)->CallVoidMethod(args->env, args->instance, midLogPacket, jpacket); jniCheckException(args->env); (*args->env)->DeleteLocalRef(args->env, jpacket); @@ -2558,6 +2530,8 @@ void log_packet(const struct arguments *args, jobject jpacket) { #endif } +static jmethodID midIsDomainBlocked = NULL; + jboolean is_domain_blocked(const struct arguments *args, const char *name) { #ifdef PROFILE float mselapsed; @@ -2568,11 +2542,13 @@ jboolean is_domain_blocked(const struct arguments *args, const char *name) { jclass clsService = (*args->env)->GetObjectClass(args->env, args->instance); const char *signature = "(Ljava/lang/String;)Z"; - jmethodID method = jniGetMethodID(args->env, clsService, "isDomainBlocked", signature); + if (midIsDomainBlocked == NULL) + midIsDomainBlocked = jniGetMethodID(args->env, clsService, "isDomainBlocked", signature); jstring jname = (*args->env)->NewStringUTF(args->env, name); - jboolean jallowed = (*args->env)->CallBooleanMethod(args->env, args->instance, method, jname); + jboolean jallowed = (*args->env)->CallBooleanMethod( + args->env, args->instance, midIsDomainBlocked, jname); jniCheckException(args->env); (*args->env)->DeleteLocalRef(args->env, jname); @@ -2589,6 +2565,8 @@ jboolean is_domain_blocked(const struct arguments *args, const char *name) { return jallowed; } +static jmethodID midIsAddressAllowed = NULL; + jboolean is_address_allowed(const struct arguments *args, jobject jpacket) { #ifdef PROFILE float mselapsed; @@ -2599,9 +2577,11 @@ jboolean is_address_allowed(const struct arguments *args, jobject jpacket) { jclass clsService = (*args->env)->GetObjectClass(args->env, args->instance); const char *signature = "(Leu/faircode/netguard/Packet;)Z"; - jmethodID method = jniGetMethodID(args->env, clsService, "isAddressAllowed", signature); + if (midIsAddressAllowed == NULL) + midIsAddressAllowed = jniGetMethodID(args->env, clsService, "isAddressAllowed", signature); - jboolean jallowed = (*args->env)->CallBooleanMethod(args->env, args->instance, method, jpacket); + jboolean jallowed = (*args->env)->CallBooleanMethod( + args->env, args->instance, midIsAddressAllowed, jpacket); jniCheckException(args->env); (*args->env)->DeleteLocalRef(args->env, jpacket); @@ -2618,6 +2598,20 @@ jboolean is_address_allowed(const struct arguments *args, jobject jpacket) { return jallowed; } +jmethodID midInitPacket = NULL; + +jfieldID fidTime = NULL; +jfieldID fidVersion = NULL; +jfieldID fidProtocol = NULL; +jfieldID fidFlags = NULL; +jfieldID fidSaddr = NULL; +jfieldID fidSport = NULL; +jfieldID fidDaddr = NULL; +jfieldID fidDport = NULL; +jfieldID fidData = NULL; +jfieldID fidUid = NULL; +jfieldID fidAllowed = NULL; + jobject create_packet(const struct arguments *args, jint version, jint protocol, @@ -2632,8 +2626,9 @@ jobject create_packet(const struct arguments *args, JNIEnv *env = args->env; const char *packet = "eu/faircode/netguard/Packet"; - jmethodID initPacket = jniGetMethodID(env, clsPacket, "", "()V"); - jobject jpacket = jniNewObject(env, clsPacket, initPacket, packet); + if (midInitPacket == NULL) + midInitPacket = jniGetMethodID(env, clsPacket, "", "()V"); + jobject jpacket = jniNewObject(env, clsPacket, midInitPacket, packet); struct timeval tv; gettimeofday(&tv, NULL); @@ -2643,18 +2638,32 @@ jobject create_packet(const struct arguments *args, jstring jdest = (*env)->NewStringUTF(env, dest); jstring jdata = (*env)->NewStringUTF(env, data); - const char *string = "Ljava/lang/String;"; - (*env)->SetLongField(env, jpacket, jniGetFieldID(env, clsPacket, "time", "J"), t); - (*env)->SetIntField(env, jpacket, jniGetFieldID(env, clsPacket, "version", "I"), version); - (*env)->SetIntField(env, jpacket, jniGetFieldID(env, clsPacket, "protocol", "I"), protocol); - (*env)->SetObjectField(env, jpacket, jniGetFieldID(env, clsPacket, "flags", string), jflags); - (*env)->SetObjectField(env, jpacket, jniGetFieldID(env, clsPacket, "saddr", string), jsource); - (*env)->SetIntField(env, jpacket, jniGetFieldID(env, clsPacket, "sport", "I"), sport); - (*env)->SetObjectField(env, jpacket, jniGetFieldID(env, clsPacket, "daddr", string), jdest); - (*env)->SetIntField(env, jpacket, jniGetFieldID(env, clsPacket, "dport", "I"), dport); - (*env)->SetObjectField(env, jpacket, jniGetFieldID(env, clsPacket, "data", string), jdata); - (*env)->SetIntField(env, jpacket, jniGetFieldID(env, clsPacket, "uid", "I"), uid); - (*env)->SetBooleanField(env, jpacket, jniGetFieldID(env, clsPacket, "allowed", "Z"), allowed); + if (fidTime == NULL) { + const char *string = "Ljava/lang/String;"; + fidTime = jniGetFieldID(env, clsPacket, "time", "J"); + fidVersion = jniGetFieldID(env, clsPacket, "version", "I"); + fidProtocol = jniGetFieldID(env, clsPacket, "protocol", "I"); + fidFlags = jniGetFieldID(env, clsPacket, "flags", string); + fidSaddr = jniGetFieldID(env, clsPacket, "saddr", string); + fidSport = jniGetFieldID(env, clsPacket, "sport", "I"); + fidDaddr = jniGetFieldID(env, clsPacket, "daddr", string); + fidDport = jniGetFieldID(env, clsPacket, "dport", "I"); + fidData = jniGetFieldID(env, clsPacket, "data", string); + fidUid = jniGetFieldID(env, clsPacket, "uid", "I"); + fidAllowed = jniGetFieldID(env, clsPacket, "allowed", "Z"); + } + + (*env)->SetLongField(env, jpacket, fidTime, t); + (*env)->SetIntField(env, jpacket, fidVersion, version); + (*env)->SetIntField(env, jpacket, fidProtocol, protocol); + (*env)->SetObjectField(env, jpacket, fidFlags, jflags); + (*env)->SetObjectField(env, jpacket, fidSaddr, jsource); + (*env)->SetIntField(env, jpacket, fidSport, sport); + (*env)->SetObjectField(env, jpacket, fidDaddr, jdest); + (*env)->SetIntField(env, jpacket, fidDport, dport); + (*env)->SetObjectField(env, jpacket, fidData, jdata); + (*env)->SetIntField(env, jpacket, fidUid, uid); + (*env)->SetBooleanField(env, jpacket, fidAllowed, allowed); (*env)->DeleteLocalRef(env, jdata); (*env)->DeleteLocalRef(env, jdest);