Lots of changes. In no particular order:
[distributedratelimiting.git] / drl / samplehold.c
1 /* See the DRL-LICENSE file for this file's software license. */
2
3 #include <arpa/inet.h>
4 #include <inttypes.h>
5 #include <netinet/in.h>
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <string.h>
9 #include <sys/time.h>
10 #include <sys/types.h>
11 #include <time.h>
12
13 #include "common_accounting.h"
14 #include "samplehold.h"
15 #include "logging.h"
16
17 static int match(const key_flow *key, const sampled_flow *flow) {
18     if (flow->state != FLOW_USED)
19         return 0;
20
21     if (key->source_ip != flow->source_ip)
22         return 0;
23
24     if (key->dest_ip != flow->dest_ip)
25         return 0;
26
27     if (key->source_port != flow->source_port)
28         return 0;
29
30     if (key->dest_port != flow->dest_port)
31         return 0;
32
33     if (key->protocol != flow->protocol)
34         return 0;
35
36     return 1;
37 }
38
39 static void get_key(key_flow *key, sampled_flow *flow) {
40     key->source_ip = flow->source_ip;
41     key->dest_ip = flow->dest_ip;
42     key->source_port = flow->source_port;
43     key->dest_port = flow->dest_port;
44     key->protocol = flow->protocol;
45
46     key->packet_size = 0;
47 }
48
49 static void move_flow(sampled_flow *dest, sampled_flow *src) {
50     memmove(dest, src, sizeof(sampled_flow));
51     memset(src, 0, sizeof(sampled_flow));
52 }
53
54 uint32_t sampled_table_size(const sampled_flow_table table) {
55     return table->size;
56 }
57
58 /*
59  * Notes to myself...
60  *
61  * max_bytes is the maximum number of bytes that can pass though DURING THE
62  * MEASUREMENT INTERVAL.  So, if you can have 100 Mbit/s and your measurement
63  * interval is 1/10 of a second, your max_bytes is 10Mbit because that's all
64  * you can transfer in 1/10 of a second.
65  *
66  * flow_percentage is the percentage of max_bytes that is considered an
67  * interesting flow.
68  *
69  * oversampling factor is a knob that tunes how accurate our results are at
70  * the cost of additional state/memory.
71  */
72 sampled_flow_table sampled_table_create(uint32_t (*hash_function)(const key_flow *key), const uint32_t max_bytes, const uint32_t flow_percentage, const uint32_t oversampling_factor, common_accounting_t *common) {
73     sampled_flow_table table = malloc(sizeof(struct sampled_flow_table));
74     double base_size = (double) 100 / (double) flow_percentage;
75
76     if (table == NULL) {
77         return NULL;
78     }
79
80     table->capacity = (uint32_t) (base_size * oversampling_factor);
81     table->size = 0;
82     table->hash_function = hash_function;
83     table->sample_prob = (double) (((double) table->capacity / (double) max_bytes) * (double) RANDOM_GRANULARITY);
84     table->threshold = (double) ((double) flow_percentage / 100) * max_bytes;
85
86
87     /* Allocate the backing and give it a little bit extra to deal with variance. */
88     table->largest = NULL;
89     table->backing = malloc(sizeof(sampled_flow) * table->capacity * 1.05);
90
91     if (table->backing == NULL) {
92         free(table);
93         return NULL;
94     }
95
96     memset(table->backing, 0, sizeof(sampled_flow) * table->capacity);
97
98     srand(time(NULL));
99
100     table->common = common;
101     gettimeofday(&table->common->last_update, NULL);
102
103     return table;
104 }
105
106 void sampled_table_destroy(sampled_flow_table table) {
107     free(table->backing);
108     free(table);
109 }
110
111 sampled_flow *sampled_table_lookup(sampled_flow_table table, const key_flow *key) {
112     uint32_t hash = table->hash_function(key) % table->capacity;
113     uint32_t location = hash;
114
115     do {
116         if (table->backing[location].state == FLOW_FREE) {
117             /* It ain't here... */
118             return NULL;
119         }
120
121         if (match(key, &table->backing[location])) {
122             /* Got it! */
123             return &table->backing[location];
124         }
125
126         location++;
127         if (location == table->capacity) {
128             location = 0;
129         }
130     } while (location != hash);
131
132     return NULL;
133 }
134
135 int sampled_table_sample(sampled_flow_table table, const key_flow *key) {
136     sampled_flow *lookup = sampled_table_lookup(table, key);
137     int random_number;
138     double packet_prob;
139
140     /* First we update the common accouting information so that we have accurate
141      * aggregate information. */
142     table->common->bytes_since += key->packet_size;
143
144     /* Below here we're dealing with individual flows. */
145
146     /* It's already in the table, update it. */
147     if (lookup != NULL) {
148         lookup->bytes += key->packet_size;
149         return 1;
150     }
151
152     /* It's not in the table, probabilistically sample it. */
153     packet_prob = table->sample_prob * (double) key->packet_size;
154     random_number = rand() % RANDOM_GRANULARITY;
155
156     if (random_number < packet_prob) {
157         /* It's being sampled - add it to the table. */
158         uint32_t hash = table->hash_function(key) % table->capacity;
159         uint32_t location = hash;
160
161         do {
162             if (table->backing[location].state == FLOW_FREE ||
163                 table->backing[location].state == FLOW_DELETED) {
164                 lookup = &table->backing[location];
165                 break;
166             }
167
168             location++;
169             if (location == table->capacity) {
170                 location = 0;
171             }
172         } while (location != hash);
173
174         if (lookup == NULL) {
175             /* Table is full!?! */
176             printlog(LOG_WARN, "samplehold.c: Table full!\n");
177             return 0;
178         }
179
180         table->size += 1;
181
182         lookup->bytes = key->packet_size;
183         lookup->source_ip = key->source_ip;
184         lookup->dest_ip = key->dest_ip;
185         lookup->source_port = key->source_port;
186         lookup->dest_port = key->dest_port;
187         lookup->protocol = key->protocol;
188         lookup->state = FLOW_USED;
189         lookup->last_bytes = 0;
190         lookup->rate = 0;
191
192         gettimeofday(&lookup->last_update, NULL);
193
194         return 1;
195     }
196
197     /* Not sampled. */
198     return 0;
199 }
200
201 int sampled_table_cleanup(sampled_flow_table table) {
202     /* This should...
203      * 1) Remove "small" flows from the table.
204      * 2) Compact the table so that the remaining flows are closer to their
205      * hash locations.
206      * 3) Reset the state of deleted flows to free.
207      */
208
209     /* How it might work...
210      * 1) Scan through the backing array.
211      * 2) If the flow is small, memset it to 0.
212      *    It it's large, add it to a linked list.
213      * 3) For all items in the linked list, hash them and put them in the
214      * correct location.
215      */
216
217     /* For now though, we're going to do it the inefficient way and loop
218      * through the backing twice.
219      */
220
221     int i;
222
223     /* Clear small items. */
224     for (i = 0; i < table->capacity; ++i) {
225         if (table->backing[i].state == FLOW_USED && table->backing[i].bytes > table->threshold) {
226             /* It gets to stick around. */
227         } else {
228             /* It dies... */
229             memset(&table->backing[i], 0, sizeof(sampled_flow));
230         }
231     }
232
233     /* Compact the table and put things closer to their hash locations. */
234     for (i = 0; i < table->capacity; ++i) {
235         if (table->backing[i].state == FLOW_USED) {
236             uint32_t hash;
237             key_flow key;
238             
239             get_key(&key, &table->backing[i]);
240             hash = table->hash_function(&key) % table->capacity;
241
242             if (i == hash) {
243                 /* Already in the best place */
244                 table->backing[i].bytes = 0;
245                 table->backing[i].last_bytes = 0;
246                 table->backing[i].rate = 0;
247             } else {
248                 uint32_t location = hash;
249
250                 do {
251                     if (table->backing[location].state == FLOW_FREE) {
252                         move_flow(&table->backing[location], &table->backing[i]);
253                         table->backing[location].bytes = 0;
254                         table->backing[location].last_bytes = 0;
255                         table->backing[location].rate = 0;
256                         break;
257                     }
258
259                     location++;
260                     if (location == table->capacity) {
261                         location = 0;
262                     }
263                 } while (location != hash);
264             }
265         }
266     }
267
268     table->largest = NULL;
269
270     return 0;
271 }
272
273 void sampled_table_update_flows(sampled_flow_table table, struct timeval now, double ewma_weight) {
274     int i = 0;
275     uint32_t largest_rate = 0;
276     uint32_t rate_delta = 0;
277     double time_delta = 0;
278     double unweighted_rate = 0;
279     struct in_addr src, dst;
280     char sip[22], dip[22];
281
282     /* Update common aggregate information. */
283     time_delta = timeval_subtract(now, table->common->last_update);
284
285     if (time_delta <= 0) {
286         unweighted_rate = 0;
287     } else {
288         unweighted_rate = table->common->bytes_since / time_delta;
289     }
290
291     table->common->last_inst_rate = table->common->inst_rate;
292     table->common->inst_rate = unweighted_rate;
293
294     table->common->last_rate = table->common->rate;
295
296     /* If the rate is zero, then we don't know anything yet.  Don't apply EWMA
297      * in that case. */
298     if (table->common->rate == 0) {
299         table->common->rate = unweighted_rate;
300     } else {
301         table->common->rate = table->common->rate * ewma_weight +
302                               unweighted_rate * (1 - ewma_weight);
303     }
304
305     printlog(LOG_DEBUG, "table->common->rate is now %u\n", table->common->rate);
306
307     table->common->bytes_since = 0;
308     table->common->last_update = now;
309     table->common->num_flows = 0;
310
311     /* Update per-flow information. */
312     table->largest = &table->backing[i];
313     largest_rate = table->backing[i].rate;
314
315     for (i = 0; i < table->capacity; ++i) {
316         if (table->backing[i].state == FLOW_USED) {
317             rate_delta = table->backing[i].bytes - table->backing[i].last_bytes;
318             time_delta = timeval_subtract(now, table->backing[i].last_update);
319
320             /* Calculate the unweighted rate.  Be careful not to divide by
321              * something silly. */
322             if (time_delta <= 0) {
323                 unweighted_rate = 0;
324             } else {
325                 unweighted_rate = rate_delta / time_delta; 
326             }
327
328             if (table->backing[i].rate == 0) {
329                 table->backing[i].rate = unweighted_rate;
330             } else {
331                 table->backing[i].rate = (table->backing[i].rate * ewma_weight +
332                                           unweighted_rate * (1 - ewma_weight));
333             }
334
335             table->backing[i].last_bytes = table->backing[i].bytes;
336             table->backing[i].last_update = now;
337
338             if (table->backing[i].rate > largest_rate) {
339                 largest_rate = table->backing[i].rate;
340                 table->largest = &table->backing[i];
341             }
342
343             table->common->num_flows += 1;
344
345             /* Print debugging info. */
346             src.s_addr = ntohl(table->backing[i].source_ip);
347             dst.s_addr = ntohl(table->backing[i].dest_ip);
348             strcpy(sip, inet_ntoa(src));
349             strcpy(dip, inet_ntoa(dst));
350             printlog(LOG_DEBUG, "FLOW: (%p)  %s:%d -> %s:%d at %d\n", &table->backing[i],
351                     sip, table->backing[i].source_port,
352                     dip, table->backing[i].dest_port,
353                     table->backing[i].rate);
354         }
355     }
356
357     table->common->max_flow_rate = largest_rate;
358 }
359
360 sampled_flow *sampled_table_largest(sampled_flow_table table) {
361     return table->largest;
362 }