5c6a6d8537fe05f1f7bee39599d67e8b6a14388f
[distributedratelimiting.git] / drl / standard.c
1 /* See the DRL-LICENSE file for this file's software license. */
2
#include <arpa/inet.h>
#include <assert.h>
#include <inttypes.h>
#include <netinet/in.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include <time.h>

#include "common_accounting.h"
#include "standard.h"
#include "logging.h"
18
19 standard_flow_table standard_table_create(uint32_t (*hash_function)(const key_flow *key), common_accounting_t *common) {
20     standard_flow_table table = malloc(sizeof(struct std_flow_table));
21
22     if (table == NULL) {
23         return NULL;
24     }
25
26     memset(table, 0, sizeof(struct std_flow_table));
27     table->common = common;
28     table->hash_function = hash_function;
29
30     gettimeofday(&table->common->last_update, NULL);
31
32     return table;
33 }
34
35 void standard_table_destroy(standard_flow_table table) {
36     standard_flow *current, *next;
37
38     if ((current = table->flows_head)) {
39         while (current->next) {
40             next = current->next;
41             free(current);
42             current = next;
43         }
44         free(current);
45     }
46
47     free(table);
48 }
49
50 /* Looks for the flow in the table.  If the flow isn't there, it allocates a
51  * place for it. */
52 standard_flow *standard_table_lookup(standard_flow_table table, const key_flow *key) {
53     uint32_t hash;
54     standard_flow *flow;
55     struct in_addr src, dst;
56     char sip[22], dip[22];
57
58     if (table == NULL) {
59         return NULL;
60     }
61
62     hash = table->hash_function(key);
63
64     /* Find the flow, if it's there. */
65     for (flow = table->flows[hash]; flow; flow = flow->nexth) {
66         if (flow->source_ip == key->source_ip &&
67                 flow->dest_ip == key->dest_ip &&
68                 flow->source_port == key->source_port &&
69                 flow->dest_port == key->dest_port &&
70                 flow->protocol == key->protocol) {
71             break;
72         }
73     }
74
75     if (flow == NULL) {
76         flow = malloc(sizeof(standard_flow));
77         if (flow == NULL) {
78             printlog(LOG_CRITICAL, "standard.c: Malloc returned NULL.\n");
79             return NULL;
80         }
81
82         memset(flow, 0, sizeof(standard_flow));
83         flow->protocol = key->protocol;
84         flow->source_ip = key->source_ip;
85         flow->dest_ip = key->dest_ip;
86         flow->source_port = key->source_port;
87         flow->dest_port = key->dest_port;
88         flow->last_packet = key->packet_time;
89         gettimeofday(&flow->last_update, NULL);
90
91         /* Add the flow to the hash list. */
92         flow->nexth = table->flows[hash];
93         table->flows[hash] = flow;
94
95         /* Add the flow to the linked list. */
96         if (table->flows_tail) {
97             flow->prev = table->flows_tail;
98             table->flows_tail->next = flow;
99             table->flows_tail = flow;
100         } else {
101             table->flows_head = table->flows_tail = flow;
102             /* next and prev are already null due to memset above. */
103         }
104
105         src.s_addr = ntohl(flow->source_ip);
106         dst.s_addr = ntohl(flow->dest_ip);
107         strcpy(sip, inet_ntoa(src));
108         strcpy(dip, inet_ntoa(dst));
109         printlog(LOG_DEBUG, "ALLOC:%s:%hu -> %s:%hu\n", sip,
110                 flow->source_port, dip, flow->dest_port);
111     }
112
113     return flow;
114 }
115
116 int standard_table_sample(standard_flow_table table, const key_flow *key) {
117     standard_flow *flow;
118
119     assert(table != NULL);
120     assert(table->common != NULL);
121
122     /* Update aggregate. */
123     table->common->bytes_since += key->packet_size;
124
125     /* Update flow. */
126     flow = standard_table_lookup(table, key);
127     if (flow == NULL) {
128         return 0;
129     }
130
131     flow->bytes_since += key->packet_size;
132     flow->last_packet = key->packet_time;
133
134     return 1;
135 }
136
137 void standard_table_remove(standard_flow_table table, standard_flow *flow) {
138     key_flow key;
139     uint32_t hash;
140
141     assert(flow);
142
143     /* Remove the flow from the hash list. */
144     key.source_ip = flow->source_ip;
145     key.dest_ip = flow->dest_ip;
146     key.source_port = flow->source_port;
147     key.dest_port = flow->dest_port;
148     key.protocol = flow->protocol;
149
150     hash = table->hash_function(&key);
151
152     assert(table->flows[hash]);
153
154     if (table->flows[hash] == flow) {
155         /* It's the head of the hash list. */
156         table->flows[hash] = flow->nexth;
157     } else {
158         standard_flow *current, *prev;
159         
160         prev = table->flows[hash];
161
162         for (current = table->flows[hash]->nexth; current; current = current->nexth) {
163             if (current == flow) {
164                 prev->nexth = flow->nexth;
165                 break;
166             } else {
167                 prev = current;
168             }
169         }
170
171         if (current == NULL) {
172             printlog(LOG_CRITICAL, "Flow %p disappeared?\n", flow);
173         }
174         assert(current != NULL);
175     }
176
177     /* Remove the flow from the linked list. */
178     if (flow->prev == NULL && flow->next == NULL) {
179         /* It's the head, tail, and only element of the list. */
180         assert(table->flows_head == flow);
181         assert(table->flows_tail == flow);
182
183         table->flows_head = NULL;
184         table->flows_tail = NULL;
185     } else if (flow->prev == NULL) {
186         /* It's the head of the list. */
187         assert(table->flows_head == flow);
188
189         table->flows_head = flow->next;
190
191         if (table->flows_head != NULL) {
192             table->flows_head->prev = NULL;
193         }
194     } else if (flow->next == NULL) {
195         /* It's the tail of the list. */
196         assert(table->flows_tail == flow);
197
198         table->flows_tail = flow->prev;
199
200         table->flows_tail->next = NULL;
201     } else {
202         /* Not the head or tail of the list. */
203         assert(table->flows_head != flow);
204
205         flow->prev->next = flow->next;
206
207         if (flow->next != NULL) {
208             flow->next->prev = flow->prev;
209         }
210     }
211
212     memset(flow, 0, sizeof(standard_flow));
213
214     /* Free the flow. */
215     free(flow);
216 }
217
218 int standard_table_cleanup(standard_flow_table table) {
219     standard_flow *current = table->flows_head;
220     standard_flow *remove;
221     time_t now = time(NULL);
222
223     while (current != NULL) {
224         if (current->last_packet + FLOW_IDLE_TIME <= now) {
225             /* Flow hasn't received a packet in the time limit - kill it. */
226             remove = current;
227             current = current->next;
228
229             standard_table_remove(table, remove);
230         } else {
231             current = current->next;
232         }
233     }
234
235     return 0;
236 }
237
238 void standard_table_update_flows(standard_flow_table table, struct timeval now, double ewma_weight) {
239     uint32_t maxflowrate = 0;
240     double time_delta;
241     double unweighted_rate;
242     standard_flow *current;
243     struct in_addr src, dst;
244     char sip[22], dip[22];
245
246     /* Reset statistics. */
247     table->common->num_flows = 0;
248     table->common->num_flows_5k = 0;
249     table->common->num_flows_10k = 0;
250     table->common->num_flows_20k = 0;
251     table->common->num_flows_50k = 0;
252     table->common->avg_rate = 0;
253     /* End statistics. */
254
255     time_delta = timeval_subtract(now, table->common->last_update);
256
257     if (time_delta <= 0) {
258         unweighted_rate = 0;
259     } else {
260         unweighted_rate = table->common->bytes_since / time_delta;
261     }
262
263     table->common->last_inst_rate = table->common->inst_rate;
264     table->common->inst_rate = unweighted_rate;
265
266     table->common->last_rate = table->common->rate;
267
268     /* If the rate is zero, then we don't know anything yet.  Don't apply EWMA
269      * in that case. */
270     if (table->common->rate == 0) {
271         table->common->rate = unweighted_rate;
272     } else {
273         table->common->rate = table->common->rate * ewma_weight +
274                               unweighted_rate * (1 - ewma_weight);
275     }
276
277     table->common->bytes_since = 0;
278     table->common->last_update = now;
279
280     /* Update per-flow information. */
281     for (current = table->flows_head; current; current = current->next) {
282         time_delta = timeval_subtract(now, current->last_update);
283
284         if (time_delta <= 0) {
285             unweighted_rate = 0;
286         } else {
287             unweighted_rate = current->bytes_since / time_delta;
288         }
289
290         current->last_rate = current->rate;
291
292         if (current->rate == 0) {
293             current->rate = unweighted_rate;
294         } else {
295             current->rate = current->rate * ewma_weight +
296                             unweighted_rate * (1 - ewma_weight);
297         }
298
299         current->bytes_since = 0;
300         current->last_update = now;
301
302         if (current->rate > maxflowrate) {
303             maxflowrate = current->rate;
304         }
305
306         if (current->rate > 51200) {
307             table->common->num_flows_50k += 1;
308             table->common->num_flows_20k += 1;
309             table->common->num_flows_10k += 1;
310             table->common->num_flows_5k += 1;
311             table->common->num_flows += 1;
312         } else if (current->rate > 20480) {
313             table->common->num_flows_20k += 1;
314             table->common->num_flows_10k += 1;
315             table->common->num_flows_5k += 1;
316             table->common->num_flows += 1;
317         } else if (current->rate > 10240) {
318             table->common->num_flows_10k += 1;
319             table->common->num_flows_5k += 1;
320             table->common->num_flows += 1;
321         } else if (current->rate > 5120) {
322             table->common->num_flows_5k += 1;
323             table->common->num_flows += 1;
324         } else {
325             table->common->num_flows += 1;
326         }
327
328         src.s_addr = ntohl(current->source_ip);
329         dst.s_addr = ntohl(current->dest_ip);
330         strcpy(sip, inet_ntoa(src));
331         strcpy(dip, inet_ntoa(dst));
332         printlog(LOG_DEBUG, "FLOW: (%p)  %s:%d -> %s:%d at %d\n", current,
333                 sip, current->source_port,
334                 dip, current->dest_port,
335                 current->rate);
336     }
337
338     if (table->common->num_flows > 0) {
339         table->common->avg_rate = table->common->rate / table->common->num_flows;
340     }
341
342     printlog(LOG_DEBUG, "FLOW:--\n--\n");
343
344     table->common->max_flow_rate = maxflowrate;
345 }