diff --git a/app/src/main/jni/netguard/netguard.c b/app/src/main/jni/netguard/netguard.c index 17f7cca6..6c87ea5b 100644 --- a/app/src/main/jni/netguard/netguard.c +++ b/app/src/main/jni/netguard/netguard.c @@ -62,6 +62,10 @@ void handle_tcp(JNIEnv *, jobject, const uint8_t *, const uint16_t); int openSocket(JNIEnv *, jobject, const struct sockaddr_in *); +int getLocalPort(const int); + +int canWrite(const int); + int writeSYN(const struct connection *, const int); void decode(JNIEnv *, jobject, const uint8_t *, const uint16_t); @@ -298,21 +302,13 @@ void *handle_events(void *a) { // Check connects if (cur->state == TCP_SYN_RECV) { - if (FD_ISSET(cur->socket, &wfds)) { + if (FD_ISSET(cur->socket, &wfds) && canWrite(args->tun)) { // Log char dest[20]; inet_ntop(AF_INET, &(cur->daddr), dest, sizeof(dest)); __android_log_print(ANDROID_LOG_DEBUG, TAG, "Established %s/%u lport %u", dest, ntohs(cur->dest), cur->lport); - // Set blocking - uint8_t flags = fcntl(cur->socket, F_GETFL, 0); - if (flags < 0 || fcntl(cur->socket, F_SETFL, flags & ~O_NONBLOCK) < 0) { - __android_log_print(ANDROID_LOG_ERROR, TAG, "fcntl error %d: %s", - errno, strerror(errno)); - return -1; - } - if (writeSYN(cur, args->tun) < 0) cur->state = TCP_CLOSE; else @@ -494,7 +490,7 @@ int openSocket(JNIEnv *env, jobject instance, const struct sockaddr_in *daddr) { // Set non blocking uint8_t flags = fcntl(sock, F_GETFL, 0); if (flags < 0 || fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) { - __android_log_print(ANDROID_LOG_ERROR, TAG, "fcntl error %d: %s", + __android_log_print(ANDROID_LOG_ERROR, TAG, "fcntl O_NONBLOCK error %d: %s", errno, strerror(errno)); return -1; } @@ -507,6 +503,13 @@ int openSocket(JNIEnv *env, jobject instance, const struct sockaddr_in *daddr) { return -1; } + // Set blocking + if (fcntl(sock, F_SETFL, flags) < 0) { + __android_log_print(ANDROID_LOG_ERROR, TAG, "fcntl error %d: %s", + errno, strerror(errno)); + return -1; + } + return sock; } @@ -521,6 +524,16 @@ int getLocalPort(const int sock) { return ntohs(sin.sin_port); } +int canWrite(const int fd) { + struct timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 0; + fd_set wfds; + FD_ZERO(&wfds); + FD_SET(fd, &wfds); + return (select(fd + 1, NULL, &wfds, NULL, &tv) > 0); +} + int writeSYN(const struct connection *cur, int tun) { // Build packet uint16_t len = sizeof(struct iphdr) + sizeof(struct tcphdr); // no data