diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index f438bd3b7689..15615c94804f 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -239,6 +239,123 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf) *pprog = prog; } +static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode) +{ + u8 *prog = *pprog; + int cnt = 0; + s64 offset; + + offset = func - (ip + X86_PATCH_SIZE); + if (!is_simm32(offset)) { + pr_err("Target call %p is out of range\n", func); + return -ERANGE; + } + EMIT1_off32(opcode, offset); + *pprog = prog; + return 0; +} + +static int emit_call(u8 **pprog, void *func, void *ip) +{ + return emit_patch(pprog, func, ip, 0xE8); +} + +static int emit_jump(u8 **pprog, void *func, void *ip) +{ + return emit_patch(pprog, func, ip, 0xE9); +} + +static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, + void *old_addr, void *new_addr, + const bool text_live) +{ + int (*emit_patch_fn)(u8 **pprog, void *func, void *ip); + const u8 *nop_insn = ideal_nops[NOP_ATOMIC5]; + u8 old_insn[X86_PATCH_SIZE] = {}; + u8 new_insn[X86_PATCH_SIZE] = {}; + u8 *prog; + int ret; + + switch (t) { + case BPF_MOD_NOP_TO_CALL ... BPF_MOD_CALL_TO_NOP: + emit_patch_fn = emit_call; + break; + case BPF_MOD_NOP_TO_JUMP ... BPF_MOD_JUMP_TO_NOP: + emit_patch_fn = emit_jump; + break; + default: + return -ENOTSUPP; + } + + switch (t) { + case BPF_MOD_NOP_TO_CALL: + case BPF_MOD_NOP_TO_JUMP: + if (!old_addr && new_addr) { + memcpy(old_insn, nop_insn, X86_PATCH_SIZE); + + prog = new_insn; + ret = emit_patch_fn(&prog, new_addr, ip); + if (ret) + return ret; + break; + } + return -ENXIO; + case BPF_MOD_CALL_TO_CALL: + case BPF_MOD_JUMP_TO_JUMP: + if (old_addr && new_addr) { + prog = old_insn; + ret = emit_patch_fn(&prog, old_addr, ip); + if (ret) + return ret; + + prog = new_insn; + ret = emit_patch_fn(&prog, new_addr, ip); + if (ret) + return ret; + break; + } + return -ENXIO; + case BPF_MOD_CALL_TO_NOP: + case BPF_MOD_JUMP_TO_NOP: + if (old_addr && !new_addr) { + memcpy(new_insn, nop_insn, X86_PATCH_SIZE); + + prog = old_insn; + ret = emit_patch_fn(&prog, old_addr, ip); + if (ret) + return ret; + break; + } + return -ENXIO; + default: + return -ENOTSUPP; + } + + ret = -EBUSY; + mutex_lock(&text_mutex); + if (memcmp(ip, old_insn, X86_PATCH_SIZE)) + goto out; + if (text_live) + text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL); + else + memcpy(ip, new_insn, X86_PATCH_SIZE); + ret = 0; +out: + mutex_unlock(&text_mutex); + return ret; +} + +int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, + void *old_addr, void *new_addr) +{ + if (!is_kernel_text((long)ip) && + !is_bpf_text_address((long)ip)) + /* BPF poking in modules is not supported */ + return -EINVAL; + + return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true); +} + /* * Generate the following code: * @@ -253,7 +370,7 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf) * goto *(prog->bpf_func + prologue_size); * out: */ -static void emit_bpf_tail_call(u8 **pprog) +static void emit_bpf_tail_call_indirect(u8 **pprog) { u8 *prog = *pprog; int label1, label2, label3; @@ -320,6 +437,69 @@ static void emit_bpf_tail_call(u8 **pprog) *pprog = prog; } +static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke, + u8 **pprog, int addr, u8 *image) +{ + u8 *prog = *pprog; + int cnt = 0; + + /* + * if (tail_call_cnt > MAX_TAIL_CALL_CNT) + * goto out; + */ + EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */ + EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ + EMIT2(X86_JA, 14); /* ja out */ + EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ + EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */ + + poke->ip = image + (addr - X86_PATCH_SIZE); + poke->adj_off = PROLOGUE_SIZE; + + memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE); + prog += X86_PATCH_SIZE; + /* out: */ + + *pprog = prog; +} + +static void bpf_tail_call_direct_fixup(struct bpf_prog *prog) +{ + static const enum bpf_text_poke_type type = BPF_MOD_NOP_TO_JUMP; + struct bpf_jit_poke_descriptor *poke; + struct bpf_array *array; + struct bpf_prog *target; + int i, ret; + + for (i = 0; i < prog->aux->size_poke_tab; i++) { + poke = &prog->aux->poke_tab[i]; + WARN_ON_ONCE(READ_ONCE(poke->ip_stable)); + + if (poke->reason != BPF_POKE_REASON_TAIL_CALL) + continue; + + array = container_of(poke->tail_call.map, struct bpf_array, map); + mutex_lock(&array->aux->poke_mutex); + target = array->ptrs[poke->tail_call.key]; + if (target) { + /* Plain memcpy is used when image is not live yet + * and still not locked as read-only. Once poke + * location is active (poke->ip_stable), any parallel + * bpf_arch_text_poke() might occur still on the + * read-write image until we finally locked it as + * read-only. Both modifications on the given image + * are under text_mutex to avoid interference. + */ + ret = __bpf_arch_text_poke(poke->ip, type, NULL, + (u8 *)target->bpf_func + + poke->adj_off, false); + BUG_ON(ret < 0); + } + WRITE_ONCE(poke->ip_stable, true); + mutex_unlock(&array->aux->poke_mutex); + } +} + static void emit_mov_imm32(u8 **pprog, bool sign_propagate, u32 dst_reg, const u32 imm32) { @@ -481,99 +661,6 @@ static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off) *pprog = prog; } -static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode) -{ - u8 *prog = *pprog; - int cnt = 0; - s64 offset; - - offset = func - (ip + X86_PATCH_SIZE); - if (!is_simm32(offset)) { - pr_err("Target call %p is out of range\n", func); - return -EINVAL; - } - EMIT1_off32(opcode, offset); - *pprog = prog; - return 0; -} - -static int emit_call(u8 **pprog, void *func, void *ip) -{ - return emit_patch(pprog, func, ip, 0xE8); -} - -static int emit_jump(u8 **pprog, void *func, void *ip) -{ - return emit_patch(pprog, func, ip, 0xE9); -} - -int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, - void *old_addr, void *new_addr) -{ - int (*emit_patch_fn)(u8 **pprog, void *func, void *ip); - u8 old_insn[X86_PATCH_SIZE] = {}; - u8 new_insn[X86_PATCH_SIZE] = {}; - u8 *prog; - int ret; - - if (!is_kernel_text((long)ip) && - !is_bpf_text_address((long)ip)) - /* BPF poking in modules is not supported */ - return -EINVAL; - - switch (t) { - case BPF_MOD_NOP_TO_CALL ... BPF_MOD_CALL_TO_NOP: - emit_patch_fn = emit_call; - break; - case BPF_MOD_NOP_TO_JUMP ... BPF_MOD_JUMP_TO_NOP: - emit_patch_fn = emit_jump; - break; - default: - return -ENOTSUPP; - } - - if (old_addr) { - prog = old_insn; - ret = emit_patch_fn(&prog, old_addr, (void *)ip); - if (ret) - return ret; - } - if (new_addr) { - prog = new_insn; - ret = emit_patch_fn(&prog, new_addr, (void *)ip); - if (ret) - return ret; - } - - ret = -EBUSY; - mutex_lock(&text_mutex); - switch (t) { - case BPF_MOD_NOP_TO_CALL: - case BPF_MOD_NOP_TO_JUMP: - if (memcmp(ip, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE)) - goto out; - text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL); - break; - case BPF_MOD_CALL_TO_CALL: - case BPF_MOD_JUMP_TO_JUMP: - if (memcmp(ip, old_insn, X86_PATCH_SIZE)) - goto out; - text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL); - break; - case BPF_MOD_CALL_TO_NOP: - case BPF_MOD_JUMP_TO_NOP: - if (memcmp(ip, old_insn, X86_PATCH_SIZE)) - goto out; - text_poke_bp(ip, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE, - NULL); - break; - } - ret = 0; -out: - mutex_unlock(&text_mutex); - return ret; -} - static bool ex_handler_bpf(const struct exception_table_entry *x, struct pt_regs *regs, int trapnr, unsigned long error_code, unsigned long fault_addr) @@ -1041,7 +1128,11 @@ xadd: if (is_imm8(insn->off)) break; case BPF_JMP | BPF_TAIL_CALL: - emit_bpf_tail_call(&prog); + if (imm32) + emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1], + &prog, addrs[i], image); + else + emit_bpf_tail_call_indirect(&prog); break; /* cond jump */ @@ -1599,6 +1690,7 @@ out_image: if (image) { if (!prog->is_func || extra_pass) { + bpf_tail_call_direct_fixup(prog); bpf_jit_binary_lock_ro(header); } else { jit_data->addrs = addrs;