Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions natives/share/jdk/internal/misc/Unsafe.c
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,7 @@ DECLARE_NATIVE("jdk/internal/misc", Unsafe, allocateInstance, "(Ljava/lang/Class

/**
 * Native backing for jdk.internal.misc.Unsafe.freeMemory0(long address).
 *
 * The address is released through free_unsafe_allocation, which frees the
 * block and removes its entry from the VM's unsafe_allocations list only when
 * the pointer is actually recorded there. An address that is not recorded is
 * never passed to free(); instead the process reports the error and aborts.
 */
DECLARE_NATIVE("jdk/internal/misc", Unsafe, freeMemory0, "(J)V") {
  DCHECK(argc == 1);
  void *address = (void *)args[0].l;
  // The helper's return value tells us whether the pointer was tracked; only
  // in that case has the memory been freed and may we return normally.
  if (free_unsafe_allocation(thread->vm, address)) {
    return value_null();
  }
  fprintf(stderr, "Attempted to free memory that was not allocated by Unsafe\n");
  abort();
}
Expand Down
Binary file added test/test_files/fma/FusedMultiplyAdd.class
Binary file not shown.
20 changes: 20 additions & 0 deletions test/test_files/fma/FusedMultiplyAdd.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

/**
 * Prints the result of Math.fma for one double-precision and one
 * single-precision input triple, one result per line (double first).
 */
public class FusedMultiplyAdd {
    public static void main(String[] args) {
        // Double-precision case.
        double factorD = 0x1.0000000001p0;
        double otherFactorD = 0x1.0000000001p0;
        double addendD = -0x1.0000000002p0;
        double fusedD = Math.fma(factorD, otherFactorD, addendD);
        System.out.println(fusedD);

        // Single-precision case.
        float factorF = 0x1.00001p0f;
        float otherFactorF = 0x1.00001p0f;
        float addendF = -0x1.00002p0f;
        float fusedF = Math.fma(factorF, otherFactorF, addendF);
        System.out.println(fusedF);
    }
}
7 changes: 7 additions & 0 deletions test/tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,13 @@ TEST_CASE("Kotlin says hi") {
REQUIRE(result.stdout_ == "First 10 terms: 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, \nHello from Kotlin!");
}

// Runs the FusedMultiplyAdd test class and requires its stdout to match the
// expected text exactly: one line per Math.fma call, double first, then float.
TEST_CASE("Fused multiply-add") {
  const char *expected_output = R"(8.271806125530277E-25
9.094947E-13
)";
  auto fma_run = run_test_case("test_files/fma/", true, "FusedMultiplyAdd");
  REQUIRE(fma_run.stdout_ == expected_output);
}

#if 0
TEST_CASE("Print useful trampolines") { print_method_sigs(); }
#endif
23 changes: 22 additions & 1 deletion vm/bjvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <reflection.h>

#include "cached_classdescs.h"
#include "wasm/wasm_utils.h"

#include <errno.h>
#include <linkage.h>
#include <monitors.h>
Expand Down Expand Up @@ -658,11 +660,27 @@ vm *create_vm(const vm_options options) {
native_ptr->callback);
}

#ifdef EMSCRIPTEN
wasm_init_fma_handles();
#endif

register_native_padding(vm);

return vm;
}

/**
 * Releases one allocation recorded in vm->unsafe_allocations.
 *
 * Scans the list for the first entry equal to `allocation`. On a match the
 * memory is passed to free() and the entry is removed with arrdelswap.
 *
 * @param vm         VM whose unsafe_allocations list is searched.
 * @param allocation Pointer to look up and release.
 * @return true if the pointer was found (and therefore freed); false if it is
 *         not in the list, in which case nothing is freed.
 */
bool free_unsafe_allocation(vm *vm, void *allocation) {
  void **tracked = vm->unsafe_allocations;
  for (int slot = 0; slot < arrlen(tracked); ++slot) {
    if (tracked[slot] != allocation) {
      continue;
    }
    free(allocation);
    arrdelswap(tracked, slot);
    return true;
  }
  return false;
}

