Merge branch 'viafb-next' of git://git.lwn.net/linux-2.6
[linux-drm-fsl-dcu.git] / net / xfrm / xfrm_state.c
1 /*
2  * xfrm_state.c
3  *
4  * Changes:
5  *      Mitsuru KANDA @USAGI
6  *      Kazunori MIYAZAWA @USAGI
7  *      Kunihiro Ishiguro <kunihiro@ipinfusion.com>
8  *              IPv6 support
9  *      YOSHIFUJI Hideaki @USAGI
10  *              Split up af-specific functions
11  *      Derek Atkins <derek@ihtfp.com>
12  *              Add UDP Encapsulation
13  *
14  */
15
16 #include <linux/workqueue.h>
17 #include <net/xfrm.h>
18 #include <linux/pfkeyv2.h>
19 #include <linux/ipsec.h>
20 #include <linux/module.h>
21 #include <linux/cache.h>
22 #include <linux/audit.h>
23 #include <asm/uaccess.h>
24 #include <linux/ktime.h>
25 #include <linux/slab.h>
26 #include <linux/interrupt.h>
27 #include <linux/kernel.h>
28
29 #include "xfrm_hash.h"
30
/* Each xfrm_state may be linked into up to three hash tables:

   1. Hash table by (spi,daddr,ah/esp) to find SA by SPI. (input,ctl)
   2. Hash table by (daddr,saddr,family,reqid) to find what SAs exist for
      a given destination/tunnel endpoint. (output)
   3. Hash table by (daddr,saddr,family) to find SAs by address pair.
 */
