diff --git a/arch/riscv/Kconfig b/arch/riscv/Kconfig index 5352932badd88..3e332ecc4c791 100644 --- a/arch/riscv/Kconfig +++ b/arch/riscv/Kconfig @@ -715,7 +715,6 @@ config TOOLCHAIN_HAS_ZACAS config RISCV_ISA_ZACAS bool "Zacas extension support for atomic CAS" - depends on TOOLCHAIN_HAS_ZACAS depends on RISCV_ALTERNATIVE default y help diff --git a/arch/riscv/include/asm/cmpxchg.h b/arch/riscv/include/asm/cmpxchg.h index 0b749e7102162..4f4f389282b8c 100644 --- a/arch/riscv/include/asm/cmpxchg.h +++ b/arch/riscv/include/asm/cmpxchg.h @@ -133,6 +133,7 @@ ({ \ if (IS_ENABLED(CONFIG_RISCV_ISA_ZABHA) && \ IS_ENABLED(CONFIG_RISCV_ISA_ZACAS) && \ + IS_ENABLED(CONFIG_TOOLCHAIN_HAS_ZACAS) && \ riscv_has_extension_unlikely(RISCV_ISA_EXT_ZABHA) && \ riscv_has_extension_unlikely(RISCV_ISA_EXT_ZACAS)) { \ r = o; \ @@ -180,6 +181,7 @@ r, p, co, o, n) \ ({ \ if (IS_ENABLED(CONFIG_RISCV_ISA_ZACAS) && \ + IS_ENABLED(CONFIG_TOOLCHAIN_HAS_ZACAS) && \ riscv_has_extension_unlikely(RISCV_ISA_EXT_ZACAS)) { \ r = o; \ \ @@ -315,7 +317,7 @@ arch_cmpxchg_release((ptr), (o), (n)); \ }) -#if defined(CONFIG_64BIT) && defined(CONFIG_RISCV_ISA_ZACAS) +#if defined(CONFIG_64BIT) && defined(CONFIG_RISCV_ISA_ZACAS) && defined(CONFIG_TOOLCHAIN_HAS_ZACAS) #define system_has_cmpxchg128() riscv_has_extension_unlikely(RISCV_ISA_EXT_ZACAS) @@ -351,7 +353,7 @@ union __u128_halves { #define arch_cmpxchg128_local(ptr, o, n) \ __arch_cmpxchg128((ptr), (o), (n), "") -#endif /* CONFIG_64BIT && CONFIG_RISCV_ISA_ZACAS */ +#endif /* CONFIG_64BIT && CONFIG_RISCV_ISA_ZACAS && CONFIG_TOOLCHAIN_HAS_ZACAS */ #ifdef CONFIG_RISCV_ISA_ZAWRS /* diff --git a/arch/riscv/kernel/setup.c b/arch/riscv/kernel/setup.c index 14888e5ea19ab..348b924bbbcab 100644 --- a/arch/riscv/kernel/setup.c +++ b/arch/riscv/kernel/setup.c @@ -288,6 +288,7 @@ static void __init riscv_spinlock_init(void) if (IS_ENABLED(CONFIG_RISCV_ISA_ZABHA) && IS_ENABLED(CONFIG_RISCV_ISA_ZACAS) && + IS_ENABLED(CONFIG_TOOLCHAIN_HAS_ZACAS) && riscv_isa_extension_available(NULL, ZABHA) && riscv_isa_extension_available(NULL, ZACAS)) { using_ext = "using Zabha"; diff --git a/arch/riscv/net/bpf_jit.h b/arch/riscv/net/bpf_jit.h index e7b032dfd17f0..632ced07bca44 100644 --- a/arch/riscv/net/bpf_jit.h +++ b/arch/riscv/net/bpf_jit.h @@ -13,21 +13,15 @@ #include #include +/* verify runtime detection extension status */ +#define rv_ext_enabled(ext) \ + (IS_ENABLED(CONFIG_RISCV_ISA_##ext) && riscv_has_extension_likely(RISCV_ISA_EXT_##ext)) + static inline bool rvc_enabled(void) { return IS_ENABLED(CONFIG_RISCV_ISA_C); } -static inline bool rvzba_enabled(void) -{ - return IS_ENABLED(CONFIG_RISCV_ISA_ZBA) && riscv_has_extension_likely(RISCV_ISA_EXT_ZBA); -} - -static inline bool rvzbb_enabled(void) -{ - return IS_ENABLED(CONFIG_RISCV_ISA_ZBB) && riscv_has_extension_likely(RISCV_ISA_EXT_ZBB); -} - enum { RV_REG_ZERO = 0, /* The constant value 0 */ RV_REG_RA = 1, /* Return address */ @@ -84,6 +78,8 @@ struct rv_jit_context { int epilogue_offset; int *offset; /* BPF to RV */ int nexentries; + int ex_insn_off; + int ex_jmp_off; unsigned long flags; int stack_size; u64 arena_vm_start; @@ -757,6 +753,17 @@ static inline u16 rvc_swsp(u32 imm8, u8 rs2) return rv_css_insn(0x6, imm, rs2, 0x2); } +/* RVZACAS instructions. */ +static inline u32 rvzacas_amocas_w(u8 rd, u8 rs2, u8 rs1, u8 aq, u8 rl) +{ + return rv_amo_insn(0x5, aq, rl, rs2, rs1, 2, rd, 0x2f); +} + +static inline u32 rvzacas_amocas_d(u8 rd, u8 rs2, u8 rs1, u8 aq, u8 rl) +{ + return rv_amo_insn(0x5, aq, rl, rs2, rs1, 3, rd, 0x2f); +} + /* RVZBA instructions. */ static inline u32 rvzba_sh2add(u8 rd, u8 rs1, u8 rs2) { @@ -1123,7 +1130,7 @@ static inline void emit_sw(u8 rs1, s32 off, u8 rs2, struct rv_jit_context *ctx) static inline void emit_sh2add(u8 rd, u8 rs1, u8 rs2, struct rv_jit_context *ctx) { - if (rvzba_enabled()) { + if (rv_ext_enabled(ZBA)) { emit(rvzba_sh2add(rd, rs1, rs2), ctx); return; } @@ -1134,7 +1141,7 @@ static inline void emit_sh2add(u8 rd, u8 rs1, u8 rs2, struct rv_jit_context *ctx static inline void emit_sh3add(u8 rd, u8 rs1, u8 rs2, struct rv_jit_context *ctx) { - if (rvzba_enabled()) { + if (rv_ext_enabled(ZBA)) { emit(rvzba_sh3add(rd, rs1, rs2), ctx); return; } @@ -1184,7 +1191,7 @@ static inline void emit_subw(u8 rd, u8 rs1, u8 rs2, struct rv_jit_context *ctx) static inline void emit_sextb(u8 rd, u8 rs, struct rv_jit_context *ctx) { - if (rvzbb_enabled()) { + if (rv_ext_enabled(ZBB)) { emit(rvzbb_sextb(rd, rs), ctx); return; } @@ -1195,7 +1202,7 @@ static inline void emit_sextb(u8 rd, u8 rs, struct rv_jit_context *ctx) static inline void emit_sexth(u8 rd, u8 rs, struct rv_jit_context *ctx) { - if (rvzbb_enabled()) { + if (rv_ext_enabled(ZBB)) { emit(rvzbb_sexth(rd, rs), ctx); return; } @@ -1211,7 +1218,7 @@ static inline void emit_sextw(u8 rd, u8 rs, struct rv_jit_context *ctx) static inline void emit_zexth(u8 rd, u8 rs, struct rv_jit_context *ctx) { - if (rvzbb_enabled()) { + if (rv_ext_enabled(ZBB)) { emit(rvzbb_zexth(rd, rs), ctx); return; } @@ -1222,7 +1229,7 @@ static inline void emit_zexth(u8 rd, u8 rs, struct rv_jit_context *ctx) static inline void emit_zextw(u8 rd, u8 rs, struct rv_jit_context *ctx) { - if (rvzba_enabled()) { + if (rv_ext_enabled(ZBA)) { emit(rvzba_zextw(rd, rs), ctx); return; } @@ -1233,7 +1240,7 @@ static inline void emit_zextw(u8 rd, u8 rs, struct rv_jit_context *ctx) static inline void emit_bswap(u8 rd, s32 imm, struct rv_jit_context *ctx) { - if (rvzbb_enabled()) { + if (rv_ext_enabled(ZBB)) { int bits = 64 - imm; emit(rvzbb_rev8(rd, rd), ctx); @@ -1289,6 +1296,35 @@ static inline void emit_bswap(u8 rd, s32 imm, struct rv_jit_context *ctx) emit_mv(rd, RV_REG_T2, ctx); } +static inline void emit_cmpxchg(u8 rd, u8 rs, u8 r0, bool is64, struct rv_jit_context *ctx) +{ + int jmp_offset; + + if (rv_ext_enabled(ZACAS)) { + ctx->ex_insn_off = ctx->ninsns; + emit(is64 ? rvzacas_amocas_d(r0, rs, rd, 1, 1) : + rvzacas_amocas_w(r0, rs, rd, 1, 1), ctx); + ctx->ex_jmp_off = ctx->ninsns; + if (!is64) + emit_zextw(r0, r0, ctx); + return; + } + + if (is64) + emit_mv(RV_REG_T2, r0, ctx); + else + emit_addiw(RV_REG_T2, r0, 0, ctx); + emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) : + rv_lr_w(r0, 0, rd, 0, 0), ctx); + jmp_offset = ninsns_rvoff(8); + emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx); + emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 1) : + rv_sc_w(RV_REG_T3, rs, rd, 0, 1), ctx); + jmp_offset = ninsns_rvoff(-6); + emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx); + emit_fence_rw_rw(ctx); +} + #endif /* __riscv_xlen == 64 */ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog); diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c index 10e01ff06312d..549c3063c7f11 100644 --- a/arch/riscv/net/bpf_jit_comp64.c +++ b/arch/riscv/net/bpf_jit_comp64.c @@ -473,138 +473,92 @@ static inline void emit_kcfi(u32 hash, struct rv_jit_context *ctx) emit(hash, ctx); } -static int emit_load_8(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) +static void emit_ldx_insn(u8 rd, s16 off, u8 rs, u8 size, bool sign_ext, + struct rv_jit_context *ctx) { - int insns_start; - - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lb(rd, off, rs), ctx); - else - emit(rv_lbu(rd, off, rs), ctx); - return ctx->ninsns - insns_start; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rs, ctx); - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lb(rd, 0, RV_REG_T1), ctx); - else - emit(rv_lbu(rd, 0, RV_REG_T1), ctx); - return ctx->ninsns - insns_start; -} - -static int emit_load_16(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) -{ - int insns_start; - - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lh(rd, off, rs), ctx); - else - emit(rv_lhu(rd, off, rs), ctx); - return ctx->ninsns - insns_start; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rs, ctx); - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lh(rd, 0, RV_REG_T1), ctx); - else - emit(rv_lhu(rd, 0, RV_REG_T1), ctx); - return ctx->ninsns - insns_start; -} - -static int emit_load_32(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) -{ - int insns_start; - - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lw(rd, off, rs), ctx); - else - emit(rv_lwu(rd, off, rs), ctx); - return ctx->ninsns - insns_start; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rs, ctx); - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lw(rd, 0, RV_REG_T1), ctx); - else - emit(rv_lwu(rd, 0, RV_REG_T1), ctx); - return ctx->ninsns - insns_start; -} - -static int emit_load_64(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) -{ - int insns_start; - - if (is_12b_int(off)) { - insns_start = ctx->ninsns; + switch (size) { + case BPF_B: + emit(sign_ext ? rv_lb(rd, off, rs) : rv_lbu(rd, off, rs), ctx); + break; + case BPF_H: + emit(sign_ext ? rv_lh(rd, off, rs) : rv_lhu(rd, off, rs), ctx); + break; + case BPF_W: + emit(sign_ext ? rv_lw(rd, off, rs) : rv_lwu(rd, off, rs), ctx); + break; + case BPF_DW: emit_ld(rd, off, rs, ctx); - return ctx->ninsns - insns_start; + break; } - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rs, ctx); - insns_start = ctx->ninsns; - emit_ld(rd, 0, RV_REG_T1, ctx); - return ctx->ninsns - insns_start; } -static void emit_store_8(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) +static void emit_stx_insn(u8 rd, s16 off, u8 rs, u8 size, struct rv_jit_context *ctx) { - if (is_12b_int(off)) { + switch (size) { + case BPF_B: emit(rv_sb(rd, off, rs), ctx); - return; + break; + case BPF_H: + emit(rv_sh(rd, off, rs), ctx); + break; + case BPF_W: + emit_sw(rd, off, rs, ctx); + break; + case BPF_DW: + emit_sd(rd, off, rs, ctx); + break; } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - emit(rv_sb(RV_REG_T1, 0, rs), ctx); } -static void emit_store_16(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) +static void emit_ldx(u8 rd, s16 off, u8 rs, u8 size, bool sign_ext, + struct rv_jit_context *ctx) { if (is_12b_int(off)) { - emit(rv_sh(rd, off, rs), ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_ldx_insn(rd, off, rs, size, sign_ext, ctx); + ctx->ex_jmp_off = ctx->ninsns; return; } emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - emit(rv_sh(RV_REG_T1, 0, rs), ctx); + emit_add(RV_REG_T1, RV_REG_T1, rs, ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_ldx_insn(rd, 0, RV_REG_T1, size, sign_ext, ctx); + ctx->ex_jmp_off = ctx->ninsns; } -static void emit_store_32(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) +static void emit_st(u8 rd, s16 off, s32 imm, u8 size, struct rv_jit_context *ctx) { + emit_imm(RV_REG_T1, imm, ctx); if (is_12b_int(off)) { - emit_sw(rd, off, rs, ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_stx_insn(rd, off, RV_REG_T1, size, ctx); + ctx->ex_jmp_off = ctx->ninsns; return; } - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - emit_sw(RV_REG_T1, 0, rs, ctx); + emit_imm(RV_REG_T2, off, ctx); + emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_stx_insn(RV_REG_T2, 0, RV_REG_T1, size, ctx); + ctx->ex_jmp_off = ctx->ninsns; } -static void emit_store_64(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) +static void emit_stx(u8 rd, s16 off, u8 rs, u8 size, struct rv_jit_context *ctx) { if (is_12b_int(off)) { - emit_sd(rd, off, rs, ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_stx_insn(rd, off, rs, size, ctx); + ctx->ex_jmp_off = ctx->ninsns; return; } emit_imm(RV_REG_T1, off, ctx); emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - emit_sd(RV_REG_T1, 0, rs, ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_stx_insn(RV_REG_T1, 0, rs, size, ctx); + ctx->ex_jmp_off = ctx->ninsns; } static int emit_atomic_ld_st(u8 rd, u8 rs, const struct bpf_insn *insn, @@ -617,20 +571,12 @@ static int emit_atomic_ld_st(u8 rd, u8 rs, const struct bpf_insn *insn, switch (imm) { /* dst_reg = load_acquire(src_reg + off16) */ case BPF_LOAD_ACQ: - switch (BPF_SIZE(code)) { - case BPF_B: - emit_load_8(false, rd, off, rs, ctx); - break; - case BPF_H: - emit_load_16(false, rd, off, rs, ctx); - break; - case BPF_W: - emit_load_32(false, rd, off, rs, ctx); - break; - case BPF_DW: - emit_load_64(false, rd, off, rs, ctx); - break; + if (BPF_MODE(code) == BPF_PROBE_ATOMIC) { + emit_add(RV_REG_T2, rs, RV_REG_ARENA, ctx); + rs = RV_REG_T2; } + + emit_ldx(rd, off, rs, BPF_SIZE(code), false, ctx); emit_fence_r_rw(ctx); /* If our next insn is a redundant zext, return 1 to tell @@ -641,21 +587,13 @@ static int emit_atomic_ld_st(u8 rd, u8 rs, const struct bpf_insn *insn, break; /* store_release(dst_reg + off16, src_reg) */ case BPF_STORE_REL: - emit_fence_rw_w(ctx); - switch (BPF_SIZE(code)) { - case BPF_B: - emit_store_8(rd, off, rs, ctx); - break; - case BPF_H: - emit_store_16(rd, off, rs, ctx); - break; - case BPF_W: - emit_store_32(rd, off, rs, ctx); - break; - case BPF_DW: - emit_store_64(rd, off, rs, ctx); - break; + if (BPF_MODE(code) == BPF_PROBE_ATOMIC) { + emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx); + rd = RV_REG_T2; } + + emit_fence_rw_w(ctx); + emit_stx(rd, off, rs, BPF_SIZE(code), ctx); break; default: pr_err_once("bpf-jit: invalid atomic load/store opcode %02x\n", imm); @@ -668,17 +606,15 @@ static int emit_atomic_ld_st(u8 rd, u8 rs, const struct bpf_insn *insn, static int emit_atomic_rmw(u8 rd, u8 rs, const struct bpf_insn *insn, struct rv_jit_context *ctx) { - u8 r0, code = insn->code; + u8 code = insn->code; s16 off = insn->off; s32 imm = insn->imm; - int jmp_offset; - bool is64; + bool is64 = BPF_SIZE(code) == BPF_DW; if (BPF_SIZE(code) != BPF_W && BPF_SIZE(code) != BPF_DW) { pr_err_once("bpf-jit: 1- and 2-byte RMW atomics are not supported\n"); return -EINVAL; } - is64 = BPF_SIZE(code) == BPF_DW; if (off) { if (is_12b_int(off)) { @@ -690,72 +626,82 @@ static int emit_atomic_rmw(u8 rd, u8 rs, const struct bpf_insn *insn, rd = RV_REG_T1; } + if (BPF_MODE(code) == BPF_PROBE_ATOMIC) { + emit_add(RV_REG_T1, rd, RV_REG_ARENA, ctx); + rd = RV_REG_T1; + } + switch (imm) { /* lock *(u32/u64 *)(dst_reg + off16) = src_reg */ case BPF_ADD: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) : rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx); + ctx->ex_jmp_off = ctx->ninsns; break; case BPF_AND: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) : rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx); + ctx->ex_jmp_off = ctx->ninsns; break; case BPF_OR: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) : rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx); + ctx->ex_jmp_off = ctx->ninsns; break; case BPF_XOR: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) : rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx); + ctx->ex_jmp_off = ctx->ninsns; break; /* src_reg = atomic_fetch_(dst_reg + off16, src_reg) */ case BPF_ADD | BPF_FETCH: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoadd_d(rs, rs, rd, 1, 1) : rv_amoadd_w(rs, rs, rd, 1, 1), ctx); + ctx->ex_jmp_off = ctx->ninsns; if (!is64) emit_zextw(rs, rs, ctx); break; case BPF_AND | BPF_FETCH: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoand_d(rs, rs, rd, 1, 1) : rv_amoand_w(rs, rs, rd, 1, 1), ctx); + ctx->ex_jmp_off = ctx->ninsns; if (!is64) emit_zextw(rs, rs, ctx); break; case BPF_OR | BPF_FETCH: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoor_d(rs, rs, rd, 1, 1) : rv_amoor_w(rs, rs, rd, 1, 1), ctx); + ctx->ex_jmp_off = ctx->ninsns; if (!is64) emit_zextw(rs, rs, ctx); break; case BPF_XOR | BPF_FETCH: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoxor_d(rs, rs, rd, 1, 1) : rv_amoxor_w(rs, rs, rd, 1, 1), ctx); + ctx->ex_jmp_off = ctx->ninsns; if (!is64) emit_zextw(rs, rs, ctx); break; /* src_reg = atomic_xchg(dst_reg + off16, src_reg); */ case BPF_XCHG: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoswap_d(rs, rs, rd, 1, 1) : rv_amoswap_w(rs, rs, rd, 1, 1), ctx); + ctx->ex_jmp_off = ctx->ninsns; if (!is64) emit_zextw(rs, rs, ctx); break; /* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */ case BPF_CMPXCHG: - r0 = bpf_to_rv_reg(BPF_REG_0, ctx); - if (is64) - emit_mv(RV_REG_T2, r0, ctx); - else - emit_addiw(RV_REG_T2, r0, 0, ctx); - emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) : - rv_lr_w(r0, 0, rd, 0, 0), ctx); - jmp_offset = ninsns_rvoff(8); - emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx); - emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 1) : - rv_sc_w(RV_REG_T3, rs, rd, 0, 1), ctx); - jmp_offset = ninsns_rvoff(-6); - emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx); - emit_fence_rw_rw(ctx); + emit_cmpxchg(rd, rs, regmap[BPF_REG_0], is64, ctx); break; default: pr_err_once("bpf-jit: invalid atomic RMW opcode %02x\n", imm); @@ -783,9 +729,8 @@ bool ex_handler_bpf(const struct exception_table_entry *ex, } /* For accesses to BTF pointers, add an entry to the exception table */ -static int add_exception_handler(const struct bpf_insn *insn, - struct rv_jit_context *ctx, - int dst_reg, int insn_len) +static int add_exception_handler(const struct bpf_insn *insn, int dst_reg, + struct rv_jit_context *ctx) { struct exception_table_entry *ex; unsigned long pc; @@ -793,21 +738,23 @@ static int add_exception_handler(const struct bpf_insn *insn, off_t fixup_offset; if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable || - (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX && - BPF_MODE(insn->code) != BPF_PROBE_MEM32)) + ctx->ex_insn_off <= 0 || ctx->ex_jmp_off <= 0) return 0; - if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries)) - return -EINVAL; + if (BPF_MODE(insn->code) != BPF_PROBE_MEM && + BPF_MODE(insn->code) != BPF_PROBE_MEMSX && + BPF_MODE(insn->code) != BPF_PROBE_MEM32 && + BPF_MODE(insn->code) != BPF_PROBE_ATOMIC) + return 0; - if (WARN_ON_ONCE(insn_len > ctx->ninsns)) + if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries)) return -EINVAL; - if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1)) + if (WARN_ON_ONCE(ctx->ex_insn_off > ctx->ninsns || ctx->ex_jmp_off > ctx->ninsns)) return -EINVAL; ex = &ctx->prog->aux->extable[ctx->nexentries]; - pc = (unsigned long)&ctx->ro_insns[ctx->ninsns - insn_len]; + pc = (unsigned long)&ctx->ro_insns[ctx->ex_insn_off]; /* * This is the relative offset of the instruction that may fault from @@ -831,7 +778,7 @@ static int add_exception_handler(const struct bpf_insn *insn, * that may fault. The execution will jump to this after handling the * fault. */ - fixup_offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16)); + fixup_offset = (long)&ex->fixup - (long)&ctx->ro_insns[ctx->ex_jmp_off]; if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset)) return -ERANGE; @@ -848,6 +795,8 @@ static int add_exception_handler(const struct bpf_insn *insn, FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg); ex->type = EX_TYPE_BPF; + ctx->ex_insn_off = 0; + ctx->ex_jmp_off = 0; ctx->nexentries++; return 0; } @@ -1857,7 +1806,6 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW: { bool sign_ext; - int insn_len; sign_ext = BPF_MODE(insn->code) == BPF_MEMSX || BPF_MODE(insn->code) == BPF_PROBE_MEMSX; @@ -1867,22 +1815,9 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, rs = RV_REG_T2; } - switch (BPF_SIZE(code)) { - case BPF_B: - insn_len = emit_load_8(sign_ext, rd, off, rs, ctx); - break; - case BPF_H: - insn_len = emit_load_16(sign_ext, rd, off, rs, ctx); - break; - case BPF_W: - insn_len = emit_load_32(sign_ext, rd, off, rs, ctx); - break; - case BPF_DW: - insn_len = emit_load_64(sign_ext, rd, off, rs, ctx); - break; - } + emit_ldx(rd, off, rs, BPF_SIZE(code), sign_ext, ctx); - ret = add_exception_handler(insn, ctx, rd, insn_len); + ret = add_exception_handler(insn, rd, ctx); if (ret) return ret; @@ -1890,238 +1825,73 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, return 1; break; } + /* speculation barrier */ case BPF_ST | BPF_NOSPEC: break; /* ST: *(size *)(dst + off) = imm */ case BPF_ST | BPF_MEM | BPF_B: - emit_imm(RV_REG_T1, imm, ctx); - if (is_12b_int(off)) { - emit(rv_sb(rd, off, RV_REG_T1), ctx); - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx); - break; - case BPF_ST | BPF_MEM | BPF_H: - emit_imm(RV_REG_T1, imm, ctx); - if (is_12b_int(off)) { - emit(rv_sh(rd, off, RV_REG_T1), ctx); - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx); - break; case BPF_ST | BPF_MEM | BPF_W: - emit_imm(RV_REG_T1, imm, ctx); - if (is_12b_int(off)) { - emit_sw(rd, off, RV_REG_T1, ctx); - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx); - break; case BPF_ST | BPF_MEM | BPF_DW: - emit_imm(RV_REG_T1, imm, ctx); - if (is_12b_int(off)) { - emit_sd(rd, off, RV_REG_T1, ctx); - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx); - break; - + /* ST | PROBE_MEM32: *(size *)(dst + RV_REG_ARENA + off) = imm */ case BPF_ST | BPF_PROBE_MEM32 | BPF_B: case BPF_ST | BPF_PROBE_MEM32 | BPF_H: case BPF_ST | BPF_PROBE_MEM32 | BPF_W: case BPF_ST | BPF_PROBE_MEM32 | BPF_DW: - { - int insn_len, insns_start; - - emit_add(RV_REG_T3, rd, RV_REG_ARENA, ctx); - rd = RV_REG_T3; - - /* Load imm to a register then store it */ - emit_imm(RV_REG_T1, imm, ctx); - - switch (BPF_SIZE(code)) { - case BPF_B: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit(rv_sb(rd, off, RV_REG_T1), ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - insns_start = ctx->ninsns; - emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_H: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit(rv_sh(rd, off, RV_REG_T1), ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - insns_start = ctx->ninsns; - emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_W: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit_sw(rd, off, RV_REG_T1, ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - insns_start = ctx->ninsns; - emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_DW: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit_sd(rd, off, RV_REG_T1, ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - insns_start = ctx->ninsns; - emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx); - insn_len = ctx->ninsns - insns_start; - break; + if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) { + emit_add(RV_REG_T3, rd, RV_REG_ARENA, ctx); + rd = RV_REG_T3; } - ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER, - insn_len); + emit_st(rd, off, imm, BPF_SIZE(code), ctx); + + ret = add_exception_handler(insn, REG_DONT_CLEAR_MARKER, ctx); if (ret) return ret; - break; - } /* STX: *(size *)(dst + off) = src */ case BPF_STX | BPF_MEM | BPF_B: - emit_store_8(rd, off, rs, ctx); - break; case BPF_STX | BPF_MEM | BPF_H: - emit_store_16(rd, off, rs, ctx); - break; case BPF_STX | BPF_MEM | BPF_W: - emit_store_32(rd, off, rs, ctx); - break; case BPF_STX | BPF_MEM | BPF_DW: - emit_store_64(rd, off, rs, ctx); + /* STX | PROBE_MEM32: *(size *)(dst + RV_REG_ARENA + off) = src */ + case BPF_STX | BPF_PROBE_MEM32 | BPF_B: + case BPF_STX | BPF_PROBE_MEM32 | BPF_H: + case BPF_STX | BPF_PROBE_MEM32 | BPF_W: + case BPF_STX | BPF_PROBE_MEM32 | BPF_DW: + if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) { + emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx); + rd = RV_REG_T2; + } + + emit_stx(rd, off, rs, BPF_SIZE(code), ctx); + + ret = add_exception_handler(insn, REG_DONT_CLEAR_MARKER, ctx); + if (ret) + return ret; break; + + /* Atomics */ case BPF_STX | BPF_ATOMIC | BPF_B: case BPF_STX | BPF_ATOMIC | BPF_H: case BPF_STX | BPF_ATOMIC | BPF_W: case BPF_STX | BPF_ATOMIC | BPF_DW: + case BPF_STX | BPF_PROBE_ATOMIC | BPF_B: + case BPF_STX | BPF_PROBE_ATOMIC | BPF_H: + case BPF_STX | BPF_PROBE_ATOMIC | BPF_W: + case BPF_STX | BPF_PROBE_ATOMIC | BPF_DW: if (bpf_atomic_is_load_store(insn)) ret = emit_atomic_ld_st(rd, rs, insn, ctx); else ret = emit_atomic_rmw(rd, rs, insn, ctx); - if (ret) - return ret; - break; - - case BPF_STX | BPF_PROBE_MEM32 | BPF_B: - case BPF_STX | BPF_PROBE_MEM32 | BPF_H: - case BPF_STX | BPF_PROBE_MEM32 | BPF_W: - case BPF_STX | BPF_PROBE_MEM32 | BPF_DW: - { - int insn_len, insns_start; - - emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx); - rd = RV_REG_T2; - - switch (BPF_SIZE(code)) { - case BPF_B: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit(rv_sb(rd, off, rs), ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - insns_start = ctx->ninsns; - emit(rv_sb(RV_REG_T1, 0, rs), ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_H: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit(rv_sh(rd, off, rs), ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - insns_start = ctx->ninsns; - emit(rv_sh(RV_REG_T1, 0, rs), ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_W: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit_sw(rd, off, rs, ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - insns_start = ctx->ninsns; - emit_sw(RV_REG_T1, 0, rs, ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_DW: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit_sd(rd, off, rs, ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - insns_start = ctx->ninsns; - emit_sd(RV_REG_T1, 0, rs, ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER, - insn_len); + ret = ret ?: add_exception_handler(insn, REG_DONT_CLEAR_MARKER, ctx); if (ret) return ret; - break; - } default: pr_err("bpf-jit: unknown opcode %02x\n", code); @@ -2249,6 +2019,20 @@ bool bpf_jit_supports_arena(void) return true; } +bool bpf_jit_supports_insn(struct bpf_insn *insn, bool in_arena) +{ + if (in_arena) { + switch (insn->code) { + case BPF_STX | BPF_ATOMIC | BPF_W: + case BPF_STX | BPF_ATOMIC | BPF_DW: + if (insn->imm == BPF_CMPXCHG) + return rv_ext_enabled(ZACAS); + } + } + + return true; +} + bool bpf_jit_supports_percpu_insn(void) { return true; diff --git a/tools/testing/selftests/bpf/progs/arena_atomics.c b/tools/testing/selftests/bpf/progs/arena_atomics.c index a52feff981126..d1841aac94a22 100644 --- a/tools/testing/selftests/bpf/progs/arena_atomics.c +++ b/tools/testing/selftests/bpf/progs/arena_atomics.c @@ -28,7 +28,8 @@ bool skip_all_tests = true; #if defined(ENABLE_ATOMICS_TESTS) && \ defined(__BPF_FEATURE_ADDR_SPACE_CAST) && \ - (defined(__TARGET_ARCH_arm64) || defined(__TARGET_ARCH_x86)) + (defined(__TARGET_ARCH_arm64) || defined(__TARGET_ARCH_x86) || \ + (defined(__TARGET_ARCH_riscv) && __riscv_xlen == 64)) bool skip_lacq_srel_tests __attribute((__section__(".data"))) = false; #else bool skip_lacq_srel_tests = true; @@ -314,7 +315,8 @@ int load_acquire(const void *ctx) { #if defined(ENABLE_ATOMICS_TESTS) && \ defined(__BPF_FEATURE_ADDR_SPACE_CAST) && \ - (defined(__TARGET_ARCH_arm64) || defined(__TARGET_ARCH_x86)) + (defined(__TARGET_ARCH_arm64) || defined(__TARGET_ARCH_x86) || \ + (defined(__TARGET_ARCH_riscv) && __riscv_xlen == 64)) #define LOAD_ACQUIRE_ARENA(SIZEOP, SIZE, SRC, DST) \ { asm volatile ( \ @@ -365,7 +367,8 @@ int store_release(const void *ctx) { #if defined(ENABLE_ATOMICS_TESTS) && \ defined(__BPF_FEATURE_ADDR_SPACE_CAST) && \ - (defined(__TARGET_ARCH_arm64) || defined(__TARGET_ARCH_x86)) + (defined(__TARGET_ARCH_arm64) || defined(__TARGET_ARCH_x86) || \ + (defined(__TARGET_ARCH_riscv) && __riscv_xlen == 64)) #define STORE_RELEASE_ARENA(SIZEOP, DST, VAL) \ { asm volatile ( \