cadafee6389a452f91ae69c1a023ab241fc5250a
[distributedratelimiting.git] / drl / standard.c
1 /* See the DRL-LICENSE file for this file's software license. */
2
3 #include <arpa/inet.h>
4 #include <assert.h>
5 #include <inttypes.h>
6 #include <netinet/in.h>
7 #include <pthread.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <sys/time.h>
12 #include <sys/types.h>
13 #include <time.h>
14
15 #include "common_accounting.h"
16 #include "standard.h"
17 #include "logging.h"
18
19 standard_flow_table standard_table_create(uint32_t (*hash_function)(const key_flow *key), common_accounting_t *common) {
20     standard_flow_table table = malloc(sizeof(struct std_flow_table));
21
22     if (table == NULL) {
23         return NULL;
24     }
25
26     memset(table, 0, sizeof(struct std_flow_table));
27     table->common = common;
28     table->hash_function = hash_function;
29
30     gettimeofday(&table->common->last_update, NULL);
31
32     return table;
33 }
34
35 void standard_table_destroy(standard_flow_table table) {
36     standard_flow *current, *next;
37
38     if ((current = table->flows_head)) {
39         while (current->next) {
40             next = current->next;
41             free(current);
42             current = next;
43         }
44         free(current);
45     }
46
47     free(table);
48 }
49
50 /* Looks for the flow in the table.  If the flow isn't there, it allocates a
51  * place for it. */
52 standard_flow *standard_table_lookup(standard_flow_table table, const key_flow *key) {
53     uint32_t hash;
54     standard_flow *flow;
55     struct in_addr src, dst;
56     char sip[22], dip[22];
57
58     if (table == NULL) {
59         return NULL;
60     }
61
62     hash = table->hash_function(key);
63
64     /* Find the flow, if it's there. */
65     for (flow = table->flows[hash]; flow; flow = flow->nexth) {
66         if (flow->source_ip == key->source_ip &&
67                 flow->dest_ip == key->dest_ip &&
68                 flow->source_port == key->source_port &&
69                 flow->dest_port == key->dest_port &&
70                 flow->protocol == key->protocol) {
71             break;
72         }
73     }
74
75     if (flow == NULL) {
76         flow = malloc(sizeof(standard_flow));
77         if (flow == NULL) {
78             printf("Malloc returned null.\n");
79             printlog(LOG_CRITICAL, "ALLOC: Malloc returned NULL.\n");
80             return NULL;
81         }
82
83         memset(flow, 0, sizeof(standard_flow));
84         flow->protocol = key->protocol;
85         flow->source_ip = key->source_ip;
86         flow->dest_ip = key->dest_ip;
87         flow->source_port = key->source_port;
88         flow->dest_port = key->dest_port;
89         flow->last_packet = key->packet_time;
90         gettimeofday(&flow->last_update, NULL);
91
92         /* Add the flow to the hash list. */
93         flow->nexth = table->flows[hash];
94         table->flows[hash] = flow;
95
96         /* Add the flow to the linked list. */
97         if (table->flows_tail) {
98             flow->prev = table->flows_tail;
99             table->flows_tail->next = flow;
100             table->flows_tail = flow;
101         } else {
102             table->flows_head = table->flows_tail = flow;
103             /* next and prev are already null due to memset above. */
104         }
105
106         src.s_addr = ntohl(flow->source_ip);
107         dst.s_addr = ntohl(flow->dest_ip);
108         strcpy(sip, inet_ntoa(src));
109         strcpy(dip, inet_ntoa(dst));
110         printlog(LOG_DEBUG, "ALLOC:%s:%hd -> %s:%hd\n", sip,
111                 flow->source_port, dip, flow->dest_port);
112     }
113
114     return flow;
115 }
116
117 int standard_table_sample(standard_flow_table table, const key_flow *key) {
118     standard_flow *flow;
119
120     assert(table != NULL);
121     assert(table->common != NULL);
122
123     /* Update aggregate. */
124     table->common->bytes_since += key->packet_size;
125
126     /* Update flow. */
127     flow = standard_table_lookup(table, key);
128     if (flow == NULL) {
129         return 0;
130     }
131
132     flow->bytes_since += key->packet_size;
133     flow->last_packet = key->packet_time;
134
135     return 1;
136 }
137
138 void standard_table_remove(standard_flow_table table, standard_flow *flow) {
139     key_flow key;
140     uint32_t hash;
141     standard_flow *current, *prev;
142
143     assert(flow);
144
145     /* Remove the flow from the hash list. */
146     key.source_ip = flow->source_ip;
147     key.dest_ip = flow->dest_ip;
148     key.source_port = flow->source_port;
149     key.dest_port = flow->dest_port;
150     key.protocol = flow->protocol;
151
152     hash = table->hash_function(&key);
153
154     assert(table->flows[hash]);
155
156     if (table->flows[hash] == flow) {
157         /* It's the head of the hash list. */
158         table->flows[hash] = flow->nexth;
159     } else {
160         prev = table->flows[hash];
161         current = table->flows[hash]->nexth;
162
163         while (current != NULL) {
164             if (current == flow) {
165                 prev->nexth = flow->nexth;
166                 break;
167             } else {
168                 prev = current;
169                 current = current->next;
170             }
171         }
172
173         assert(current != NULL);
174     }
175
176     /* Remove the flow from the linked list. */
177     if (flow->prev == NULL && flow->next == NULL) {
178         /* It's the head, tail, and only element of the list. */
179         assert(table->flows_head == flow);
180         assert(table->flows_tail == flow);
181
182         table->flows_head = NULL;
183         table->flows_tail = NULL;
184     } else if (flow->prev == NULL) {
185         /* It's the head of the list. */
186         assert(table->flows_head == flow);
187
188         table->flows_head = flow->next;
189
190         if (table->flows_head != NULL) {
191             table->flows_head->prev = NULL;
192         }
193     } else if (flow->next == NULL) {
194         /* It's the tail of the list. */
195         assert(table->flows_tail == flow);
196
197         table->flows_tail = flow->prev;
198
199         table->flows_tail->next = NULL;
200     } else {
201         /* Not the head or tail of the list. */
202         assert(table->flows_head != flow);
203
204         flow->prev->next = flow->next;
205
206         if (flow->next != NULL) {
207             flow->next->prev = flow->prev;
208         }
209     }
210
211     memset(flow, 0, sizeof(standard_flow));
212
213     /* Free the flow. */
214     free(flow);
215 }
216
217 int standard_table_cleanup(standard_flow_table table) {
218     standard_flow *current = table->flows_head;
219     standard_flow *remove;
220     time_t now = time(NULL);
221
222     while (current != NULL) {
223         if (current->last_packet + FLOW_IDLE_TIME <= now) {
224             /* Flow hasn't received a packet in the time limit - kill it. */
225             remove = current;
226             current = current->next;
227
228             standard_table_remove(table, remove);
229         } else {
230             current = current->next;
231         }
232     }
233
234     return 0;
235 }
236
237 void standard_table_update_flows(standard_flow_table table, struct timeval now, double ewma_weight) {
238     uint32_t maxflowrate = 0;
239     double time_delta;
240     double unweighted_rate;
241     standard_flow *current;
242     struct in_addr src, dst;
243     char sip[22], dip[22];
244
245     time_delta = timeval_subtract(now, table->common->last_update);
246
247     if (time_delta <= 0) {
248         unweighted_rate = 0;
249     } else {
250         unweighted_rate = table->common->bytes_since / time_delta;
251     }
252
253     table->common->last_inst_rate = table->common->inst_rate;
254     table->common->inst_rate = unweighted_rate;
255
256     table->common->last_rate = table->common->rate;
257
258     /* If the rate is zero, then we don't know anything yet.  Don't apply EWMA
259      * in that case. */
260     if (table->common->rate == 0) {
261         table->common->rate = unweighted_rate;
262     } else {
263         table->common->rate = table->common->rate * ewma_weight +
264                               unweighted_rate * (1 - ewma_weight);
265     }
266
267     table->common->bytes_since = 0;
268     table->common->last_update = now;
269
270     //printf("Flows: ");
271
272     /* Update per-flow information. */
273     for (current = table->flows_head; current; current = current->next) {
274         time_delta = timeval_subtract(now, current->last_update);
275
276         if (time_delta <= 0) {
277             unweighted_rate = 0;
278         } else {
279             unweighted_rate = current->bytes_since / time_delta;
280         }
281
282         current->last_rate = current->rate;
283
284         if (current->rate == 0) {
285             current->rate = unweighted_rate;
286         } else {
287             current->rate = current->rate * ewma_weight +
288                             unweighted_rate * (1 - ewma_weight);
289         }
290
291         current->bytes_since = 0;
292         current->last_update = now;
293
294         if (current->rate > maxflowrate) {
295             maxflowrate = current->rate;
296         }
297
298         //printf("%d, ", current->rate);
299
300         src.s_addr = ntohl(current->source_ip);
301         dst.s_addr = ntohl(current->dest_ip);
302         strcpy(sip, inet_ntoa(src));
303         strcpy(dip, inet_ntoa(dst));
304         printlog(LOG_DEBUG, "FLOW: (%p)  %s:%d -> %s:%d at %d\n", current,
305                 sip, current->source_port,
306                 dip, current->dest_port,
307                 current->rate);
308     }
309
310     //printf("\n");
311     printlog(LOG_DEBUG, "FLOW:--\n--\n");
312
313     table->common->max_flow_rate = maxflowrate;
314 }