Lots of changes. In no particular order:
[distributedratelimiting.git] / drl / multipleinterval.c
diff --git a/drl/multipleinterval.c b/drl/multipleinterval.c
new file mode 100644 (file)
index 0000000..131d7f9
--- /dev/null
@@ -0,0 +1,458 @@
+/* See the DRL-LICENSE file for this file's software license. */
+
+#include <arpa/inet.h>
+#include <assert.h>
+#include <inttypes.h>
+#include <netinet/in.h>
+#include <pthread.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <time.h>
+
+#include "common_accounting.h"
+#include "multipleinterval.h"
+#include "logging.h"
+
+/* Allocates and initializes a multiple-interval flow table.
+ *
+ * hash_function   hash callback stored on the table for later use.
+ * interval_count  number of slots in the circular interval ring; must be
+ *                 greater than zero.
+ * common          shared accounting state; must not be NULL.  Its last_update
+ *                 field is set to the current time.
+ *
+ * Returns the new table, or NULL if an argument is invalid or an allocation
+ * fails. */
+multiple_flow_table multiple_table_create(uint32_t (*hash_function)(const key_flow *key), uint32_t interval_count, common_accounting_t *common) {
+    uint32_t i;
+    multiple_flow_table table;
+
+    /* Slot 0 is written unconditionally below and common is dereferenced, so
+     * an empty ring or a missing accounting struct cannot be supported. */
+    if (interval_count == 0 || common == NULL) {
+        return NULL;
+    }
+
+    table = calloc(1, sizeof(struct mul_flow_table));
+    if (table == NULL) {
+        return NULL;
+    }
+
+    table->common = common;
+    table->hash_function = hash_function;
+    table->interval_count = interval_count;
+
+    gettimeofday(&table->common->last_update, NULL);
+
+    /* calloc zeroes the ring and rejects a count * size overflow. */
+    table->intervals = calloc(interval_count, sizeof(interval));
+    if (table->intervals == NULL) {
+        free(table);
+        return NULL;
+    }
+
+    /* Only the first slot starts out valid; it begins at creation time. */
+    table->intervals[0].valid = 1;
+    table->intervals[0].last_update = table->common->last_update;
+
+    /* Link the slots into a ring: the last slot points back to the first. */
+    for (i = 0; i < interval_count; ++i) {
+        table->intervals[i].next = &table->intervals[(i + 1) % interval_count];
+    }
+
+    table->current_interval = &table->intervals[0];
+
+    return table;
+}
+
+/* Frees a flow table, every flow still stored in it, and each flow's
+ * interval ring.  The shared accounting struct (table->common) is not freed.
+ * Passing NULL is a no-op. */
+void multiple_table_destroy(multiple_flow_table table) {
+    multiple_flow *current, *next;
+
+    if (table == NULL) {
+        return;
+    }
+
+    for (current = table->flows_head; current != NULL; current = next) {
+        /* Save the link before the node is released. */
+        next = current->next;
+        free(current->intervals);
+        free(current);
+    }
+
+    free(table->intervals);
+    free(table);
+}
+
+/* Returns the table entry for the flow described by key, creating and
+ * registering a new entry when none exists yet.  Returns NULL when table is
+ * NULL or when memory for a new entry cannot be allocated. */
+multiple_flow *multiple_table_lookup(multiple_flow_table table, const key_flow *key) {
+    multiple_flow *entry;
+    uint32_t bucket;
+    struct in_addr src_addr, dst_addr;
+    char src_text[22], dst_text[22];
+    int slot;
+
+    if (table == NULL) {
+        return NULL;
+    }
+
+    bucket = table->hash_function(key);
+
+    /* Walk the bucket's chain; an exact five-tuple match ends the search. */
+    entry = table->flows[bucket];
+    while (entry != NULL) {
+        if (entry->source_ip == key->source_ip &&
+                entry->dest_ip == key->dest_ip &&
+                entry->source_port == key->source_port &&
+                entry->dest_port == key->dest_port &&
+                entry->protocol == key->protocol) {
+            return entry;
+        }
+        entry = entry->nexth;
+    }
+
+    /* No match: build a fresh, zeroed entry. */
+    entry = malloc(sizeof(multiple_flow));
+    if (entry == NULL) {
+        printlog(LOG_CRITICAL, "multipleinterval.c: Malloc returned NULL.\n");
+        return NULL;
+    }
+    memset(entry, 0, sizeof(multiple_flow));
+
+    entry->intervals = malloc(table->interval_count * sizeof(interval));
+    if (entry->intervals == NULL) {
+        free(entry);
+        printlog(LOG_CRITICAL, "multipleinterval.c: Malloc returned NULL.\n");
+        return NULL;
+    }
+    memset(entry->intervals, 0, table->interval_count * sizeof(interval));
+
+    /* Copy the identifying five-tuple out of the key. */
+    entry->source_ip = key->source_ip;
+    entry->source_port = key->source_port;
+    entry->dest_ip = key->dest_ip;
+    entry->dest_port = key->dest_port;
+    entry->protocol = key->protocol;
+
+    /* Chain the interval slots into a ring and start on slot 0, the only
+     * slot that is valid initially. */
+    for (slot = 0; slot < table->interval_count; ++slot) {
+        entry->intervals[slot].next = &entry->intervals[(slot + 1) % table->interval_count];
+    }
+    entry->intervals[0].valid = 1;
+    entry->intervals[0].last_update = table->common->last_update;
+    entry->intervals[0].last_packet = key->packet_time;
+    entry->current_interval = &entry->intervals[0];
+
+    /* Push onto the front of the hash bucket's chain. */
+    entry->nexth = table->flows[bucket];
+    table->flows[bucket] = entry;
+
+    /* Append to the list of all flows.  prev/next are already NULL from the
+     * memset, so an empty list needs no further link fix-up. */
+    if (table->flows_tail == NULL) {
+        table->flows_head = entry;
+    } else {
+        entry->prev = table->flows_tail;
+        table->flows_tail->next = entry;
+    }
+    table->flows_tail = entry;
+
+    /* inet_ntoa reuses one static buffer, so each address is copied out
+     * before the next conversion. */
+    src_addr.s_addr = ntohl(entry->source_ip);
+    dst_addr.s_addr = ntohl(entry->dest_ip);
+    strcpy(src_text, inet_ntoa(src_addr));
+    strcpy(dst_text, inet_ntoa(dst_addr));
+    printlog(LOG_DEBUG, "ALLOC:%s:%hu -> %s:%hu\n", src_text,
+            entry->source_port, dst_text, entry->dest_port);
+
+    return entry;
+}
+
+/* Accounts one sampled packet, described by key, against both the table's
+ * aggregate and the packet's flow.  Returns 1 on success, or 0 when the flow
+ * entry could not be found or created; in that case the aggregate has
+ * already been charged. */
+int multiple_table_sample(multiple_flow_table table, const key_flow *key) {
+    multiple_flow *entry;
+    interval *slot;
+
+    assert(table != NULL);
+    assert(table->common != NULL);
+
+    /* Aggregate: charge the bytes to the table's current interval. */
+    slot = table->current_interval;
+    slot->bytes_since += key->packet_size;
+    slot->valid = 1;
+
+    entry = multiple_table_lookup(table, key);
+    if (entry == NULL) {
+        return 0;
+    }
+
+    /* Remember when the flow was last seen so idle flows can be expired. */
+    entry->last_packet = key->packet_time;
+
+    /* Per-flow: charge the bytes to the flow's current interval. */
+    slot = entry->current_interval;
+    slot->bytes_since += key->packet_size;
+    slot->last_packet = key->packet_time;
+    slot->valid = 1;
+
+    return 1;
+}
+
+/* Unlinks a flow from its hash chain and from the table's list of all flows,
+ * then frees the flow together with its interval ring.
+ *
+ * table  table that holds the flow; must not be NULL.
+ * flow   flow to remove; must not be NULL and is expected to be stored in
+ *        table.  The pointer is invalid once this function returns. */
+void multiple_table_remove(multiple_flow_table table, multiple_flow *flow) {
+    key_flow key;
+    uint32_t hash;
+    multiple_flow **link;
+
+    assert(table);
+    assert(flow);
+
+    /* Rebuild the lookup key.  Zero it first so that fields which are not
+     * copied from the flow never reach the hash function with indeterminate
+     * values. */
+    memset(&key, 0, sizeof(key));
+    key.source_ip = flow->source_ip;
+    key.dest_ip = flow->dest_ip;
+    key.source_port = flow->source_port;
+    key.dest_port = flow->dest_port;
+    key.protocol = flow->protocol;
+
+    hash = table->hash_function(&key);
+
+    assert(table->flows[hash]);
+
+    /* Remove the flow from the hash list.  link always addresses the pointer
+     * that leads to the node under inspection, so the head and interior
+     * cases are handled the same way. */
+    link = &table->flows[hash];
+    while (*link != NULL && *link != flow) {
+        link = &(*link)->nexth;
+    }
+
+    if (*link == NULL) {
+        printlog(LOG_CRITICAL, "Flow %p disappeared?\n", (void *) flow);
+    }
+    assert(*link != NULL);
+
+    if (*link != NULL) {
+        *link = flow->nexth;
+    }
+
+    /* Remove the flow from the linked list.  A missing neighbour means the
+     * flow sits at that end of the list, so the table's end pointer is
+     * updated instead of the neighbour. */
+    if (flow->prev != NULL) {
+        flow->prev->next = flow->next;
+    } else {
+        assert(table->flows_head == flow);
+        table->flows_head = flow->next;
+    }
+
+    if (flow->next != NULL) {
+        flow->next->prev = flow->prev;
+    } else {
+        assert(table->flows_tail == flow);
+        table->flows_tail = flow->prev;
+    }
+
+    /* Free the interval info.  The memory is zeroed before release --
+     * presumably so that stale pointers are easier to notice. */
+    memset(flow->intervals, 0, table->interval_count * sizeof(interval));
+    free(flow->intervals);
+
+    /* Free the flow. */
+    memset(flow, 0, sizeof(multiple_flow));
+    free(flow);
+}
+
+/* Removes every flow whose last packet is at least MUL_FLOW_IDLE_TIME older
+ * than the current time.  Always returns 0. */
+int multiple_table_cleanup(multiple_flow_table table) {
+    const time_t now = time(NULL);
+    multiple_flow *cursor;
+    multiple_flow *following;
+
+    for (cursor = table->flows_head; cursor != NULL; cursor = following) {
+        /* Removal frees the node, so grab its successor first. */
+        following = cursor->next;
+
+        if (cursor->last_packet + MUL_FLOW_IDLE_TIME <= now) {
+            /* Idle for the whole time limit - kill it. */
+            multiple_table_remove(table, cursor);
+        }
+    }
+
+    return 0;
+}
+
+/* Walks the ring forward from newest and returns the first valid slot met.
+ * If the walk comes back around to newest without finding one, newest itself
+ * is returned. */
+static interval *get_oldest_interval(interval *newest) {
+    interval *slot;
+
+    for (slot = newest->next; slot != newest; slot = slot->next) {
+        if (slot->valid) {
+            return slot;
+        }
+    }
+
+    return newest;
+}
+
+/* Sums bytes_since over the ring slots from oldest up to and including
+ * newest, following the next links. */
+static uint32_t get_bytes_over_interval(interval *newest, interval *oldest) {
+    uint32_t total = newest->bytes_since;
+    interval *slot;
+
+    for (slot = oldest; slot != newest; slot = slot->next) {
+        total += slot->bytes_since;
+    }
+
+    return total;
+}
+
+/* Recomputes the aggregate and per-flow rates and advances every interval
+ * ring by one slot.
+ *
+ * For the table, and then for each flow, the bytes recorded from the oldest
+ * valid slot through the current one are divided by the time elapsed since
+ * the oldest valid slot's last_update.  That measurement is folded into the
+ * stored rate using ewma_weight, and the ring moves on to a zeroed slot
+ * stamped with now.  The flow-count statistics, average rate and
+ * largest-flow fields in table->common are rebuilt from scratch.
+ *
+ * table        table to update.
+ * now          timestamp recorded on the slots that become current.
+ * ewma_weight  weight given to the previous rate; (1 - ewma_weight) is given
+ *              to the newly measured rate. */
+void multiple_table_update_flows(multiple_flow_table table, struct timeval now, double ewma_weight) {
+    uint32_t maxflowrate = 0;
+    double time_delta;
+    double unweighted_rate;
+    multiple_flow *current;
+    struct in_addr src, dst;
+    char sip[22], dip[22];
+    key_flow largest_flow_info;
+
+    /* Table interval variables. */
+    interval *table_newest = NULL;
+    interval *table_oldest = NULL;
+    uint32_t table_bytes_over_intervals = 0;
+
+    /* The key is hashed at the end even when no flow ever beats a rate of
+     * zero, so it must never hold indeterminate values. */
+    memset(&largest_flow_info, 0, sizeof(largest_flow_info));
+
+    /* Reset statistics. */
+    table->common->num_flows = 0;
+    table->common->num_flows_5k = 0;
+    table->common->num_flows_10k = 0;
+    table->common->num_flows_20k = 0;
+    table->common->num_flows_50k = 0;
+    table->common->avg_rate = 0;
+    /* End statistics. */
+
+    table_newest = table->current_interval;
+    table_oldest = get_oldest_interval(table_newest);
+
+    table_bytes_over_intervals = get_bytes_over_interval(table_newest, table_oldest);
+
+    time_delta = timeval_subtract(now, table_oldest->last_update);
+
+    /* Guard against a zero or negative elapsed time. */
+    if (time_delta <= 0) {
+        unweighted_rate = 0;
+    } else {
+        unweighted_rate = table_bytes_over_intervals / time_delta;
+    }
+
+    table->common->last_inst_rate = table->common->inst_rate;
+    table->common->inst_rate = unweighted_rate;
+    printf("Unweighted rate is: %.3f, computed from %" PRIu32 " bytes in %f seconds\n", unweighted_rate, table_bytes_over_intervals, time_delta);
+
+    table->common->last_rate = table->common->rate;
+
+    /* If the rate is zero, then we don't know anything yet.  Don't apply EWMA
+     * in that case. */
+    if (table->common->rate == 0) {
+        table->common->rate = unweighted_rate;
+    } else {
+        //FIXME: Continue to use ewma here?
+        table->common->rate = table->common->rate * ewma_weight + unweighted_rate * (1 - ewma_weight);
+    }
+
+    /* Advance the table's ring to a fresh slot that starts now. */
+    table->common->last_update = now;
+    table->current_interval = table->current_interval->next;
+    table->current_interval->last_update = now;
+    table->current_interval->bytes_since = 0;
+    table->current_interval->valid = 1;
+
+    /* Update per-flow information. */
+    for (current = table->flows_head; current; current = current->next) {
+        interval *newest = current->current_interval;
+        interval *oldest = get_oldest_interval(newest);
+        uint32_t bytes_over_intervals = 0;
+
+        /* This flow is invalid - don't consider it further. */
+        if (newest->valid == 0) {
+            printlog(LOG_WARN, "Found invalid flow in table.\n");
+            continue;
+        }
+
+        time_delta = timeval_subtract(now, oldest->last_update);
+        bytes_over_intervals = get_bytes_over_interval(newest, oldest);
+
+        if (time_delta <= 0) {
+            unweighted_rate = 0;
+        } else {
+            unweighted_rate = bytes_over_intervals / time_delta;
+        }
+
+        current->last_rate = current->rate;
+
+        /* As with the aggregate, a zero rate means there is no history to
+         * smooth against yet. */
+        if (current->rate == 0) {
+            current->rate = unweighted_rate;
+        } else {
+            //FIXME: Continue to use ewma here?
+            current->rate = current->rate * ewma_weight + unweighted_rate * (1 - ewma_weight);
+        }
+
+        /* Update the accounting info for intervals. */
+        current->current_interval = current->current_interval->next;
+        current->current_interval->last_update = now;
+        current->current_interval->bytes_since = 0;
+        current->current_interval->valid = 1;
+
+        /* Track the fastest flow seen so far and remember its key. */
+        if (current->rate > maxflowrate) {
+            maxflowrate = current->rate;
+            largest_flow_info.source_ip = current->source_ip;
+            largest_flow_info.dest_ip = current->dest_ip;
+            largest_flow_info.source_port = current->source_port;
+            largest_flow_info.dest_port = current->dest_port;
+            largest_flow_info.protocol = current->protocol;
+        }
+
+        /* The counters are cumulative: a flow above a threshold is also
+         * counted under every lower threshold and in the total. */
+        table->common->num_flows += 1;
+        if (current->rate > 5120) {
+            table->common->num_flows_5k += 1;
+        }
+        if (current->rate > 10240) {
+            table->common->num_flows_10k += 1;
+        }
+        if (current->rate > 20480) {
+            table->common->num_flows_20k += 1;
+        }
+        if (current->rate > 51200) {
+            table->common->num_flows_50k += 1;
+        }
+
+        /* inet_ntoa reuses one static buffer, so each address is copied out
+         * before the next conversion. */
+        src.s_addr = ntohl(current->source_ip);
+        dst.s_addr = ntohl(current->dest_ip);
+        strcpy(sip, inet_ntoa(src));
+        strcpy(dip, inet_ntoa(dst));
+        printlog(LOG_DEBUG, "FLOW: (%p)  %s:%d -> %s:%d at %d\n", (void *) current,
+                sip, current->source_port,
+                dip, current->dest_port,
+                current->rate);
+    }
+
+    if (table->common->num_flows > 0) {
+        table->common->avg_rate = table->common->rate / table->common->num_flows;
+    }
+
+    printlog(LOG_DEBUG, "FLOW:--\n--\n");
+
+    table->common->max_flow_rate = maxflowrate;
+    table->common->max_flow_rate_flow_hash = table->hash_function(&largest_flow_info);
+}