37
/* Taken around updates and walks of the per-net state hash tables below. */
38 static DEFINE_SPINLOCK(xfrm_state_lock);
39
/* Bucket-count ceiling consulted by xfrm_hash_grow_check() before growing. */
40 static unsigned int xfrm_state_hashmax __read_mostly = 1 * 1024 * 1024;
41 static unsigned int xfrm_state_genid;
42
/* Forward declarations; the definitions appear later in this file. */
43 static struct xfrm_state_afinfo *xfrm_state_get_afinfo(unsigned int family);
44 static void xfrm_state_put_afinfo(struct xfrm_state_afinfo *afinfo);
45
/* Without CONFIG_AUDITSYSCALL the replay audit hook compiles to a no-op. */
46 #ifdef CONFIG_AUDITSYSCALL
47 static void xfrm_audit_state_replay(struct xfrm_state *x,
48                                     struct sk_buff *skb, __be32 net_seq);
49 #else
50 #define xfrm_audit_state_replay(x, s, sq)       do { ; } while (0)
51 #endif /* CONFIG_AUDITSYSCALL */
52
53 static inline unsigned int xfrm_dst_hash(struct net *net,
54                                          xfrm_address_t *daddr,
55                                          xfrm_address_t *saddr,
56                                          u32 reqid,
57                                          unsigned short family)
58 {
59         return __xfrm_dst_hash(daddr, saddr, reqid, family, net->xfrm.state_hmask);
60 }
61
62 static inline unsigned int xfrm_src_hash(struct net *net,
63                                          xfrm_address_t *daddr,
64                                          xfrm_address_t *saddr,
65                                          unsigned short family)
66 {
67         return __xfrm_src_hash(daddr, saddr, family, net->xfrm.state_hmask);
68 }
69
70 static inline unsigned int
71 xfrm_spi_hash(struct net *net, xfrm_address_t *daddr, __be32 spi, u8 proto, unsigned short family)
72 {
73         return __xfrm_spi_hash(daddr, spi, proto, family, net->xfrm.state_hmask);
74 }
75
76 static void xfrm_hash_transfer(struct hlist_head *list,
77                                struct hlist_head *ndsttable,
78                                struct hlist_head *nsrctable,
79                                struct hlist_head *nspitable,
80                                unsigned int nhashmask)
81 {
82         struct hlist_node *entry, *tmp;
83         struct xfrm_state *x;
84
85         hlist_for_each_entry_safe(x, entry, tmp, list, bydst) {
86                 unsigned int h;
87
88                 h = __xfrm_dst_hash(&x->id.daddr, &x->props.saddr,
89                                     x->props.reqid, x->props.family,
90                                     nhashmask);
91                 hlist_add_head(&x->bydst, ndsttable+h);
92
93                 h = __xfrm_src_hash(&x->id.daddr, &x->props.saddr,
94                                     x->props.family,
95                                     nhashmask);
96                 hlist_add_head(&x->bysrc, nsrctable+h);
97
98                 if (x->id.spi) {
99                         h = __xfrm_spi_hash(&x->id.daddr, x->id.spi,
100                                             x->id.proto, x->props.family,
101                                             nhashmask);
102                         hlist_add_head(&x->byspi, nspitable+h);
103                 }
104         }
105 }
106
107 static unsigned long xfrm_hash_new_size(unsigned int state_hmask)
108 {
109         return ((state_hmask + 1) << 1) * sizeof(struct hlist_head);
110 }
111
112 static DEFINE_MUTEX(hash_resize_mutex);
113
114 static void xfrm_hash_resize(struct work_struct *work)
115 {
116         struct net *net = container_of(work, struct net, xfrm.state_hash_work);
117         struct hlist_head *ndst, *nsrc, *nspi, *odst, *osrc, *ospi;
118         unsigned long nsize, osize;
119         unsigned int nhashmask, ohashmask;
120         int i;
121
122         mutex_lock(&hash_resize_mutex);
123
124         nsize = xfrm_hash_new_size(net->xfrm.state_hmask);
125         ndst = xfrm_hash_alloc(nsize);
126         if (!ndst)
127                 goto out_unlock;
128         nsrc = xfrm_hash_alloc(nsize);
129         if (!nsrc) {
130                 xfrm_hash_free(ndst, nsize);
131                 goto out_unlock;
132         }
133         nspi = xfrm_hash_alloc(nsize);
134         if (!nspi) {
135                 xfrm_hash_free(ndst, nsize);
136                 xfrm_hash_free(nsrc, nsize);
137                 goto out_unlock;
138         }
139
140         spin_lock_bh(&xfrm_state_lock);
141
142         nhashmask = (nsize / sizeof(struct hlist_head)) - 1U;
143         for (i = net->xfrm.state_hmask; i >= 0; i--)
144                 xfrm_hash_transfer(net->xfrm.state_bydst+i, ndst, nsrc, nspi,
145                                    nhashmask);
146
147         odst = net->xfrm.state_bydst;
148         osrc = net->xfrm.state_bysrc;
149         ospi = net->xfrm.state_byspi;
150         ohashmask = net->xfrm.state_hmask;
151
152         net->xfrm.state_bydst = ndst;
153         net->xfrm.state_bysrc = nsrc;
154         net->xfrm.state_byspi = nspi;
155         net->xfrm.state_hmask = nhashmask;
156
157         spin_unlock_bh(&xfrm_state_lock);
158
159         osize = (ohashmask + 1) * sizeof(struct hlist_head);
160         xfrm_hash_free(odst, osize);
161         xfrm_hash_free(osrc, osize);
162         xfrm_hash_free(ospi, osize);
163
164 out_unlock:
165         mutex_unlock(&hash_resize_mutex);
166 }
167
/* Per-family afinfo table, indexed by address family, and its rwlock. */
168 static DEFINE_RWLOCK(xfrm_state_afinfo_lock);
169 static struct xfrm_state_afinfo *xfrm_state_afinfo[NPROTO];
170
/* Taken around additions to and removal from the per-net state_gc_list. */
171 static DEFINE_SPINLOCK(xfrm_state_gc_lock);
172
/* Defined later in this file; needed by the timer handler. */
173 int __xfrm_state_delete(struct xfrm_state *x);
174
/* Key manager entry points used below; not defined in the visible part of this file. */
175 int km_query(struct xfrm_state *x, struct xfrm_tmpl *t, struct xfrm_policy *pol);
176 void km_state_expired(struct xfrm_state *x, int hard, u32 pid);
177
178 static struct xfrm_state_afinfo *xfrm_state_lock_afinfo(unsigned int family)
179 {
180         struct xfrm_state_afinfo *afinfo;
181         if (unlikely(family >= NPROTO))
182                 return NULL;
183         write_lock_bh(&xfrm_state_afinfo_lock);
184         afinfo = xfrm_state_afinfo[family];
185         if (unlikely(!afinfo))
186                 write_unlock_bh(&xfrm_state_afinfo_lock);
187         return afinfo;
188 }
189
190 static void xfrm_state_unlock_afinfo(struct xfrm_state_afinfo *afinfo)
191         __releases(xfrm_state_afinfo_lock)
192 {
193         write_unlock_bh(&xfrm_state_afinfo_lock);
194 }
195
196 int xfrm_register_type(const struct xfrm_type *type, unsigned short family)
197 {
198         struct xfrm_state_afinfo *afinfo = xfrm_state_lock_afinfo(family);
199         const struct xfrm_type **typemap;
200         int err = 0;
201
202         if (unlikely(afinfo == NULL))
203                 return -EAFNOSUPPORT;
204         typemap = afinfo->type_map;
205
206         if (likely(typemap[type->proto] == NULL))
207                 typemap[type->proto] = type;
208         else
209                 err = -EEXIST;
210         xfrm_state_unlock_afinfo(afinfo);
211         return err;
212 }
213 EXPORT_SYMBOL(xfrm_register_type);
214
215 int xfrm_unregister_type(const struct xfrm_type *type, unsigned short family)
216 {
217         struct xfrm_state_afinfo *afinfo = xfrm_state_lock_afinfo(family);
218         const struct xfrm_type **typemap;
219         int err = 0;
220
221         if (unlikely(afinfo == NULL))
222                 return -EAFNOSUPPORT;
223         typemap = afinfo->type_map;
224
225         if (unlikely(typemap[type->proto] != type))
226                 err = -ENOENT;
227         else
228                 typemap[type->proto] = NULL;
229         xfrm_state_unlock_afinfo(afinfo);
230         return err;
231 }
232 EXPORT_SYMBOL(xfrm_unregister_type);
233
234 static const struct xfrm_type *xfrm_get_type(u8 proto, unsigned short family)
235 {
236         struct xfrm_state_afinfo *afinfo;
237         const struct xfrm_type **typemap;
238         const struct xfrm_type *type;
239         int modload_attempted = 0;
240
241 retry:
242         afinfo = xfrm_state_get_afinfo(family);
243         if (unlikely(afinfo == NULL))
244                 return NULL;
245         typemap = afinfo->type_map;
246
247         type = typemap[proto];
248         if (unlikely(type && !try_module_get(type->owner)))
249                 type = NULL;
250         if (!type && !modload_attempted) {
251                 xfrm_state_put_afinfo(afinfo);
252                 request_module("xfrm-type-%d-%d", family, proto);
253                 modload_attempted = 1;
254                 goto retry;
255         }
256
257         xfrm_state_put_afinfo(afinfo);
258         return type;
259 }
260
261 static void xfrm_put_type(const struct xfrm_type *type)
262 {
263         module_put(type->owner);
264 }
265
266 int xfrm_register_mode(struct xfrm_mode *mode, int family)
267 {
268         struct xfrm_state_afinfo *afinfo;
269         struct xfrm_mode **modemap;
270         int err;
271
272         if (unlikely(mode->encap >= XFRM_MODE_MAX))
273                 return -EINVAL;
274
275         afinfo = xfrm_state_lock_afinfo(family);
276         if (unlikely(afinfo == NULL))
277                 return -EAFNOSUPPORT;
278
279         err = -EEXIST;
280         modemap = afinfo->mode_map;
281         if (modemap[mode->encap])
282                 goto out;
283
284         err = -ENOENT;
285         if (!try_module_get(afinfo->owner))
286                 goto out;
287
288         mode->afinfo = afinfo;
289         modemap[mode->encap] = mode;
290         err = 0;
291
292 out:
293         xfrm_state_unlock_afinfo(afinfo);
294         return err;
295 }
296 EXPORT_SYMBOL(xfrm_register_mode);
297
298 int xfrm_unregister_mode(struct xfrm_mode *mode, int family)
299 {
300         struct xfrm_state_afinfo *afinfo;
301         struct xfrm_mode **modemap;
302         int err;
303
304         if (unlikely(mode->encap >= XFRM_MODE_MAX))
305                 return -EINVAL;
306
307         afinfo = xfrm_state_lock_afinfo(family);
308         if (unlikely(afinfo == NULL))
309                 return -EAFNOSUPPORT;
310
311         err = -ENOENT;
312         modemap = afinfo->mode_map;
313         if (likely(modemap[mode->encap] == mode)) {
314                 modemap[mode->encap] = NULL;
315                 module_put(mode->afinfo->owner);
316                 err = 0;
317         }
318
319         xfrm_state_unlock_afinfo(afinfo);
320         return err;
321 }
322 EXPORT_SYMBOL(xfrm_unregister_mode);
323
324 static struct xfrm_mode *xfrm_get_mode(unsigned int encap, int family)
325 {
326         struct xfrm_state_afinfo *afinfo;
327         struct xfrm_mode *mode;
328         int modload_attempted = 0;
329
330         if (unlikely(encap >= XFRM_MODE_MAX))
331                 return NULL;
332
333 retry:
334         afinfo = xfrm_state_get_afinfo(family);
335         if (unlikely(afinfo == NULL))
336                 return NULL;
337
338         mode = afinfo->mode_map[encap];
339         if (unlikely(mode && !try_module_get(mode->owner)))
340                 mode = NULL;
341         if (!mode && !modload_attempted) {
342                 xfrm_state_put_afinfo(afinfo);
343                 request_module("xfrm-mode-%d-%d", family, encap);
344                 modload_attempted = 1;
345                 goto retry;
346         }
347
348         xfrm_state_put_afinfo(afinfo);
349         return mode;
350 }
351
352 static void xfrm_put_mode(struct xfrm_mode *mode)
353 {
354         module_put(mode->owner);
355 }
356
357 static void xfrm_state_gc_destroy(struct xfrm_state *x)
358 {
359         tasklet_hrtimer_cancel(&x->mtimer);
360         del_timer_sync(&x->rtimer);
361         kfree(x->aalg);
362         kfree(x->ealg);
363         kfree(x->calg);
364         kfree(x->encap);
365         kfree(x->coaddr);
366         if (x->inner_mode)
367                 xfrm_put_mode(x->inner_mode);
368         if (x->inner_mode_iaf)
369                 xfrm_put_mode(x->inner_mode_iaf);
370         if (x->outer_mode)
371                 xfrm_put_mode(x->outer_mode);
372         if (x->type) {
373                 x->type->destructor(x);
374                 xfrm_put_type(x->type);
375         }
376         security_xfrm_state_free(x);
377         kfree(x);
378 }
379
380 static void xfrm_state_gc_task(struct work_struct *work)
381 {
382         struct net *net = container_of(work, struct net, xfrm.state_gc_work);
383         struct xfrm_state *x;
384         struct hlist_node *entry, *tmp;
385         struct hlist_head gc_list;
386
387         spin_lock_bh(&xfrm_state_gc_lock);
388         hlist_move_list(&net->xfrm.state_gc_list, &gc_list);
389         spin_unlock_bh(&xfrm_state_gc_lock);
390
391         hlist_for_each_entry_safe(x, entry, tmp, &gc_list, gclist)
392                 xfrm_state_gc_destroy(x);
393
394         wake_up(&net->xfrm.km_waitq);
395 }
396
397 static inline unsigned long make_jiffies(long secs)
398 {
399         if (secs >= (MAX_SCHEDULE_TIMEOUT-1)/HZ)
400                 return MAX_SCHEDULE_TIMEOUT-1;
401         else
402                 return secs*HZ;
403 }
404
405 static enum hrtimer_restart xfrm_timer_handler(struct hrtimer * me)
406 {
407         struct tasklet_hrtimer *thr = container_of(me, struct tasklet_hrtimer, timer);
408         struct xfrm_state *x = container_of(thr, struct xfrm_state, mtimer);
409         struct net *net = xs_net(x);
410         unsigned long now = get_seconds();
411         long next = LONG_MAX;
412         int warn = 0;
413         int err = 0;
414
415         spin_lock(&x->lock);
416         if (x->km.state == XFRM_STATE_DEAD)
417                 goto out;
418         if (x->km.state == XFRM_STATE_EXPIRED)
419                 goto expired;
420         if (x->lft.hard_add_expires_seconds) {
421                 long tmo = x->lft.hard_add_expires_seconds +
422                         x->curlft.add_time - now;
423                 if (tmo <= 0)
424                         goto expired;
425                 if (tmo < next)
426                         next = tmo;
427         }
428         if (x->lft.hard_use_expires_seconds) {
429                 long tmo = x->lft.hard_use_expires_seconds +
430                         (x->curlft.use_time ? : now) - now;
431                 if (tmo <= 0)
432                         goto expired;
433                 if (tmo < next)
434                         next = tmo;
435         }
436         if (x->km.dying)
437                 goto resched;
438         if (x->lft.soft_add_expires_seconds) {
439                 long tmo = x->lft.soft_add_expires_seconds +
440                         x->curlft.add_time - now;
441                 if (tmo <= 0)
442                         warn = 1;
443                 else if (tmo < next)
444                         next = tmo;
445         }
446         if (x->lft.soft_use_expires_seconds) {
447                 long tmo = x->lft.soft_use_expires_seconds +
448                         (x->curlft.use_time ? : now) - now;
449                 if (tmo <= 0)
450                         warn = 1;
451                 else if (tmo < next)
452                         next = tmo;
453         }
454
455         x->km.dying = warn;
456         if (warn)
457                 km_state_expired(x, 0, 0);
458 resched:
459         if (next != LONG_MAX){
460                 tasklet_hrtimer_start(&x->mtimer, ktime_set(next, 0), HRTIMER_MODE_REL);
461         }
462
463         goto out;
464
465 expired:
466         if (x->km.state == XFRM_STATE_ACQ && x->id.spi == 0) {
467                 x->km.state = XFRM_STATE_EXPIRED;
468                 wake_up(&net->xfrm.km_waitq);
469                 next = 2;
470                 goto resched;
471         }
472
473         err = __xfrm_state_delete(x);
474         if (!err && x->id.spi)
475                 km_state_expired(x, 1, 0);
476
477         xfrm_audit_state_delete(x, err ? 0 : 1,
478                                 audit_get_loginuid(current),
479                                 audit_get_sessionid(current), 0);
480
481 out:
482         spin_unlock(&x->lock);
483         return HRTIMER_NORESTART;
484 }
485
486 static void xfrm_replay_timer_handler(unsigned long data);
487
488 struct xfrm_state *xfrm_state_alloc(struct net *net)
489 {
490         struct xfrm_state *x;
491
492         x = kzalloc(sizeof(struct xfrm_state), GFP_ATOMIC);
493
494         if (x) {
495                 write_pnet(&x->xs_net, net);
496                 atomic_set(&x->refcnt, 1);
497                 atomic_set(&x->tunnel_users, 0);
498                 INIT_LIST_HEAD(&x->km.all);
499                 INIT_HLIST_NODE(&x->bydst);
500                 INIT_HLIST_NODE(&x->bysrc);
501                 INIT_HLIST_NODE(&x->byspi);
502                 tasklet_hrtimer_init(&x->mtimer, xfrm_timer_handler, CLOCK_REALTIME, HRTIMER_MODE_ABS);
503                 setup_timer(&x->rtimer, xfrm_replay_timer_handler,
504                                 (unsigned long)x);
505                 x->curlft.add_time = get_seconds();
506                 x->lft.soft_byte_limit = XFRM_INF;
507                 x->lft.soft_packet_limit = XFRM_INF;
508                 x->lft.hard_byte_limit = XFRM_INF;
509                 x->lft.hard_packet_limit = XFRM_INF;
510                 x->replay_maxage = 0;
511                 x->replay_maxdiff = 0;
512                 x->inner_mode = NULL;
513                 x->inner_mode_iaf = NULL;
514                 spin_lock_init(&x->lock);
515         }
516         return x;
517 }
518 EXPORT_SYMBOL(xfrm_state_alloc);
519
520 void __xfrm_state_destroy(struct xfrm_state *x)
521 {
522         struct net *net = xs_net(x);
523
524         WARN_ON(x->km.state != XFRM_STATE_DEAD);
525
526         spin_lock_bh(&xfrm_state_gc_lock);
527         hlist_add_head(&x->gclist, &net->xfrm.state_gc_list);
528         spin_unlock_bh(&xfrm_state_gc_lock);
529         schedule_work(&net->xfrm.state_gc_work);
530 }
531 EXPORT_SYMBOL(__xfrm_state_destroy);
532
533 int __xfrm_state_delete(struct xfrm_state *x)
534 {
535         struct net *net = xs_net(x);
536         int err = -ESRCH;
537
538         if (x->km.state != XFRM_STATE_DEAD) {
539                 x->km.state = XFRM_STATE_DEAD;
540                 spin_lock(&xfrm_state_lock);
541                 list_del(&x->km.all);
542                 hlist_del(&x->bydst);
543                 hlist_del(&x->bysrc);
544                 if (x->id.spi)
545                         hlist_del(&x->byspi);
546                 net->xfrm.state_num--;
547                 spin_unlock(&xfrm_state_lock);
548
549                 /* All xfrm_state objects are created by xfrm_state_alloc.
550                  * The xfrm_state_alloc call gives a reference, and that
551                  * is what we are dropping here.
552                  */
553                 xfrm_state_put(x);
554                 err = 0;
555         }
556
557         return err;
558 }
559 EXPORT_SYMBOL(__xfrm_state_delete);
560
561 int xfrm_state_delete(struct xfrm_state *x)
562 {
563         int err;
564
565         spin_lock_bh(&x->lock);
566         err = __xfrm_state_delete(x);
567         spin_unlock_bh(&x->lock);
568
569         return err;
570 }
571 EXPORT_SYMBOL(xfrm_state_delete);
572
573 #ifdef CONFIG_SECURITY_NETWORK_XFRM
574 static inline int
575 xfrm_state_flush_secctx_check(struct net *net, u8 proto, struct xfrm_audit *audit_info)
576 {
577         int i, err = 0;
578
579         for (i = 0; i <= net->xfrm.state_hmask; i++) {
580                 struct hlist_node *entry;
581                 struct xfrm_state *x;
582
583                 hlist_for_each_entry(x, entry, net->xfrm.state_bydst+i, bydst) {
584                         if (xfrm_id_proto_match(x->id.proto, proto) &&
585                            (err = security_xfrm_state_delete(x)) != 0) {
586                                 xfrm_audit_state_delete(x, 0,
587                                                         audit_info->loginuid,
588                                                         audit_info->sessionid,
589                                                         audit_info->secid);
590                                 return err;
591                         }
592                 }
593         }
594
595         return err;
596 }
597 #else
598 static inline int
599 xfrm_state_flush_secctx_check(struct net *net, u8 proto, struct xfrm_audit *audit_info)
600 {
601         return 0;
602 }
603 #endif
604
605 int xfrm_state_flush(struct net *net, u8 proto, struct xfrm_audit *audit_info)
606 {
607         int i, err = 0, cnt = 0;
608
609         spin_lock_bh(&xfrm_state_lock);
610         err = xfrm_state_flush_secctx_check(net, proto, audit_info);
611         if (err)
612                 goto out;
613
614         err = -ESRCH;
615         for (i = 0; i <= net->xfrm.state_hmask; i++) {
616                 struct hlist_node *entry;
617                 struct xfrm_state *x;
618 restart:
619                 hlist_for_each_entry(x, entry, net->xfrm.state_bydst+i, bydst) {
620                         if (!xfrm_state_kern(x) &&
621                             xfrm_id_proto_match(x->id.proto, proto)) {
622                                 xfrm_state_hold(x);
623                                 spin_unlock_bh(&xfrm_state_lock);
624
625                                 err = xfrm_state_delete(x);
626                                 xfrm_audit_state_delete(x, err ? 0 : 1,
627                                                         audit_info->loginuid,
628                                                         audit_info->sessionid,
629                                                         audit_info->secid);
630                                 xfrm_state_put(x);
631                                 if (!err)
632                                         cnt++;
633
634                                 spin_lock_bh(&xfrm_state_lock);
635                                 goto restart;
636                         }
637                 }
638         }
639         if (cnt)
640                 err = 0;
641
642 out:
643         spin_unlock_bh(&xfrm_state_lock);
644         wake_up(&net->xfrm.km_waitq);
645         return err;
646 }
647 EXPORT_SYMBOL(xfrm_state_flush);
648
649 void xfrm_sad_getinfo(struct net *net, struct xfrmk_sadinfo *si)
650 {
651         spin_lock_bh(&xfrm_state_lock);
652         si->sadcnt = net->xfrm.state_num;
653         si->sadhcnt = net->xfrm.state_hmask;
654         si->sadhmcnt = xfrm_state_hashmax;
655         spin_unlock_bh(&xfrm_state_lock);
656 }
657 EXPORT_SYMBOL(xfrm_sad_getinfo);
658
659 static int
660 xfrm_init_tempsel(struct xfrm_state *x, struct flowi *fl,
661                   struct xfrm_tmpl *tmpl,
662                   xfrm_address_t *daddr, xfrm_address_t *saddr,
663                   unsigned short family)
664 {
665         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
666         if (!afinfo)
667                 return -1;
668         afinfo->init_tempsel(x, fl, tmpl, daddr, saddr);
669         xfrm_state_put_afinfo(afinfo);
670         return 0;
671 }
672
673 static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark, xfrm_address_t *daddr, __be32 spi, u8 proto, unsigned short family)
674 {
675         unsigned int h = xfrm_spi_hash(net, daddr, spi, proto, family);
676         struct xfrm_state *x;
677         struct hlist_node *entry;
678
679         hlist_for_each_entry(x, entry, net->xfrm.state_byspi+h, byspi) {
680                 if (x->props.family != family ||
681                     x->id.spi       != spi ||
682                     x->id.proto     != proto ||
683                     xfrm_addr_cmp(&x->id.daddr, daddr, family))
684                         continue;
685
686                 if ((mark & x->mark.m) != x->mark.v)
687                         continue;
688                 xfrm_state_hold(x);
689                 return x;
690         }
691
692         return NULL;
693 }
694
695 static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark, xfrm_address_t *daddr, xfrm_address_t *saddr, u8 proto, unsigned short family)
696 {
697         unsigned int h = xfrm_src_hash(net, daddr, saddr, family);
698         struct xfrm_state *x;
699         struct hlist_node *entry;
700
701         hlist_for_each_entry(x, entry, net->xfrm.state_bysrc+h, bysrc) {
702                 if (x->props.family != family ||
703                     x->id.proto     != proto ||
704                     xfrm_addr_cmp(&x->id.daddr, daddr, family) ||
705                     xfrm_addr_cmp(&x->props.saddr, saddr, family))
706                         continue;
707
708                 if ((mark & x->mark.m) != x->mark.v)
709                         continue;
710                 xfrm_state_hold(x);
711                 return x;
712         }
713
714         return NULL;
715 }
716
717 static inline struct xfrm_state *
718 __xfrm_state_locate(struct xfrm_state *x, int use_spi, int family)
719 {
720         struct net *net = xs_net(x);
721         u32 mark = x->mark.v & x->mark.m;
722
723         if (use_spi)
724                 return __xfrm_state_lookup(net, mark, &x->id.daddr,
725                                            x->id.spi, x->id.proto, family);
726         else
727                 return __xfrm_state_lookup_byaddr(net, mark,
728                                                   &x->id.daddr,
729                                                   &x->props.saddr,
730                                                   x->id.proto, family);
731 }
732
733 static void xfrm_hash_grow_check(struct net *net, int have_hash_collision)
734 {
735         if (have_hash_collision &&
736             (net->xfrm.state_hmask + 1) < xfrm_state_hashmax &&
737             net->xfrm.state_num > net->xfrm.state_hmask)
738                 schedule_work(&net->xfrm.state_hash_work);
739 }
740
741 static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x,
742                                struct flowi *fl, unsigned short family,
743                                xfrm_address_t *daddr, xfrm_address_t *saddr,
744                                struct xfrm_state **best, int *acq_in_progress,
745                                int *error)
746 {
747         /* Resolution logic:
748          * 1. There is a valid state with matching selector. Done.
749          * 2. Valid state with inappropriate selector. Skip.
750          *
751          * Entering area of "sysdeps".
752          *
753          * 3. If state is not valid, selector is temporary, it selects
754          *    only session which triggered previous resolution. Key
755          *    manager will do something to install a state with proper
756          *    selector.
757          */
758         if (x->km.state == XFRM_STATE_VALID) {
759                 if ((x->sel.family &&
760                      !xfrm_selector_match(&x->sel, fl, x->sel.family)) ||
761                     !security_xfrm_state_pol_flow_match(x, pol, fl))
762                         return;
763
764                 if (!*best ||
765                     (*best)->km.dying > x->km.dying ||
766                     ((*best)->km.dying == x->km.dying &&
767                      (*best)->curlft.add_time < x->curlft.add_time))
768                         *best = x;
769         } else if (x->km.state == XFRM_STATE_ACQ) {
770                 *acq_in_progress = 1;
771         } else if (x->km.state == XFRM_STATE_ERROR ||
772                    x->km.state == XFRM_STATE_EXPIRED) {
773                 if (xfrm_selector_match(&x->sel, fl, x->sel.family) &&
774                     security_xfrm_state_pol_flow_match(x, pol, fl))
775                         *error = -ESRCH;
776         }
777 }
778
779 struct xfrm_state *
780 xfrm_state_find(xfrm_address_t *daddr, xfrm_address_t *saddr,
781                 struct flowi *fl, struct xfrm_tmpl *tmpl,
782                 struct xfrm_policy *pol, int *err,
783                 unsigned short family)
784 {
785         static xfrm_address_t saddr_wildcard = { };
786         struct net *net = xp_net(pol);
787         unsigned int h, h_wildcard;
788         struct hlist_node *entry;
789         struct xfrm_state *x, *x0, *to_put;
790         int acquire_in_progress = 0;
791         int error = 0;
792         struct xfrm_state *best = NULL;
793         u32 mark = pol->mark.v & pol->mark.m;
794
795         to_put = NULL;
796
797         spin_lock_bh(&xfrm_state_lock);
798         h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, family);
799         hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h, bydst) {
800                 if (x->props.family == family &&
801                     x->props.reqid == tmpl->reqid &&
802                     (mark & x->mark.m) == x->mark.v &&
803                     !(x->props.flags & XFRM_STATE_WILDRECV) &&
804                     xfrm_state_addr_check(x, daddr, saddr, family) &&
805                     tmpl->mode == x->props.mode &&
806                     tmpl->id.proto == x->id.proto &&
807                     (tmpl->id.spi == x->id.spi || !tmpl->id.spi))
808                         xfrm_state_look_at(pol, x, fl, family, daddr, saddr,
809                                            &best, &acquire_in_progress, &error);
810         }
811         if (best)
812                 goto found;
813
814         h_wildcard = xfrm_dst_hash(net, daddr, &saddr_wildcard, tmpl->reqid, family);
815         hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h_wildcard, bydst) {
816                 if (x->props.family == family &&
817                     x->props.reqid == tmpl->reqid &&
818                     (mark & x->mark.m) == x->mark.v &&
819                     !(x->props.flags & XFRM_STATE_WILDRECV) &&
820                     xfrm_state_addr_check(x, daddr, saddr, family) &&
821                     tmpl->mode == x->props.mode &&
822                     tmpl->id.proto == x->id.proto &&
823                     (tmpl->id.spi == x->id.spi || !tmpl->id.spi))
824                         xfrm_state_look_at(pol, x, fl, family, daddr, saddr,
825                                            &best, &acquire_in_progress, &error);
826         }
827
828 found:
829         x = best;
830         if (!x && !error && !acquire_in_progress) {
831                 if (tmpl->id.spi &&
832                     (x0 = __xfrm_state_lookup(net, mark, daddr, tmpl->id.spi,
833                                               tmpl->id.proto, family)) != NULL) {
834                         to_put = x0;
835                         error = -EEXIST;
836                         goto out;
837                 }
838                 x = xfrm_state_alloc(net);
839                 if (x == NULL) {
840                         error = -ENOMEM;
841                         goto out;
842                 }
843                 /* Initialize temporary selector matching only
844                  * to current session. */
845                 xfrm_init_tempsel(x, fl, tmpl, daddr, saddr, family);
846                 memcpy(&x->mark, &pol->mark, sizeof(x->mark));
847
848                 error = security_xfrm_state_alloc_acquire(x, pol->security, fl->secid);
849                 if (error) {
850                         x->km.state = XFRM_STATE_DEAD;
851                         to_put = x;
852                         x = NULL;
853                         goto out;
854                 }
855
856                 if (km_query(x, tmpl, pol) == 0) {
857                         x->km.state = XFRM_STATE_ACQ;
858                         list_add(&x->km.all, &net->xfrm.state_all);
859                         hlist_add_head(&x->bydst, net->xfrm.state_bydst+h);
860                         h = xfrm_src_hash(net, daddr, saddr, family);
861                         hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h);
862                         if (x->id.spi) {
863                                 h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, family);
864                                 hlist_add_head(&x->byspi, net->xfrm.state_byspi+h);
865                         }
866                         x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires;
867                         tasklet_hrtimer_start(&x->mtimer, ktime_set(net->xfrm.sysctl_acq_expires, 0), HRTIMER_MODE_REL);
868                         net->xfrm.state_num++;
869                         xfrm_hash_grow_check(net, x->bydst.next != NULL);
870                 } else {
871                         x->km.state = XFRM_STATE_DEAD;
872                         to_put = x;
873                         x = NULL;
874                         error = -ESRCH;
875                 }
876         }
877 out:
878         if (x)
879                 xfrm_state_hold(x);
880         else
881                 *err = acquire_in_progress ? -EAGAIN : error;
882         spin_unlock_bh(&xfrm_state_lock);
883         if (to_put)
884                 xfrm_state_put(to_put);
885         return x;
886 }
887
888 struct xfrm_state *
889 xfrm_stateonly_find(struct net *net, u32 mark,
890                     xfrm_address_t *daddr, xfrm_address_t *saddr,
891                     unsigned short family, u8 mode, u8 proto, u32 reqid)
892 {
893         unsigned int h;
894         struct xfrm_state *rx = NULL, *x = NULL;
895         struct hlist_node *entry;
896
897         spin_lock(&xfrm_state_lock);
898         h = xfrm_dst_hash(net, daddr, saddr, reqid, family);
899         hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h, bydst) {
900                 if (x->props.family == family &&
901                     x->props.reqid == reqid &&
902                     (mark & x->mark.m) == x->mark.v &&
903                     !(x->props.flags & XFRM_STATE_WILDRECV) &&
904                     xfrm_state_addr_check(x, daddr, saddr, family) &&
905                     mode == x->props.mode &&
906                     proto == x->id.proto &&
907                     x->km.state == XFRM_STATE_VALID) {
908                         rx = x;
909                         break;
910                 }
911         }
912
913         if (rx)
914                 xfrm_state_hold(rx);
915         spin_unlock(&xfrm_state_lock);
916
917
918         return rx;
919 }
920 EXPORT_SYMBOL(xfrm_stateonly_find);
921
922 static void __xfrm_state_insert(struct xfrm_state *x)
923 {
924         struct net *net = xs_net(x);
925         unsigned int h;
926
927         x->genid = ++xfrm_state_genid;
928
929         list_add(&x->km.all, &net->xfrm.state_all);
930
931         h = xfrm_dst_hash(net, &x->id.daddr, &x->props.saddr,
932                           x->props.reqid, x->props.family);
933         hlist_add_head(&x->bydst, net->xfrm.state_bydst+h);
934
935         h = xfrm_src_hash(net, &x->id.daddr, &x->props.saddr, x->props.family);
936         hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h);
937
938         if (x->id.spi) {
939                 h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto,
940                                   x->props.family);
941
942                 hlist_add_head(&x->byspi, net->xfrm.state_byspi+h);
943         }
944
945         tasklet_hrtimer_start(&x->mtimer, ktime_set(1, 0), HRTIMER_MODE_REL);
946         if (x->replay_maxage)
947                 mod_timer(&x->rtimer, jiffies + x->replay_maxage);
948
949         wake_up(&net->xfrm.km_waitq);
950
951         net->xfrm.state_num++;
952
953         xfrm_hash_grow_check(net, x->bydst.next != NULL);
954 }
955
956 /* xfrm_state_lock is held */
957 static void __xfrm_state_bump_genids(struct xfrm_state *xnew)
958 {
959         struct net *net = xs_net(xnew);
960         unsigned short family = xnew->props.family;
961         u32 reqid = xnew->props.reqid;
962         struct xfrm_state *x;
963         struct hlist_node *entry;
964         unsigned int h;
965         u32 mark = xnew->mark.v & xnew->mark.m;
966
967         h = xfrm_dst_hash(net, &xnew->id.daddr, &xnew->props.saddr, reqid, family);
968         hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h, bydst) {
969                 if (x->props.family     == family &&
970                     x->props.reqid      == reqid &&
971                     (mark & x->mark.m) == x->mark.v &&
972                     !xfrm_addr_cmp(&x->id.daddr, &xnew->id.daddr, family) &&
973                     !xfrm_addr_cmp(&x->props.saddr, &xnew->props.saddr, family))
974                         x->genid = xfrm_state_genid;
975         }
976 }
977
978 void xfrm_state_insert(struct xfrm_state *x)
979 {
980         spin_lock_bh(&xfrm_state_lock);
981         __xfrm_state_bump_genids(x);
982         __xfrm_state_insert(x);
983         spin_unlock_bh(&xfrm_state_lock);
984 }
985 EXPORT_SYMBOL(xfrm_state_insert);
986
987 /* xfrm_state_lock is held */
988 static struct xfrm_state *__find_acq_core(struct net *net, struct xfrm_mark *m, unsigned short family, u8 mode, u32 reqid, u8 proto, xfrm_address_t *daddr, xfrm_address_t *saddr, int create)
989 {
990         unsigned int h = xfrm_dst_hash(net, daddr, saddr, reqid, family);
991         struct hlist_node *entry;
992         struct xfrm_state *x;
993         u32 mark = m->v & m->m;
994
995         hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h, bydst) {
996                 if (x->props.reqid  != reqid ||
997                     x->props.mode   != mode ||
998                     x->props.family != family ||
999                     x->km.state     != XFRM_STATE_ACQ ||
1000                     x->id.spi       != 0 ||
1001                     x->id.proto     != proto ||
1002                     (mark & x->mark.m) != x->mark.v ||
1003                     xfrm_addr_cmp(&x->id.daddr, daddr, family) ||
1004                     xfrm_addr_cmp(&x->props.saddr, saddr, family))
1005                         continue;
1006
1007                 xfrm_state_hold(x);
1008                 return x;
1009         }
1010
1011         if (!create)
1012                 return NULL;
1013
1014         x = xfrm_state_alloc(net);
1015         if (likely(x)) {
1016                 switch (family) {
1017                 case AF_INET:
1018                         x->sel.daddr.a4 = daddr->a4;
1019                         x->sel.saddr.a4 = saddr->a4;
1020                         x->sel.prefixlen_d = 32;
1021                         x->sel.prefixlen_s = 32;
1022                         x->props.saddr.a4 = saddr->a4;
1023                         x->id.daddr.a4 = daddr->a4;
1024                         break;
1025
1026                 case AF_INET6:
1027                         ipv6_addr_copy((struct in6_addr *)x->sel.daddr.a6,
1028                                        (struct in6_addr *)daddr);
1029                         ipv6_addr_copy((struct in6_addr *)x->sel.saddr.a6,
1030                                        (struct in6_addr *)saddr);
1031                         x->sel.prefixlen_d = 128;
1032                         x->sel.prefixlen_s = 128;
1033                         ipv6_addr_copy((struct in6_addr *)x->props.saddr.a6,
1034                                        (struct in6_addr *)saddr);
1035                         ipv6_addr_copy((struct in6_addr *)x->id.daddr.a6,
1036                                        (struct in6_addr *)daddr);
1037                         break;
1038                 }
1039
1040                 x->km.state = XFRM_STATE_ACQ;
1041                 x->id.proto = proto;
1042                 x->props.family = family;
1043                 x->props.mode = mode;
1044                 x->props.reqid = reqid;
1045                 x->mark.v = m->v;
1046                 x->mark.m = m->m;
1047                 x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires;
1048                 xfrm_state_hold(x);
1049                 tasklet_hrtimer_start(&x->mtimer, ktime_set(net->xfrm.sysctl_acq_expires, 0), HRTIMER_MODE_REL);
1050                 list_add(&x->km.all, &net->xfrm.state_all);
1051                 hlist_add_head(&x->bydst, net->xfrm.state_bydst+h);
1052                 h = xfrm_src_hash(net, daddr, saddr, family);
1053                 hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h);
1054
1055                 net->xfrm.state_num++;
1056
1057                 xfrm_hash_grow_check(net, x->bydst.next != NULL);
1058         }
1059
1060         return x;
1061 }
1062
1063 static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq);
1064
1065 int xfrm_state_add(struct xfrm_state *x)
1066 {
1067         struct net *net = xs_net(x);
1068         struct xfrm_state *x1, *to_put;
1069         int family;
1070         int err;
1071         u32 mark = x->mark.v & x->mark.m;
1072         int use_spi = xfrm_id_proto_match(x->id.proto, IPSEC_PROTO_ANY);
1073
1074         family = x->props.family;
1075
1076         to_put = NULL;
1077
1078         spin_lock_bh(&xfrm_state_lock);
1079
1080         x1 = __xfrm_state_locate(x, use_spi, family);
1081         if (x1) {
1082                 to_put = x1;
1083                 x1 = NULL;
1084                 err = -EEXIST;
1085                 goto out;
1086         }
1087
1088         if (use_spi && x->km.seq) {
1089                 x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq);
1090                 if (x1 && ((x1->id.proto != x->id.proto) ||
1091                     xfrm_addr_cmp(&x1->id.daddr, &x->id.daddr, family))) {
1092                         to_put = x1;
1093                         x1 = NULL;
1094                 }
1095         }
1096
1097         if (use_spi && !x1)
1098                 x1 = __find_acq_core(net, &x->mark, family, x->props.mode,
1099                                      x->props.reqid, x->id.proto,
1100                                      &x->id.daddr, &x->props.saddr, 0);
1101
1102         __xfrm_state_bump_genids(x);
1103         __xfrm_state_insert(x);
1104         err = 0;
1105
1106 out:
1107         spin_unlock_bh(&xfrm_state_lock);
1108
1109         if (x1) {
1110                 xfrm_state_delete(x1);
1111                 xfrm_state_put(x1);
1112         }
1113
1114         if (to_put)
1115                 xfrm_state_put(to_put);
1116
1117         return err;
1118 }
1119 EXPORT_SYMBOL(xfrm_state_add);
1120
1121 #ifdef CONFIG_XFRM_MIGRATE
1122 static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig, int *errp)
1123 {
1124         struct net *net = xs_net(orig);
1125         int err = -ENOMEM;
1126         struct xfrm_state *x = xfrm_state_alloc(net);
1127         if (!x)
1128                 goto out;
1129
1130         memcpy(&x->id, &orig->id, sizeof(x->id));
1131         memcpy(&x->sel, &orig->sel, sizeof(x->sel));
1132         memcpy(&x->lft, &orig->lft, sizeof(x->lft));
1133         x->props.mode = orig->props.mode;
1134         x->props.replay_window = orig->props.replay_window;
1135         x->props.reqid = orig->props.reqid;
1136         x->props.family = orig->props.family;
1137         x->props.saddr = orig->props.saddr;
1138
1139         if (orig->aalg) {
1140                 x->aalg = xfrm_algo_auth_clone(orig->aalg);
1141                 if (!x->aalg)
1142                         goto error;
1143         }
1144         x->props.aalgo = orig->props.aalgo;
1145
1146         if (orig->ealg) {
1147                 x->ealg = xfrm_algo_clone(orig->ealg);
1148                 if (!x->ealg)
1149                         goto error;
1150         }
1151         x->props.ealgo = orig->props.ealgo;
1152
1153         if (orig->calg) {
1154                 x->calg = xfrm_algo_clone(orig->calg);
1155                 if (!x->calg)
1156                         goto error;
1157         }
1158         x->props.calgo = orig->props.calgo;
1159
1160         if (orig->encap) {
1161                 x->encap = kmemdup(orig->encap, sizeof(*x->encap), GFP_KERNEL);
1162                 if (!x->encap)
1163                         goto error;
1164         }
1165
1166         if (orig->coaddr) {
1167                 x->coaddr = kmemdup(orig->coaddr, sizeof(*x->coaddr),
1168                                     GFP_KERNEL);
1169                 if (!x->coaddr)
1170                         goto error;
1171         }
1172
1173         memcpy(&x->mark, &orig->mark, sizeof(x->mark));
1174
1175         err = xfrm_init_state(x);
1176         if (err)
1177                 goto error;
1178
1179         x->props.flags = orig->props.flags;
1180
1181         x->curlft.add_time = orig->curlft.add_time;
1182         x->km.state = orig->km.state;
1183         x->km.seq = orig->km.seq;
1184
1185         return x;
1186
1187  error:
1188         xfrm_state_put(x);
1189 out:
1190         if (errp)
1191                 *errp = err;
1192         return NULL;
1193 }
1194
1195 /* xfrm_state_lock is held */
1196 struct xfrm_state * xfrm_migrate_state_find(struct xfrm_migrate *m)
1197 {
1198         unsigned int h;
1199         struct xfrm_state *x;
1200         struct hlist_node *entry;
1201
1202         if (m->reqid) {
1203                 h = xfrm_dst_hash(&init_net, &m->old_daddr, &m->old_saddr,
1204                                   m->reqid, m->old_family);
1205                 hlist_for_each_entry(x, entry, init_net.xfrm.state_bydst+h, bydst) {
1206                         if (x->props.mode != m->mode ||
1207                             x->id.proto != m->proto)
1208                                 continue;
1209                         if (m->reqid && x->props.reqid != m->reqid)
1210                                 continue;
1211                         if (xfrm_addr_cmp(&x->id.daddr, &m->old_daddr,
1212                                           m->old_family) ||
1213                             xfrm_addr_cmp(&x->props.saddr, &m->old_saddr,
1214                                           m->old_family))
1215                                 continue;
1216                         xfrm_state_hold(x);
1217                         return x;
1218                 }
1219         } else {
1220                 h = xfrm_src_hash(&init_net, &m->old_daddr, &m->old_saddr,
1221                                   m->old_family);
1222                 hlist_for_each_entry(x, entry, init_net.xfrm.state_bysrc+h, bysrc) {
1223                         if (x->props.mode != m->mode ||
1224                             x->id.proto != m->proto)
1225                                 continue;
1226                         if (xfrm_addr_cmp(&x->id.daddr, &m->old_daddr,
1227                                           m->old_family) ||
1228                             xfrm_addr_cmp(&x->props.saddr, &m->old_saddr,
1229                                           m->old_family))
1230                                 continue;
1231                         xfrm_state_hold(x);
1232                         return x;
1233                 }
1234         }
1235
1236         return NULL;
1237 }
1238 EXPORT_SYMBOL(xfrm_migrate_state_find);
1239
1240 struct xfrm_state * xfrm_state_migrate(struct xfrm_state *x,
1241                                        struct xfrm_migrate *m)
1242 {
1243         struct xfrm_state *xc;
1244         int err;
1245
1246         xc = xfrm_state_clone(x, &err);
1247         if (!xc)
1248                 return NULL;
1249
1250         memcpy(&xc->id.daddr, &m->new_daddr, sizeof(xc->id.daddr));
1251         memcpy(&xc->props.saddr, &m->new_saddr, sizeof(xc->props.saddr));
1252
1253         /* add state */
1254         if (!xfrm_addr_cmp(&x->id.daddr, &m->new_daddr, m->new_family)) {
1255                 /* a care is needed when the destination address of the
1256                    state is to be updated as it is a part of triplet */
1257                 xfrm_state_insert(xc);
1258         } else {
1259                 if ((err = xfrm_state_add(xc)) < 0)
1260                         goto error;
1261         }
1262
1263         return xc;
1264 error:
1265         kfree(xc);
1266         return NULL;
1267 }
1268 EXPORT_SYMBOL(xfrm_state_migrate);
1269 #endif
1270
1271 int xfrm_state_update(struct xfrm_state *x)
1272 {
1273         struct xfrm_state *x1, *to_put;
1274         int err;
1275         int use_spi = xfrm_id_proto_match(x->id.proto, IPSEC_PROTO_ANY);
1276
1277         to_put = NULL;
1278
1279         spin_lock_bh(&xfrm_state_lock);
1280         x1 = __xfrm_state_locate(x, use_spi, x->props.family);
1281
1282         err = -ESRCH;
1283         if (!x1)
1284                 goto out;
1285
1286         if (xfrm_state_kern(x1)) {
1287                 to_put = x1;
1288                 err = -EEXIST;
1289                 goto out;
1290         }
1291
1292         if (x1->km.state == XFRM_STATE_ACQ) {
1293                 __xfrm_state_insert(x);
1294                 x = NULL;
1295         }
1296         err = 0;
1297
1298 out:
1299         spin_unlock_bh(&xfrm_state_lock);
1300
1301         if (to_put)
1302                 xfrm_state_put(to_put);
1303
1304         if (err)
1305                 return err;
1306
1307         if (!x) {
1308                 xfrm_state_delete(x1);
1309                 xfrm_state_put(x1);
1310                 return 0;
1311         }
1312
1313         err = -EINVAL;
1314         spin_lock_bh(&x1->lock);
1315         if (likely(x1->km.state == XFRM_STATE_VALID)) {
1316                 if (x->encap && x1->encap)
1317                         memcpy(x1->encap, x->encap, sizeof(*x1->encap));
1318                 if (x->coaddr && x1->coaddr) {
1319                         memcpy(x1->coaddr, x->coaddr, sizeof(*x1->coaddr));
1320                 }
1321                 if (!use_spi && memcmp(&x1->sel, &x->sel, sizeof(x1->sel)))
1322                         memcpy(&x1->sel, &x->sel, sizeof(x1->sel));
1323                 memcpy(&x1->lft, &x->lft, sizeof(x1->lft));
1324                 x1->km.dying = 0;
1325
1326                 tasklet_hrtimer_start(&x1->mtimer, ktime_set(1, 0), HRTIMER_MODE_REL);
1327                 if (x1->curlft.use_time)
1328                         xfrm_state_check_expire(x1);
1329
1330                 err = 0;
1331         }
1332         spin_unlock_bh(&x1->lock);
1333
1334         xfrm_state_put(x1);
1335
1336         return err;
1337 }
1338 EXPORT_SYMBOL(xfrm_state_update);
1339
1340 int xfrm_state_check_expire(struct xfrm_state *x)
1341 {
1342         if (!x->curlft.use_time)
1343                 x->curlft.use_time = get_seconds();
1344
1345         if (x->km.state != XFRM_STATE_VALID)
1346                 return -EINVAL;
1347
1348         if (x->curlft.bytes >= x->lft.hard_byte_limit ||
1349             x->curlft.packets >= x->lft.hard_packet_limit) {
1350                 x->km.state = XFRM_STATE_EXPIRED;
1351                 tasklet_hrtimer_start(&x->mtimer, ktime_set(0,0), HRTIMER_MODE_REL);
1352                 return -EINVAL;
1353         }
1354
1355         if (!x->km.dying &&
1356             (x->curlft.bytes >= x->lft.soft_byte_limit ||
1357              x->curlft.packets >= x->lft.soft_packet_limit)) {
1358                 x->km.dying = 1;
1359                 km_state_expired(x, 0, 0);
1360         }
1361         return 0;
1362 }
1363 EXPORT_SYMBOL(xfrm_state_check_expire);
1364
1365 struct xfrm_state *
1366 xfrm_state_lookup(struct net *net, u32 mark, xfrm_address_t *daddr, __be32 spi,
1367                   u8 proto, unsigned short family)
1368 {
1369         struct xfrm_state *x;
1370
1371         spin_lock_bh(&xfrm_state_lock);
1372         x = __xfrm_state_lookup(net, mark, daddr, spi, proto, family);
1373         spin_unlock_bh(&xfrm_state_lock);
1374         return x;
1375 }
1376 EXPORT_SYMBOL(xfrm_state_lookup);
1377
1378 struct xfrm_state *
1379 xfrm_state_lookup_byaddr(struct net *net, u32 mark,
1380                          xfrm_address_t *daddr, xfrm_address_t *saddr,
1381                          u8 proto, unsigned short family)
1382 {
1383         struct xfrm_state *x;
1384
1385         spin_lock_bh(&xfrm_state_lock);
1386         x = __xfrm_state_lookup_byaddr(net, mark, daddr, saddr, proto, family);
1387         spin_unlock_bh(&xfrm_state_lock);
1388         return x;
1389 }
1390 EXPORT_SYMBOL(xfrm_state_lookup_byaddr);
1391
1392 struct xfrm_state *
1393 xfrm_find_acq(struct net *net, struct xfrm_mark *mark, u8 mode, u32 reqid, u8 proto,
1394               xfrm_address_t *daddr, xfrm_address_t *saddr,
1395               int create, unsigned short family)
1396 {
1397         struct xfrm_state *x;
1398
1399         spin_lock_bh(&xfrm_state_lock);
1400         x = __find_acq_core(net, mark, family, mode, reqid, proto, daddr, saddr, create);
1401         spin_unlock_bh(&xfrm_state_lock);
1402
1403         return x;
1404 }
1405 EXPORT_SYMBOL(xfrm_find_acq);
1406
1407 #ifdef CONFIG_XFRM_SUB_POLICY
1408 int
1409 xfrm_tmpl_sort(struct xfrm_tmpl **dst, struct xfrm_tmpl **src, int n,
1410                unsigned short family)
1411 {
1412         int err = 0;
1413         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
1414         if (!afinfo)
1415                 return -EAFNOSUPPORT;
1416
1417         spin_lock_bh(&xfrm_state_lock);
1418         if (afinfo->tmpl_sort)
1419                 err = afinfo->tmpl_sort(dst, src, n);
1420         spin_unlock_bh(&xfrm_state_lock);
1421         xfrm_state_put_afinfo(afinfo);
1422         return err;
1423 }
1424 EXPORT_SYMBOL(xfrm_tmpl_sort);
1425
1426 int
1427 xfrm_state_sort(struct xfrm_state **dst, struct xfrm_state **src, int n,
1428                 unsigned short family)
1429 {
1430         int err = 0;
1431         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
1432         if (!afinfo)
1433                 return -EAFNOSUPPORT;
1434
1435         spin_lock_bh(&xfrm_state_lock);
1436         if (afinfo->state_sort)
1437                 err = afinfo->state_sort(dst, src, n);
1438         spin_unlock_bh(&xfrm_state_lock);
1439         xfrm_state_put_afinfo(afinfo);
1440         return err;
1441 }
1442 EXPORT_SYMBOL(xfrm_state_sort);
1443 #endif
1444
1445 /* Silly enough, but I'm lazy to build resolution list */
1446
1447 static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
1448 {
1449         int i;
1450
1451         for (i = 0; i <= net->xfrm.state_hmask; i++) {
1452                 struct hlist_node *entry;
1453                 struct xfrm_state *x;
1454
1455                 hlist_for_each_entry(x, entry, net->xfrm.state_bydst+i, bydst) {
1456                         if (x->km.seq == seq &&
1457                             (mark & x->mark.m) == x->mark.v &&
1458                             x->km.state == XFRM_STATE_ACQ) {
1459                                 xfrm_state_hold(x);
1460                                 return x;
1461                         }
1462                 }
1463         }
1464         return NULL;
1465 }
1466
1467 struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
1468 {
1469         struct xfrm_state *x;
1470
1471         spin_lock_bh(&xfrm_state_lock);
1472         x = __xfrm_find_acq_byseq(net, mark, seq);
1473         spin_unlock_bh(&xfrm_state_lock);
1474         return x;
1475 }
1476 EXPORT_SYMBOL(xfrm_find_acq_byseq);
1477
1478 u32 xfrm_get_acqseq(void)
1479 {
1480         u32 res;
1481         static atomic_t acqseq;
1482
1483         do {
1484                 res = atomic_inc_return(&acqseq);
1485         } while (!res);
1486
1487         return res;
1488 }
1489 EXPORT_SYMBOL(xfrm_get_acqseq);
1490
1491 int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high)
1492 {
1493         struct net *net = xs_net(x);
1494         unsigned int h;
1495         struct xfrm_state *x0;
1496         int err = -ENOENT;
1497         __be32 minspi = htonl(low);
1498         __be32 maxspi = htonl(high);
1499         u32 mark = x->mark.v & x->mark.m;
1500
1501         spin_lock_bh(&x->lock);
1502         if (x->km.state == XFRM_STATE_DEAD)
1503                 goto unlock;
1504
1505         err = 0;
1506         if (x->id.spi)
1507                 goto unlock;
1508
1509         err = -ENOENT;
1510
1511         if (minspi == maxspi) {
1512                 x0 = xfrm_state_lookup(net, mark, &x->id.daddr, minspi, x->id.proto, x->props.family);
1513                 if (x0) {
1514                         xfrm_state_put(x0);
1515                         goto unlock;
1516                 }
1517                 x->id.spi = minspi;
1518         } else {
1519                 u32 spi = 0;
1520                 for (h=0; h<high-low+1; h++) {
1521                         spi = low + net_random()%(high-low+1);
1522                         x0 = xfrm_state_lookup(net, mark, &x->id.daddr, htonl(spi), x->id.proto, x->props.family);
1523                         if (x0 == NULL) {
1524                                 x->id.spi = htonl(spi);
1525                                 break;
1526                         }
1527                         xfrm_state_put(x0);
1528                 }
1529         }
1530         if (x->id.spi) {
1531                 spin_lock_bh(&xfrm_state_lock);
1532                 h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family);
1533                 hlist_add_head(&x->byspi, net->xfrm.state_byspi+h);
1534                 spin_unlock_bh(&xfrm_state_lock);
1535
1536                 err = 0;
1537         }
1538
1539 unlock:
1540         spin_unlock_bh(&x->lock);
1541
1542         return err;
1543 }
1544 EXPORT_SYMBOL(xfrm_alloc_spi);
1545
1546 int xfrm_state_walk(struct net *net, struct xfrm_state_walk *walk,
1547                     int (*func)(struct xfrm_state *, int, void*),
1548                     void *data)
1549 {
1550         struct xfrm_state *state;
1551         struct xfrm_state_walk *x;
1552         int err = 0;
1553
1554         if (walk->seq != 0 && list_empty(&walk->all))
1555                 return 0;
1556
1557         spin_lock_bh(&xfrm_state_lock);
1558         if (list_empty(&walk->all))
1559                 x = list_first_entry(&net->xfrm.state_all, struct xfrm_state_walk, all);
1560         else
1561                 x = list_entry(&walk->all, struct xfrm_state_walk, all);
1562         list_for_each_entry_from(x, &net->xfrm.state_all, all) {
1563                 if (x->state == XFRM_STATE_DEAD)
1564                         continue;
1565                 state = container_of(x, struct xfrm_state, km);
1566                 if (!xfrm_id_proto_match(state->id.proto, walk->proto))
1567                         continue;
1568                 err = func(state, walk->seq, data);
1569                 if (err) {
1570                         list_move_tail(&walk->all, &x->all);
1571                         goto out;
1572                 }
1573                 walk->seq++;
1574         }
1575         if (walk->seq == 0) {
1576                 err = -ENOENT;
1577                 goto out;
1578         }
1579         list_del_init(&walk->all);
1580 out:
1581         spin_unlock_bh(&xfrm_state_lock);
1582         return err;
1583 }
1584 EXPORT_SYMBOL(xfrm_state_walk);
1585
1586 void xfrm_state_walk_init(struct xfrm_state_walk *walk, u8 proto)
1587 {
1588         INIT_LIST_HEAD(&walk->all);
1589         walk->proto = proto;
1590         walk->state = XFRM_STATE_DEAD;
1591         walk->seq = 0;
1592 }
1593 EXPORT_SYMBOL(xfrm_state_walk_init);
1594
1595 void xfrm_state_walk_done(struct xfrm_state_walk *walk)
1596 {
1597         if (list_empty(&walk->all))
1598                 return;
1599
1600         spin_lock_bh(&xfrm_state_lock);
1601         list_del(&walk->all);
1602         spin_unlock_bh(&xfrm_state_lock);
1603 }
1604 EXPORT_SYMBOL(xfrm_state_walk_done);
1605
1606
1607 void xfrm_replay_notify(struct xfrm_state *x, int event)
1608 {
1609         struct km_event c;
1610         /* we send notify messages in case
1611          *  1. we updated on of the sequence numbers, and the seqno difference
1612          *     is at least x->replay_maxdiff, in this case we also update the
1613          *     timeout of our timer function
1614          *  2. if x->replay_maxage has elapsed since last update,
1615          *     and there were changes
1616          *
1617          *  The state structure must be locked!
1618          */
1619
1620         switch (event) {
1621         case XFRM_REPLAY_UPDATE:
1622                 if (x->replay_maxdiff &&
1623                     (x->replay.seq - x->preplay.seq < x->replay_maxdiff) &&
1624                     (x->replay.oseq - x->preplay.oseq < x->replay_maxdiff)) {
1625                         if (x->xflags & XFRM_TIME_DEFER)
1626                                 event = XFRM_REPLAY_TIMEOUT;
1627                         else
1628                                 return;
1629                 }
1630
1631                 break;
1632
1633         case XFRM_REPLAY_TIMEOUT:
1634                 if ((x->replay.seq == x->preplay.seq) &&
1635                     (x->replay.bitmap == x->preplay.bitmap) &&
1636                     (x->replay.oseq == x->preplay.oseq)) {
1637                         x->xflags |= XFRM_TIME_DEFER;
1638                         return;
1639                 }
1640
1641                 break;
1642         }
1643
1644         memcpy(&x->preplay, &x->replay, sizeof(struct xfrm_replay_state));
1645         c.event = XFRM_MSG_NEWAE;
1646         c.data.aevent = event;
1647         km_state_notify(x, &c);
1648
1649         if (x->replay_maxage &&
1650             !mod_timer(&x->rtimer, jiffies + x->replay_maxage))
1651                 x->xflags &= ~XFRM_TIME_DEFER;
1652 }
1653
1654 static void xfrm_replay_timer_handler(unsigned long data)
1655 {
1656         struct xfrm_state *x = (struct xfrm_state*)data;
1657
1658         spin_lock(&x->lock);
1659
1660         if (x->km.state == XFRM_STATE_VALID) {
1661                 if (xfrm_aevent_is_on(xs_net(x)))
1662                         xfrm_replay_notify(x, XFRM_REPLAY_TIMEOUT);
1663                 else
1664                         x->xflags |= XFRM_TIME_DEFER;
1665         }
1666
1667         spin_unlock(&x->lock);
1668 }
1669
1670 int xfrm_replay_check(struct xfrm_state *x,
1671                       struct sk_buff *skb, __be32 net_seq)
1672 {
1673         u32 diff;
1674         u32 seq = ntohl(net_seq);
1675
1676         if (unlikely(seq == 0))
1677                 goto err;
1678
1679         if (likely(seq > x->replay.seq))
1680                 return 0;
1681
1682         diff = x->replay.seq - seq;
1683         if (diff >= min_t(unsigned int, x->props.replay_window,
1684                           sizeof(x->replay.bitmap) * 8)) {
1685                 x->stats.replay_window++;
1686                 goto err;
1687         }
1688
1689         if (x->replay.bitmap & (1U << diff)) {
1690                 x->stats.replay++;
1691                 goto err;
1692         }
1693         return 0;
1694
1695 err:
1696         xfrm_audit_state_replay(x, skb, net_seq);
1697         return -EINVAL;
1698 }
1699
1700 void xfrm_replay_advance(struct xfrm_state *x, __be32 net_seq)
1701 {
1702         u32 diff;
1703         u32 seq = ntohl(net_seq);
1704
1705         if (seq > x->replay.seq) {
1706                 diff = seq - x->replay.seq;
1707                 if (diff < x->props.replay_window)
1708                         x->replay.bitmap = ((x->replay.bitmap) << diff) | 1;
1709                 else
1710                         x->replay.bitmap = 1;
1711                 x->replay.seq = seq;
1712         } else {
1713                 diff = x->replay.seq - seq;
1714                 x->replay.bitmap |= (1U << diff);
1715         }
1716
1717         if (xfrm_aevent_is_on(xs_net(x)))
1718                 xfrm_replay_notify(x, XFRM_REPLAY_UPDATE);
1719 }
1720
/* List of registered key managers (struct xfrm_mgr, linked via ->list). */
static LIST_HEAD(xfrm_km_list);
/*
 * Protects xfrm_km_list: walkers take it with read_lock(), registration
 * and unregistration take it with write_lock_bh().
 */