void free_unsafe_allocations(vm *vm) {
for (int i = 0; i < arrlen(vm->unsafe_allocations); ++i) {
free(vm->unsafe_allocations[i]);
Expand Down Expand Up @@ -2605,11 +2623,14 @@ DEFINE_ASYNC(run_native) {
}

self->native_struct = malloc(hand->async_ctx_bytes);
arrput(thread->vm->unsafe_allocations, self->native_struct);

*self->native_struct = (async_natives_args){{thread, target_handle, native_args, argc}, 0};
AWAIT_FUTURE_EXPR(((native_callback *)frame->method->native_handle)->async(self->native_struct));

// We've laid out the context struct so that the result is always at offset 0
stack_value result = ((async_natives_args *)self->native_struct)->result;
free(self->native_struct);
free_unsafe_allocation(thread->vm, self->native_struct);

ASYNC_END(result);

Expand Down
5 changes: 2 additions & 3 deletions vm/bjvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ typedef struct vm {

s64 next_thread_id; // MUST BE 64 BITS

// Vector of allocations done via Unsafe.allocateMemory0, to be freed in case
// Vector of allocations done via Unsafe.allocateMemory0 or for other purposes in the VM, to be freed in case
// the finalizers aren't run
void **unsafe_allocations;

Expand Down Expand Up @@ -614,10 +614,9 @@ struct native_MethodType *resolve_method_type(vm_thread *thread, method_descript
* check.
*/
void pop_frame(vm_thread *thr, [[maybe_unused]] const stack_frame *reference);

vm_options default_vm_options();

vm *create_vm(vm_options options);
/**
 * Frees `allocation` and removes it from vm->unsafe_allocations. Returns true if the pointer was
 * found in that list; otherwise returns false without freeing anything.
 */
bool free_unsafe_allocation(vm *vm, void *allocation);

typedef struct {
u32 stack_space;
Expand Down
3 changes: 2 additions & 1 deletion vm/classfile.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ typedef enum : u8 {
insn_putstatic_L,

/** intrinsics understood by the interpreter */
insn_sqrt
insn_fma, // fused multiply add (float or double)
insn_sqrt // square root (float or double)
} insn_code_kind;

#define MAX_INSN_KIND (insn_sqrt + 1)
Expand Down
31 changes: 29 additions & 2 deletions vm/interpreter2.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
#include <analysis.h>
#include <debugger.h>
#include <math.h>
#include <tgmath.h>

#include <exceptions.h>
#include <linkage.h>
Expand Down Expand Up @@ -1726,6 +1725,10 @@ static int intrinsify(bytecode_insn *inst) {
inst->kind = insn_sqrt;
return 1;
}
if (utf8_equals(method->name, "fma")) {
inst->kind = insn_fma;
return 1;
}
}
return 0;
}
Expand Down Expand Up @@ -2691,6 +2694,28 @@ static s64 sqrt_impl_float(ARGS_FLOAT) {
NEXT_FLOAT(sqrt(tos))
}

/**
 * Handler for insn_fma on doubles. The two factors are read from the stack
 * slots at sp - 3 and sp - 2, the addend is `tos`. Two slots are popped and
 * the fused result is handed to NEXT_DOUBLE.
 */
static s64 fma_impl_double(ARGS_DOUBLE) {
  DEBUG_CHECK();
  const double factor_a = (sp - 3)->d;
  const double factor_b = (sp - 2)->d;
#ifdef EMSCRIPTEN
  // On WASM builds go through the handle set up by wasm_init_fma_handles.
  const double fused = wasm_fma_impl(factor_a, factor_b, tos);
#else
  const double fused = fma(factor_a, factor_b, tos);
#endif
  sp -= 2;
  NEXT_DOUBLE(fused)
}

/**
 * Handler for insn_fma on floats. The two factors are read from the stack
 * slots at sp - 3 and sp - 2, the addend is `tos`. Two slots are popped and
 * the fused result is handed to NEXT_FLOAT.
 */
static s64 fma_impl_float(ARGS_FLOAT) {
  DEBUG_CHECK();
  const float factor_a = (sp - 3)->f;
  const float factor_b = (sp - 2)->f;
#ifdef EMSCRIPTEN
  // On WASM builds go through the handle set up by wasm_init_fma_handles.
  const float fused = wasm_fmaf_impl(factor_a, factor_b, tos);
#else
  const float fused = fmaf(factor_a, factor_b, tos);
#endif
  sp -= 2;
  NEXT_FLOAT(fused)
}

static s64 frem_impl_float(ARGS_FLOAT) {
DEBUG_CHECK();
float a = (sp - 2)->f, b = tos;
Expand Down Expand Up @@ -3093,7 +3118,7 @@ stack_value interpret_2(future_t *fut, vm_thread *thread, stack_frame *frame) {
return result;
}

/** Jump table definitions. Must be kept in sync with the enum order. */
/** Jump table definitions. */

#define PAGE_ALIGN _Alignas(4096)

Expand Down Expand Up @@ -3212,6 +3237,7 @@ PAGE_ALIGN static s64 (*jmp_table_double[MAX_INSN_KIND])(ARGS_VOID) = {
[insn_getstatic_L] = getstatic_L_impl_double,
[insn_putstatic_D] = putstatic_D_impl_double,
[insn_drem] = drem_impl_double,
[insn_fma] = fma_impl_double,
[insn_sqrt] = sqrt_impl_double};

PAGE_ALIGN static s64 (*jmp_table_int[MAX_INSN_KIND])(ARGS_VOID) = {
Expand Down Expand Up @@ -3446,4 +3472,5 @@ PAGE_ALIGN static s64 (*jmp_table_float[MAX_INSN_KIND])(ARGS_VOID) = {
[insn_getstatic_L] = getstatic_L_impl_float,
[insn_putstatic_F] = putstatic_F_impl_float,
[insn_frem] = frem_impl_float,
[insn_fma] = fma_impl_float,
[insn_sqrt] = sqrt_impl_float};
1 change: 1 addition & 0 deletions vm/pretty_print.c
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ const char *insn_code_to_string(insn_code_kind code) {
CASE(putstatic_L)
CASE(putstatic_Z)
CASE(invokesigpoly)
CASE(fma)
CASE(sqrt)
}
printf("Unknown code: %d\n", code);
Expand Down
80 changes: 80 additions & 0 deletions vm/wasm/fma_support.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Dynamically link in faster implementations of fma and fmaf if supported by the runtime.

#ifdef EMSCRIPTEN

#include "wasm_utils.h"
#include "math.h"
#include <emscripten.h>

/**
Bytes below are:
(module
(func (export "fma") (param f64 f64 f64) (result f64)
local.get 0
f64x2.splat
local.get 1
f64x2.splat
local.get 2
f64x2.splat
f64x2.relaxed_madd
f64x2.extract_lane 0)
(func (export "fmaf") (param f32 f32 f32) (result f32)
local.get 0
f32x4.splat
local.get 1
f32x4.splat
local.get 2
f32x4.splat
f32x4.relaxed_madd
f32x4.extract_lane 0))
*/

// Function pointers through which fmaf/fma are performed on WASM builds.
// Both start out null and are assigned by wasm_init_fma_handles(), either to
// functions exported by the embedded WASM module or to the libm fmaf/fma.
wasm_fmaf_handle wasm_fmaf_impl;
wasm_fma_handle wasm_fma_impl;

/**
 * Initializes wasm_fma_impl and wasm_fmaf_impl.
 *
 * Tries to instantiate the small embedded WASM module (whose text form is
 * given in the comment above) and to register its "fma"/"fmaf" exports via
 * addFunction. If that throws, or if either registered function does not
 * produce a correctly fused result on a probe input, both handles fall back
 * to the libm fma/fmaf implementations.
 *
 * Calling this again after wasm_fma_impl has been set is a no-op.
 */
void wasm_init_fma_handles() {
  if (wasm_fma_impl != nullptr)
    return;

  int failed = EM_ASM_INT({
    try {
      const module_wasm = new Uint8Array([
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x0f, 0x02, 0x60,
        0x03, 0x7c, 0x7c, 0x7c, 0x01, 0x7c, 0x60, 0x03, 0x7d, 0x7d, 0x7d, 0x01,
        0x7d, 0x03, 0x03, 0x02, 0x00, 0x01, 0x07, 0x0e, 0x02, 0x03, 0x66, 0x6d,
        0x61, 0x00, 0x00, 0x04, 0x66, 0x6d, 0x61, 0x66, 0x00, 0x01, 0x0a, 0x2b,
        0x02, 0x14, 0x00, 0x20, 0x00, 0xfd, 0x14, 0x20, 0x01, 0xfd, 0x14, 0x20,
        0x02, 0xfd, 0x14, 0xfd, 0x87, 0x02, 0xfd, 0x21, 0x00, 0x0b, 0x14, 0x00,
        0x20, 0x00, 0xfd, 0x13, 0x20, 0x01, 0xfd, 0x13, 0x20, 0x02, 0xfd, 0x13,
        0xfd, 0x85, 0x02, 0xfd, 0x1f, 0x00, 0x0b, 0x00, 0x0c, 0x04, 0x6e, 0x61,
        0x6d, 0x65, 0x02, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00
      ]);
      const module = new WebAssembly.Module(module_wasm);
      const instance = new WebAssembly.Instance(module);
      const fma = addFunction(instance.exports.fma, 'dddd');
      const fmaf = addFunction(instance.exports.fmaf, 'ffff');
      HEAP32[$0 >> 2] = fma;
      HEAP32[$1 >> 2] = fmaf;
      return 0;
    } catch (e) {
      return 1;
    }
  }, &wasm_fma_impl, &wasm_fmaf_impl);

  if (!failed) { // Now check that we're on a system which actually supports FMA
    DCHECK(wasm_fma_impl);
    DCHECK(wasm_fmaf_impl);

    // Probe inputs whose exact product needs more bits than the type holds:
    // a fused operation yields 2^-80 (double) / 2^-40 (float), while a
    // multiply followed by a separately rounded add yields 0.
    double a = 0x1.0000000001p0;
    double b = 0x1.0000000001p0;
    double c = -0x1.0000000002p0;
    float af = 0x1.00001p0f;
    float bf = 0x1.00001p0f;
    float cf = -0x1.00002p0f;
    if (wasm_fma_impl(a, b, c) != 0x1p-80 || wasm_fmaf_impl(af, bf, cf) != 0x1p-40f) { // sad
      failed = 1;
    }
  }

  if (failed) { // fall back to software implementation
    wasm_fma_impl = &fma;
    wasm_fmaf_impl = &fmaf;
  }
}
#endif
12 changes: 12 additions & 0 deletions vm/wasm/wasm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,18 @@ typedef struct {
void free_wasm_instantiation_result(wasm_instantiation_result *result);
wasm_instantiation_result *wasm_instantiate_module(wasm_module *module, const char *debug_name);

/** fma_support.c, WASM only */

#ifdef EMSCRIPTEN
// Signatures of the single- and double-precision fused multiply-add handles.
typedef float (*wasm_fmaf_handle)(float a, float b, float c);
typedef double (*wasm_fma_handle)(double a, double b, double c);

// Handles defined in fma_support.c; assigned by wasm_init_fma_handles().
extern wasm_fmaf_handle wasm_fmaf_impl;
extern wasm_fma_handle wasm_fma_impl;

// Sets the two handles above; returns immediately if wasm_fma_impl is already set.
void wasm_init_fma_handles();
#endif

#ifdef __cplusplus
}
#endif
Expand Down