diff --git a/include/linux/mm.h b/include/linux/mm.h
index 17b0b3e7a37b..b44e0b2f9975 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1251,6 +1251,7 @@ static inline void INIT_VMA(struct vm_area_struct *vma)
 	INIT_LIST_HEAD(&vma->anon_vma_chain);
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
 	seqcount_init(&vma->vm_sequence);
+	atomic_set(&vma->vm_ref_count, 1);
 #endif
 }
 
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 8f1478057812..121a5796bbe4 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -346,6 +346,7 @@ struct vm_area_struct {
 	struct vm_userfaultfd_ctx vm_userfaultfd_ctx;
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
 	seqcount_t vm_sequence;
+	atomic_t vm_ref_count;		/* see get_vma(), put_vma() */
 #endif
 } __randomize_layout;
 
@@ -364,6 +365,9 @@ struct kioctx_table;
 struct mm_struct {
 	struct vm_area_struct *mmap;		/* list of VMAs */
 	struct rb_root mm_rb;
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+	rwlock_t mm_rb_lock;
+#endif
 	u32 vmacache_seqnum;                   /* per-thread vmacache */
 #ifdef CONFIG_MMU
 	unsigned long (*get_unmapped_area) (struct file *filp,
diff --git a/kernel/fork.c b/kernel/fork.c
index 9d3eb700dfa2..d323550d6c14 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -811,6 +811,9 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
 	mm->mmap = NULL;
 	mm->mm_rb = RB_ROOT;
 	mm->vmacache_seqnum = 0;
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+	rwlock_init(&mm->mm_rb_lock);
+#endif
 	atomic_set(&mm->mm_users, 1);
 	atomic_set(&mm->mm_count, 1);
 	init_rwsem(&mm->mmap_sem);
diff --git a/mm/init-mm.c b/mm/init-mm.c
index f94d5d15ebc0..e71ac37a98c4 100644
--- a/mm/init-mm.c
+++ b/mm/init-mm.c
@@ -17,6 +17,9 @@
 
 struct mm_struct init_mm = {
 	.mm_rb		= RB_ROOT,
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+	.mm_rb_lock	= __RW_LOCK_UNLOCKED(init_mm.mm_rb_lock),
+#endif
 	.pgd		= swapper_pg_dir,
 	.mm_users	= ATOMIC_INIT(2),
 	.mm_count	= ATOMIC_INIT(1),
diff --git a/mm/internal.h b/mm/internal.h
index 1df011f62480..7a2d10376775 100644
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -40,6 +40,12 @@ void page_writeback_init(void);
 
 int do_swap_page(struct vm_fault *vmf);
 
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+extern struct vm_area_struct *get_vma(struct mm_struct *mm,
+				      unsigned long addr);
+extern void put_vma(struct vm_area_struct *vma);
+#endif
+
 void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *start_vma,
 		unsigned long floor, unsigned long ceiling);
 
diff --git a/mm/mmap.c b/mm/mmap.c
index 2855b682f58a..aa5bf9a71e5b 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -160,6 +160,27 @@ void unlink_file_vma(struct vm_area_struct *vma)
 	}
 }
 
+static void __free_vma(struct vm_area_struct *vma)
+{
+	if (vma->vm_file)
+		fput(vma->vm_file);
+	mpol_put(vma_policy(vma));
+	kmem_cache_free(vm_area_cachep, vma);
+}
+
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+void put_vma(struct vm_area_struct *vma)
+{
+	if (atomic_dec_and_test(&vma->vm_ref_count))
+		__free_vma(vma);
+}
+#else
+static inline void put_vma(struct vm_area_struct *vma)
+{
+	__free_vma(vma);
+}
+#endif
+
 /*
  * Close a vm structure and free it, returning the next.
  */
@@ -170,10 +191,7 @@ static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
 	might_sleep();
 	if (vma->vm_ops && vma->vm_ops->close)
 		vma->vm_ops->close(vma);
-	if (vma->vm_file)
-		fput(vma->vm_file);
-	mpol_put(vma_policy(vma));
-	kmem_cache_free(vm_area_cachep, vma);
+	put_vma(vma);
 	return next;
 }
 
@@ -393,6 +411,14 @@ static void validate_mm(struct mm_struct *mm)
 #define validate_mm(mm) do { } while (0)
 #endif
 
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+#define mm_rb_write_lock(mm)	write_lock(&(mm)->mm_rb_lock)
+#define mm_rb_write_unlock(mm)	write_unlock(&(mm)->mm_rb_lock)
+#else
+#define mm_rb_write_lock(mm)	do { } while (0)
+#define mm_rb_write_unlock(mm)	do { } while (0)
+#endif /* CONFIG_SPECULATIVE_PAGE_FAULT */
+
 RB_DECLARE_CALLBACKS(static, vma_gap_callbacks, struct vm_area_struct, vm_rb,
 		     unsigned long, rb_subtree_gap, vma_compute_subtree_gap)
 
@@ -411,26 +437,37 @@ static void vma_gap_update(struct vm_area_struct *vma)
 }
 
 static inline void vma_rb_insert(struct vm_area_struct *vma,
-				 struct rb_root *root)
+				 struct mm_struct *mm)
 {
+	struct rb_root *root = &mm->mm_rb;
+
 	/* All rb_subtree_gap values must be consistent prior to insertion */
 	validate_mm_rb(root, NULL);
 
 	rb_insert_augmented(&vma->vm_rb, root, &vma_gap_callbacks);
 }
 
-static void __vma_rb_erase(struct vm_area_struct *vma, struct rb_root *root)
+static void __vma_rb_erase(struct vm_area_struct *vma, struct mm_struct *mm)
 {
+	struct rb_root *root = &mm->mm_rb;
 	/*
 	 * Note rb_erase_augmented is a fairly large inline function,
 	 * so make sure we instantiate it only once with our desired
 	 * augmented rbtree callbacks.
 	 */
+	mm_rb_write_lock(mm);
 	rb_erase_augmented(&vma->vm_rb, root, &vma_gap_callbacks);
+	mm_rb_write_unlock(mm); /* wmb */
+
+	/*
+	 * Ensure the removal is complete before clearing the node.
+	 * Matched by vma_has_changed()/handle_speculative_fault().
+	 */
+	RB_CLEAR_NODE(&vma->vm_rb);
 }
 
 static __always_inline void vma_rb_erase_ignore(struct vm_area_struct *vma,
-						struct rb_root *root,
+						struct mm_struct *mm,
 						struct vm_area_struct *ignore)
 {
 	/*
@@ -438,21 +475,21 @@ static __always_inline void vma_rb_erase_ignore(struct vm_area_struct *vma,
 	 * with the possible exception of the "next" vma being erased if
 	 * next->vm_start was reduced.
 	 */
-	validate_mm_rb(root, ignore);
+	validate_mm_rb(&mm->mm_rb, ignore);
 
-	__vma_rb_erase(vma, root);
+	__vma_rb_erase(vma, mm);
 }
 
 static __always_inline void vma_rb_erase(struct vm_area_struct *vma,
-					 struct rb_root *root)
+					 struct mm_struct *mm)
 {
 	/*
 	 * All rb_subtree_gap values must be consistent prior to erase,
 	 * with the possible exception of the vma being erased.
 	 */
-	validate_mm_rb(root, vma);
+	validate_mm_rb(&mm->mm_rb, vma);
 
-	__vma_rb_erase(vma, root);
+	__vma_rb_erase(vma, mm);
 }
 
 /*
@@ -567,10 +604,12 @@ void __vma_link_rb(struct mm_struct *mm, struct vm_area_struct *vma,
 	 * immediately update the gap to the correct value. Finally we
 	 * rebalance the rbtree after all augmented values have been set.
 	 */
+	mm_rb_write_lock(mm);
 	rb_link_node(&vma->vm_rb, rb_parent, rb_link);
 	vma->rb_subtree_gap = 0;
 	vma_gap_update(vma);
-	vma_rb_insert(vma, &mm->mm_rb);
+	vma_rb_insert(vma, mm);
+	mm_rb_write_unlock(mm);
 }
 
 static void __vma_link_file(struct vm_area_struct *vma)
@@ -646,7 +685,7 @@ static __always_inline void __vma_unlink_common(struct mm_struct *mm,
 {
 	struct vm_area_struct *next;
 
-	vma_rb_erase_ignore(vma, &mm->mm_rb, ignore);
+	vma_rb_erase_ignore(vma, mm, ignore);
 	next = vma->vm_next;
 	if (has_prev)
 		prev->vm_next = next;
@@ -923,16 +962,13 @@ again:
 	}
 
 	if (remove_next) {
-		if (file) {
+		if (file)
 			uprobe_munmap(next, next->vm_start, next->vm_end);
-			fput(file);
-		}
 		if (next->anon_vma)
 			anon_vma_merge(vma, next);
 		mm->map_count--;
-		mpol_put(vma_policy(next));
 		vm_raw_write_end(next);
-		kmem_cache_free(vm_area_cachep, next);
+		put_vma(next);
 		/*
 		 * In mprotect's case 6 (see comments on vma_merge),
 		 * we must remove another next too. It would clutter
@@ -2165,15 +2201,11 @@ get_unmapped_area(struct file *file, unsigned long addr, unsigned long len,
 EXPORT_SYMBOL(get_unmapped_area);
 
 /* Look up the first VMA which satisfies  addr < vm_end,  NULL if none. */
-struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
+static struct vm_area_struct *__find_vma(struct mm_struct *mm,
+					 unsigned long addr)
 {
 	struct rb_node *rb_node;
-	struct vm_area_struct *vma;
-
-	/* Check the cache first. */
-	vma = vmacache_find(mm, addr);
-	if (likely(vma))
-		return vma;
+	struct vm_area_struct *vma = NULL;
 
 	rb_node = mm->mm_rb.rb_node;
 
@@ -2191,13 +2223,40 @@ struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
 			rb_node = rb_node->rb_right;
 	}
 
+	return vma;
+}
+
+struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
+{
+	struct vm_area_struct *vma;
+
+	/* Check the cache first. */
+	vma = vmacache_find(mm, addr);
+	if (likely(vma))
+		return vma;
+
+	vma = __find_vma(mm, addr);
 	if (vma)
 		vmacache_update(addr, vma);
 	return vma;
 }
-
 EXPORT_SYMBOL(find_vma);
 
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+struct vm_area_struct *get_vma(struct mm_struct *mm, unsigned long addr)
+{
+	struct vm_area_struct *vma = NULL;
+
+	read_lock(&mm->mm_rb_lock);
+	vma = __find_vma(mm, addr);
+	if (vma)
+		atomic_inc(&vma->vm_ref_count);
+	read_unlock(&mm->mm_rb_lock);
+
+	return vma;
+}
+#endif
+
 /*
  * Same as find_vma, but also return a pointer to the previous VMA in *pprev.
 */
@@ -2565,7 +2624,7 @@ detach_vmas_to_be_unmapped(struct mm_struct *mm, struct vm_area_struct *vma,
 	insertion_point = (prev ? &prev->vm_next : &mm->mmap);
 	vma->vm_prev = NULL;
 	do {
-		vma_rb_erase(vma, &mm->mm_rb);
+		vma_rb_erase(vma, mm);
 		mm->map_count--;
 		tail_vma = vma;
 		vma = vma->vm_next;
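
For context, here is a minimal sketch (not part of the patch above) of how the two new helpers are meant to be consumed by a path that does not hold mmap_sem: get_vma() looks the VMA up under mm_rb_lock and pins it via vm_ref_count, so it cannot be freed while being inspected, and put_vma() drops the pin, freeing the VMA if it was unlinked in the meantime. The caller spf_lookup_example() and its vm_sequence checks are illustrative assumptions, not code from this series; the real consumer is handle_speculative_fault(), introduced in a later patch.

/*
 * Illustrative sketch only.  get_vma()/put_vma() are declared in
 * mm/internal.h, so a real caller would live under mm/.
 */
static int spf_lookup_example(struct mm_struct *mm, unsigned long addr)
{
	struct vm_area_struct *vma;
	unsigned int seq;
	int ret = VM_FAULT_RETRY;

	/* Reference taken under mm_rb_lock: the VMA cannot be freed. */
	vma = get_vma(mm, addr);
	if (!vma)
		return ret;

	/* Snapshot vm_sequence; an odd count means a writer is active. */
	seq = raw_read_seqcount(&vma->vm_sequence);
	if (seq & 1)
		goto out;

	/* ... speculatively handle the fault using the pinned VMA ... */

	/* Fall back to the classic path if the VMA changed under us. */
	if (read_seqcount_retry(&vma->vm_sequence, seq))
		goto out;
	ret = 0;
out:
	put_vma(vma);	/* may free the VMA if it was unlinked meanwhile */
	return ret;
}

Note the ordering this relies on: __vma_rb_erase() clears the node only after the write-side unlock, so a reader that found the VMA before the erase either still sees it linked or observes the vm_sequence/RB_CLEAR_NODE change and retries.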