diff --git a/.gitignore b/.gitignore index 96ca72be..d36696fe 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ .DS_Store /build /captures +/*.sh diff --git a/app/build.gradle b/app/build.gradle index f7db20b3..3fb2e719 100644 --- a/app/build.gradle +++ b/app/build.gradle @@ -34,6 +34,16 @@ model { release { minifyEnabled = true proguardFiles.add(file('proguard-rules.pro')) + ndk.with { + debuggable = true + } + } + } + android.buildTypes { + debug { + ndk.with { + debuggable = true + } } } android.productFlavors { diff --git a/app/src/main/java/eu/faircode/netguard/SinkholeService.java b/app/src/main/java/eu/faircode/netguard/SinkholeService.java index d3d2c43b..df355e0e 100644 --- a/app/src/main/java/eu/faircode/netguard/SinkholeService.java +++ b/app/src/main/java/eu/faircode/netguard/SinkholeService.java @@ -113,6 +113,8 @@ public class SinkholeService extends VpnService implements SharedPreferences.OnS private static final String ACTION_SCREEN_OFF_DELAYED = "eu.faircode.netguard.SCREEN_OFF_DELAYED"; + private native void jni_init(); + private native void jni_start(int tun); private native void jni_stop(int tun); @@ -849,6 +851,8 @@ public class SinkholeService extends VpnService implements SharedPreferences.OnS public void onCreate() { Log.i(TAG, "Create"); + jni_init(); + SharedPreferences prefs = PreferenceManager.getDefaultSharedPreferences(this); prefs.registerOnSharedPreferenceChangeListener(this); diff --git a/app/src/main/jni/netguard/netguard.c b/app/src/main/jni/netguard/netguard.c index c2b3745c..68946099 100644 --- a/app/src/main/jni/netguard/netguard.c +++ b/app/src/main/jni/netguard/netguard.c @@ -35,9 +35,8 @@ struct arguments { }; struct data { - uint32_t seq; // host notation - jbyte *data; uint32_t len; + uint8_t *data; struct data *next; }; @@ -52,14 +51,13 @@ struct connection { uint8_t state; jint socket; uint32_t lport; // host notation - struct data *received; struct data *sent; struct connection *next; }; void *handle_events(void *); -void handle_tcp(JNIEnv *, jobject, const uint8_t *, const uint16_t); +void handle_tcp(JNIEnv *, jobject, const struct arguments *args, const uint8_t *, const uint16_t); int openSocket(JNIEnv *, jobject, const struct sockaddr_in *); @@ -69,7 +67,7 @@ int canWrite(const int); int writeSYN(const struct connection *, const int); -void decode(JNIEnv *, jobject, const uint8_t *, const uint16_t); +void decode(JNIEnv *, jobject, const struct arguments *args, const uint8_t *, const uint16_t); jint getUid(const int, const int, const void *, const uint16_t); @@ -88,6 +86,12 @@ struct connection *connection = NULL; // JNI +JNIEXPORT void JNICALL +Java_eu_faircode_netguard_SinkholeService_jni_1init(JNIEnv *env, jobject instance) { + __android_log_print(ANDROID_LOG_DEBUG, TAG, "Init"); + connection = NULL; +} + JNIEXPORT void JNICALL Java_eu_faircode_netguard_SinkholeService_jni_1start(JNIEnv *env, jobject instance, jint tun) { __android_log_print(ANDROID_LOG_DEBUG, TAG, "Starting tun=%d", tun); @@ -202,26 +206,12 @@ void *handle_events(void *a) { __android_log_print(ANDROID_LOG_DEBUG, TAG, "Idle %s/%u lport %u", dest, ntohs(cur->dest), cur->lport); + // TODO check if open shutdown(cur->socket, SHUT_RDWR); // TODO check for errors - if (last == NULL) - connection = cur->next; - else - last->next = cur->next; - struct data *prev; - - struct data *received = cur->received; - while (received != NULL) { - prev = received; - received = received->next; - if (prev->data != NULL) - free(prev->data); - free(prev); - } - struct data *sent = cur->sent; while (sent != NULL) { prev = sent; @@ -231,11 +221,18 @@ void *handle_events(void *a) { free(prev); } - free(cur); + if (last == NULL) + connection = cur->next; + else + last->next = cur->next; + + struct connection *c = cur; + cur = cur->next; + free(c); + continue; } else { if (cur->state == TCP_SYN_RECV) { - // TODO check if tun writable? FD_SET(cur->socket, &wfds); if (cur->socket > max) max = cur->socket; @@ -274,12 +271,13 @@ void *handle_events(void *a) { if (ready == 0) __android_log_print(ANDROID_LOG_DEBUG, TAG, "Yield"); else { - // Check tun + // Check tun exception if (FD_ISSET(args->tun, &efds)) { __android_log_print(ANDROID_LOG_ERROR, TAG, "tun exception"); break; } + // Check tun read if (FD_ISSET(args->tun, &rfds)) { uint8_t buffer[MAXPKT]; ssize_t length = read(args->tun, buffer, MAXPKT); @@ -289,7 +287,7 @@ void *handle_events(void *a) { break; } if (length > 0) - decode(env, args->instance, buffer, length); + decode(env, args->instance, args, buffer, length); else { __android_log_print(ANDROID_LOG_ERROR, TAG, "tun empty read"); break; @@ -299,7 +297,7 @@ void *handle_events(void *a) { // Check sockets struct connection *cur = connection; while (cur != NULL) { - // Check exceptions + // Check socket exception if (FD_ISSET(cur->socket, &efds)) { int serr; socklen_t optlen = sizeof(serr); @@ -317,8 +315,8 @@ void *handle_events(void *a) { } } - // Check connects if (cur->state == TCP_SYN_RECV) { + // Check socket connect if (FD_ISSET(cur->socket, &wfds) && canWrite(args->tun)) { // Log char dest[20]; @@ -333,22 +331,23 @@ void *handle_events(void *a) { } } - // Check incoming data - if (cur->state == TCP_ESTABLISHED) { + else if (cur->state == TCP_ESTABLISHED) { + // Check socket read if (FD_ISSET(cur->socket, &rfds)) { uint8_t buffer[MAXPKT]; - ssize_t bytes = recv(cur->socket, buffer, MAXPKT, MSG_DONTWAIT); + ssize_t bytes = recv(cur->socket, buffer, MAXPKT, 0); if (bytes < 0) { - // TODO handle EINTR - __android_log_print(ANDROID_LOG_ERROR, TAG, "recv error %d: %s", + __android_log_print(ANDROID_LOG_ERROR, TAG, "recv socket error %d: %s", errno, strerror(errno)); - cur->state = TCP_CLOSE; + if (errno != EINTR) + cur->state = TCP_CLOSE; } else if (bytes == 0) { - __android_log_print(ANDROID_LOG_ERROR, TAG, "recv empty"); + __android_log_print(ANDROID_LOG_ERROR, TAG, "recv socket empty"); cur->state = TCP_CLOSE; } else { - __android_log_print(ANDROID_LOG_DEBUG, TAG, "recv lport %u bytes %d", + __android_log_print(ANDROID_LOG_DEBUG, TAG, + "recv socket lport %u bytes %d", cur->lport, bytes); } } @@ -369,9 +368,10 @@ void *handle_events(void *a) { args->tun, thread_id); } -void handle_tcp(JNIEnv *env, jobject instance, const uint8_t *buffer, uint16_t length) { +void handle_tcp(JNIEnv *env, jobject instance, const struct arguments *args, + const uint8_t *buffer, uint16_t length) { // Check version - jbyte version = (*buffer) >> 4; + uint8_t version = (*buffer) >> 4; if (version != 4) return; @@ -379,6 +379,7 @@ void handle_tcp(JNIEnv *env, jobject instance, const uint8_t *buffer, uint16_t l struct iphdr *iphdr = buffer; uint8_t optlen = (iphdr->ihl - 5) * 4; struct tcphdr *tcphdr = buffer + sizeof(struct iphdr) + optlen; + __android_log_print(ANDROID_LOG_DEBUG, TAG, "optlen %d", optlen); if (ntohs(iphdr->tot_len) != length) __android_log_print(ANDROID_LOG_WARN, TAG, "Invalid length %u/%d", iphdr->tot_len, length); @@ -386,13 +387,14 @@ void handle_tcp(JNIEnv *env, jobject instance, const uint8_t *buffer, uint16_t l // Get data uint16_t dataoff = sizeof(struct iphdr) + optlen + sizeof(struct tcphdr); uint16_t datalen = length - dataoff; - struct data *data = malloc(sizeof(struct data)); - data->seq = ntohl(tcphdr->seq); - data->data = malloc(datalen); // TODO free - data->next = NULL; - if (datalen) + struct data *data = NULL; + if (datalen > 0) { + data = malloc(sizeof(struct data)); + data->len = datalen; + data->data = malloc(datalen); // TODO free memcpy(data->data, buffer + dataoff, datalen); - data->len = datalen; + data->next = NULL; + } // Search connection struct connection *last = NULL; @@ -416,21 +418,21 @@ void handle_tcp(JNIEnv *env, jobject instance, const uint8_t *buffer, uint16_t l // Register connection struct connection *syn = malloc(sizeof(struct connection)); // TODO check/free syn->time = time(NULL); - syn->remote_seq = ntohl(tcphdr->seq); - syn->local_seq = 123; // TODO randomize + syn->remote_seq = ntohl(tcphdr->seq); // ISN remote + syn->local_seq = 123; // ISN local TODO randomize syn->saddr = iphdr->saddr; syn->source = tcphdr->source; syn->daddr = iphdr->daddr; syn->dest = tcphdr->dest; syn->state = TCP_SYN_RECV; - syn->received = NULL; syn->sent = NULL; syn->next = NULL; // Ignore data - if (data->data != NULL) + if (data != NULL) { free(data->data); - free(data); + free(data); + } // Build target address struct sockaddr_in daddr; @@ -470,9 +472,10 @@ void handle_tcp(JNIEnv *env, jobject instance, const uint8_t *buffer, uint16_t l if (cur->state == TCP_SYN_SENT) { // TODO proper warp around if (ntohl(tcphdr->ack_seq) == cur->local_seq + 1 && - ntohl(tcphdr->seq) == cur->remote_seq + 1) { + ntohl(tcphdr->seq) >= cur->remote_seq + 1) { cur->local_seq += 1; cur->remote_seq += 1; + // TODO process data __android_log_print(ANDROID_LOG_DEBUG, TAG, "Established"); cur->state = TCP_ESTABLISHED; @@ -482,14 +485,31 @@ void handle_tcp(JNIEnv *env, jobject instance, const uint8_t *buffer, uint16_t l } else if (cur->state == TCP_ESTABLISHED) { + // TODO proper wrap around if (ntohl(tcphdr->seq) + 1 == cur->remote_seq) + // TODO respond to keepalive? __android_log_print(ANDROID_LOG_DEBUG, TAG, "Keep alive"); - else - __android_log_print(ANDROID_LOG_DEBUG, TAG, "Data"); + else if (ntohl(tcphdr->seq) < cur->remote_seq) + __android_log_print(ANDROID_LOG_WARN, TAG, "Processed ack"); + else { + __android_log_print(ANDROID_LOG_DEBUG, TAG, "New ack"); + if (data != NULL && data->len) { + // TODO non blocking + __android_log_print(ANDROID_LOG_DEBUG, TAG, "send socket data %u", + data->len); + if (send(cur->socket, data->data, data->len, 0) < 0) + __android_log_print(ANDROID_LOG_ERROR, TAG, "send error %d: %s", + errno, strerror(errno)); + else { + if (writeACK(cur, data->len, args->tun)) + cur->remote_seq += data->len; + } + } + } } else { - __android_log_print(ANDROID_LOG_WARN, TAG, "Ignored"); + __android_log_print(ANDROID_LOG_WARN, TAG, "Ignored state %d", cur->state); } } } @@ -543,7 +563,7 @@ int openSocket(JNIEnv *env, jobject instance, const struct sockaddr_in *daddr) { } // Set blocking - if (fcntl(sock, F_SETFL, flags) < 0) { + if (fcntl(sock, F_SETFL, flags & ~O_NONBLOCK) < 0) { __android_log_print(ANDROID_LOG_ERROR, TAG, "fcntl error %d: %s", errno, strerror(errno)); return -1; @@ -600,10 +620,11 @@ int writeSYN(const struct connection *cur, int tun) { tcp->doff = sizeof(struct tcphdr) >> 2; tcp->syn = 1; tcp->ack = 1; + tcp->window = htons(2048); // Calculate TCP checksum uint16_t clen = sizeof(struct ippseudo) + sizeof(struct tcphdr); - jbyte csum[clen]; + uint8_t csum[clen]; // Build pseudo header struct ippseudo *pseudo = csum; @@ -638,7 +659,73 @@ int writeSYN(const struct connection *cur, int tun) { return res; } -void decode(JNIEnv *env, jobject instance, const uint8_t *buffer, const uint16_t length) { +int writeACK(const struct connection *cur, uint32_t datalen, int tun) { + // Build packet + uint16_t len = sizeof(struct iphdr) + sizeof(struct tcphdr); // no data + u_int8_t *buffer = calloc(len, 1); + struct iphdr *ip = buffer; + struct tcphdr *tcp = buffer + sizeof(struct iphdr); + + // Build IP header + ip->version = 4; + ip->ihl = sizeof(struct iphdr) >> 2; + ip->tot_len = htons(len); + ip->ttl = TTL; + ip->protocol = IPPROTO_TCP; + ip->saddr = cur->daddr; + ip->daddr = cur->saddr; + + // Calculate IP checksum + ip->check = checksum(ip, sizeof(struct iphdr)); + + // Build TCP header + tcp->source = cur->dest; + tcp->dest = cur->source; + tcp->seq = htonl(cur->local_seq); + tcp->ack_seq = htonl(cur->remote_seq + datalen); // TODO proper wrap around + tcp->doff = sizeof(struct tcphdr) >> 2; + tcp->ack = 1; + tcp->window = htons(2048); + + // Calculate TCP checksum + uint16_t clen = sizeof(struct ippseudo) + sizeof(struct tcphdr); + uint8_t csum[clen]; + + // Build pseudo header + struct ippseudo *pseudo = csum; + pseudo->ippseudo_src.s_addr = ip->saddr; + pseudo->ippseudo_dst.s_addr = ip->daddr; + pseudo->ippseudo_pad = 0; + pseudo->ippseudo_p = ip->protocol; + pseudo->ippseudo_len = htons(sizeof(struct tcphdr)); // no data + + // Copy TCP header + memcpy(csum + sizeof(struct ippseudo), tcp, sizeof(struct tcphdr)); + + tcp->check = checksum(csum, clen); + + char to[20]; + inet_ntop(AF_INET, &(ip->daddr), to, sizeof(to)); + + // Send packet + __android_log_print(ANDROID_LOG_DEBUG, TAG, + "Sending ACK to tun %s/%u seq %u ack %u", + to, ntohs(tcp->dest), + ntohl(tcp->seq), ntohl(tcp->ack_seq)); + int res = write(tun, buffer, len); + if (res < 0) { + // TODO handle EINTR + __android_log_print(ANDROID_LOG_ERROR, TAG, "write error %d: %s", + errno, strerror(errno)); + } + + free(buffer); + + return res; +} + +void decode(JNIEnv *env, jobject instance, const struct arguments *args, + const uint8_t *buffer, const uint16_t length) { uint8_t protocol; void *saddr; void *daddr; @@ -649,7 +736,7 @@ void decode(JNIEnv *env, jobject instance, const uint8_t *buffer, const uint16_t uint8_t *payload; // Get protocol, addresses & payload - jbyte version = (*buffer) >> 4; + uint8_t version = (*buffer) >> 4; if (version == 4) { struct iphdr *ip4hdr = buffer; @@ -752,7 +839,7 @@ void decode(JNIEnv *env, jobject instance, const uint8_t *buffer, const uint16_t version, source, sport, dest, dport, protocol, flags, uid); if (protocol == IPPROTO_TCP) - handle_tcp(env, instance, buffer, length); + handle_tcp(env, instance, args, buffer, length); // Call back jclass cls = (*env)->GetObjectClass(env, instance);