diff --git a/src/Stack.zig b/src/Stack.zig index 3529007..68bdb72 100644 --- a/src/Stack.zig +++ b/src/Stack.zig @@ -1,3 +1,4 @@ +const builtin = @import("builtin"); const std = @import("std"); const assert = std.debug.assert; @@ -14,13 +15,34 @@ const f32x4 = def.f32x4; const f64x2 = def.f64x2; const v128 = def.v128; const Instruction = def.Instruction; -const Val = def.Val; const ValType = def.ValType; +const FuncRef = def.FuncRef; +const ExternRef = def.ExternRef; const inst = @import("instance.zig"); const ModuleInstance = inst.ModuleInstance; const TrapError = inst.TrapError; +// V128 values are packed into 2 separate StackVals. This helps reduce wasted memory due to +// alignment requirements since most values are 4 or 8 bytes. +// This is an extern union to avoid the debug error checking Zig adds to unions by default - +// bytebox is aware of the bitwise contents of the stack and often uses types interchangably +const StackVal = extern union { + I32: i32, + I64: i64, + F32: f32, + F64: f64, + FuncRef: FuncRef, + ExternRef: ExternRef, + + comptime { + if (builtin.mode == .ReleaseFast) { + std.debug.assert(@sizeOf(StackVal) == 8); + } + } +}; + +// The number of locals/params/returns in this struct counts V128 as 2 values pub const FunctionInstance = struct { type_def_index: usize, def_index: usize, @@ -57,9 +79,10 @@ pub const FuncCallData = struct { const Stack = @This(); -values: []Val, +values: []StackVal, labels: []Label, frames: []CallFrame, +locals: []StackVal, // references values num_values: u32, num_labels: u16, num_frames: u16, @@ -74,9 +97,10 @@ const AllocOpts = struct { pub fn init(allocator: std.mem.Allocator) Stack { const stack = Stack{ - .values = &[_]Val{}, + .values = &[_]StackVal{}, .labels = &[_]Label{}, .frames = &[_]CallFrame{}, + .locals = &.{}, .num_values = 0, .num_labels = 0, .num_frames = 0, @@ -94,8 +118,8 @@ pub fn deinit(stack: *Stack) void { } pub fn allocMemory(stack: *Stack, opts: AllocOpts) !void { - const alignment = 
@max(@alignOf(Val), @alignOf(Label), @alignOf(CallFrame)); - const values_alloc_size = std.mem.alignForward(usize, @as(usize, @intCast(opts.max_values)) * @sizeOf(Val), alignment); + const alignment = @max(@alignOf(StackVal), @alignOf(Label), @alignOf(CallFrame)); + const values_alloc_size = std.mem.alignForward(usize, @as(usize, @intCast(opts.max_values)) * @sizeOf(StackVal), alignment); const labels_alloc_size = std.mem.alignForward(usize, @as(usize, @intCast(opts.max_labels)) * @sizeOf(Label), alignment); const frames_alloc_size = std.mem.alignForward(usize, @as(usize, @intCast(opts.max_frames)) * @sizeOf(CallFrame), alignment); const total_alloc_size: usize = values_alloc_size + labels_alloc_size + frames_alloc_size; @@ -104,7 +128,7 @@ pub fn allocMemory(stack: *Stack, opts: AllocOpts) !void { const begin_frames = values_alloc_size + labels_alloc_size; stack.mem = try stack.allocator.alloc(u8, total_alloc_size); - stack.values.ptr = @as([*]Val, @alignCast(@ptrCast(stack.mem.ptr))); + stack.values.ptr = @as([*]StackVal, @alignCast(@ptrCast(stack.mem.ptr))); stack.values.len = opts.max_values; stack.labels.ptr = @as([*]Label, @alignCast(@ptrCast(stack.mem[begin_labels..].ptr))); stack.labels.len = opts.max_labels; @@ -112,44 +136,41 @@ pub fn allocMemory(stack: *Stack, opts: AllocOpts) !void { stack.frames.len = opts.max_frames; } -pub fn pushValue(stack: *Stack, value: Val) void { - stack.values[stack.num_values] = value; - stack.num_values += 1; -} - pub fn pushI32(stack: *Stack, v: i32) void { - stack.values[stack.num_values] = Val{ .I32 = v }; + stack.values[stack.num_values] = .{ .I32 = v }; stack.num_values += 1; } pub fn pushI64(stack: *Stack, v: i64) void { - stack.values[stack.num_values] = Val{ .I64 = v }; + stack.values[stack.num_values] = .{ .I64 = v }; stack.num_values += 1; } pub fn pushF32(stack: *Stack, v: f32) void { - stack.values[stack.num_values] = Val{ .F32 = v }; + stack.values[stack.num_values] = .{ .F32 = v }; stack.num_values += 1; } pub 
fn pushF64(stack: *Stack, v: f64) void { - stack.values[stack.num_values] = Val{ .F64 = v }; + stack.values[stack.num_values] = .{ .F64 = v }; stack.num_values += 1; } -pub fn pushV128(stack: *Stack, v: v128) void { - stack.values[stack.num_values] = Val{ .V128 = v }; +pub fn pushFuncRef(stack: *Stack, v: FuncRef) void { + stack.values[stack.num_values] = .{ .FuncRef = v }; stack.num_values += 1; } -pub fn popValue(stack: *Stack) Val { - stack.num_values -= 1; - const value: Val = stack.values[stack.num_values]; - return value; +pub fn pushExternRef(stack: *Stack, v: ExternRef) void { + stack.values[stack.num_values] = .{ .ExternRef = v }; + stack.num_values += 1; } -pub fn topValue(stack: *const Stack) Val { - return stack.values[stack.num_values - 1]; +pub fn pushV128(stack: *Stack, v: v128) void { + const vec2 = @as(f64x2, @bitCast(v)); + stack.values[stack.num_values + 0].F64 = vec2[0]; + stack.values[stack.num_values + 1].F64 = vec2[1]; + stack.num_values += 2; } pub fn popI32(stack: *Stack) i32 { @@ -172,9 +193,16 @@ pub fn popF64(stack: *Stack) f64 { return stack.values[stack.num_values].F64; } -pub fn popV128(stack: *Stack) v128 { +pub fn popFuncRef(stack: *Stack) FuncRef { stack.num_values -= 1; - return stack.values[stack.num_values].V128; + return stack.values[stack.num_values].FuncRef; +} + +pub fn popV128(stack: *Stack) v128 { + stack.num_values -= 2; + const f0 = stack.values[stack.num_values + 0].F64; + const f1 = stack.values[stack.num_values + 1].F64; + return @bitCast(@as(f64x2, .{ f0, f1 })); } pub fn popIndexType(stack: *Stack, index_type: ValType) i64 { @@ -225,12 +253,12 @@ pub fn popAllUntilLabelId(stack: *Stack, label_id: u64, pop_final_label: bool, n const dest_begin: usize = label.start_offset_values; const dest_end: usize = label.start_offset_values + num_returns; - const returns_source: []const Val = stack.values[source_begin..source_end]; - const returns_dest: []Val = stack.values[dest_begin..dest_end]; + const returns_source: []const 
StackVal = stack.values[source_begin..source_end]; + const returns_dest: []StackVal = stack.values[dest_begin..dest_end]; if (dest_begin <= source_begin) { - std.mem.copyForwards(Val, returns_dest, returns_source); + std.mem.copyForwards(StackVal, returns_dest, returns_source); } else { - std.mem.copyBackwards(Val, returns_dest, returns_source); + std.mem.copyBackwards(StackVal, returns_dest, returns_source); } stack.num_values = @as(u32, @intCast(dest_end)); @@ -271,6 +299,7 @@ pub fn pushFrame(stack: *Stack, func: *const FunctionInstance, module_instance: @memset(std.mem.sliceAsBytes(func_locals), 0); stack.num_values = values_index_end; + stack.locals = stack.values[values_index_begin..values_index_end]; stack.frames[stack.num_frames] = CallFrame{ .func = func, @@ -283,7 +312,7 @@ pub fn pushFrame(stack: *Stack, func: *const FunctionInstance, module_instance: } pub fn popFrame(stack: *Stack) ?FuncCallData { - const frame: *CallFrame = stack.topFrame(); + var frame: *CallFrame = stack.topFrame(); const continuation: u32 = stack.labels[frame.start_offset_labels].continuation; const num_returns: usize = frame.num_returns; @@ -296,17 +325,20 @@ pub fn popFrame(stack: *Stack) ?FuncCallData { // Because a function's locals take up stack space, the return values are located // after the locals, so we need to copy them back down to the start of the function's // stack space, where the caller expects them to be. 
- const returns_source: []const Val = stack.values[source_begin..source_end]; - const returns_dest: []Val = stack.values[dest_begin..dest_end]; - std.mem.copyForwards(Val, returns_dest, returns_source); + const returns_source: []const StackVal = stack.values[source_begin..source_end]; + const returns_dest: []StackVal = stack.values[dest_begin..dest_end]; + std.mem.copyForwards(StackVal, returns_dest, returns_source); stack.num_values = @as(u32, @intCast(dest_end)); stack.num_labels = frame.start_offset_labels; stack.num_frames -= 1; if (stack.num_frames > 0) { + frame = stack.topFrame(); + stack.locals = stack.values[frame.start_offset_values .. stack.num_values + frame.func.num_locals]; + return FuncCallData{ - .code = stack.topFrame().func.code, + .code = frame.func.code, .continuation = continuation, }; } @@ -318,21 +350,71 @@ pub fn topFrame(stack: *const Stack) *CallFrame { return &stack.frames[stack.num_frames - 1]; } -pub fn locals(stack: *const Stack) []Val { - const frame = stack.topFrame(); - return stack.values[frame.start_offset_values..]; +pub fn localGet(stack: *Stack, local_index: usize) void { + stack.values[stack.num_values] = stack.locals[local_index]; + stack.num_values += 1; +} + +pub fn localGetV128(stack: *Stack, local_index: usize) void { + stack.values[stack.num_values + 0] = stack.locals[local_index + 0]; + stack.values[stack.num_values + 1] = stack.locals[local_index + 1]; + stack.num_values += 2; +} + +pub fn localSet(stack: *Stack, local_index: usize) void { + stack.num_values -= 1; + stack.locals[local_index] = stack.values[stack.num_values]; +} + +pub fn localSetV128(stack: *Stack, local_index: usize) void { + stack.num_values -= 2; + stack.locals[local_index + 0] = stack.values[stack.num_values + 0]; + stack.locals[local_index + 1] = stack.values[stack.num_values + 1]; +} + +pub fn localTee(stack: *Stack, local_index: usize) void { + stack.locals[local_index] = stack.values[stack.num_values - 1]; +} + +pub fn localTeeV128(stack: 
*Stack, local_index: usize) void { + stack.locals[local_index + 0] = stack.values[stack.num_values - 2]; + stack.locals[local_index + 1] = stack.values[stack.num_values - 1]; +} + +pub fn select(stack: *Stack) void { + const boolean: i32 = stack.values[stack.num_values - 1].I32; + if (boolean == 0) { + stack.values[stack.num_values - 3] = stack.values[stack.num_values - 2]; + } + stack.num_values -= 2; +} + +pub fn selectV128(stack: *Stack) void { + const boolean: i32 = stack.values[stack.num_values - 1].I32; + if (boolean == 0) { + stack.values[stack.num_values - 5] = stack.values[stack.num_values - 3]; + stack.values[stack.num_values - 4] = stack.values[stack.num_values - 2]; + } + stack.num_values -= 3; } pub fn popAll(stack: *Stack) void { stack.num_values = 0; stack.num_labels = 0; stack.num_frames = 0; + stack.locals = &.{}; } pub fn debugDump(stack: Stack) void { std.debug.print("===== stack dump =====\n", .{}); for (stack.values[0..stack.num_values]) |val| { - std.debug.print("I32: {}, I64: {}, F32: {}, F64: {}\n", .{ val.I32, val.I64, val.F32, val.F64 }); + std.debug.print("I32: {}, I64: {}, F32: {}, F64: {}, FuncRef: {}\n", .{ + val.I32, + val.I64, + val.F32, + val.F64, + val.FuncRef.func, + }); } std.debug.print("======================\n", .{}); } diff --git a/src/core.zig b/src/core.zig index 517a86d..ffefd74 100644 --- a/src/core.zig +++ b/src/core.zig @@ -36,6 +36,7 @@ pub const ModuleDefinitionOpts = def.ModuleDefinitionOpts; pub const TaggedVal = def.TaggedVal; pub const Val = def.Val; pub const ValType = def.ValType; +pub const ExternRef = def.ExternRef; pub const UnlinkableError = inst.UnlinkableError; pub const UninstantiableError = inst.UninstantiableError; diff --git a/src/definition.zig b/src/definition.zig index d8e01a8..ba33aaf 100644 --- a/src/definition.zig +++ b/src/definition.zig @@ -77,6 +77,8 @@ pub const ValidationError = error{ ValidationFuncRefUndeclared, ValidationIfElseMismatch, ValidationInvalidLaneIndex, + 
ValidationTooManyFunctionImportParams, + ValidationTooManyFunctionImportReturns, }; pub const i8x16 = @Vector(16, i8); @@ -179,9 +181,12 @@ pub const FuncRef = extern union { comptime { std.debug.assert(@sizeOf(?*const anyopaque) == @sizeOf(usize)); std.debug.assert(@sizeOf(FuncRef) == @sizeOf(usize)); + std.debug.assert(@sizeOf(FuncRef) == @sizeOf(ExternRef)); } }; +pub const ExternRef = usize; + pub const Val = extern union { I32: i32, I64: i64, @@ -189,7 +194,7 @@ pub const Val = extern union { F64: f64, V128: v128, FuncRef: FuncRef, - ExternRef: u64, + ExternRef: ExternRef, pub fn default(valtype: ValType) Val { return switch (valtype) { @@ -386,7 +391,7 @@ pub const BlockTypeValue = extern union { } } - fn getBlocktypeReturnTypes(value: BlockTypeValue, block_type: BlockType, module_def: *const ModuleDefinition) []const ValType { + pub fn getBlocktypeReturnTypes(value: BlockTypeValue, block_type: BlockType, module_def: *const ModuleDefinition) []const ValType { switch (block_type) { .Void => return &.{}, .ValType => return switch (value.ValType) { @@ -481,19 +486,16 @@ pub const ConstantExpression = union(ConstantExpressionType) { }; pub const FunctionTypeDefinition = struct { - types: std.ArrayList(ValType), + types: std.ArrayList(ValType), // TODO replace this with offsets into a single array in the ModuleDefinition num_params: u32, pub fn getParams(self: *const FunctionTypeDefinition) []const ValType { return self.types.items[0..self.num_params]; } + pub fn getReturns(self: *const FunctionTypeDefinition) []const ValType { return self.types.items[self.num_params..]; } - pub fn calcNumReturns(self: *const FunctionTypeDefinition) u32 { - const total: u32 = @as(u32, @intCast(self.types.items.len)); - return total - self.num_params; - } pub const SortContext = struct { const Self = @This(); @@ -698,6 +700,9 @@ pub const DataDefinition = struct { } }; +pub const MAX_FUNCTION_IMPORT_PARAMS = 256; +pub const MAX_FUNCTION_IMPORT_RETURNS = 256; + pub const ImportNames 
= struct { module_name: []const u8, import_name: []const u8, @@ -806,7 +811,7 @@ pub const InstructionImmediates = union { } }; -const ValidationImmediates = union { +pub const ValidationImmediates = union { Void: void, BlockOrIf: struct { block_type: BlockType, @@ -829,7 +834,7 @@ pub const Instruction = struct { } } - fn decode(reader: anytype, module: *ModuleDefinition) !DecodedInstruction { + fn decode(reader: anytype, module: *ModuleDefinition, func: *FunctionDefinition) !DecodedInstruction { const Helpers = struct { fn decodeBlockType( _reader: anytype, @@ -909,7 +914,8 @@ pub const Instruction = struct { const wasm_op: WasmOpcode = try WasmOpcode.decode(reader); - const opcode: Opcode = wasm_op.toOpcode(); + // note that this opcode can be remapped as we get more information about the instruction + var opcode: Opcode = wasm_op.toOpcode(); var immediate = InstructionImmediates{ .Void = {} }; var validation_immediates = ValidationImmediates{ .Void = {} }; @@ -921,20 +927,45 @@ pub const Instruction = struct { } immediate = InstructionImmediates{ .ValType = try ValType.decode(reader) }; }, - .Local_Get => { - immediate = InstructionImmediates{ .Index = try common.decodeLEB128(u32, reader) }; - }, - .Local_Set => { - immediate = InstructionImmediates{ .Index = try common.decodeLEB128(u32, reader) }; - }, - .Local_Tee => { - immediate = InstructionImmediates{ .Index = try common.decodeLEB128(u32, reader) }; - }, - .Global_Get => { - immediate = InstructionImmediates{ .Index = try common.decodeLEB128(u32, reader) }; + .Local_Get, .Local_Set, .Local_Tee => { + const index = try common.decodeLEB128(u32, reader); + immediate = InstructionImmediates{ .Index = index }; + + const type_def: *const FunctionTypeDefinition = func.typeDefinition(module); + const params = type_def.getParams(); + + // note we don't do validation here, we'll do that after decode + var is_local_v128: bool = false; + if (index < params.len) { + is_local_v128 = params[index] == .V128; + } else { + 
const locals = func.locals(module); + const func_locals_index = index - params.len; + if (func_locals_index < locals.len) { + is_local_v128 = locals[func_locals_index] == .V128; + } + } + + if (is_local_v128) { + opcode = switch (opcode) { + .Local_Get => .Local_Get_V128, + .Local_Set => .Local_Set_V128, + .Local_Tee => .Local_Tee_V128, + else => unreachable, + }; + } }, - .Global_Set => { + .Global_Get, .Global_Set => { immediate = InstructionImmediates{ .Index = try common.decodeLEB128(u32, reader) }; + if (immediate.Index < module.globals.items.len) { + if (module.globals.items[immediate.Index].valtype == .V128) { + opcode = switch (opcode) { + .Global_Get => .Global_Get_V128, + .Global_Set => .Global_Set_V128, + else => unreachable, + }; + } + } }, .Table_Get => { immediate = InstructionImmediates{ .Index = try common.decodeLEB128(u32, reader) }; @@ -1230,7 +1261,21 @@ pub const Instruction = struct { const memarg = try MemArg.decode(reader, 64); immediate = InstructionImmediates{ .MemoryOffset = memarg.offset }; }, - .I8x16_Extract_Lane_S, .I8x16_Extract_Lane_U, .I8x16_Replace_Lane, .I16x8_Extract_Lane_S, .I16x8_Extract_Lane_U, .I16x8_Replace_Lane, .I32x4_Extract_Lane, .I32x4_Replace_Lane, .I64x2_Extract_Lane, .I64x2_Replace_Lane, .F32x4_Extract_Lane, .F32x4_Replace_Lane, .F64x2_Extract_Lane, .F64x2_Replace_Lane => { + .I8x16_Extract_Lane_S, + .I8x16_Extract_Lane_U, + .I8x16_Replace_Lane, + .I16x8_Extract_Lane_S, + .I16x8_Extract_Lane_U, + .I16x8_Replace_Lane, + .I32x4_Extract_Lane, + .I32x4_Replace_Lane, + .I64x2_Extract_Lane, + .I64x2_Replace_Lane, + .F32x4_Extract_Lane, + .F32x4_Replace_Lane, + .F64x2_Extract_Lane, + .F64x2_Replace_Lane, + => { immediate = InstructionImmediates{ .Index = try reader.readByte() }; // laneidx }, .V128_Store => { @@ -1497,11 +1542,13 @@ const ModuleValidator = struct { try self.pushControl(Opcode.Call_Local, func_type_def.getParams(), func_type_def.getReturns()); } + // Note that validateCode() can modify the instruction 
depending on if the type information causes a change. + // For example, see Drop which is converted into Drop_V128 if the dropped type is a V128. fn validateCode( self: *ModuleValidator, module: *const ModuleDefinition, func: *const FunctionDefinition, - instruction: Instruction, + instruction: *Instruction, validation_immediates: ValidationImmediates, ) !void { const Helpers = struct { @@ -1621,7 +1668,7 @@ const ModuleValidator = struct { } } - fn validateLoadLaneOp(validator: *ModuleValidator, module_: *const ModuleDefinition, instruction_: Instruction, comptime T: type) !void { + fn validateLoadLaneOp(validator: *ModuleValidator, module_: *const ModuleDefinition, instruction_: *Instruction, comptime T: type) !void { const immediate_index = instruction_.immediate.Index; const immediates: MemoryOffsetAndLaneImmediates = module_.code.memory_offset_and_lane_immediates.items[immediate_index]; try validateVectorLane(T, immediates.laneidx); @@ -1631,7 +1678,7 @@ const ModuleValidator = struct { try validator.pushType(.V128); } - fn validateStoreLaneOp(validator: *ModuleValidator, module_: *const ModuleDefinition, instruction_: Instruction, comptime T: type) !void { + fn validateStoreLaneOp(validator: *ModuleValidator, module_: *const ModuleDefinition, instruction_: *Instruction, comptime T: type) !void { const immediate_index = instruction_.immediate.Index; const immediates: MemoryOffsetAndLaneImmediates = module_.code.memory_offset_and_lane_immediates.items[immediate_index]; try validateVectorLane(T, immediates.laneidx); @@ -1640,14 +1687,14 @@ const ModuleValidator = struct { try validateMemoryIndex(module_); } - fn validateVecExtractLane(comptime T: type, validator: *ModuleValidator, instruction_: Instruction) !void { + fn validateVecExtractLane(comptime T: type, validator: *ModuleValidator, instruction_: *Instruction) !void { try validateVectorLane(T, instruction_.immediate.Index); const lane_valtype = vecLaneTypeToValtype(@typeInfo(T).vector.child); try 
validator.popType(.V128); try validator.pushType(lane_valtype); } - fn validateVecReplaceLane(comptime T: type, validator: *ModuleValidator, instruction_: Instruction) !void { + fn validateVecReplaceLane(comptime T: type, validator: *ModuleValidator, instruction_: *Instruction) !void { try validateVectorLane(T, instruction_.immediate.Index); const lane_valtype = vecLaneTypeToValtype(@typeInfo(T).vector.child); try validator.popType(lane_valtype); @@ -1686,8 +1733,16 @@ const ModuleValidator = struct { }, .DebugTrap, .Noop => {}, .Drop => { - _ = try self.popAnyType(); + if (try self.popAnyType()) |valtype| { + switch (valtype) { + .V128 => { + instruction.opcode = .Drop_V128; + }, + else => {}, + } + } }, + .Drop_V128 => unreachable, // validation generates this instruction, it shouldn't be generated externally .Block => { try Helpers.enterBlock(self, module, instruction.opcode, validation_immediates); }, @@ -1828,6 +1883,9 @@ const ModuleValidator = struct { } try self.pushType(valtype1); } + + const valtype = self.type_stack.items[self.type_stack.items.len - 1]; + instruction.opcode = if (valtype == .V128) .Select_V128 else .Select; }, .Select_T => { const valtype: ValType = instruction.immediate.ValType; @@ -1835,25 +1893,28 @@ const ModuleValidator = struct { try self.popType(valtype); try self.popType(valtype); try self.pushType(valtype); + + instruction.opcode = if (valtype == .V128) .Select_V128 else .Select; }, - .Local_Get => { + .Select_V128 => unreachable, // this opcode is generated by validation only + .Local_Get, .Local_Get_V128 => { const valtype = try Helpers.getLocalValtype(self, module, func, instruction.immediate.Index); try self.pushType(valtype); }, - .Local_Set => { + .Local_Set, .Local_Set_V128 => { const valtype = try Helpers.getLocalValtype(self, module, func, instruction.immediate.Index); try self.popType(valtype); }, - .Local_Tee => { + .Local_Tee, .Local_Tee_V128 => { const valtype = try Helpers.getLocalValtype(self, module, func, 
instruction.immediate.Index); try self.popType(valtype); try self.pushType(valtype); }, - .Global_Get => { + .Global_Get, .Global_Get_V128 => { const valtype = try Helpers.getGlobalValtype(module, instruction.immediate.Index, .None); try self.pushType(valtype); }, - .Global_Set => { + .Global_Set, .Global_Set_V128 => { const valtype = try Helpers.getGlobalValtype(module, instruction.immediate.Index, .Mutable); try self.popType(valtype); }, @@ -2626,6 +2687,7 @@ pub const ModuleDefinition = struct { const Code = struct { locals: std.ArrayList(ValType), instructions: std.ArrayList(Instruction), + validation_immediates: std.ArrayList(ValidationImmediates), wasm_address_to_instruction_index: std.AutoHashMap(u32, u32), @@ -2681,6 +2743,7 @@ pub const ModuleDefinition = struct { .allocator = allocator, .code = Code{ .instructions = std.ArrayList(Instruction).init(allocator), + .validation_immediates = std.ArrayList(ValidationImmediates).init(allocator), .locals = std.ArrayList(ValType).init(allocator), .wasm_address_to_instruction_index = std.AutoHashMap(u32, u32).init(allocator), @@ -2720,7 +2783,7 @@ pub const ModuleDefinition = struct { pub fn decode(self: *ModuleDefinition, wasm: []const u8) anyerror!void { std.debug.assert(self.is_decoded == false); - self.decode_internal(wasm) catch |e| { + self.decodeInternal(wasm) catch |e| { const wrapped_error: anyerror = switch (e) { error.EndOfStream => error.MalformedUnexpectedEnd, else => e, @@ -2729,7 +2792,7 @@ pub const ModuleDefinition = struct { }; } - fn decode_internal(self: *ModuleDefinition, wasm: []const u8) anyerror!void { + fn decodeInternal(self: *ModuleDefinition, wasm: []const u8) anyerror!void { const DecodeHelpers = struct { fn readRefValue(valtype: ValType, reader: anytype) !Val { switch (valtype) { @@ -2738,7 +2801,7 @@ pub const ModuleDefinition = struct { return Val.funcrefFromIndex(func_index); }, .ExternRef => { - const ref = try common.decodeLEB128(u64, reader); + const ref = try 
common.decodeLEB128(usize, reader); return Val{ .ExternRef = ref }; }, else => unreachable, @@ -2878,6 +2941,13 @@ pub const ModuleDefinition = struct { 0x00 => { const type_index = try common.decodeLEB128(u32, reader); try ModuleValidator.validateTypeIndex(type_index, self); + const func_type: *const FunctionTypeDefinition = &self.types.items[type_index]; + if (func_type.num_params >= MAX_FUNCTION_IMPORT_PARAMS) { + return ValidationError.ValidationTooManyFunctionImportParams; + } + if (func_type.getReturns().len >= MAX_FUNCTION_IMPORT_RETURNS) { + return ValidationError.ValidationTooManyFunctionImportReturns; + } try self.imports.functions.append(FunctionImportDefinition{ .names = names, .type_index = type_index, @@ -2937,6 +3007,9 @@ pub const ModuleDefinition = struct { }; const instructions_end = self.code.instructions.items.len; + try self.code.validation_immediates.ensureUnusedCapacity(2); + self.code.validation_immediates.appendNTimesAssumeCapacity(.{ .Void = {} }, 2); + const func = FunctionDefinition{ .type_index = type_index, .instructions_begin = instructions_begin, @@ -3242,6 +3315,8 @@ pub const ModuleDefinition = struct { defer if_to_else_offsets.deinit(); var instructions = &self.code.instructions; + var instruction_validation_immediates = &self.code.validation_immediates; + std.debug.assert(instructions.items.len == instruction_validation_immediates.items.len); const num_codes = try common.decodeLEB128(u32, reader); @@ -3308,7 +3383,7 @@ pub const ModuleDefinition = struct { const wasm_instruction_address = stream.pos - wasm_code_address_begin; - const decoded_instruction: DecodedInstruction = try Instruction.decode(reader, self); + const decoded_instruction: DecodedInstruction = try Instruction.decode(reader, self, func_def); const validation_immediates: ValidationImmediates = decoded_instruction.validation_immediates; var instruction: Instruction = decoded_instruction.instruction; @@ -3363,7 +3438,7 @@ pub const ModuleDefinition = struct { } } - 
try validator.validateCode(self, func_def, instruction, validation_immediates); + try validator.validateCode(self, func_def, &instruction, validation_immediates); try self.code.wasm_address_to_instruction_index.put(@as(u32, @intCast(wasm_instruction_address)), instruction_index); @@ -3371,6 +3446,8 @@ pub const ModuleDefinition = struct { .Noop => {}, // no need to emit noops since they don't do anything else => { try instructions.append(instruction); + try instruction_validation_immediates.append(validation_immediates); + std.debug.assert(instructions.items.len == instruction_validation_immediates.items.len); }, } } @@ -3432,6 +3509,7 @@ pub const ModuleDefinition = struct { pub fn destroy(self: *ModuleDefinition) void { self.code.instructions.deinit(); + self.code.validation_immediates.deinit(); self.code.locals.deinit(); self.code.wasm_address_to_instruction_index.deinit(); self.code.branch_table_immediates.deinit(); diff --git a/src/instance.zig b/src/instance.zig index 097e1f7..0863670 100644 --- a/src/instance.zig +++ b/src/instance.zig @@ -405,7 +405,7 @@ const HostFunctionCallback = *const fn (userdata: ?*anyopaque, module: *ModuleIn const HostFunction = struct { userdata: ?*anyopaque, - func_def: FunctionTypeDefinition, + func_type_def: FunctionTypeDefinition, callback: HostFunctionCallback, }; @@ -426,12 +426,12 @@ pub const FunctionImport = struct { copy.name = try allocator.dupe(u8, copy.name); switch (copy.data) { .Host => |*data| { - var func_def = FunctionTypeDefinition{ + var func_type_def = FunctionTypeDefinition{ .types = std.ArrayList(ValType).init(allocator), - .num_params = data.func_def.num_params, + .num_params = data.func_type_def.num_params, }; - try func_def.types.appendSlice(data.func_def.types.items); - data.func_def = func_def; + try func_type_def.types.appendSlice(data.func_type_def.types.items); + data.func_type_def = func_type_def; }, .Wasm => {}, } @@ -444,7 +444,7 @@ pub const FunctionImport = struct { switch (import.data) { .Host 
=> |*data| { - data.func_def.types.deinit(); + data.func_type_def.types.deinit(); }, .Wasm => {}, } @@ -454,7 +454,7 @@ pub const FunctionImport = struct { var type_comparer = FunctionTypeDefinition.SortContext{}; switch (import.data) { .Host => |data| { - return type_comparer.eql(&data.func_def, type_signature); + return type_comparer.eql(&data.func_type_def, type_signature); }, .Wasm => |data| { const func_type_def: *const FunctionTypeDefinition = data.module_instance.findFuncTypeDef(data.index); @@ -553,7 +553,7 @@ pub const ModuleImportPackage = struct { .data = .{ .Host = HostFunction{ .userdata = userdata, - .func_def = FunctionTypeDefinition{ + .func_type_def = FunctionTypeDefinition{ .types = type_list, .num_params = @as(u32, @intCast(param_types.len)), }, @@ -569,7 +569,7 @@ pub const ModuleImportPackage = struct { for (self.functions.items) |*item| { self.allocator.free(item.name); switch (item.data) { - .Host => |h| h.func_def.types.deinit(), + .Host => |h| h.func_type_def.types.deinit(), else => {}, } } @@ -982,6 +982,7 @@ pub const ModuleInstance = struct { return error.UnlinkableIncompatibleImportType; } + // NOTE: the try store.imports.functions.append(try import_func.dupe(allocator)); } @@ -1325,18 +1326,7 @@ pub const ModuleInstance = struct { } fn findFuncTypeDef(self: *ModuleInstance, index: usize) *const FunctionTypeDefinition { - // const num_imports: usize = self.store.imports.functions.items.len; - // if (index >= num_imports) { - // const local_func_index: usize = index - num_imports; return self.vm.findFuncTypeDef(self, index); - // } else { - // const import: *const FunctionImport = &self.store.imports.functions.items[index]; - // const func_type_def: *const FunctionTypeDefinition = switch (import.data) { - // .Host => |data| &data.func_def, - // .Wasm => |data| data.module_instance.findFuncTypeDef(data.index), - // }; - // return func_type_def; - // } } fn getGlobalWithIndex(self: *ModuleInstance, index: usize) *GlobalInstance { diff 
--git a/src/opcode.zig b/src/opcode.zig index aa1eb32..851170f 100644 --- a/src/opcode.zig +++ b/src/opcode.zig @@ -21,13 +21,20 @@ pub const Opcode = enum(u16) { Call_Import, // Has no corresponding mapping in WasmOpcode, only calls imported functions Call_Indirect, Drop, + Drop_V128, // Has no corresponding mapping in WasmOpcode Select, Select_T, + Select_V128, Local_Get, Local_Set, Local_Tee, + Local_Get_V128, // Has no corresponding mapping in WasmOpcode + Local_Set_V128, // Has no corresponding mapping in WasmOpcode + Local_Tee_V128, // Has no corresponding mapping in WasmOpcode Global_Get, Global_Set, + Global_Get_V128, // Has no corresponding mapping in WasmOpcode + Global_Set_V128, // Has no corresponding mapping in WasmOpcode Table_Get, Table_Set, I32_Load, diff --git a/src/stack_ops.zig b/src/stack_ops.zig index 5cddc72..8a43e00 100644 --- a/src/stack_ops.zig +++ b/src/stack_ops.zig @@ -40,6 +40,10 @@ const TableDefinition = def.TableDefinition; const TablePairImmediates = def.TablePairImmediates; const Val = def.Val; const ValType = def.ValType; +const FuncRef = def.FuncRef; +const ExternRef = def.ExternRef; +const MAX_FUNCTION_IMPORT_PARAMS = def.MAX_FUNCTION_IMPORT_PARAMS; +const MAX_FUNCTION_IMPORT_RETURNS = def.MAX_FUNCTION_IMPORT_RETURNS; const inst = @import("instance.zig"); const UnlinkableError = inst.UnlinkableError; @@ -69,6 +73,11 @@ const Label = Stack.Label; const StackVM = @import("vm_stack.zig").StackVM; +pub const HostFunctionData = struct { + num_param_values: u16 = 0, + num_return_values: u16 = 0, +}; + pub fn traceInstruction(instruction_name: []const u8, pc: u32, stack: *const Stack) void { if (config.enable_debug_trace and DebugTrace.shouldTraceInstructions()) { const frame: *const CallFrame = stack.topFrame(); @@ -165,7 +174,7 @@ pub inline fn end(pc: u32, code: [*]const Instruction, stack: *Stack) ?FuncCallD pub inline fn branch(pc: u32, code: [*]const Instruction, stack: *Stack) ?FuncCallData { const label_id: u32 = 
code[pc].immediate.LabelId; - return _branch(code, stack, label_id); + return branchToLabel(code, stack, label_id); } pub inline fn branchIf(pc: u32, code: [*]const Instruction, stack: *Stack) ?FuncCallData { @@ -173,7 +182,7 @@ pub inline fn branchIf(pc: u32, code: [*]const Instruction, stack: *Stack) ?Func const v = stack.popI32(); if (v != 0) { const label_id: u32 = code[pc].immediate.LabelId; - next = _branch(code, stack, label_id); + next = branchToLabel(code, stack, label_id); } else { next = FuncCallData{ .code = code, @@ -193,7 +202,7 @@ pub inline fn branchTable(pc: u32, code: [*]const Instruction, stack: *Stack) ?F const label_index = stack.popI32(); const label_id: u32 = if (label_index >= 0 and label_index < table.len) table[@as(usize, @intCast(label_index))] else immediates.fallback_id; - return _branch(code, stack, label_id); + return branchToLabel(code, stack, label_id); } pub inline fn @"return"(stack: *Stack) ?FuncCallData { @@ -208,7 +217,7 @@ pub inline fn callLocal(pc: u32, code: [*]const Instruction, stack: *Stack) !Fun std.debug.assert(func_index < stack_vm.functions.items.len); const func: *const FunctionInstance = &stack_vm.functions.items[@as(usize, @intCast(func_index))]; - return call(pc, stack, module_instance, func); + return @call(.always_inline, call, .{ pc, stack, module_instance, func }); } pub inline fn callImport(pc: u32, code: [*]const Instruction, stack: *Stack) !FuncCallData { @@ -221,28 +230,69 @@ pub inline fn callImport(pc: u32, code: [*]const Instruction, stack: *Stack) !Fu const func_import = &store.imports.functions.items[func_index]; switch (func_import.data) { .Host => |data| { - const params_len: u32 = @as(u32, @intCast(data.func_def.getParams().len)); - const returns_len: u32 = @as(u32, @intCast(data.func_def.calcNumReturns())); + const vm: *const StackVM = StackVM.fromVM(module_instance.vm); + const host_function_data: *const HostFunctionData = &vm.host_function_import_data.items[func_index]; + const num_params = 
host_function_data.num_param_values; + const num_returns = host_function_data.num_return_values; - std.debug.assert(stack.num_values + returns_len < stack.values.len); + std.debug.assert(num_params < MAX_FUNCTION_IMPORT_PARAMS); + std.debug.assert(num_returns < MAX_FUNCTION_IMPORT_RETURNS); - const module: *ModuleInstance = stack.topFrame().module_instance; - const params = stack.values[stack.num_values - params_len .. stack.num_values]; - const returns_temp = stack.values[stack.num_values .. stack.num_values + returns_len]; + std.debug.assert(stack.num_values >= num_params); + std.debug.assert(stack.num_values - num_params + num_returns < stack.values.len); - DebugTrace.traceHostFunction(module, stack.num_frames + 1, func_import.name); + const module: *ModuleInstance = stack.topFrame().module_instance; + const stack_params = stack.values[stack.num_values - num_params .. stack.num_values]; + + // because StackVal is not compatible with Val, we have to marshal the values + var vals_memory: [MAX_FUNCTION_IMPORT_PARAMS + MAX_FUNCTION_IMPORT_RETURNS]Val = undefined; + const params: []Val = vals_memory[0..num_params]; + { + const param_types: []const ValType = data.func_type_def.getParams(); + var stack_index: u32 = 0; + for (param_types, 0..) |valtype, param_index| { + switch (valtype) { + .V128 => { + const f0 = stack_params[stack_index + 0].F64; + const f1 = stack_params[stack_index + 1].F64; + params[param_index].V128 = @bitCast(f64x2{ f0, f1 }); + stack_index += 2; + }, + else => { + params[param_index].I64 = stack_params[stack_index].I64; + stack_index += 1; + }, + } + } + } - try data.callback(data.userdata, module, params.ptr, returns_temp.ptr); + const returns: []Val = vals_memory[num_params .. num_params + num_returns]; - stack.num_values = (stack.num_values - params_len) + returns_len; - const returns_dest = stack.values[stack.num_values - returns_len .. 
stack.num_values]; + DebugTrace.traceHostFunction(module, stack.num_frames + 1, func_import.name); - if (params_len > 0) { - std.debug.assert(@intFromPtr(returns_dest.ptr) < @intFromPtr(returns_temp.ptr)); - std.mem.copyForwards(Val, returns_dest, returns_temp); - } else { - // no copy needed in this case since the return values will go into the same location - std.debug.assert(returns_dest.ptr == returns_temp.ptr); + try data.callback(data.userdata, module, params.ptr, returns.ptr); + + const stack_returns = stack.values[stack.num_values - num_params .. stack.num_values - num_params + num_returns]; + stack.num_values = stack.num_values - num_params + num_returns; + + // marshalling back into StackVal from Val + { + const return_types: []const ValType = data.func_type_def.getReturns(); + var stack_index: u32 = 0; + for (return_types, 0..) |valtype, return_index| { + switch (valtype) { + .V128 => { + const vec2: f64x2 = @bitCast(returns[return_index].V128); + stack_returns[stack_index + 0].F64 = vec2[0]; + stack_returns[stack_index + 1].F64 = vec2[1]; + stack_index += 2; + }, + else => { + stack_returns[stack_index].I64 = returns[return_index].I64; + stack_index += 1; + }, + } + } } return FuncCallData{ @@ -251,8 +301,8 @@ pub inline fn callImport(pc: u32, code: [*]const Instruction, stack: *Stack) !Fu }; }, .Wasm => |data| { - var stack_vm: *StackVM = StackVM.fromVM(data.module_instance.vm); - const func_instance: *const FunctionInstance = &stack_vm.functions.items[data.index]; + const import_vm: *const StackVM = StackVM.fromVM(data.module_instance.vm); + const func_instance: *const FunctionInstance = &import_vm.functions.items[data.index]; return call(pc, stack, data.module_instance, func_instance); }, } @@ -292,64 +342,73 @@ pub inline fn callIndirect(pc: u32, code: [*]const Instruction, stack: *Stack) ! 
} pub inline fn drop(stack: *Stack) void { - _ = stack.popValue(); + _ = stack.popI64(); } -pub inline fn select(stack: *Stack) void { - const boolean: i32 = stack.popI32(); - const v2: Val = stack.popValue(); - const v1: Val = stack.popValue(); - - if (boolean != 0) { - stack.pushValue(v1); - } else { - stack.pushValue(v2); - } +pub inline fn dropV128(stack: *Stack) void { + _ = stack.popV128(); } -pub inline fn selectT(stack: *Stack) void { - const boolean: i32 = stack.popI32(); - const v2: Val = stack.popValue(); - const v1: Val = stack.popValue(); +pub inline fn select(stack: *Stack) void { + stack.select(); +} - if (boolean != 0) { - stack.pushValue(v1); - } else { - stack.pushValue(v2); - } +pub inline fn selectV128(stack: *Stack) void { + stack.selectV128(); } pub inline fn localGet(pc: u32, code: [*]const Instruction, stack: *Stack) void { const locals_index: u32 = code[pc].immediate.Index; - const locals = stack.locals(); - const v: Val = locals[locals_index]; - stack.pushValue(v); + stack.localGet(locals_index); +} + +pub inline fn localGetV128(pc: u32, code: [*]const Instruction, stack: *Stack) void { + const locals_index: u32 = code[pc].immediate.Index; + stack.localGetV128(locals_index); } pub inline fn localSet(pc: u32, code: [*]const Instruction, stack: *Stack) void { const locals_index: u32 = code[pc].immediate.Index; - const locals = stack.locals(); - const v: Val = stack.popValue(); - locals[locals_index] = v; + stack.localSet(locals_index); +} + +pub inline fn localSetV128(pc: u32, code: [*]const Instruction, stack: *Stack) void { + const locals_index: u32 = code[pc].immediate.Index; + stack.localSetV128(locals_index); } pub inline fn localTee(pc: u32, code: [*]const Instruction, stack: *Stack) void { const locals_index: u32 = code[pc].immediate.Index; - const locals = stack.locals(); - const v: Val = stack.topValue(); - locals[locals_index] = v; + stack.localTee(locals_index); +} + +pub inline fn localTeeV128(pc: u32, code: [*]const Instruction, 
stack: *Stack) void { + const locals_index: u32 = code[pc].immediate.Index; + stack.localTeeV128(locals_index); } pub inline fn globalGet(pc: u32, code: [*]const Instruction, stack: *Stack) void { const global_index: u32 = code[pc].immediate.Index; const global: *GlobalInstance = getStore(stack).getGlobal(global_index); - stack.pushValue(global.value); + stack.pushI64(global.value.I64); } pub inline fn globalSet(pc: u32, code: [*]const Instruction, stack: *Stack) void { const global_index: u32 = code[pc].immediate.Index; const global: *GlobalInstance = getStore(stack).getGlobal(global_index); - global.value = stack.popValue(); + global.value.I64 = stack.popI64(); +} + +pub inline fn globalGetV128(pc: u32, code: [*]const Instruction, stack: *Stack) void { + const global_index: u32 = code[pc].immediate.Index; + const global: *GlobalInstance = getStore(stack).getGlobal(global_index); + stack.pushV128(global.value.V128); +} + +pub inline fn globalSetV128(pc: u32, code: [*]const Instruction, stack: *Stack) void { + const global_index: u32 = code[pc].immediate.Index; + const global: *GlobalInstance = getStore(stack).getGlobal(global_index); + global.value.V128 = stack.popV128(); } pub inline fn tableGet(pc: u32, code: [*]const Instruction, stack: *Stack) !void { @@ -359,19 +418,19 @@ pub inline fn tableGet(pc: u32, code: [*]const Instruction, stack: *Stack) !void if (table.refs.items.len <= index or index < 0) { return error.TrapOutOfBoundsTableAccess; } - const ref = table.refs.items[@as(usize, @intCast(index))]; - stack.pushValue(ref); + const ref: Val = table.refs.items[@as(usize, @intCast(index))]; + stack.pushFuncRef(ref.FuncRef); } pub inline fn tableSet(pc: u32, code: [*]const Instruction, stack: *Stack) !void { const table_index: u32 = code[pc].immediate.Index; var table: *TableInstance = getStore(stack).getTable(table_index); - const ref = stack.popValue(); + const ref = stack.popFuncRef(); const index: i32 = stack.popI32(); if (table.refs.items.len <= index or 
index < 0) { return error.TrapOutOfBoundsTableAccess; } - table.refs.items[@as(usize, @intCast(index))] = ref; + table.refs.items[@as(usize, @intCast(index))].FuncRef = ref; } pub inline fn i32Load(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { @@ -1461,24 +1520,22 @@ pub inline fn i64Extend32S(stack: *Stack) void { stack.pushI64(v_extended); } -pub inline fn refNull(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { - const valtype = code[pc].immediate.ValType; - const val: ?Val = Val.nullRef(valtype); - std.debug.assert(val != null); // should have been validated in debug - stack.pushValue(val.?); +pub inline fn refNull(stack: *Stack) anyerror!void { + const ref = FuncRef.nullRef(); + stack.pushFuncRef(ref); } pub inline fn refIsNull(stack: *Stack) void { - const val: Val = stack.popValue(); - const boolean: i32 = if (val.isNull()) 1 else 0; + const ref: FuncRef = stack.popFuncRef(); + const boolean: i32 = if (ref.isNull()) 1 else 0; stack.pushI32(boolean); } pub inline fn refFunc(pc: u32, code: [*]const Instruction, stack: *Stack) void { const stack_vm = StackVM.fromVM(stack.topFrame().module_instance.vm); const func_index: u32 = code[pc].immediate.Index; - const val = Val{ .FuncRef = .{ .func = &stack_vm.functions.items[func_index] } }; - stack.pushValue(val); + const ref = FuncRef{ .func = &stack_vm.functions.items[func_index] }; + stack.pushFuncRef(ref); } pub inline fn i32TruncSatF32S(stack: *Stack) void { @@ -1702,7 +1759,7 @@ pub inline fn tableGrow(pc: u32, code: [*]const Instruction, stack: *Stack) void const table_index: u32 = code[pc].immediate.Index; const table: *TableInstance = getStore(stack).getTable(table_index); const length = @as(u32, @bitCast(stack.popI32())); - const init_value = stack.popValue(); + const init_value: Val = .{ .FuncRef = stack.popFuncRef() }; const old_length = @as(i32, @intCast(table.refs.items.len)); const return_value: i32 = if (table.grow(length, init_value)) old_length else -1; 
stack.pushI32(return_value); @@ -1720,7 +1777,7 @@ pub inline fn tableFill(pc: u32, code: [*]const Instruction, stack: *Stack) !voi const table: *TableInstance = getStore(stack).getTable(table_index); const length_i32 = stack.popI32(); - const funcref = stack.popValue(); + const funcref = Val{ .FuncRef = stack.popFuncRef() }; const dest_table_index = stack.popI32(); if (dest_table_index + length_i32 > table.refs.items.len or length_i32 < 0) { @@ -3003,10 +3060,10 @@ fn call(pc: u32, stack: *Stack, module_instance: *ModuleInstance, func: *const F }; } -fn _branch(code: [*]const Instruction, stack: *Stack, label_id: u32) ?FuncCallData { +inline fn branchToLabel(code: [*]const Instruction, stack: *Stack, label_id: u32) ?FuncCallData { const label: *const Label = stack.findLabel(@as(u32, @intCast(label_id))); const frame_label: *const Label = stack.frameLabel(); - // TODO generate BranchToFunctionEnd if this can be statically determined at decode time (or just generate a Return?) + // TODO generate Return opcode at decode time since this should be able to be statically determined for some opcodes (e.g. 
unconditional branch) if (label == frame_label) { return stack.popFrame(); } diff --git a/src/vm_stack.zig b/src/vm_stack.zig index d70e96b..a341b9a 100644 --- a/src/vm_stack.zig +++ b/src/vm_stack.zig @@ -26,6 +26,8 @@ pub const v128 = def.v128; const BlockImmediates = def.BlockImmediates; const BranchTableImmediates = def.BranchTableImmediates; const CallIndirectImmediates = def.CallIndirectImmediates; +const IfImmediates = def.IfImmediates; +const ValidationImmediates = def.ValidationImmediates; const ConstantExpression = def.ConstantExpression; const DataDefinition = def.DataDefinition; const ElementDefinition = def.ElementDefinition; @@ -37,7 +39,6 @@ const FunctionHandleType = def.FunctionHandleType; const FunctionTypeDefinition = def.FunctionTypeDefinition; const GlobalDefinition = def.GlobalDefinition; const GlobalMut = def.GlobalMut; -const IfImmediates = def.IfImmediates; const ImportNames = def.ImportNames; const Instruction = def.Instruction; const Limits = def.Limits; @@ -82,6 +83,7 @@ const FuncCallData = Stack.FuncCallData; const FunctionInstance = Stack.FunctionInstance; const OpHelpers = @import("stack_ops.zig"); +const HostFunctionData = OpHelpers.HostFunctionData; fn preamble(name: []const u8, pc: u32, code: [*]const Instruction, stack: *Stack) TrapError!void { if (metering.enabled) { @@ -149,13 +151,20 @@ const InstructionFuncs = struct { &op_Call_Import, &op_Call_Indirect, &op_Drop, + &op_Drop_V128, &op_Select, - &op_Select_T, + &op_Invalid, // Opcode.SelectT should have been replaced with either .Select or .SelectV128 + &op_Select_V128, &op_Local_Get, &op_Local_Set, &op_Local_Tee, + &op_Local_Get_V128, + &op_Local_Set_V128, + &op_Local_Tee_V128, &op_Global_Get, &op_Global_Set, + &op_Global_Get_V128, + &op_Global_Set_V128, &op_Table_Get, &op_Table_Set, &op_I32_Load, @@ -629,14 +638,12 @@ const InstructionFuncs = struct { fn op_Else(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { try preamble("Else", pc, code, stack); - // 
getting here means we reached the end of the if opcode chain, so skip to the true end opcode const next_pc = OpHelpers.@"else"(pc, code); try @call(.always_tail, InstructionFuncs.lookup(code[next_pc].opcode), .{ next_pc, code, stack }); } fn op_End(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { try preamble("End", pc, code, stack); - const next = OpHelpers.end(pc, code, stack) orelse return; try @call(.always_tail, InstructionFuncs.lookup(next.code[next.continuation].opcode), .{ next.continuation, next.code, stack }); } @@ -667,25 +674,19 @@ const InstructionFuncs = struct { fn op_Call_Local(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { try preamble("Call", pc, code, stack); - const next = try OpHelpers.callLocal(pc, code, stack); - try @call(.always_tail, InstructionFuncs.lookup(next.code[next.continuation].opcode), .{ next.continuation, next.code, stack }); } fn op_Call_Import(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { try preamble("Call", pc, code, stack); - const next = try OpHelpers.callImport(pc, code, stack); - try @call(.always_tail, InstructionFuncs.lookup(next.code[next.continuation].opcode), .{ next.continuation, next.code, stack }); } fn op_Call_Indirect(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { try preamble("Call_Indirect", pc, code, stack); - const next = try OpHelpers.callIndirect(pc, code, stack); - try @call(.always_tail, InstructionFuncs.lookup(next.code[next.continuation].opcode), .{ next.continuation, next.code, stack }); } @@ -695,19 +696,21 @@ const InstructionFuncs = struct { try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); } + fn op_Drop_V128(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { + try preamble("Drop_V128", pc, code, stack); + OpHelpers.dropV128(stack); + try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); + } + fn op_Select(pc: 
u32, code: [*]const Instruction, stack: *Stack) anyerror!void { try preamble("Select", pc, code, stack); - OpHelpers.select(stack); - try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); } - fn op_Select_T(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { - try preamble("Select_T", pc, code, stack); - - OpHelpers.selectT(stack); - + fn op_Select_V128(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { + try preamble("Select_V128", pc, code, stack); + OpHelpers.selectV128(stack); try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); } @@ -719,7 +722,6 @@ const InstructionFuncs = struct { fn op_Local_Set(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { try preamble("Local_Set", pc, code, stack); - OpHelpers.localSet(pc, code, stack); try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); } @@ -730,6 +732,24 @@ const InstructionFuncs = struct { try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); } + fn op_Local_Get_V128(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { + try preamble("Local_Get_V128", pc, code, stack); + OpHelpers.localGetV128(pc, code, stack); + try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); + } + + fn op_Local_Set_V128(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { + try preamble("Local_Set_V128", pc, code, stack); + OpHelpers.localSetV128(pc, code, stack); + try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); + } + + fn op_Local_Tee_V128(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { + try preamble("Local_Tee_V128", pc, code, stack); + OpHelpers.localTeeV128(pc, code, stack); + try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); 
+ } + fn op_Global_Get(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { try preamble("Global_Get", pc, code, stack); OpHelpers.globalGet(pc, code, stack); @@ -742,6 +762,18 @@ const InstructionFuncs = struct { try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); } + fn op_Global_Get_V128(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { + try preamble("Global_Get_V128", pc, code, stack); + OpHelpers.globalGetV128(pc, code, stack); + try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); + } + + fn op_Global_Set_V128(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { + try preamble("Global_Set_V128", pc, code, stack); + OpHelpers.globalSetV128(pc, code, stack); + try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); + } + fn op_Table_Get(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { try preamble("Table_Get", pc, code, stack); try OpHelpers.tableGet(pc, code, stack); @@ -1698,7 +1730,7 @@ const InstructionFuncs = struct { fn op_Ref_Null(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { try preamble("Ref_Null", pc, code, stack); - try OpHelpers.refNull(pc, code, stack); + try OpHelpers.refNull(stack); try @call(.always_tail, InstructionFuncs.lookup(code[pc + 1].opcode), .{ pc + 1, code, stack }); } @@ -3235,6 +3267,17 @@ const InstructionFuncs = struct { } }; +fn calcNumValues(types: []const ValType) u16 { + var num: u16 = 0; + for (types) |valtype| { + num += switch (valtype) { + .V128 => 2, + else => 1, + }; + } + return num; +} + pub const StackVM = struct { const TrapType = enum { Step, @@ -3272,7 +3315,9 @@ pub const StackVM = struct { } else void; stack: Stack, + instructions: std.ArrayList(Instruction), functions: std.ArrayList(FunctionInstance), + host_function_import_data: std.ArrayList(HostFunctionData), debug_state: ?DebugState, 
meter_state: MeterState, @@ -3283,7 +3328,9 @@ pub const StackVM = struct { pub fn init(vm: *VM) void { var self: *StackVM = fromVM(vm); self.stack = Stack.init(vm.allocator); + self.instructions = std.ArrayList(Instruction).init(vm.allocator); self.functions = std.ArrayList(FunctionInstance).init(vm.allocator); + self.host_function_import_data = std.ArrayList(HostFunctionData).init(vm.allocator); self.debug_state = null; } @@ -3291,7 +3338,8 @@ pub const StackVM = struct { var self: *StackVM = fromVM(vm); self.functions.deinit(); - + self.host_function_import_data.deinit(); + self.instructions.deinit(); self.stack.deinit(); if (self.debug_state) |*debug_state| { debug_state.trapped_opcodes.deinit(); @@ -3317,32 +3365,112 @@ pub const StackVM = struct { .max_frames = @as(u16, @intFromFloat(stack_size_f * 0.01)), }); + // vm keeps a copy of the instructions to mutate some of them + try self.instructions.appendSlice(module.module_def.code.instructions.items); + + var locals_remap: std.ArrayList(u32) = .init(vm.allocator); + defer locals_remap.deinit(); + try locals_remap.ensureTotalCapacity(1024); + try self.functions.ensureTotalCapacity(module.module_def.functions.items.len); for (module.module_def.functions.items, 0..) 
|*def_func, i| { const func_type: *const FunctionTypeDefinition = &module.module_def.types.items[def_func.type_index]; const param_types: []const ValType = func_type.getParams(); + const local_types: []const ValType = def_func.locals(module.module_def); + + var num_params: u16 = 0; + var num_locals: u32 = 0; + + // remap local indices to ensure v128 gets 2 local slots + try locals_remap.resize(0); + { + for (param_types) |valtype| { + const num_values: u16 = switch (valtype) { + .V128 => 2, + else => 1, + }; + try locals_remap.append(num_params); + num_params += num_values; + } + + for (local_types) |valtype| { + const num_values: u16 = switch (valtype) { + .V128 => 2, + else => 1, + }; + try locals_remap.append(num_params + num_locals); + num_locals += num_values; + } + } + + const return_types: []const ValType = func_type.getReturns(); + const num_returns: u16 = calcNumValues(return_types); - const locals: []const ValType = def_func.locals(module.module_def); - const num_locals: u32 = @intCast(locals.len); - const num_params: u16 = @intCast(param_types.len); - const num_values: u32 = @intCast(def_func.stack_stats.values); + const max_values_on_stack: u32 = @intCast(def_func.stack_stats.values); const f = FunctionInstance{ .type_def_index = def_func.type_index, .def_index = @as(u32, @intCast(i)), - .code = module.module_def.code.instructions.items.ptr, + .code = self.instructions.items.ptr, .instructions_begin = def_func.instructions_begin, .num_locals = num_locals, .num_params = num_params, - .num_returns = @intCast(func_type.getReturns().len), + .num_returns = num_returns, // maximum number of values that can be on the stack for this function - .max_values = num_values + num_locals + num_params, + .max_values = max_values_on_stack + num_locals + num_params, .max_labels = @intCast(def_func.stack_stats.labels), .module = module, }; try self.functions.append(f); + + // fixup immediates + std.debug.assert(self.instructions.items.len == 
module.module_def.code.validation_immediates.items.len); + const func_code: []Instruction = self.instructions.items[def_func.instructions_begin..def_func.instructions_end]; + const func_validation_immediates: []ValidationImmediates = module.module_def.code.validation_immediates.items[def_func.instructions_begin..def_func.instructions_end]; + for (func_code, func_validation_immediates) |*instruction, validation_immediates| { + switch (instruction.opcode) { + .Local_Get, .Local_Set, .Local_Tee, .Local_Get_V128, .Local_Set_V128, .Local_Tee_V128 => { + const remapped_index = locals_remap.items[instruction.immediate.Index]; + instruction.immediate.Index = remapped_index; + }, + .Block, .Loop, .If => { + const immediates = validation_immediates.BlockOrIf; + const block_return_types = immediates.block_value.getBlocktypeReturnTypes(immediates.block_type, module.module_def); + const num_block_returns: u16 = calcNumValues(block_return_types); + + if (instruction.opcode == .If) { + instruction.immediate.If.num_returns = num_block_returns; + } else { + instruction.immediate.Block.num_returns = num_block_returns; + } + }, + else => {}, + } + } + } + + // precalculate some data for function imports to avoid having to do this at runtime + try self.host_function_import_data.ensureTotalCapacity(module.store.imports.functions.items.len); + for (module.store.imports.functions.items) |import| { + const data: HostFunctionData = switch (import.data) { + .Host => |host_import| blk: { + const params: []const ValType = host_import.func_type_def.getParams(); + const returns: []const ValType = host_import.func_type_def.getReturns(); + + const num_param_values: u16 = calcNumValues(params); + const num_return_values: u16 = calcNumValues(returns); + + const data: HostFunctionData = .{ + .num_param_values = num_param_values, + .num_return_values = num_return_values, + }; + break :blk data; + }, + .Wasm => .{}, + }; + self.host_function_import_data.appendAssumeCapacity(data); } } @@ -3367,7 
+3495,69 @@ pub const StackVM = struct { } } - try self.invokeInternal(module, handle.index, params, returns); + const func: *const FunctionInstance = &self.functions.items[handle.index]; + const func_def: *const FunctionDefinition = &module.module_def.functions.items[func.def_index]; + const type_def: *const FunctionTypeDefinition = func_def.typeDefinition(module.module_def); + const param_types: []const ValType = type_def.getParams(); + const return_types: []const ValType = type_def.getReturns(); + + // use the count of params/returns from the type since it corresponds to the number of Vals. The function instances' param/return + // counts double count V128 as 2 parameters. + const params_slice = params[0..param_types.len]; + var returns_slice = returns[0..return_types.len]; + + // Ensure any leftover stack state doesn't pollute this invoke. Can happen if the previous invoke returned an error. + self.stack.popAll(); + + // pushFrame() assumes the stack already contains the params to the function, so ensure they exist + // on the value stack + for (params_slice, param_types) |v, valtype| { + switch (valtype) { + .V128 => { + const vec2: f64x2 = @bitCast(v.V128); + self.stack.pushF64(vec2[0]); + self.stack.pushF64(vec2[1]); + }, + else => self.stack.pushI64(v.I64), + } + } + + try self.stack.pushFrame(func, module); + try self.stack.pushLabel(func.num_returns, @intCast(func_def.continuation)); + + DebugTrace.traceFunction(module, self.stack.num_frames, func.def_index); + + if (config.vm_kind == .tailcall) { + try InstructionFuncs.run(@intCast(func.instructions_begin), func.code, &self.stack); + } else { + try self.run(@intCast(func.instructions_begin), func.code); + } + + if (returns_slice.len > 0) { + std.debug.assert(returns_slice.len == return_types.len); + for (0..returns_slice.len) |i| { + const index = returns_slice.len - 1 - i; + switch (return_types[index]) { + .V128 => { + var vec2: f64x2 = undefined; + vec2[1] = self.stack.popF64(); + vec2[0] = 
self.stack.popF64(); + returns_slice[index].V128 = @bitCast(vec2); + }, + else => { + returns_slice[index].I64 = self.stack.popI64(); + }, + } + } + } + + if (self.debug_state) |*debug_state| { + debug_state.onInvokeFinished(); + } + + if (metering.enabled and self.meter_state.enabled) { + self.meter_state.onInvokeFinished(); + } } pub fn resumeInvoke(vm: *VM, module: *ModuleInstance, returns: []Val, opts: ResumeInvokeOpts) anyerror!void { @@ -3398,14 +3588,33 @@ pub const StackVM = struct { return error.TrapInvalidResume; } - const op_func = InstructionFuncs.lookup(opcode); - try op_func(pc, module.module_def.code.instructions.items.ptr, &self.stack); + const func: *const FunctionInstance = self.stack.topFrame().func; + + if (config.vm_kind == .tailcall) { + try InstructionFuncs.run(pc, func.code, &self.stack); + } else { + try self.run(pc, func.code); + } if (returns.len > 0) { - var index: i32 = @as(i32, @intCast(returns.len - 1)); - while (index >= 0) { - returns[@as(usize, @intCast(index))] = self.stack.popValue(); - index -= 1; + const func_def: *const FunctionDefinition = &module.module_def.functions.items[func.def_index]; + const type_def: *const FunctionTypeDefinition = func_def.typeDefinition(module.module_def); + const return_types: []const ValType = type_def.getReturns(); + std.debug.assert(returns.len == return_types.len); + + for (0..returns.len) |i| { + const index = returns.len - 1 - i; + switch (return_types[index]) { + .V128 => { + var vec2: f64x2 = undefined; + vec2[1] = self.stack.popF64(); + vec2[0] = self.stack.popF64(); + returns[index].V128 = @bitCast(vec2); + }, + else => { + returns[index].I64 = self.stack.popI64(); + }, + } } } @@ -3487,7 +3696,7 @@ pub const StackVM = struct { const name_section: *const NameCustomSection = &frame.func.module.module_def.name_section; const module_name = name_section.getModuleName(); - const func_name_index: usize = frame.func.def_index + frame.func.module.module_def.imports.functions.items.len; + const 
func_name_index: usize = frame.func.def_index; const function_name = name_section.findFunctionName(func_name_index); try writer.print("{}: {s}!{s}\n", .{ reverse_index, module_name, function_name }); @@ -3509,50 +3718,6 @@ pub const StackVM = struct { return if (func.isNull()) func else FuncRef{ .func = &self.functions.items[func.index] }; } - fn invokeInternal(self: *StackVM, module: *ModuleInstance, func_instance_index: usize, params: [*]const Val, returns: [*]Val) !void { - const func: FunctionInstance = self.functions.items[func_instance_index]; - const func_def: FunctionDefinition = module.module_def.functions.items[func.def_index]; - - const params_slice = params[0..func.num_params]; - var returns_slice = returns[0..func.num_returns]; - - // Ensure any leftover stack state doesn't pollute this invoke. Can happen if the previous invoke returned an error. - self.stack.popAll(); - - // pushFrame() assumes the stack already contains the params to the function, so ensure they exist - // on the value stack - for (params_slice) |v| { - self.stack.pushValue(v); - } - - try self.stack.pushFrame(&func, module); - try self.stack.pushLabel(func.num_returns, @intCast(func_def.continuation)); - - DebugTrace.traceFunction(module, self.stack.num_frames, func.def_index); - - if (config.vm_kind == .tailcall) { - try InstructionFuncs.run(@intCast(func.instructions_begin), func.code, &self.stack); - } else { - try self.run(@intCast(func.instructions_begin), func.code); - } - - if (returns_slice.len > 0) { - var index: i32 = @as(i32, @intCast(returns_slice.len - 1)); - while (index >= 0) { - returns_slice[@as(usize, @intCast(index))] = self.stack.popValue(); - index -= 1; - } - } - - if (self.debug_state) |*debug_state| { - debug_state.onInvokeFinished(); - } - - if (metering.enabled and self.meter_state.enabled) { - self.meter_state.onInvokeFinished(); - } - } - fn run(self: *StackVM, start_pc: u32, start_code: [*]const Instruction) anyerror!void { var pc: u32 = start_pc; var 
code: [*]const Instruction = start_code; @@ -3609,14 +3774,12 @@ pub const StackVM = struct { Opcode.Else => { try preamble("Else", pc, code, stack); - // getting here means we reached the end of the if opcode chain, so skip to the true end opcode pc = OpHelpers.@"else"(pc, code); continue :interpret code[pc].opcode; }, Opcode.End => { try preamble("End", pc, code, stack); - const next = OpHelpers.end(pc, code, stack) orelse return; pc = next.continuation; code = next.code; @@ -3657,7 +3820,6 @@ pub const StackVM = struct { Opcode.Call_Local => { try preamble("Call_Local", pc, code, stack); - const next = try OpHelpers.callLocal(pc, code, stack); pc = next.continuation; code = next.code; @@ -3666,9 +3828,7 @@ pub const StackVM = struct { Opcode.Call_Import => { try preamble("Call_Import", pc, code, stack); - const next = try OpHelpers.callImport(pc, code, stack); - pc = next.continuation; code = next.code; continue :interpret code[pc].opcode; @@ -3676,9 +3836,7 @@ pub const StackVM = struct { Opcode.Call_Indirect => { try preamble("Call_Indirect", pc, code, stack); - const next = try OpHelpers.callIndirect(pc, code, stack); - pc = next.continuation; code = next.code; continue :interpret code[pc].opcode; @@ -3691,20 +3849,28 @@ pub const StackVM = struct { continue :interpret code[pc].opcode; }, + Opcode.Drop_V128 => { + try preamble("Drop_V128", pc, code, stack); + OpHelpers.dropV128(stack); + pc += 1; + continue :interpret code[pc].opcode; + }, + Opcode.Select => { try preamble("Select", pc, code, stack); - OpHelpers.select(stack); - pc += 1; continue :interpret code[pc].opcode; }, Opcode.Select_T => { - try preamble("Select_T", pc, code, stack); - - OpHelpers.selectT(stack); + // should have been switched to Select in validation + unreachable; + }, + Opcode.Select_V128 => { + try preamble("SelectV128", pc, code, stack); + OpHelpers.selectV128(stack); pc += 1; continue :interpret code[pc].opcode; }, @@ -3718,7 +3884,6 @@ pub const StackVM = struct { 
Opcode.Local_Set => { try preamble("Local_Set", pc, code, stack); - OpHelpers.localSet(pc, code, stack); pc += 1; continue :interpret code[pc].opcode; @@ -3731,6 +3896,27 @@ pub const StackVM = struct { continue :interpret code[pc].opcode; }, + Opcode.Local_Get_V128 => { + try preamble("Local_Get_V128", pc, code, stack); + OpHelpers.localGetV128(pc, code, stack); + pc += 1; + continue :interpret code[pc].opcode; + }, + + Opcode.Local_Set_V128 => { + try preamble("Local_Set_V128", pc, code, stack); + OpHelpers.localSetV128(pc, code, stack); + pc += 1; + continue :interpret code[pc].opcode; + }, + + Opcode.Local_Tee_V128 => { + try preamble("Local_Tee_V128", pc, code, stack); + OpHelpers.localTeeV128(pc, code, stack); + pc += 1; + continue :interpret code[pc].opcode; + }, + Opcode.Global_Get => { try preamble("Global_Get", pc, code, stack); OpHelpers.globalGet(pc, code, stack); @@ -3745,6 +3931,20 @@ pub const StackVM = struct { continue :interpret code[pc].opcode; }, + Opcode.Global_Get_V128 => { + try preamble("Global_Get_V128", pc, code, stack); + OpHelpers.globalGetV128(pc, code, stack); + pc += 1; + continue :interpret code[pc].opcode; + }, + + Opcode.Global_Set_V128 => { + try preamble("Global_Set_V128", pc, code, stack); + OpHelpers.globalSetV128(pc, code, stack); + pc += 1; + continue :interpret code[pc].opcode; + }, + Opcode.Table_Get => { try preamble("Table_Get", pc, code, stack); try OpHelpers.tableGet(pc, code, stack); @@ -4860,7 +5060,7 @@ pub const StackVM = struct { Opcode.Ref_Null => { try preamble("Ref_Null", pc, code, stack); - try OpHelpers.refNull(pc, code, stack); + try OpHelpers.refNull(stack); pc += 1; continue :interpret code[pc].opcode; },