4b79f3eec59e7e9a402251d2bacb18e043ce1994
[distributedratelimiting.git] / drl / samplehold.c
1 /* See the DRL-LICENSE file for this file's software license. */
2
3 #include <inttypes.h>
4 #include <stdio.h>
5 #include <stdlib.h>
6 #include <string.h>
7 #include <sys/time.h>
8 #include <time.h>
9
10 #include "common_accounting.h"
11 #include "samplehold.h"
12
13 static int match(const key_flow *key, const sampled_flow *flow) {
14     if (flow->state != FLOW_USED)
15         return 0;
16
17     if (key->source_ip != flow->source_ip)
18         return 0;
19
20     if (key->dest_ip != flow->dest_ip)
21         return 0;
22
23     if (key->source_port != flow->source_port)
24         return 0;
25
26     if (key->dest_port != flow->dest_port)
27         return 0;
28
29     if (key->protocol != flow->protocol)
30         return 0;
31
32     return 1;
33 }
34
35 static void get_key(key_flow *key, sampled_flow *flow) {
36     key->source_ip = flow->source_ip;
37     key->dest_ip = flow->dest_ip;
38     key->source_port = flow->source_port;
39     key->dest_port = flow->dest_port;
40     key->protocol = flow->protocol;
41
42     key->packet_size = 0;
43 }
44
45 static void move_flow(sampled_flow *dest, sampled_flow *src) {
46     memmove(dest, src, sizeof(sampled_flow));
47     memset(src, 0, sizeof(sampled_flow));
48 }
49
50 uint32_t sampled_table_size(const sampled_flow_table table) {
51     return table->size;
52 }
53
54 /*
55  * Notes to myself...
56  *
57  * max_bytes is the maximum number of bytes that can pass though DURING THE
58  * MEASUREMENT INTERVAL.  So, if you can have 100 Mbit/s and your measurement
59  * interval is 1/10 of a second, your max_bytes is 10Mbit because that's all
60  * you can transfer in 1/10 of a second.
61  *
62  * flow_percentage is the percentage of max_bytes that is considered an
63  * interesting flow.
64  *
65  * oversampling factor is a knob that tunes how accurate our results are at
66  * the cost of additional state/memory.
67  */
68 sampled_flow_table sampled_table_create(uint32_t (*hash_function)(const key_flow *key), const uint32_t max_bytes, const uint32_t flow_percentage, const uint32_t oversampling_factor, common_accounting_t *common) {
69     sampled_flow_table table = malloc(sizeof(struct sampled_flow_table));
70     double base_size = (double) 100 / (double) flow_percentage;
71
72     if (table == NULL) {
73         return NULL;
74     }
75
76     table->capacity = (uint32_t) ((base_size * oversampling_factor) * 1.03);
77     table->size = 0;
78     table->hash_function = hash_function;
79     table->sample_prob = (double) (((double) table->capacity / (double) max_bytes) * (double) RANDOM_GRANULARITY);
80     table->threshold = (double) ((double) flow_percentage / 100) * max_bytes;
81
82     table->largest = NULL;
83     table->backing = malloc(sizeof(sampled_flow) * table->capacity);
84
85     if (table->backing == NULL) {
86         free(table);
87         return NULL;
88     }
89
90     memset(table->backing, 0, sizeof(sampled_flow) * table->capacity);
91
92     srand(time(NULL));
93
94     table->common = common;
95     gettimeofday(&table->common->last_update, NULL);
96
97     return table;
98 }
99
100 void sampled_table_destroy(sampled_flow_table table) {
101     free(table->backing);
102     free(table);
103 }
104
105 sampled_flow *sampled_table_lookup(sampled_flow_table table, const key_flow *key) {
106     uint32_t hash = table->hash_function(key) % table->capacity;
107     uint32_t location = hash;
108
109     do {
110         if (table->backing[location].state == FLOW_FREE) {
111             /* It ain't here... */
112             return NULL;
113         }
114
115         if (match(key, &table->backing[location])) {
116             /* Got it! */
117             return &table->backing[location];
118         }
119
120         location++;
121         if (location == table->capacity) {
122             location = 0;
123         }
124     } while (location != hash);
125
126     return NULL;
127 }
128
129 int sampled_table_sample(sampled_flow_table table, const key_flow *key) {
130     sampled_flow *lookup = sampled_table_lookup(table, key);
131     int random_number;
132     double packet_prob;
133
134     /* First we update the common accouting information so that we have accurate
135      * aggregate information. */
136     table->common->bytes_since += key->packet_size;
137
138     /* Below here we're dealing with individual flows. */
139
140     /* It's already in the table, update it. */
141     if (lookup != NULL) {
142         lookup->bytes += key->packet_size;
143         return 1;
144     }
145
146     /* It's not in the table, probabilistically sample it. */
147     packet_prob = table->sample_prob * (double) key->packet_size;
148     random_number = rand() % RANDOM_GRANULARITY;
149
150     if (random_number < packet_prob) {
151         /* It's being sampled - add it to the table. */
152         uint32_t hash = table->hash_function(key) % table->capacity;
153         uint32_t location = hash;
154
155         do {
156             if (table->backing[location].state == FLOW_FREE ||
157                 table->backing[location].state == FLOW_DELETED) {
158                 lookup = &table->backing[location];
159                 break;
160             }
161
162             location++;
163             if (location == table->capacity) {
164                 location = 0;
165             }
166         } while (location != hash);
167
168         if (lookup == NULL) {
169             /* Table is full!?! */
170             printf("Full table!\n");
171             return 0;
172         }
173
174         table->size += 1;
175
176         lookup->bytes = key->packet_size;
177         lookup->source_ip = key->source_ip;
178         lookup->dest_ip = key->dest_ip;
179         lookup->source_port = key->source_port;
180         lookup->dest_port = key->dest_port;
181         lookup->protocol = key->protocol;
182         lookup->state = FLOW_USED;
183         lookup->last_bytes = 0;
184         lookup->rate = 0;
185
186         gettimeofday(&lookup->last_update, NULL);
187
188         return 1;
189     }
190
191     /* Not sampled. */
192     return 0;
193 }
194
195 int sampled_table_cleanup(sampled_flow_table table) {
196     /* This should...
197      * 1) Remove "small" flows from the table.
198      * 2) Compact the table so that the remaining flows are closer to their
199      * hash locations.
200      * 3) Reset the state of deleted flows to free.
201      */
202
203     /* How it might work...
204      * 1) Scan through the backing array.
205      * 2) If the flow is small, memset it to 0.
206      *    It it's large, add it to a linked list.
207      * 3) For all items in the linked list, hash them and put them in the
208      * correct location.
209      */
210
211     /* For now though, we're going to do it the inefficient way and loop
212      * through the backing twice.
213      */
214
215     int i;
216
217     /* Clear small items. */
218     for (i = 0; i < table->capacity; ++i) {
219         if (table->backing[i].state == FLOW_USED && table->backing[i].bytes > table->threshold) {
220             /* It gets to stick around. */
221         } else {
222             /* It dies... */
223             memset(&table->backing[i], 0, sizeof(sampled_flow));
224         }
225     }
226
227     /* Compact the table and put things closer to their hash locations. */
228     for (i = 0; i < table->capacity; ++i) {
229         if (table->backing[i].state == FLOW_USED) {
230             uint32_t hash;
231             key_flow key;
232             
233             get_key(&key, &table->backing[i]);
234             hash = table->hash_function(&key) % table->capacity;
235
236             if (i == hash) {
237                 /* Already in the best place */
238                 table->backing[i].bytes = 0;
239                 table->backing[i].last_bytes = 0;
240                 table->backing[i].rate = 0;
241             } else {
242                 uint32_t location = hash;
243
244                 do {
245                     if (table->backing[location].state == FLOW_FREE) {
246                         move_flow(&table->backing[location], &table->backing[i]);
247                         table->backing[location].bytes = 0;
248                         table->backing[location].last_bytes = 0;
249                         table->backing[location].rate = 0;
250                         break;
251                     }
252
253                     location++;
254                     if (location == table->capacity) {
255                         location = 0;
256                     }
257                 } while (location != hash);
258             }
259         }
260     }
261
262     table->largest = NULL;
263
264     return 0;
265 }
266
267 void sampled_table_update_flows(sampled_flow_table table, struct timeval now, double ewma_weight) {
268     int i = 0;
269     uint32_t largest_rate = 0;
270     uint32_t rate_delta = 0;
271     double time_delta = 0;
272     double unweighted_rate = 0;
273
274     /* Update common aggregate information. */
275     time_delta = timeval_subtract(now, table->common->last_update);
276
277     if (time_delta <= 0) {
278         unweighted_rate = 0;
279     } else {
280         unweighted_rate = table->common->bytes_since / time_delta;
281     }
282
283     table->common->last_inst_rate = table->common->inst_rate;
284     table->common->inst_rate = unweighted_rate;
285
286     table->common->last_rate = table->common->rate;
287
288     /* If the rate is zero, then we don't know anything yet.  Don't apply EWMA
289      * in that case. */
290     if (table->common->rate == 0) {
291         table->common->rate = unweighted_rate;
292     } else {
293         table->common->rate = table->common->rate * ewma_weight +
294                               unweighted_rate * (1 - ewma_weight);
295     }
296
297     table->common->bytes_since = 0;
298     table->common->last_update = now;
299
300     /* Update per-flow information. */
301     table->largest = &table->backing[i];
302     largest_rate = table->backing[i].rate;
303
304     for (i = 0; i < table->capacity; ++i) {
305         if (table->backing[i].state == FLOW_USED) {
306             rate_delta = table->backing[i].bytes - table->backing[i].last_bytes;
307             time_delta = timeval_subtract(now, table->backing[i].last_update);
308
309             /* Calculate the unweighted rate.  Be careful not to divide by
310              * something silly. */
311             if (time_delta <= 0) {
312                 unweighted_rate = 0;
313             } else {
314                 unweighted_rate = rate_delta / time_delta; 
315             }
316
317             if (table->backing[i].rate == 0) {
318                 table->backing[i].rate = unweighted_rate;
319             } else {
320                 table->backing[i].rate = (table->backing[i].rate * ewma_weight +
321                                           unweighted_rate * (1 - ewma_weight));
322             }
323
324             table->backing[i].last_bytes = table->backing[i].bytes;
325             table->backing[i].last_update = now;
326
327             if (table->backing[i].rate > largest_rate) {
328                 largest_rate = table->backing[i].rate;
329                 table->largest = &table->backing[i];
330             }
331         }
332     }
333
334     table->common->max_flow_rate = largest_rate;
335 }
336
337 sampled_flow *sampled_table_largest(sampled_flow_table table) {
338     return table->largest;
339 }