static DEFINE_RWLOCK(xfrm_km_lock);
1723
1724 void km_policy_notify(struct xfrm_policy *xp, int dir, struct km_event *c)
1725 {
1726         struct xfrm_mgr *km;
1727
1728         read_lock(&xfrm_km_lock);
1729         list_for_each_entry(km, &xfrm_km_list, list)
1730                 if (km->notify_policy)
1731                         km->notify_policy(xp, dir, c);
1732         read_unlock(&xfrm_km_lock);
1733 }
1734
1735 void km_state_notify(struct xfrm_state *x, struct km_event *c)
1736 {
1737         struct xfrm_mgr *km;
1738         read_lock(&xfrm_km_lock);
1739         list_for_each_entry(km, &xfrm_km_list, list)
1740                 if (km->notify)
1741                         km->notify(x, c);
1742         read_unlock(&xfrm_km_lock);
1743 }
1744
1745 EXPORT_SYMBOL(km_policy_notify);
1746 EXPORT_SYMBOL(km_state_notify);
1747
1748 void km_state_expired(struct xfrm_state *x, int hard, u32 pid)
1749 {
1750         struct net *net = xs_net(x);
1751         struct km_event c;
1752
1753         c.data.hard = hard;
1754         c.pid = pid;
1755         c.event = XFRM_MSG_EXPIRE;
1756         km_state_notify(x, &c);
1757
1758         if (hard)
1759                 wake_up(&net->xfrm.km_waitq);
1760 }
1761
1762 EXPORT_SYMBOL(km_state_expired);
1763 /*
1764  * We send to all registered managers regardless of failure
1765  * We are happy with one success
1766 */
1767 int km_query(struct xfrm_state *x, struct xfrm_tmpl *t, struct xfrm_policy *pol)
1768 {
1769         int err = -EINVAL, acqret;
1770         struct xfrm_mgr *km;
1771
1772         read_lock(&xfrm_km_lock);
1773         list_for_each_entry(km, &xfrm_km_list, list) {
1774                 acqret = km->acquire(x, t, pol, XFRM_POLICY_OUT);
1775                 if (!acqret)
1776                         err = acqret;
1777         }
1778         read_unlock(&xfrm_km_lock);
1779         return err;
1780 }
1781 EXPORT_SYMBOL(km_query);
1782
1783 int km_new_mapping(struct xfrm_state *x, xfrm_address_t *ipaddr, __be16 sport)
1784 {
1785         int err = -EINVAL;
1786         struct xfrm_mgr *km;
1787
1788         read_lock(&xfrm_km_lock);
1789         list_for_each_entry(km, &xfrm_km_list, list) {
1790                 if (km->new_mapping)
1791                         err = km->new_mapping(x, ipaddr, sport);
1792                 if (!err)
1793                         break;
1794         }
1795         read_unlock(&xfrm_km_lock);
1796         return err;
1797 }
1798 EXPORT_SYMBOL(km_new_mapping);
1799
1800 void km_policy_expired(struct xfrm_policy *pol, int dir, int hard, u32 pid)
1801 {
1802         struct net *net = xp_net(pol);
1803         struct km_event c;
1804
1805         c.data.hard = hard;
1806         c.pid = pid;
1807         c.event = XFRM_MSG_POLEXPIRE;
1808         km_policy_notify(pol, dir, &c);
1809
1810         if (hard)
1811                 wake_up(&net->xfrm.km_waitq);
1812 }
1813 EXPORT_SYMBOL(km_policy_expired);
1814
1815 #ifdef CONFIG_XFRM_MIGRATE
1816 int km_migrate(struct xfrm_selector *sel, u8 dir, u8 type,
1817                struct xfrm_migrate *m, int num_migrate,
1818                struct xfrm_kmaddress *k)
1819 {
1820         int err = -EINVAL;
1821         int ret;
1822         struct xfrm_mgr *km;
1823
1824         read_lock(&xfrm_km_lock);
1825         list_for_each_entry(km, &xfrm_km_list, list) {
1826                 if (km->migrate) {
1827                         ret = km->migrate(sel, dir, type, m, num_migrate, k);
1828                         if (!ret)
1829                                 err = ret;
1830                 }
1831         }
1832         read_unlock(&xfrm_km_lock);
1833         return err;
1834 }
1835 EXPORT_SYMBOL(km_migrate);
1836 #endif
1837
1838 int km_report(struct net *net, u8 proto, struct xfrm_selector *sel, xfrm_address_t *addr)
1839 {
1840         int err = -EINVAL;
1841         int ret;
1842         struct xfrm_mgr *km;
1843
1844         read_lock(&xfrm_km_lock);
1845         list_for_each_entry(km, &xfrm_km_list, list) {
1846                 if (km->report) {
1847                         ret = km->report(net, proto, sel, addr);
1848                         if (!ret)
1849                                 err = ret;
1850                 }
1851         }
1852         read_unlock(&xfrm_km_lock);
1853         return err;
1854 }
1855 EXPORT_SYMBOL(km_report);
1856
1857 int xfrm_user_policy(struct sock *sk, int optname, u8 __user *optval, int optlen)
1858 {
1859         int err;
1860         u8 *data;
1861         struct xfrm_mgr *km;
1862         struct xfrm_policy *pol = NULL;
1863
1864         if (optlen <= 0 || optlen > PAGE_SIZE)
1865                 return -EMSGSIZE;
1866
1867         data = kmalloc(optlen, GFP_KERNEL);
1868         if (!data)
1869                 return -ENOMEM;
1870
1871         err = -EFAULT;
1872         if (copy_from_user(data, optval, optlen))
1873                 goto out;
1874
1875         err = -EINVAL;
1876         read_lock(&xfrm_km_lock);
1877         list_for_each_entry(km, &xfrm_km_list, list) {
1878                 pol = km->compile_policy(sk, optname, data,
1879                                          optlen, &err);
1880                 if (err >= 0)
1881                         break;
1882         }
1883         read_unlock(&xfrm_km_lock);
1884
1885         if (err >= 0) {
1886                 xfrm_sk_policy_insert(sk, err, pol);
1887                 xfrm_pol_put(pol);
1888                 err = 0;
1889         }
1890
1891 out:
1892         kfree(data);
1893         return err;
1894 }
1895 EXPORT_SYMBOL(xfrm_user_policy);
1896
1897 int xfrm_register_km(struct xfrm_mgr *km)
1898 {
1899         write_lock_bh(&xfrm_km_lock);
1900         list_add_tail(&km->list, &xfrm_km_list);
1901         write_unlock_bh(&xfrm_km_lock);
1902         return 0;
1903 }
1904 EXPORT_SYMBOL(xfrm_register_km);
1905
1906 int xfrm_unregister_km(struct xfrm_mgr *km)
1907 {
1908         write_lock_bh(&xfrm_km_lock);
1909         list_del(&km->list);
1910         write_unlock_bh(&xfrm_km_lock);
1911         return 0;
1912 }
1913 EXPORT_SYMBOL(xfrm_unregister_km);
1914
1915 int xfrm_state_register_afinfo(struct xfrm_state_afinfo *afinfo)
1916 {
1917         int err = 0;
1918         if (unlikely(afinfo == NULL))
1919                 return -EINVAL;
1920         if (unlikely(afinfo->family >= NPROTO))
1921                 return -EAFNOSUPPORT;
1922         write_lock_bh(&xfrm_state_afinfo_lock);
1923         if (unlikely(xfrm_state_afinfo[afinfo->family] != NULL))
1924                 err = -ENOBUFS;
1925         else
1926                 xfrm_state_afinfo[afinfo->family] = afinfo;
1927         write_unlock_bh(&xfrm_state_afinfo_lock);
1928         return err;
1929 }
1930 EXPORT_SYMBOL(xfrm_state_register_afinfo);
1931
1932 int xfrm_state_unregister_afinfo(struct xfrm_state_afinfo *afinfo)
1933 {
1934         int err = 0;
1935         if (unlikely(afinfo == NULL))
1936                 return -EINVAL;
1937         if (unlikely(afinfo->family >= NPROTO))
1938                 return -EAFNOSUPPORT;
1939         write_lock_bh(&xfrm_state_afinfo_lock);
1940         if (likely(xfrm_state_afinfo[afinfo->family] != NULL)) {
1941                 if (unlikely(xfrm_state_afinfo[afinfo->family] != afinfo))
1942                         err = -EINVAL;
1943                 else
1944                         xfrm_state_afinfo[afinfo->family] = NULL;
1945         }
1946         write_unlock_bh(&xfrm_state_afinfo_lock);
1947         return err;
1948 }
1949 EXPORT_SYMBOL(xfrm_state_unregister_afinfo);
1950
1951 static struct xfrm_state_afinfo *xfrm_state_get_afinfo(unsigned int family)
1952 {
1953         struct xfrm_state_afinfo *afinfo;
1954         if (unlikely(family >= NPROTO))
1955                 return NULL;
1956         read_lock(&xfrm_state_afinfo_lock);
1957         afinfo = xfrm_state_afinfo[family];
1958         if (unlikely(!afinfo))
1959                 read_unlock(&xfrm_state_afinfo_lock);
1960         return afinfo;
1961 }
1962
1963 static void xfrm_state_put_afinfo(struct xfrm_state_afinfo *afinfo)
1964         __releases(xfrm_state_afinfo_lock)
1965 {
1966         read_unlock(&xfrm_state_afinfo_lock);
1967 }
1968
1969 /* Temporarily located here until net/xfrm/xfrm_tunnel.c is created */
1970 void xfrm_state_delete_tunnel(struct xfrm_state *x)
1971 {
1972         if (x->tunnel) {
1973                 struct xfrm_state *t = x->tunnel;
1974
1975                 if (atomic_read(&t->tunnel_users) == 2)
1976                         xfrm_state_delete(t);
1977                 atomic_dec(&t->tunnel_users);
1978                 xfrm_state_put(t);
1979                 x->tunnel = NULL;
1980         }
1981 }
1982 EXPORT_SYMBOL(xfrm_state_delete_tunnel);
1983
1984 int xfrm_state_mtu(struct xfrm_state *x, int mtu)
1985 {
1986         int res;
1987
1988         spin_lock_bh(&x->lock);
1989         if (x->km.state == XFRM_STATE_VALID &&
1990             x->type && x->type->get_mtu)
1991                 res = x->type->get_mtu(x, mtu);
1992         else
1993                 res = mtu - x->props.header_len;
1994         spin_unlock_bh(&x->lock);
1995         return res;
1996 }
1997
1998 int xfrm_init_state(struct xfrm_state *x)
1999 {
2000         struct xfrm_state_afinfo *afinfo;
2001         struct xfrm_mode *inner_mode;
2002         int family = x->props.family;
2003         int err;
2004
2005         err = -EAFNOSUPPORT;
2006         afinfo = xfrm_state_get_afinfo(family);
2007         if (!afinfo)
2008                 goto error;
2009
2010         err = 0;
2011         if (afinfo->init_flags)
2012                 err = afinfo->init_flags(x);
2013
2014         xfrm_state_put_afinfo(afinfo);
2015
2016         if (err)
2017                 goto error;
2018
2019         err = -EPROTONOSUPPORT;
2020
2021         if (x->sel.family != AF_UNSPEC) {
2022                 inner_mode = xfrm_get_mode(x->props.mode, x->sel.family);
2023                 if (inner_mode == NULL)
2024                         goto error;
2025
2026                 if (!(inner_mode->flags & XFRM_MODE_FLAG_TUNNEL) &&
2027                     family != x->sel.family) {
2028                         xfrm_put_mode(inner_mode);
2029                         goto error;
2030                 }
2031
2032                 x->inner_mode = inner_mode;
2033         } else {
2034                 struct xfrm_mode *inner_mode_iaf;
2035                 int iafamily = AF_INET;
2036
2037                 inner_mode = xfrm_get_mode(x->props.mode, x->props.family);
2038                 if (inner_mode == NULL)
2039                         goto error;
2040
2041                 if (!(inner_mode->flags & XFRM_MODE_FLAG_TUNNEL)) {
2042                         xfrm_put_mode(inner_mode);
2043                         goto error;
2044                 }
2045                 x->inner_mode = inner_mode;
2046
2047                 if (x->props.family == AF_INET)
2048                         iafamily = AF_INET6;
2049
2050                 inner_mode_iaf = xfrm_get_mode(x->props.mode, iafamily);
2051                 if (inner_mode_iaf) {
2052                         if (inner_mode_iaf->flags & XFRM_MODE_FLAG_TUNNEL)
2053                                 x->inner_mode_iaf = inner_mode_iaf;
2054                         else
2055                                 xfrm_put_mode(inner_mode_iaf);
2056                 }
2057         }
2058
2059         x->type = xfrm_get_type(x->id.proto, family);
2060         if (x->type == NULL)
2061                 goto error;
2062
2063         err = x->type->init_state(x);
2064         if (err)
2065                 goto error;
2066
2067         x->outer_mode = xfrm_get_mode(x->props.mode, family);
2068         if (x->outer_mode == NULL)
2069                 goto error;
2070
2071         x->km.state = XFRM_STATE_VALID;
2072
2073 error:
2074         return err;
2075 }
2076
2077 EXPORT_SYMBOL(xfrm_init_state);
2078
2079 int __net_init xfrm_state_init(struct net *net)
2080 {
2081         unsigned int sz;
2082
2083         INIT_LIST_HEAD(&net->xfrm.state_all);
2084
2085         sz = sizeof(struct hlist_head) * 8;
2086
2087         net->xfrm.state_bydst = xfrm_hash_alloc(sz);
2088         if (!net->xfrm.state_bydst)
2089                 goto out_bydst;
2090         net->xfrm.state_bysrc = xfrm_hash_alloc(sz);
2091         if (!net->xfrm.state_bysrc)
2092                 goto out_bysrc;
2093         net->xfrm.state_byspi = xfrm_hash_alloc(sz);
2094         if (!net->xfrm.state_byspi)
2095                 goto out_byspi;
2096         net->xfrm.state_hmask = ((sz / sizeof(struct hlist_head)) - 1);
2097
2098         net->xfrm.state_num = 0;
2099         INIT_WORK(&net->xfrm.state_hash_work, xfrm_hash_resize);
2100         INIT_HLIST_HEAD(&net->xfrm.state_gc_list);
2101         INIT_WORK(&net->xfrm.state_gc_work, xfrm_state_gc_task);
2102         init_waitqueue_head(&net->xfrm.km_waitq);
2103         return 0;
2104
2105 out_byspi:
2106         xfrm_hash_free(net->xfrm.state_bysrc, sz);
2107 out_bysrc:
2108         xfrm_hash_free(net->xfrm.state_bydst, sz);
2109 out_bydst:
2110         return -ENOMEM;
2111 }
2112
2113 void xfrm_state_fini(struct net *net)
2114 {
2115         struct xfrm_audit audit_info;
2116         unsigned int sz;
2117
2118         flush_work(&net->xfrm.state_hash_work);
2119         audit_info.loginuid = -1;
2120         audit_info.sessionid = -1;
2121         audit_info.secid = 0;
2122         xfrm_state_flush(net, IPSEC_PROTO_ANY, &audit_info);
2123         flush_work(&net->xfrm.state_gc_work);
2124
2125         WARN_ON(!list_empty(&net->xfrm.state_all));
2126
2127         sz = (net->xfrm.state_hmask + 1) * sizeof(struct hlist_head);
2128         WARN_ON(!hlist_empty(net->xfrm.state_byspi));
2129         xfrm_hash_free(net->xfrm.state_byspi, sz);
2130         WARN_ON(!hlist_empty(net->xfrm.state_bysrc));
2131         xfrm_hash_free(net->xfrm.state_bysrc, sz);
2132         WARN_ON(!hlist_empty(net->xfrm.state_bydst));
2133         xfrm_hash_free(net->xfrm.state_bydst, sz);
2134 }
2135
2136 #ifdef CONFIG_AUDITSYSCALL
2137 static void xfrm_audit_helper_sainfo(struct xfrm_state *x,
2138                                      struct audit_buffer *audit_buf)
2139 {
2140         struct xfrm_sec_ctx *ctx = x->security;
2141         u32 spi = ntohl(x->id.spi);
2142
2143         if (ctx)
2144                 audit_log_format(audit_buf, " sec_alg=%u sec_doi=%u sec_obj=%s",
2145                                  ctx->ctx_alg, ctx->ctx_doi, ctx->ctx_str);
2146
2147         switch(x->props.family) {
2148         case AF_INET:
2149                 audit_log_format(audit_buf, " src=%pI4 dst=%pI4",
2150                                  &x->props.saddr.a4, &x->id.daddr.a4);
2151                 break;
2152         case AF_INET6:
2153                 audit_log_format(audit_buf, " src=%pI6 dst=%pI6",
2154                                  x->props.saddr.a6, x->id.daddr.a6);
2155                 break;
2156         }
2157
2158         audit_log_format(audit_buf, " spi=%u(0x%x)", spi, spi);
2159 }
2160
2161 static void xfrm_audit_helper_pktinfo(struct sk_buff *skb, u16 family,
2162                                       struct audit_buffer *audit_buf)
2163 {
2164         struct iphdr *iph4;
2165         struct ipv6hdr *iph6;
2166
2167         switch (family) {
2168         case AF_INET:
2169                 iph4 = ip_hdr(skb);
2170                 audit_log_format(audit_buf, " src=%pI4 dst=%pI4",
2171                                  &iph4->saddr, &iph4->daddr);
2172                 break;
2173         case AF_INET6:
2174                 iph6 = ipv6_hdr(skb);
2175                 audit_log_format(audit_buf,
2176                                  " src=%pI6 dst=%pI6 flowlbl=0x%x%02x%02x",
2177                                  &iph6->saddr,&iph6->daddr,
2178                                  iph6->flow_lbl[0] & 0x0f,
2179                                  iph6->flow_lbl[1],
2180                                  iph6->flow_lbl[2]);
2181                 break;
2182         }
2183 }
2184
2185 void xfrm_audit_state_add(struct xfrm_state *x, int result,
2186                           uid_t auid, u32 sessionid, u32 secid)
2187 {
2188         struct audit_buffer *audit_buf;
2189
2190         audit_buf = xfrm_audit_start("SAD-add");
2191         if (audit_buf == NULL)
2192                 return;
2193         xfrm_audit_helper_usrinfo(auid, sessionid, secid, audit_buf);
2194         xfrm_audit_helper_sainfo(x, audit_buf);
2195         audit_log_format(audit_buf, " res=%u", result);
2196         audit_log_end(audit_buf);
2197 }
2198 EXPORT_SYMBOL_GPL(xfrm_audit_state_add);
2199
2200 void xfrm_audit_state_delete(struct xfrm_state *x, int result,
2201                              uid_t auid, u32 sessionid, u32 secid)
2202 {
2203         struct audit_buffer *audit_buf;
2204
2205         audit_buf = xfrm_audit_start("SAD-delete");
2206         if (audit_buf == NULL)
2207                 return;
2208         xfrm_audit_helper_usrinfo(auid, sessionid, secid, audit_buf);
2209         xfrm_audit_helper_sainfo(x, audit_buf);
2210         audit_log_format(audit_buf, " res=%u", result);
2211         audit_log_end(audit_buf);
2212 }
2213 EXPORT_SYMBOL_GPL(xfrm_audit_state_delete);
2214
2215 void xfrm_audit_state_replay_overflow(struct xfrm_state *x,
2216                                       struct sk_buff *skb)
2217 {
2218         struct audit_buffer *audit_buf;
2219         u32 spi;
2220
2221         audit_buf = xfrm_audit_start("SA-replay-overflow");
2222         if (audit_buf == NULL)
2223                 return;
2224         xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf);
2225         /* don't record the sequence number because it's inherent in this kind
2226          * of audit message */
2227         spi = ntohl(x->id.spi);
2228         audit_log_format(audit_buf, " spi=%u(0x%x)", spi, spi);
2229         audit_log_end(audit_buf);
2230 }
2231 EXPORT_SYMBOL_GPL(xfrm_audit_state_replay_overflow);
2232
2233 static void xfrm_audit_state_replay(struct xfrm_state *x,
2234                              struct sk_buff *skb, __be32 net_seq)
2235 {
2236         struct audit_buffer *audit_buf;
2237         u32 spi;
2238
2239         audit_buf = xfrm_audit_start("SA-replayed-pkt");
2240         if (audit_buf == NULL)
2241                 return;
2242         xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf);
2243         spi = ntohl(x->id.spi);
2244         audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u",
2245                          spi, spi, ntohl(net_seq));
2246         audit_log_end(audit_buf);
2247 }
2248
2249 void xfrm_audit_state_notfound_simple(struct sk_buff *skb, u16 family)
2250 {
2251         struct audit_buffer *audit_buf;
2252
2253         audit_buf = xfrm_audit_start("SA-notfound");
2254         if (audit_buf == NULL)
2255                 return;
2256         xfrm_audit_helper_pktinfo(skb, family, audit_buf);
2257         audit_log_end(audit_buf);
2258 }
2259 EXPORT_SYMBOL_GPL(xfrm_audit_state_notfound_simple);
2260
2261 void xfrm_audit_state_notfound(struct sk_buff *skb, u16 family,
2262                                __be32 net_spi, __be32 net_seq)
2263 {
2264         struct audit_buffer *audit_buf;
2265         u32 spi;
2266
2267         audit_buf = xfrm_audit_start("SA-notfound");
2268         if (audit_buf == NULL)
2269                 return;
2270         xfrm_audit_helper_pktinfo(skb, family, audit_buf);
2271         spi = ntohl(net_spi);
2272         audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u",
2273                          spi, spi, ntohl(net_seq));
2274         audit_log_end(audit_buf);
2275 }
2276 EXPORT_SYMBOL_GPL(xfrm_audit_state_notfound);
2277
2278 void xfrm_audit_state_icvfail(struct xfrm_state *x,
2279                               struct sk_buff *skb, u8 proto)
2280 {
2281         struct audit_buffer *audit_buf;
2282         __be32 net_spi;
2283         __be32 net_seq;
2284
2285         audit_buf = xfrm_audit_start("SA-icv-failure");
2286         if (audit_buf == NULL)
2287                 return;
2288         xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf);
2289         if (xfrm_parse_spi(skb, proto, &net_spi, &net_seq) == 0) {
2290                 u32 spi = ntohl(net_spi);
2291                 audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u",
2292                                  spi, spi, ntohl(net_seq));
2293         }
2294         audit_log_end(audit_buf);
2295 }
2296 EXPORT_SYMBOL_GPL(xfrm_audit_state_icvfail);
2297 #endif /* CONFIG_AUDITSYSCALL */