Skip to content

Implement type imports and exports #7330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions scripts/test/fuzzing.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@
'stack_switching_resume.wast',
'stack_switching_resume_throw.wast',
'stack_switching_switch.wast',
# TODO: fuzzer support for type imports
'type-imports.wast',
'type-imports.wat'
# TODO: fuzzer support for exact references
'exact-references.wast',
'optimize-instructions-exact.wast',
Expand Down
10 changes: 10 additions & 0 deletions src/binaryen-c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4940,6 +4940,13 @@ BinaryenExportRef BinaryenAddTagExport(BinaryenModuleRef module,
((Module*)module)->addExport(ret);
return ret;
}
BinaryenExportRef BinaryenAddTypeExport(BinaryenModuleRef module,
BinaryenHeapType type,
const char* externalName) {
auto* ret = new Export(externalName, ExternalKind::Type, HeapType(type));
((Module*)module)->addExport(ret);
return ret;
}
BinaryenExportRef BinaryenGetExport(BinaryenModuleRef module,
const char* externalName) {
return ((Module*)module)->getExportOrNull(externalName);
Expand Down Expand Up @@ -5924,6 +5931,9 @@ const char* BinaryenExportGetName(BinaryenExportRef export_) {
const char* BinaryenExportGetValue(BinaryenExportRef export_) {
return ((Export*)export_)->getInternalName()->str.data();
}
BinaryenHeapType BinaryenExportGetHeapType(BinaryenExportRef export_) {
return ((Export*)export_)->getHeapType()->getID();
}

//
// ========= Custom sections =========
Expand Down
7 changes: 7 additions & 0 deletions src/binaryen-c.h
Original file line number Diff line number Diff line change
Expand Up @@ -2717,6 +2717,10 @@ BINARYEN_API BinaryenExportRef BinaryenAddGlobalExport(
BINARYEN_API BinaryenExportRef BinaryenAddTagExport(BinaryenModuleRef module,
const char* internalName,
const char* externalName);
// Adds a type export to the module.
BINARYEN_API BinaryenExportRef BinaryenAddTypeExport(BinaryenModuleRef module,
BinaryenHeapType type,
const char* externalName);
// Gets an export reference by external name. Returns NULL if the export does
// not exist.
BINARYEN_API BinaryenExportRef BinaryenGetExport(BinaryenModuleRef module,
Expand Down Expand Up @@ -3319,6 +3323,9 @@ BinaryenExportGetKind(BinaryenExportRef export_);
BINARYEN_API const char* BinaryenExportGetName(BinaryenExportRef export_);
// Gets the internal name of the specified export.
BINARYEN_API const char* BinaryenExportGetValue(BinaryenExportRef export_);
// Gets the heap type of the specified type export.
BINARYEN_API BinaryenHeapType
BinaryenTypeExportGetHeapType(BinaryenExportRef export_);

//
// ========= Custom sections =========
Expand Down
1 change: 1 addition & 0 deletions src/ir/gc-type-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ inline std::optional<Field> getField(HeapType type, Index index = 0) {
return type.getArray().element;
case HeapTypeKind::Func:
case HeapTypeKind::Cont:
case HeapTypeKind::Import:
case HeapTypeKind::Basic:
break;
}
Expand Down
32 changes: 28 additions & 4 deletions src/ir/module-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,11 @@ InsertOrderedMap<HeapType, HeapTypeInfo> collectHeapTypeInfo(
for (auto& curr : wasm.elementSegments) {
info.note(curr->type);
}
for (auto& curr : wasm.exports) {
if (auto* heapType = curr->getHeapType()) {
info.note(*heapType);
}
}

// Collect info from functions in parallel.
ModuleUtils::ParallelFunctionAnalysis<TypeInfos, Immutable, InsertOrderedMap>
Expand Down Expand Up @@ -655,6 +660,9 @@ void classifyTypeVisibility(Module& wasm,
case ExternalKind::Tag:
notePublic(wasm.getTag(*ex->getInternalName())->type);
continue;
case ExternalKind::Type:
notePublic(*ex->getHeapType());
continue;
case ExternalKind::Invalid:
break;
}
Expand Down Expand Up @@ -730,9 +738,9 @@ std::vector<HeapType> getPrivateHeapTypes(Module& wasm) {
return types;
}

IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
auto counts = collectHeapTypeInfo(wasm, TypeInclusion::BinaryTypes);

IndexedHeapTypes sortHeapTypes(Module& wasm,
InsertOrderedMap<HeapType, HeapTypeInfo>& counts,
std::function<HeapType(HeapType)> map) {
// Collect the rec groups.
std::unordered_map<RecGroup, size_t> groupIndices;
std::vector<RecGroup> groups;
Expand All @@ -759,6 +767,7 @@ IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
for (size_t i = 0; i < groups.size(); ++i) {
for (auto type : groups[i]) {
for (auto child : type.getReferencedHeapTypes()) {
child = map(child);
if (child.isBasic()) {
continue;
}
Expand Down Expand Up @@ -793,6 +802,11 @@ IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
}
}

std::vector<bool> isImport(groups.size());
for (size_t i = 0; i < groups.size(); ++i) {
isImport[i] = groups[i][0].isImport();
}

// If we've preserved the input type order on the module, we have to respect
// that first. Use the index of the first type from each group. In principle
// we could try to do something more robust like take the minimum index of all
Expand All @@ -812,7 +826,11 @@ IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
}
}

auto order = TopologicalSort::minSort(deps, [&](size_t a, size_t b) {
auto order = TopologicalSort::minSort(deps, [&](size_t a, size_t b) -> bool {
// Imports should be first
if (isImport[a] != isImport[b]) {
return isImport[a];
}
auto indexA = groupTypeIndices[a];
auto indexB = groupTypeIndices[b];
// Groups with indices must be sorted before groups without indices to
Expand Down Expand Up @@ -846,4 +864,10 @@ IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
return indexedTypes;
}

IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
auto counts = collectHeapTypeInfo(wasm, TypeInclusion::BinaryTypes);
return sortHeapTypes(
wasm, counts, [](HeapType type) -> HeapType { return type; });
}

} // namespace wasm::ModuleUtils
6 changes: 6 additions & 0 deletions src/ir/module-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,12 @@ struct IndexedHeapTypes {
std::unordered_map<HeapType, Index> indices;
};

// Orders the types to be valid (after renaming by the map function)
// and sorts the types by frequency of use to minimize code size.
IndexedHeapTypes sortHeapTypes(Module& wasm,
InsertOrderedMap<HeapType, HeapTypeInfo>& counts,
std::function<HeapType(HeapType)> map);

// Similar to `collectHeapTypes`, but provides fast lookup of the index for each
// type as well. Also orders the types to be valid and sorts the types by
// frequency of use to minimize code size.
Expand Down
3 changes: 3 additions & 0 deletions src/ir/subtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ struct SubTypes {
break;
case HeapTypeKind::Cont:
WASM_UNREACHABLE("TODO: cont");
case HeapTypeKind::Import:
basic = type.getImport().bound;
break;
case HeapTypeKind::Basic:
WASM_UNREACHABLE("unexpected kind");
}
Expand Down
8 changes: 8 additions & 0 deletions src/ir/type-updating.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ GlobalTypeRewriter::TypeMap GlobalTypeRewriter::rebuildTypes(
}
case HeapTypeKind::Cont:
WASM_UNREACHABLE("TODO: cont");
case HeapTypeKind::Import: {
break;
}
case HeapTypeKind::Basic:
WASM_UNREACHABLE("unexpected kind");
}
Expand Down Expand Up @@ -302,6 +305,11 @@ void GlobalTypeRewriter::mapTypes(const TypeMap& oldToNewTypes) {
for (auto& tag : wasm.tags) {
tag->type = updater.getNew(tag->type);
}
for (auto& exp : wasm.exports) {
if (auto* heapType = exp->getHeapType()) {
*heapType = updater.getNew(*heapType);
}
}
}

void GlobalTypeRewriter::mapTypeNamesAndIndices(const TypeMap& oldToNewTypes) {
Expand Down
68 changes: 65 additions & 3 deletions src/parser/contexts.h
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,10 @@ struct ParseDeclsCtx : NullTypeParserCtx, NullInstrParserCtx {
std::vector<DefPos> dataDefs;
std::vector<DefPos> tagDefs;

// Type imports: name, export names, import names, and positions.
std::vector<std::tuple<Name, std::vector<Name>, ImportNames, Index>>
typeImports;

// Positions of export definitions.
std::vector<Index> exportDefs;

Expand All @@ -962,6 +966,7 @@ struct ParseDeclsCtx : NullTypeParserCtx, NullInstrParserCtx {

// Used to verify that all imports come before all non-imports.
bool hasNonImport = false;
bool hasTypeDefinition = false;

Result<> checkImport(Index pos, ImportNames* import) {
if (import) {
Expand All @@ -983,9 +988,10 @@ struct ParseDeclsCtx : NullTypeParserCtx, NullInstrParserCtx {
void setOpen() {}
void setShared() {}
Result<> addSubtype(HeapTypeT) { return Ok{}; }
void finishTypeDef(Name name, Index pos) {
void finishTypeDef(Name name, const std::vector<Name>& exports, Index pos) {
// TODO: type annotations
typeDefs.push_back({name, pos, Index(typeDefs.size()), {}});
hasTypeDefinition = true;
}
size_t getRecGroupStartIndex() { return 0; }
void addRecGroup(Index, size_t) {}
Expand Down Expand Up @@ -1097,10 +1103,28 @@ struct ParseDeclsCtx : NullTypeParserCtx, NullInstrParserCtx {
TypeUseT type,
Index pos);

Result<> addTypeImport(Name name,
const std::vector<Name>& exports,
ImportNames* import,
Index pos,
Index typePos) {
if (hasTypeDefinition) {
return in.err(pos, "type import after type definitions");
}
typeDefs.push_back({name, pos, Index(typeDefs.size()), {}});
typeImports.push_back({name, exports, *import, typePos});
return Ok{};
}

Result<> addExport(Index pos, Ok, Name, ExternalKind) {
exportDefs.push_back(pos);
return Ok{};
}

Result<> addTypeExport(Index pos, Ok, Name) {
exportDefs.push_back(pos);
return Ok{};
}
};

// Phase 2: Parse type definitions into a TypeBuilder.
Expand All @@ -1113,12 +1137,15 @@ struct ParseTypeDefsCtx : TypeParserCtx<ParseTypeDefsCtx> {
// Parse the names of types and fields as we go.
std::vector<TypeNames> names;

// Keep track of type exports
std::vector<std::vector<Name>> typeExports;

// The index of the subtype definition we are parsing.
Index index = 0;

ParseTypeDefsCtx(Lexer& in, TypeBuilder& builder, const IndexMap& typeIndices)
: TypeParserCtx<ParseTypeDefsCtx>(typeIndices), in(in), builder(builder),
names(builder.size()) {}
names(builder.size()), typeExports(builder.size()) {}

TypeT
makeRefType(HeapTypeT ht, Nullability nullability, Exactness exactness) {
Expand Down Expand Up @@ -1162,7 +1189,18 @@ struct ParseTypeDefsCtx : TypeParserCtx<ParseTypeDefsCtx> {
return Ok{};
}

void finishTypeDef(Name name, Index pos) { names[index++].name = name; }
void finishTypeDef(Name name, const std::vector<Name>& exports, Index pos) {
typeExports[index] = exports;
names[index++].name = name;
}

Result<> addTypeImport(Name name,
const std::vector<Name>& exports,
ImportNames* import,
Index pos,
Index typePos) {
return Ok{};
}

size_t getRecGroupStartIndex() { return index; }

Expand Down Expand Up @@ -1411,6 +1449,14 @@ struct ParseModuleTypesCtx : TypeParserCtx<ParseModuleTypesCtx>,
t->type = use.type;
return Ok{};
}

Result<> addTypeImport(Name name,
const std::vector<Name>& exports,
ImportNames* import,
Index pos,
Index typePos) {
return Ok{};
}
};

// Phase 5: Parse module element definitions, including instructions.
Expand Down Expand Up @@ -1765,6 +1811,14 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> {
return Ok{};
}

Result<> addTypeImport(Name,
const std::vector<Name> exports,
ImportNames* import,
Index pos,
Index typePos) {
return Ok{};
}

Result<> addExport(Index pos, Name value, Name name, ExternalKind kind) {
if (wasm.getExportOrNull(name)) {
return in.err(pos, "duplicate export");
Expand All @@ -1773,6 +1827,14 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> {
return Ok{};
}

Result<> addTypeExport(Index pos, HeapType heaptype, Name name) {
if (wasm.getExportOrNull(name)) {
return in.err(pos, "duplicate export");
}
wasm.addExport(builder.makeExport(name, heaptype, ExternalKind::Type));
return Ok{};
}

Result<Index> addScratchLocal(Index pos, Type type) {
if (!func) {
return in.err(pos,
Expand Down
18 changes: 18 additions & 0 deletions src/parser/parse-2-typedefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ Result<> parseTypeDefs(
std::unordered_map<HeapType, std::unordered_map<Name, Index>>& typeNames) {
TypeBuilder builder(decls.typeDefs.size());
ParseTypeDefsCtx ctx(input, builder, typeIndices);
for (auto& [name, exports, importNames, pos] : decls.typeImports) {
WithPosition with(ctx, pos);
auto heaptype = typetype(ctx);
CHECK_ERR(heaptype);
builder[ctx.index] = TypeImport(importNames.mod, importNames.nm, *heaptype);
ctx.typeExports[ctx.index] = exports;
ctx.names[ctx.index++].name = name;
}
for (auto& recType : decls.recTypeDefs) {
WithPosition with(ctx, recType.pos);
CHECK_ERR(rectype(ctx));
Expand All @@ -49,6 +57,16 @@ Result<> parseTypeDefs(
}
}
}
for (size_t i = 0; i < types.size(); ++i) {
for (Name& name : ctx.typeExports[i]) {
if (decls.wasm.getExportOrNull(name)) {
// TODO: Fix error location
return ctx.in.err("repeated export name");
}
decls.wasm.addExport(
Builder(decls.wasm).makeExport(name, types[i], ExternalKind::Type));
}
}
return Ok{};
}

Expand Down
Loading