From e137b76b0b11ce2fbcd252f196568f6183b90f1f Mon Sep 17 00:00:00 2001
From: milo
Date: Fri, 26 Sep 2025 14:18:52 -0400
Subject: [PATCH] write compiler

---
 Makefile   |   6 +
 input.scm  |  22 ++
 scm2wasm.c | 699 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 727 insertions(+)
 create mode 100644 Makefile
 create mode 100644 input.scm
 create mode 100644 scm2wasm.c

diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..42fb469
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,6 @@
+CFLAGS = -Wall -Wextra
+
+all: scm2wasm
+
+clean:
+	rm -f scm2wasm
diff --git a/input.scm b/input.scm
new file mode 100644
index 0000000..b4142ad
--- /dev/null
+++ b/input.scm
@@ -0,0 +1,22 @@
+(define (map f xs)
+  (if (null? xs)
+      '()
+      (cons (f (car xs))
+            (map f (cdr xs)))))
+
+(define (foldl f z xs)
+  (if (null? xs)
+      z
+      (foldl f (f (car xs) z) (cdr xs))))
+
+(define (range n)
+  (define (go i)
+    (if (> i n)
+        '()
+        (cons i (go (+ i 1)))))
+  (go 1))
+
+;; 1² + 2² + 3² + 4² = 30
+(let* ([sqr (lambda (x) (* x x))]
+       [sqrs (map sqr (range 4))])
+  (foldl + 0 sqrs))
diff --git a/scm2wasm.c b/scm2wasm.c
new file mode 100644
index 0000000..4c3f6ce
--- /dev/null
+++ b/scm2wasm.c
@@ -0,0 +1,699 @@
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+// == values ==
+
+// Node kinds. NIL..NUMBER are produced by the reader; CONSTANT..SETBANG are
+// AST nodes produced by the expander.
+enum tag {
+  NIL, PAIR, SYMBOL, NUMBER,
+  CONSTANT, LAMBDA, IFELSE, BEGIN, SETBANG,
+};
+
+typedef struct val* Val;
+
+// A tagged node: `name` is used by SYMBOL, `value` by NUMBER, and `field`
+// holds however many children allocVal() was asked to reserve.
+struct val {
+  enum tag tag;
+  union {
+    const char *name;
+    long int value;
+  };
+  Val field[];
+};
+
+struct val nil = {.tag = NIL};
+
+#define car(v_) (v_)->field[0]
+#define cdr(v_) (v_)->field[1]
+#define cadr(v_) car(cdr(v_))
+#define cddr(v_) cdr(cdr(v_))
+#define caddr(v_) car(cddr(v_))
+#define cadddr(v_) cadr(cddr(v_))
+#define list1(v_) cons(v_, &nil)
+#define list2(v1_,v2_) cons(v1_, list1(v2_))
+#define list3(v1_,v2_,v3_) cons(v1_, list2(v2_,v3_))
+
+// Allocates a node with room for n child pointers; only the tag is initialized.
+Val allocVal(enum tag tag, size_t n) {
+  Val v = malloc(sizeof(struct val) + n * sizeof(Val));
+  if (v == NULL) {
+    fprintf(stderr, "error: out of memory\n");
+    exit(1);
+  }
+  v->tag = tag;
+  return v;
+}
+
+// Builds the pair (hd . tl).
+Val cons(Val hd, Val tl) {
+  Val pair = allocVal(PAIR, 2);
+  car(pair) = hd;
+  cdr(pair) = tl;
+  return pair;
+}
+
+// Wraps name in a SYMBOL node; the string is referenced, not copied.
+Val makeSymbol(const char* name) {
+  Val s = allocVal(SYMBOL, 0);
+  s->name = name;
+  return s;
+}
+
+// Wraps value in a NUMBER node.
+Val makeNumber(long int value) {
+  Val s = allocVal(NUMBER, 0);
+  s->value = value;
+  return s;
+}
+
+// True when v is a SYMBOL spelled exactly like name.
+int symbolEq(Val v, const char *name) {
+  return v->tag == SYMBOL && strcmp(v->name, name) == 0;
+}
+
+// Returns a freshly consed copy of list in reverse order; a non-pair tail
+// is dropped.
+Val reverse(Val list) {
+  Val rev = &nil;
+  for (; list->tag == PAIR; list = cdr(list))
+    rev = cons(car(list), rev);
+  return rev;
+}
+
+// == reader ==
+
+// Read cursor into the NUL-terminated program text.
+const char *source;
+
+Val readSexp();
+
+// Classifies c: 0 = whitespace, 1 = delimiter (including NUL), 2 = symbol char.
+int describeChar(char c) {
+  switch (c) {
+  case ' ': case '\t': case '\r': case '\n': return 0;
+  case '(': case ')': case '[': case ']': case '{': case '}':
+  case '.': case '\'': case ';': case '\0': return 1;
+  default: return 2;
+  }
+}
+
+// Advances source past whitespace and `;` line comments.
+void skipWhitespace() {
+  for (;;) {
+    if (*source == ';')
+      // Stop at NUL too: a comment on the last line may lack a newline.
+      while (*source != '\n' && *source != '\0') source++;
+    if (describeChar(*source) == 0)
+      source++;
+    else
+      break;
+  }
+}
+
+// Reads a maximal run of symbol characters into a new NUL-terminated string,
+// or returns NULL when the cursor is not on a symbol character.
+char *readSymbol() {
+  const char *start = source;
+  while (describeChar(*source) == 2)
+    source++;
+  if (source == start)
+    return NULL;
+  return memcpy(calloc(source - start + 1, 1), start, source - start);
+}
+
+// Reads list elements up to the closing character eol ('\0' means end of
+// input); `. x` makes x the tail of the list.
+Val readSexpList(char eol) {
+  skipWhitespace();
+  if (*source == eol) {
+    if (eol != '\0') source++;
+    return &nil;
+  }
+  if (*source == '.') {
+    source++;
+    Val tl = readSexp();
+    // Consume the closer here; otherwise it would terminate the enclosing list.
+    skipWhitespace();
+    if (*source == eol && eol != '\0') source++;
+    return tl;
+  }
+  Val hd = readSexp();
+  Val tl = readSexpList(eol);
+  return cons(hd, tl);
+}
+
+// Reads one datum: a bracketed list, 'x as (quote x), a number or a symbol.
+// Exits with an error when no datum starts at the cursor.
+Val readSexp() {
+  skipWhitespace();
+  switch (*source) {
+  case '(':
+    source++;
+    return readSexpList(')');
+  case '[':
+    source++;
+    return readSexpList(']');
+  case '{':
+    source++;
+    return readSexpList('}');
+  case '\'':
+    {
+      source++;
+      Val data = readSexp();
+      static struct val quote = {.tag = SYMBOL, .name = "quote"};
+      return list2(&quote, data);
+    }
+  }
+  char *symbol = readSymbol(), *end;
+  if (symbol == NULL) {
+    char who[4] = "EOF";
+    if (*source) {
+      who[1] = *source;
+      who[0] = who[2] = '\'';
+    }
+    fprintf(stderr, "error: unexpected %s\n", who);
+    exit(1);
+  }
+  long int value = strtol(symbol, &end, 10);
+  // Only a fully numeric token is a number; `1+` and the like stay symbols.
+  if (end != symbol && *end == '\0')
+    return makeNumber(value);
+  return makeSymbol(symbol);
+}
+
+// == expander ==
+
+// Defines MAP as the function applying F to every element of a proper list.
+#define DEFINE_MAP(MAP,F) \
+  Val MAP(Val xs_) { \
+    return (xs_->tag == NIL) ? xs_ \
+      : cons(F(car(xs_)), MAP(cdr(xs_))); \
+  }
+
+DEFINE_MAP(cars, car);
+DEFINE_MAP(cadrs, cadr);
+
+Val expandExpr(Val sexp);
+
+DEFINE_MAP(expandList, expandExpr);
+
+// Expands sexp; a `define` form becomes a SETBANG node and its name is
+// pushed onto *names so expandBlock() can bind it.
+Val expandDefinition(Val sexp, Val *names) {
+  if (sexp->tag == PAIR) {
+    if (symbolEq(car(sexp), "define")) {
+      Val name = cadr(sexp);
+      Val body = caddr(sexp);
+      if (name->tag == PAIR) {
+        // (define (f x ...) e) -> (define f (lambda (x ...) e))
+        static struct val lambda = {.tag = SYMBOL, .name = "lambda"};
+        body = cons(&lambda, cons(cdr(name), cddr(sexp)));
+        name = car(name);
+      }
+      Val ast = allocVal(SETBANG, 2);
+      ast->field[0] = name;
+      ast->field[1] = expandExpr(body);
+      *names = cons(name, *names);
+      return ast;
+    }
+  }
+  return expandExpr(sexp);
+}
+
+// Expands a body: a sequence of forms becomes nested BEGIN nodes, and any
+// internal defines are bound by a wrapping lambda applied to '() placeholders.
+Val expandBlock(Val list) {
+  if (list->tag != PAIR)
+    return expandExpr(&nil);
+
+  Val names = &nil;
+  Val result = NULL;
+  for (list = reverse(list); list->tag == PAIR; list = cdr(list)) {
+    Val expr = expandDefinition(car(list), &names);
+    if (result != NULL) {
+      Val seq = allocVal(BEGIN, 2);
+      seq->field[0] = expr;
+      seq->field[1] = result;
+      result = seq;
+    } else
+      result = expr;
+  }
+
+  if (names->tag == NIL)
+    return result;
+
+  // (define x e) ... b -> ((lambda (x ...) (set! x e) ... b) '() ...)
+  Val args = &nil;
+  Val initArg = allocVal(CONSTANT, 1);
+  initArg->field[0] = &nil;
+  for (list = names; list->tag == PAIR; list = cdr(list))
+    args = cons(initArg, args);
+
+  Val lambda = allocVal(LAMBDA, 2);
+  lambda->field[0] = names;
+  lambda->field[1] = result;
+  return cons(lambda, args);
+}
+
+// Converts an S-expression into an AST, desugaring let and let*.
+Val expandExpr(Val sexp) {
+  if (sexp->tag == SYMBOL) {
+    return sexp;
+  } else if (sexp->tag != PAIR) {
+    Val ast = allocVal(CONSTANT, 1);
+    ast->field[0] = sexp;
+    return ast;
+  } else if (symbolEq(car(sexp), "let")) {
+    static struct val lambda = {.tag = SYMBOL, .name = "lambda"};
+    Val xs = cars(cadr(sexp));
+    Val vs = cadrs(cadr(sexp));
+    Val body = cddr(sexp);
+    return expandExpr(cons(cons(&lambda, cons(xs, body)), vs));
+  } else if (symbolEq(car(sexp), "let*")) {
+    static struct val let = {.tag = SYMBOL, .name = "let"};
+    static struct val begin = {.tag = SYMBOL, .name = "begin"};
+    Val body = cons(&begin, cddr(sexp));
+    for (Val vars = reverse(cadr(sexp)); vars->tag == PAIR; vars = cdr(vars))
+      body = list3(&let, list1(car(vars)), body);
+    return expandExpr(body);
+  } else if (symbolEq(car(sexp), "quote")) {
+    Val ast = allocVal(CONSTANT, 1);
+    ast->field[0] = cadr(sexp);
+    return ast;
+  } else if (symbolEq(car(sexp), "lambda")) {
+    Val ast = allocVal(LAMBDA, 2);
+    ast->field[0] = cadr(sexp);
+    ast->field[1] = expandBlock(cddr(sexp));
+    return ast;
+  } else if (symbolEq(car(sexp), "if")) {
+    Val ast = allocVal(IFELSE, 3);
+    ast->field[0] = expandExpr(cadr(sexp));
+    ast->field[1] = expandExpr(caddr(sexp));
+    ast->field[2] = expandExpr(cadddr(sexp));
+    return ast;
+  } else if (symbolEq(car(sexp), "begin")) {
+    return expandBlock(cdr(sexp));
+  } else if (symbolEq(car(sexp), "set!")) {
+    Val ast = allocVal(SETBANG, 2);
+    ast->field[0] = cadr(sexp);
+    ast->field[1] = expandExpr(caddr(sexp));
+    return ast;
+  }
+  return expandList(sexp);
+}
+
+// == binary format ==
+
+// Growable output buffer: outputPtr is the write position inside outputBuf.
+char *outputBuf, *outputPtr;
+size_t outputSize;
+
+// Starts a fresh 16-byte output buffer; the previous one is not freed.
+void initOutput() {
+  outputBuf = outputPtr = malloc(outputSize = 16);
+}
+
+// Appends one byte, doubling the buffer when it is full.
+void writeByte(char c) {
+  size_t used = outputPtr - outputBuf;
+  if (used >= outputSize) {
+    // Keep the old pointer until realloc() is known to have succeeded.
+    char *grown = realloc(outputBuf, outputSize * 2);
+    if (grown == NULL) {
+      fprintf(stderr, "error: out of memory\n");
+      exit(1);
+    }
+    outputBuf = grown;
+    outputPtr = outputBuf + used;
+    outputSize *= 2;
+  }
+  *outputPtr++ = c;
+}
+
+// Appends len bytes starting at bs.
+void writeNBytes(const char *bs, size_t len) {
+  for (size_t i = 0; i < len; i++)
+    writeByte(bs[i]);
+}
+
+// Appends a string literal without its terminating NUL; embedded NULs count.
+#define writeBytes(bs_) writeNBytes(bs_, sizeof(bs_) - 1)
+
+// Appends n as unsigned LEB128.
+void writeUint(unsigned long int n) {
+  while (n >= 0x80) {
+    writeByte((n & 0x7F) | 0x80);
+    n >>= 7;
+  }
+  writeByte(n);
+}
+
+// Appends n as signed LEB128; relies on >> being an arithmetic shift for
+// negative n.
+void writeSint(long int n) {
+  while (n >= 0x40 || n <= -0x40) {
+    writeByte((n & 0x7F) | 0x80);
+    n >>= 7;
+  }
+  writeByte(n & 0x7F);
+}
+
+// Reserves a 5-byte padded LEB128 length field and returns its offset.
+size_t beginLength() {
+  size_t pos = outputPtr - outputBuf;
+  writeBytes("\x80\x80\x80\x80\x00");
+  return pos;
+}
+
+// Patches the field reserved at pos with the number of bytes written after it.
+void endLength(size_t pos) {
+  size_t len = outputPtr - outputBuf - pos - 5;
+  while (len != 0) {
+    outputBuf[pos++] |= len & 0x7F;
+    len >>= 7;
+  }
+}
+
+// == compiler ==
+
+#define TYPE_SECTION 1
+#define FUNC_SECTION 3
+#define TABLE_SECTION 4
+#define EXPORT_SECTION 7
+#define ELEM_SECTION 9
+#define CODE_SECTION 10
+
+// Instruction byte sequences, with type/local indices already baked in.
+#define DROP "\x1A"
+#define I32_CONST "\x41"
+#define I31_REF "\xFB\x1C"
+#define I31_ONE I32_CONST "\x01" I31_REF
+#define CAST_I32 "\xFB\x16\x6C" "\xFB\x1D"
+#define GET_ARGS "\x20\x00"
+#define SET_ARGS "\x21\x00"
+#define GET_ENV "\x20\x01"
+#define SET_ENV "\x21\x01"
+#define REF_NULL_ENV "\xD0\x01"
+#define REF_NEW_ENV "\xFB\x00\x01"
+#define REF_ENV_DROP "\xFB\x02\x01\x00"
+#define REF_ENV_POP "\xFB\x02\x01\x01"
+#define REF_NULL "\xD0\x6E"
+#define REF_IS_NULL "\xD1"
+#define REF_CONS "\xFB\x00\x00"
+#define CAST_CONS "\xFB\x16\x00"
+#define REF_CAR CAST_CONS "\xFB\x02\x00\x00"
+#define REF_CDR CAST_CONS "\xFB\x02\x00\x01"
+#define SET_CAR "\xFB\x05\x00\x00"
+#define REF_NEW_PROC "\xFB\x00\x02"
+#define REF_PROC_PREPARE_CALL \
+  "\xFB\x16\x02" "\x22\x02" "\xFB\x02\x02\x00" "\x20\x02" "\xFB\x02\x02\x01"
+#define CALL_INDIRECT "\x11\x03\x00"
+#define RETURN_CALL_INDIRECT "\x13\x03\x00"
+#define CALL "\x10"
+#define IF "\x04\x6E"
+#define ELSE "\x05"
+#define END "\x0B"
+
+// One entry of the code section: type index plus encoded locals and body.
+struct func {
+  size_t type;
+  const char *expr;
+  size_t exprLen;
+};
+
+struct func funcs[1024];
+size_t funcsLen = 0;
+
+size_t elems[1024];
+size_t elemsLen = 0;
+
+// Compile-time stack of parameter lists, innermost scope last.
+Val env[1024];
+Val *envPtr = env;
+// Zero while the expression being compiled is in tail position.
+int tail = 0;
+
+void compileExpr(Val expr);
+
+// Compiles expr as the body of a new type-3 function in its own buffer and
+// returns the function index; the caller's output buffer is restored after.
+size_t compileProc(Val expr) {
+  char *prevOutputPtr = outputPtr;
+  char *prevOutputBuf = outputBuf;
+  size_t prevOutputSize = outputSize;
+  initOutput();
+  // (local 2 (ref $PROC)), scratch slot used by REF_PROC_PREPARE_CALL
+  writeBytes("\x01\x01\x64\x02");
+  int prevTail = tail;
+  tail = 0;
+  writeBytes(GET_ENV GET_ARGS REF_NEW_ENV SET_ENV);
+  compileExpr(expr);
+  writeBytes(END);
+  tail = prevTail;
+  size_t funcidx = funcsLen++;
+  funcs[funcidx].type = 3;
+  funcs[funcidx].expr = outputBuf;
+  funcs[funcidx].exprLen = outputPtr - outputBuf;
+  outputPtr = prevOutputPtr;
+  outputBuf = prevOutputBuf;
+  outputSize = prevOutputSize;
+  return funcidx;
+}
+
+// Emits code that leaves the list cell whose car holds variable `name` on
+// the stack, or exits if no enclosing scope binds it.
+void compileVar(const char *name) {
+  Val *envPos = envPtr;
+  writeBytes(GET_ENV);
+  while (envPos > env) {
+    Val params = *--envPos;
+    for (size_t offset = 0; params->tag != NIL; offset++) {
+      if (symbolEq(car(params), name)) {
+        writeBytes(REF_ENV_POP);
+        for (size_t i = 0; i < offset; i++)
+          writeBytes(REF_CDR);
+        return;
+      }
+      params = cdr(params);
+    }
+    writeBytes(REF_ENV_DROP);
+  }
+  fprintf(stderr, "variable not found: %s\n", name);
+  exit(1);
+}
+
+// Emits code building quoted data: numbers as i31 refs, pairs as $PAIR
+// structs, anything else (including symbols) as null.
+void compileConstant(Val data) {
+  // TODO(?): symbol
+  if (data->tag == NUMBER) {
+    writeBytes(I32_CONST);
+    writeSint(data->value);
+    writeBytes(I31_REF);
+  } else if (data->tag == PAIR) {
+    compileConstant(car(data));
+    compileConstant(cdr(data));
+    writeBytes(REF_CONS);
+  } else
+    writeBytes(REF_NULL);
+}
+
+// Emits code evaluating each element of list and consing the results.
+void compileList(Val list) {
+  if (list->tag == PAIR) {
+    compileExpr(car(list));
+    compileList(cdr(list));
+    writeBytes(REF_CONS);
+  } else
+    writeBytes(REF_NULL);
+}
+
+// Emits code for one AST node; the result is left on the wasm stack.
+void compileExpr(Val expr) {
+  if (expr->tag == CONSTANT) {
+    compileConstant(expr->field[0]);
+  } else if (expr->tag == SYMBOL) {
+    compileVar(expr->name);
+    writeBytes(REF_CAR);
+  } else if (expr->tag == SETBANG) {
+    compileVar(expr->field[0]->name);
+    writeBytes(CAST_CONS);
+    tail++;
+    compileExpr(expr->field[1]);
+    tail--;
+    writeBytes(SET_CAR REF_NULL);
+  } else if (expr->tag == LAMBDA) {
+    *envPtr++ = expr->field[0];
+    size_t funcidx = compileProc(expr->field[1]);
+    envPtr--;
+    size_t elemidx = elemsLen++;
+    elems[elemidx] = funcidx;
+    writeBytes(GET_ENV I32_CONST);
+    writeSint(elemidx);
+    writeBytes(REF_NEW_PROC);
+  } else if (expr->tag == IFELSE) {
+    tail++;
+    compileExpr(expr->field[0]);
+    tail--;
+    writeBytes(REF_IS_NULL IF);
+    compileExpr(expr->field[2]);
+    writeBytes(ELSE);
+    compileExpr(expr->field[1]);
+    writeBytes(END);
+  } else if (expr->tag == BEGIN) {
+    tail++;
+    compileExpr(expr->field[0]);
+    tail--;
+    writeBytes(DROP);
+    compileExpr(expr->field[1]);
+  } else {
+    tail++;
+    compileList(cdr(expr));
+    compileExpr(car(expr));
+    tail--;
+    writeBytes(REF_PROC_PREPARE_CALL);
+    if (tail == 0)
+      writeBytes(RETURN_CALL_INDIRECT);
+    else
+      writeBytes(CALL_INDIRECT);
+  }
+}
+
+// == builtin environment ==
+
+// Registers a hand-assembled function and emits code that conses a closure
+// for it onto the argument list; the name is added to the current scope.
+void builtinN(const char *name, const char *expr, size_t exprLen) {
+  size_t funcidx = funcsLen++;
+  funcs[funcidx].type = 3;
+  funcs[funcidx].expr = expr;
+  funcs[funcidx].exprLen = exprLen;
+  size_t elemidx = elemsLen++;
+  elems[elemidx] = funcidx;
+  writeBytes(REF_NULL_ENV I32_CONST); writeSint(elemidx);
+  writeBytes(REF_NEW_PROC GET_ARGS REF_CONS SET_ARGS);
+  *envPtr = cons(makeSymbol(name), *envPtr);
+}
+
+#define builtin(name_,expr_) builtinN(name_, expr_, sizeof(expr_) - 1)
+
+// Registers the primitives; each body starts with "\0" (no locals).
+void initBuiltins() {
+#define CODE_INT_OP2(op_) "\0" \
+  GET_ARGS REF_CAR CAST_I32 \
+  GET_ARGS REF_CDR REF_CAR CAST_I32 \
+  op_ I31_REF END
+
+#define CODE_INT_CMP(op_) "\0" \
+  GET_ARGS REF_CAR CAST_I32 \
+  GET_ARGS REF_CDR REF_CAR CAST_I32 \
+  op_ IF I31_ONE ELSE REF_NULL END END
+
+  builtin("<", CODE_INT_CMP("\x48"));
+  builtin(">", CODE_INT_CMP("\x4A"));
+  builtin("=", CODE_INT_CMP("\x46"));
+  builtin("*", CODE_INT_OP2("\x6C"));
+  builtin("-", CODE_INT_OP2("\x6B"));
+  builtin("+", CODE_INT_OP2("\x6A"));
+  builtin("car", "\0" GET_ARGS REF_CAR REF_CAR END);
+  builtin("cdr", "\0" GET_ARGS REF_CAR REF_CDR END);
+  builtin("cons", "\0" GET_ARGS REF_CAR GET_ARGS REF_CDR REF_CAR REF_CONS END);
+  builtin("null?", "\0" GET_ARGS REF_CAR REF_IS_NULL IF I31_ONE ELSE REF_NULL END END);
+}
+
+// Reads a program from stdin and writes the compiled module to stdout.
+int main() {
+  // static: a 1 MiB automatic array can overflow small stacks.
+  static char sourceBuf[1048576];
+  size_t sourceLen = fread(sourceBuf, 1, 1048575, stdin);
+  sourceBuf[sourceLen] = '\0';
+  source = sourceBuf;
+  Val main = readSexpList('\0');
+  main = expandBlock(main);
+
+  size_t startFuncidx = funcsLen++;
+  initOutput();
+  // (local 0 anyref)
+  writeBytes("\x01\x01\x6E");
+  *envPtr = &nil;
+  initBuiltins();
+  envPtr++;
+  writeBytes(GET_ARGS REF_NULL_ENV CALL);
+  writeUint(compileProc(main));
+  writeBytes(CAST_I32 END);
+  funcs[startFuncidx].type = 4;
+  funcs[startFuncidx].expr = outputBuf;
+  funcs[startFuncidx].exprLen = outputPtr - outputBuf;
+
+  size_t sectionLen;
+  initOutput();
+  writeBytes("\0asm\1\0\0\0");
+
+  writeByte(TYPE_SECTION);
+  sectionLen = beginLength();
+  writeUint(5);
+  // (type $PAIR 0 (struct (field mut anyref) (field mut anyref)))
+  writeBytes("\x5F\x02\x6E\x01\x6E\x01");
+  // (type $ENV 1 (struct (field (ref null $ENV)) (field anyref)))
+  writeBytes("\x5F\x02\x63\x01\x00\x6E\x00");
+  // (type $PROC 2 (struct (field (ref null $ENV)) (field i32)))
+  writeBytes("\x5F\x02\x63\x01\x00\x7F\x00");
+  // (type 3 (func (param anyref) (param (ref null $ENV)) (result anyref)))
+  writeBytes("\x60\x02\x6E\x63\x01\x01\x6E");
+  // (type 4 (func (result i32)))
+  writeBytes("\x60\x00\x01\x7F");
+  endLength(sectionLen);
+
+  writeByte(FUNC_SECTION);
+  sectionLen = beginLength();
+  writeUint(funcsLen);
+  for (size_t i = 0; i < funcsLen; i++)
+    writeUint(funcs[i].type);
+  endLength(sectionLen);
+
+  writeByte(TABLE_SECTION);
+  sectionLen = beginLength();
+  // (table 0 funcref)
+  writeBytes("\x01\x70\x00");
+  writeUint(elemsLen);
+  endLength(sectionLen);
+
+  writeByte(EXPORT_SECTION);
+  sectionLen = beginLength();
+  // (export "start" (func startFuncidx))
+  writeBytes("\x01\x05""start\x00");
+  writeUint(startFuncidx);
+  endLength(sectionLen);
+
+  writeByte(ELEM_SECTION);
+  sectionLen = beginLength();
+  // (elem 0 (i32.const 0) func funcidx*)
+  writeBytes("\x01\x00\x41\x00\x0B");
+  writeUint(elemsLen);
+  for (size_t i = 0; i < elemsLen; i++)
+    writeUint(elems[i]);
+  endLength(sectionLen);
+
+  writeByte(CODE_SECTION);
+  sectionLen = beginLength();
+  writeUint(funcsLen);
+  for (size_t i = 0; i < funcsLen; i++) {
+    writeUint(funcs[i].exprLen);
+    writeNBytes(funcs[i].expr, funcs[i].exprLen);
+  }
+  endLength(sectionLen);
+
+  // fwrite() only returns a short count on error, so retrying would spin.
+  size_t outputLen = outputPtr - outputBuf;
+  if (fwrite(outputBuf, 1, outputLen, stdout) != outputLen
+      || fflush(stdout) != 0) {
+    perror("scm2wasm");
+    return 1;
+  }
+  return 0;
+}