From f454526c5a79944853669b0213c5ce623c3a1bcf Mon Sep 17 00:00:00 2001
From: Moritz Gmeiner
Date: Sun, 30 Mar 2025 15:38:38 +0200
Subject: [PATCH] deserialiser

---
 src/deserialise.zig | 618 ++++++++++++++++++++++++++++++++++++++++++++
 src/root.zig        |  57 +++-
 src/utils.zig       | 111 ++++++++
 3 files changed, 783 insertions(+), 3 deletions(-)
 create mode 100644 src/deserialise.zig
 create mode 100644 src/utils.zig

diff --git a/src/deserialise.zig b/src/deserialise.zig
new file mode 100644
index 0000000..9fbc739
--- /dev/null
+++ b/src/deserialise.zig
@@ -0,0 +1,618 @@
+const std = @import("std");
+
+const utils = @import("utils.zig");
+
+const Object = @import("root.zig").Object;
+const MapEntry = @import("root.zig").MapEntry;
+
+/// A decoded object together with the number of input bytes it was decoded from.
+const ObjectLen = struct {
+    bytes_read: usize,
+    obj: Object,
+};
+
+const DeserialiseError = error{
+    BadLength,
+    InvalidTag,
+    IntOutOfRange,
+    InvalidUtf8,
+    NotImplemented,
+    OutOfMemory,
+};
+
+/// Selects which `Raw` variant `deserialise_raw` produces.
+const RawKind = enum { str, binary };
+
+/// Returns `bytes[start..][0..len]`, or `error.BadLength` if `bytes` is too short for it.
+fn take_bytes(bytes: []const u8, start: usize, len: usize) error{BadLength}![]const u8 {
+    if (start > bytes.len or len > bytes.len - start) {
+        return error.BadLength;
+    }
+
+    return bytes[start..][0..len];
+}
+
+/// Reads the big-endian length prefix of `width` (1, 2 or 4) bytes at the start of `payload`.
+fn read_len(payload: []const u8, comptime width: usize) error{BadLength}!usize {
+    if (payload.len < width) {
+        return error.BadLength;
+    }
+
+    return switch (width) {
+        1 => payload[0],
+        2 => std.mem.readInt(u16, payload[0..2], .big),
+        4 => std.mem.readInt(u32, payload[0..4], .big),
+        else => @compileError("length prefix must be 1, 2 or 4 bytes wide"),
+    };
+}
+
+/// Decodes a float 32 / float 64 object; `bytes[0]` must be its tag.
+///
+/// Bytes following the payload are ignored, the object may be embedded in a larger buffer.
+fn deserialise_float(bytes: []const u8) DeserialiseError!ObjectLen {
+    std.debug.assert(bytes.len != 0);
+
+    const tag = bytes[0];
+
+    const payload = bytes[1..];
+
+    std.debug.assert(tag >> 1 == 0b1100101);
+
+    switch (@as(u1, @intCast(tag & 0b1))) {
+        0b0 => { // float 32
+            if (payload.len < 4) {
+                return error.BadLength;
+            }
+
+            const f: f32 = @bitCast(std.mem.readInt(u32, payload[0..4], .big));
+
+            return .{ .bytes_read = 5, .obj = .{ .float = f } };
+        },
+        0b1 => { // float 64
+            if (payload.len < 8) {
+                return error.BadLength;
+            }
+
+            const f: f64 = @bitCast(std.mem.readInt(u64, payload[0..8], .big));
+
+            return .{ .bytes_read = 9, .obj = .{ .float = f } };
+        },
+    }
+}
+
+/// Decodes a uint 8/16/32/64 object; `bytes[0]` must be its tag.
+///
+/// Returns `error.IntOutOfRange` for uint 64 values that do not fit into an `i64`.
+fn deserialise_uint(bytes: []const u8) DeserialiseError!ObjectLen {
+    std.debug.assert(bytes.len != 0);
+
+    const tag = bytes[0];
+
+    const payload = bytes[1..];
+
+    std.debug.assert(tag >> 2 == 0b110011);
+
+    switch (@as(u2, @intCast(tag & 0b11))) {
+        0b00 => {
+            if (payload.len < 1) {
+                return error.BadLength;
+            }
+
+            const obj = Object{ .integer = @as(i64, payload[0]) };
+
+            return .{ .bytes_read = 2, .obj = obj };
+        },
+
+        0b01 => {
+            if (payload.len < 2) {
+                return error.BadLength;
+            }
+
+            const obj = Object{ .integer = std.mem.readInt(u16, payload[0..2], .big) };
+
+            return .{ .bytes_read = 3, .obj = obj };
+        },
+
+        0b10 => {
+            if (payload.len < 4) {
+                return error.BadLength;
+            }
+
+            const obj = Object{ .integer = std.mem.readInt(u32, payload[0..4], .big) };
+
+            return .{ .bytes_read = 5, .obj = obj };
+        },
+
+        0b11 => {
+            if (payload.len < 8) {
+                return error.BadLength;
+            }
+
+            const i = std.mem.readInt(u64, payload[0..8], .big);
+
+            if (i > std.math.maxInt(i64)) {
+                return error.IntOutOfRange;
+            }
+
+            const obj = Object{ .integer = @intCast(i) };
+
+            return .{ .bytes_read = 9, .obj = obj };
+        },
+    }
+}
+
+/// Decodes an int 8/16/32/64 object; `bytes[0]` must be its tag.
+fn deserialise_int(bytes: []const u8) DeserialiseError!ObjectLen {
+    std.debug.assert(bytes.len != 0);
+
+    const tag = bytes[0];
+
+    const payload = bytes[1..];
+
+    std.debug.assert(tag >> 2 == 0b110100);
+
+    switch (@as(u2, @intCast(tag & 0b11))) {
+        0b00 => {
+            if (payload.len < 1) {
+                return error.BadLength;
+            }
+
+            const obj = Object{ .integer = @as(i64, @as(i8, @bitCast(payload[0]))) };
+
+            return .{ .bytes_read = 2, .obj = obj };
+        },
+
+        0b01 => {
+            if (payload.len < 2) {
+                return error.BadLength;
+            }
+
+            const obj = Object{ .integer = std.mem.readInt(i16, payload[0..2], .big) };
+
+            return .{ .bytes_read = 3, .obj = obj };
+        },
+
+        0b10 => {
+            if (payload.len < 4) {
+                return error.BadLength;
+            }
+
+            const obj = Object{ .integer = std.mem.readInt(i32, payload[0..4], .big) };
+
+            return .{ .bytes_read = 5, .obj = obj };
+        },
+
+        0b11 => {
+            if (payload.len < 8) {
+                return error.BadLength;
+            }
+
+            const obj = Object{ .integer = std.mem.readInt(i64, payload[0..8], .big) };
+
+            return .{ .bytes_read = 9, .obj = obj };
+        },
+    }
+}
+
+/// Copies `bytes` into a newly allocated string or binary object.
+///
+/// Only strings are checked for valid UTF-8, binary payloads may hold arbitrary bytes.
+fn deserialise_raw(alloc: std.mem.Allocator, bytes: []const u8, len: usize, comptime kind: RawKind) DeserialiseError!Object {
+    std.debug.assert(bytes.len == len);
+
+    if (kind == .str and !utils.validateUtf8(bytes)) {
+        return error.InvalidUtf8;
+    }
+
+    const s = try alloc.dupe(u8, bytes);
+
+    const obj = switch (kind) {
+        .str => Object{ .raw = .{ .string = s } },
+        .binary => Object{ .raw = .{ .binary = s } },
+    };
+
+    return obj;
+}
+
+/// Decodes a bin/str object whose length is stored in the first `width` bytes of `payload`.
+fn deserialise_sized_raw(alloc: std.mem.Allocator, payload: []const u8, comptime width: usize, comptime kind: RawKind) DeserialiseError!ObjectLen {
+    const len = try read_len(payload, width);
+
+    const obj = try deserialise_raw(alloc, try take_bytes(payload, width, len), len, kind);
+
+    // 1 byte for the tag, `width` bytes for the length prefix
+    return .{ .bytes_read = 1 + width + len, .obj = obj };
+}
+
+/// Decodes `len` consecutive objects from the start of `bytes` into an array object.
+///
+/// `bytes_read` of the result only counts the elements, not the array header.
+fn deserialise_array(alloc: std.mem.Allocator, bytes: []const u8, len: usize) DeserialiseError!ObjectLen {
+    // every element takes at least one byte, so reject impossible lengths before allocating
+    if (len > bytes.len) {
+        return error.BadLength;
+    }
+
+    const array = try alloc.alloc(Object, len);
+
+    // number of elements of `array` that are initialised and must be released on error
+    var filled: usize = 0;
+
+    errdefer {
+        for (array[0..filled]) |done| {
+            done.deinit(alloc);
+        }
+
+        alloc.free(array);
+    }
+
+    var bytes_read: usize = 0;
+
+    for (array) |*obj| {
+        if (bytes_read > bytes.len) {
+            return error.BadLength;
+        }
+
+        const r = try deserialise_with_count(alloc, bytes[bytes_read..]);
+
+        obj.* = r.obj;
+        filled += 1;
+
+        bytes_read += r.bytes_read;
+    }
+
+    return .{ .bytes_read = bytes_read, .obj = .{ .array = array } };
+}
+
+/// Decodes `len` consecutive key/value pairs from the start of `bytes` into a map object.
+///
+/// `bytes_read` of the result only counts the entries, not the map header.
+fn deserialise_map(alloc: std.mem.Allocator, bytes: []const u8, len: usize) DeserialiseError!ObjectLen {
+    // every entry takes at least two bytes, so reject impossible lengths before allocating
+    if (len > bytes.len / 2) {
+        return error.BadLength;
+    }
+
+    const entries = try alloc.alloc(MapEntry, len);
+
+    // number of entries that are fully initialised and must be released on error
+    var filled: usize = 0;
+
+    errdefer {
+        for (entries[0..filled]) |done| {
+            done.deinit(alloc);
+        }
+
+        alloc.free(entries);
+    }
+
+    var bytes_read: usize = 0;
+
+    for (entries) |*entry| {
+        if (bytes_read > bytes.len) {
+            return error.BadLength;
+        }
+
+        const key = try deserialise_with_count(alloc, bytes[bytes_read..]);
+        // the key is not part of `entries` yet, release it if the value fails to decode
+        errdefer key.obj.deinit(alloc);
+
+        bytes_read += key.bytes_read;
+
+        if (bytes_read > bytes.len) {
+            return error.BadLength;
+        }
+
+        const value = try deserialise_with_count(alloc, bytes[bytes_read..]);
+
+        bytes_read += value.bytes_read;
+
+        entry.* = .{ .key = key.obj, .value = value.obj };
+        filled += 1;
+    }
+
+    return .{ .bytes_read = bytes_read, .obj = .{ .map = entries } };
+}
+
+/// Copies `data` into a newly allocated extension object of type `type_`.
+fn deserialise_ext(alloc: std.mem.Allocator, type_: u8, data: []const u8, len: usize) error{OutOfMemory}!Object {
+    std.debug.assert(data.len == len);
+
+    const bytes = try alloc.dupe(u8, data);
+
+    return .{ .extension = .{ .type = type_, .bytes = bytes } };
+}
+
+/// Decodes an ext 8/16/32 object: a `width`-byte length, the type byte, then the data.
+fn deserialise_sized_ext(alloc: std.mem.Allocator, payload: []const u8, comptime width: usize) DeserialiseError!ObjectLen {
+    const len = try read_len(payload, width);
+
+    const type_ = (try take_bytes(payload, width, 1))[0];
+
+    const obj = try deserialise_ext(alloc, type_, try take_bytes(payload, width + 1, len), len);
+
+    return .{ .bytes_read = 1 + width + 1 + len, .obj = obj };
+}
+
+/// Decodes a fixext object: the type byte followed by exactly `len` bytes of data.
+fn deserialise_fixext(alloc: std.mem.Allocator, payload: []const u8, len: usize) DeserialiseError!ObjectLen {
+    const type_ = (try take_bytes(payload, 0, 1))[0];
+
+    const obj = try deserialise_ext(alloc, type_, try take_bytes(payload, 1, len), len);
+
+    return .{ .bytes_read = 1 + 1 + len, .obj = obj };
+}
+
+/// Decodes the single object at the start of `bytes`.
+///
+/// Returns the object and the number of bytes it occupied; bytes after it are ignored.
+/// The object owns memory allocated with `alloc` and must be released with `deinit`.
+pub fn deserialise_with_count(alloc: std.mem.Allocator, bytes: []const u8) DeserialiseError!ObjectLen {
+    if (bytes.len == 0) {
+        return error.BadLength;
+    }
+
+    const tag = bytes[0];
+
+    if (tag >> 7 == 0) {
+        // positive fixint
+        return .{ .bytes_read = 1, .obj = .{ .integer = @as(i64, tag) } };
+    }
+
+    if (tag >> 5 == 0b111) {
+        // negative fixint
+        return .{ .bytes_read = 1, .obj = .{ .integer = @as(i64, @as(i8, @bitCast(tag))) } };
+    }
+
+    const payload = bytes[1..];
+
+    if (tag >> 6 == 0b10) {
+        // these are the remaining tags that store information inside the tag,
+        // do them all inside this branch so we can use a fixed switch on the rest
+
+        if (tag >> 5 == 0b101) { // fixstr
+            const len: usize = tag & 0b11111;
+
+            const obj = try deserialise_raw(alloc, try take_bytes(payload, 0, len), len, .str);
+
+            return .{ .bytes_read = 1 + len, .obj = obj };
+        }
+
+        if (tag >> 4 == 0b1001) { // fixarray
+            const len: usize = tag & 0b1111;
+
+            var r = try deserialise_array(alloc, payload, len);
+
+            // add 1 byte for tag to length
+            r.bytes_read += 1;
+
+            return r;
+        }
+
+        if (tag >> 4 == 0b1000) { // fixmap
+            const len: usize = tag & 0b1111;
+
+            var r = try deserialise_map(alloc, payload, len);
+
+            // add 1 byte for tag to length
+            r.bytes_read += 1;
+
+            return r;
+        }
+
+        return error.NotImplemented;
+    }
+
+    if (tag >> 1 == 0b1100101) {
+        return deserialise_float(bytes);
+    }
+
+    if (tag >> 2 == 0b110011) {
+        return deserialise_uint(bytes);
+    }
+
+    if (tag >> 2 == 0b110100) {
+        return deserialise_int(bytes);
+    }
+
+    // from this point on the tag should be fixed, i.e. not contain additional information
+
+    switch (tag) {
+        0b11000000 => return .{ .bytes_read = 1, .obj = .nil },
+        0b11000001 => return error.InvalidTag,
+
+        0b11000010 => return .{ .bytes_read = 1, .obj = .{ .bool = false } },
+        0b11000011 => return .{ .bytes_read = 1, .obj = .{ .bool = true } },
+
+        0b11000100 => return deserialise_sized_raw(alloc, payload, 1, .binary), // bin 8
+        0b11000101 => return deserialise_sized_raw(alloc, payload, 2, .binary), // bin 16
+        0b11000110 => return deserialise_sized_raw(alloc, payload, 4, .binary), // bin 32
+
+        0b11000111 => return deserialise_sized_ext(alloc, payload, 1), // ext 8
+        0b11001000 => return deserialise_sized_ext(alloc, payload, 2), // ext 16
+        0b11001001 => return deserialise_sized_ext(alloc, payload, 4), // ext 32
+
+        0b11010100 => return deserialise_fixext(alloc, payload, 1), // fixext 1
+        0b11010101 => return deserialise_fixext(alloc, payload, 2), // fixext 2
+        0b11010110 => return deserialise_fixext(alloc, payload, 4), // fixext 4
+        0b11010111 => return deserialise_fixext(alloc, payload, 8), // fixext 8
+        0b11011000 => return deserialise_fixext(alloc, payload, 16), // fixext 16
+
+        0b11011001 => return deserialise_sized_raw(alloc, payload, 1, .str), // str 8
+        0b11011010 => return deserialise_sized_raw(alloc, payload, 2, .str), // str 16
+        0b11011011 => return deserialise_sized_raw(alloc, payload, 4, .str), // str 32
+
+        0b11011100 => { // array 16
+            const len = try read_len(payload, 2);
+
+            var r = try deserialise_array(alloc, payload[2..], len);
+
+            // add 1 byte for tag and 2 bytes for length to length
+            r.bytes_read += 1 + 2;
+
+            return r;
+        },
+        0b11011101 => { // array 32
+            const len = try read_len(payload, 4);
+
+            var r = try deserialise_array(alloc, payload[4..], len);
+
+            // add 1 byte for tag and 4 bytes for length to length
+            r.bytes_read += 1 + 4;
+
+            return r;
+        },
+
+        0b11011110 => { // map 16
+            const len = try read_len(payload, 2);
+
+            var r = try deserialise_map(alloc, payload[2..], len);
+
+            // add 1 byte for tag and 2 bytes for length to length
+            r.bytes_read += 1 + 2;
+
+            return r;
+        },
+        0b11011111 => { // map 32
+            const len = try read_len(payload, 4);
+
+            var r = try deserialise_map(alloc, payload[4..], len);
+
+            // add 1 byte for tag and 4 bytes for length to length
+            r.bytes_read += 1 + 4;
+
+            return r;
+        },
+
+        // every tag is handled above, this only guards against future gaps
+        else => return error.InvalidTag,
+    }
+}
+
+/// Decodes `bytes`, which must contain exactly one object.
+///
+/// Returns `error.BadLength` if bytes remain after the object. The caller owns the
+/// result and releases it with `Object.deinit`, passing the same allocator.
+pub fn deserialise(alloc: std.mem.Allocator, bytes: []const u8) DeserialiseError!Object {
+    const r = try deserialise_with_count(alloc, bytes);
+
+    // trailing bytes depend on the input, so report them instead of asserting
+    if (r.bytes_read != bytes.len) {
+        r.obj.deinit(alloc);
+
+        return error.BadLength;
+    }
+
+    return r.obj;
+}
+
+/// Test helper: expects `bytes` to decode to a binary object holding `expected`.
+fn expect_binary(expected: []const u8, bytes: []const u8) !void {
+    const alloc = std.testing.allocator;
+
+    const obj = try deserialise(alloc, bytes);
+    defer obj.deinit(alloc);
+
+    switch (obj) {
+        .raw => |raw| {
+            switch (raw) {
+                .binary => |s| try std.testing.expectEqualStrings(expected, s),
+                .string => return error.TestExpectedEqual,
+            }
+        },
+        else => return error.TestExpectedEqual,
+    }
+}
+
+test "pos fixint" {
+    const alloc = std.testing.allocator;
+
+    const bytes = [1]u8{0x07};
+
+    const obj = try deserialise(alloc, &bytes);
+
+    try std.testing.expectEqual(Object{ .integer = 0x07 }, obj);
+}
+
+test "neg fixint" {
+    const alloc = std.testing.allocator;
+
+    const bytes = [1]u8{0b111_11111};
+
+    const obj = try deserialise(alloc, &bytes);
+
+    try std.testing.expectEqual(Object{ .integer = -1 }, obj);
+}
+
+test "raw" {
+    try expect_binary("ABC", &[_]u8{ 0xc4, 0x03, 'A', 'B', 'C' });
+
+    try expect_binary("", &[_]u8{ 0xc4, 0x00 });
+
+    try expect_binary("A" ** 255, &([_]u8{ 0xc4, 0xff } ++ [_]u8{'A'} ** 255));
+
+    try expect_binary("A" ** 256, &([_]u8{ 0xc5, 0x01, 0x00 } ++ [_]u8{'A'} ** 256));
+}
+
+test "objects embedded in an array" {
+    const alloc = std.testing.allocator;
+
+    // fixarray of uint 8 (1), int 8 (-2) and float 32 (1.5)
+    const bytes = [_]u8{ 0x93, 0xcc, 0x01, 0xd0, 0xfe, 0xca, 0x3f, 0xc0, 0x00, 0x00 };
+
+    const obj = try deserialise(alloc, &bytes);
+    defer obj.deinit(alloc);
+
+    try std.testing.expectEqual(@as(usize, 3), obj.array.len);
+    try std.testing.expectEqual(Object{ .integer = 1 }, obj.array[0]);
+    try std.testing.expectEqual(Object{ .integer = -2 }, obj.array[1]);
+    try std.testing.expectEqual(Object{ .float = 1.5 }, obj.array[2]);
+}
+
+test "str and fixext 16" {
+    const alloc = std.testing.allocator;
+
+    {
+        const obj = try deserialise(alloc, &[_]u8{ 0xd9, 0x02, 'h', 'i' });
+        defer obj.deinit(alloc);
+
+        try std.testing.expectEqualStrings("hi", obj.raw.string);
+    }
+
+    {
+        // binary payloads do not have to be valid UTF-8
+        const obj = try deserialise(alloc, &[_]u8{ 0xc4, 0x01, 0xff });
+        defer obj.deinit(alloc);
+
+        try std.testing.expectEqualSlices(u8, &[_]u8{0xff}, obj.raw.binary);
+    }
+
+    {
+        const bytes = [_]u8{ 0xd8, 0x2a } ++ [_]u8{0xab} ** 16;
+
+        const obj = try deserialise(alloc, &bytes);
+        defer obj.deinit(alloc);
+
+        try std.testing.expectEqual(@as(u8, 0x2a), obj.extension.type);
+        try std.testing.expectEqual(@as(usize, 16), obj.extension.bytes.len);
+    }
+}
+
+test "malformed input" {
+    const alloc = std.testing.allocator;
+
+    // bin 8 announcing more bytes than are present
+    try std.testing.expectError(error.BadLength, deserialise(alloc, &[_]u8{ 0xc4, 0x03, 'A' }));
+
+    // array 16 without its length prefix
+    try std.testing.expectError(error.BadLength, deserialise(alloc, &[_]u8{0xdc}));
+
+    // fixarray whose second element is truncated, the first one must not leak
+    try std.testing.expectError(error.BadLength, deserialise(alloc, &[_]u8{ 0x92, 0xa1, 'A', 0xa1 }));
+
+    // trailing byte after a complete object
+    try std.testing.expectError(error.BadLength, deserialise(alloc, &[_]u8{ 0x07, 0x07 }));
+
+    // str 8 that is not valid UTF-8
+    try std.testing.expectError(error.InvalidUtf8, deserialise(alloc, &[_]u8{ 0xd9, 0x01, 0xff }));
+}
diff --git a/src/root.zig b/src/root.zig
index 4ea96eb..2f19975 100644
--- a/src/root.zig
+++ b/src/root.zig
@@ -1,21 +1,72 @@
 const std = @import("std");
 
-const Raw = union(enum) {
+pub const deserialise = @import("deserialise.zig");
+
+/// Owned payload of a string or binary object.
+pub const Raw = union(enum) {
     string: []u8,
     binary: []u8,
+
+    /// Frees the payload with `alloc`.
+    pub fn deinit(self: Raw, alloc: std.mem.Allocator) void {
+        switch (self) {
+            .string => |s| alloc.free(s),
+            .binary => |b| alloc.free(b),
+        }
+    }
 };
 
-const Object = union(enum) {
+/// One key/value pair of a map object.
+pub const MapEntry = struct {
+    key: Object,
+    value: Object,
+
+    /// Releases the key and the value with `alloc`.
+    pub fn deinit(self: MapEntry, alloc: std.mem.Allocator) void {
+        self.key.deinit(alloc);
+        self.value.deinit(alloc);
+    }
+};
+
+/// A value produced by the deserialiser.
+pub const Object = union(enum) {
     nil,
     bool: bool,
     integer: i64,
     float: f64,
     raw: Raw,
     array: []Object,
-    map: []struct { key: Object, value: Object },
+    map: []MapEntry,
     extension: struct { type: u8, bytes: []u8 },
+
+    /// Frees the payload with `alloc`, recursing into array elements and map entries.
+    pub fn deinit(self: Object, alloc: std.mem.Allocator) void {
+        switch (self) {
+            .raw => |raw| raw.deinit(alloc),
+            .array => |array| {
+                for (array) |x| {
+                    x.deinit(alloc);
+                }
+
+                alloc.free(array);
+            },
+            .map => |map| {
+                for (map) |entry| {
+                    entry.deinit(alloc);
+                }
+
+                alloc.free(map);
+            },
+            .extension => |ext| alloc.free(ext.bytes),
+            else => {},
+        }
+    }
 };
 
+test {
+    std.testing.refAllDecls(@This());
+}
+
 test {
     const o: Object = .nil;
 
diff --git a/src/utils.zig b/src/utils.zig
new file mode 100644
index 0000000..f5fea6a
--- /dev/null
+++ b/src/utils.zig
@@ -0,0 +1,111 @@
+const std = @import("std");
+
+// https://tools.ietf.org/html/rfc3629
+// width in bytes of the UTF-8 sequence introduced by each possible first byte,
+// 0 means invalid
+const UTF8_CHAR_WIDTH: [256]u8 = [256]u8{
+    // 0  1  2  3  4  5  6  7  8  9  A  B  C  D  E  F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 1
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 2
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 3
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 4
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 5
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 6
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 7
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 8
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 9
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // A
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // B
+    0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // C
+    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // D
+    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // E
+    4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // F
+};
+
+/// Returns true if `byte` is a UTF-8 continuation byte, i.e. has the form `0b10xxxxxx`.
+fn isContinuation(byte: u8) bool {
+    return (byte & 0b11_000000) == 0b10_000000;
+}
+
+/// Returns true if `bytes` is well-formed UTF-8 as defined by RFC 3629.
+///
+/// Overlong encodings, surrogates (U+D800..U+DFFF), code points above U+10FFFF and
+/// truncated sequences are rejected.
+pub fn validateUtf8(bytes: []const u8) bool {
+    var idx: usize = 0;
+
+    while (idx < bytes.len) {
+        const first = bytes[idx];
+
+        if (first < 128) {
+            // fast-path the ASCII case
+            // TODO: use SIMD here
+
+            idx += 1;
+            while (idx < bytes.len and bytes[idx] < 128) {
+                idx += 1;
+            }
+
+            continue;
+        }
+
+        const width = UTF8_CHAR_WIDTH[first];
+
+        std.debug.assert(width != 1); // should be already handled
+
+        if (width == 0) {
+            return false;
+        }
+
+        // the whole sequence has to fit into the remaining input
+        if (width > bytes.len - idx) {
+            return false;
+        }
+
+        for (bytes[idx + 1 ..][0 .. width - 1]) |byte| {
+            if (!isContinuation(byte)) {
+                return false;
+            }
+        }
+
+        const second = bytes[idx + 1];
+
+        // the first continuation byte has a restricted range for some lead bytes
+        const second_in_range = switch (first) {
+            0xE0 => second >= 0xA0, // below: overlong encoding
+            0xED => second < 0xA0, // above: surrogates
+            0xF0 => second >= 0x90, // below: overlong encoding
+            0xF4 => second < 0x90, // above: beyond U+10FFFF
+            else => true,
+        };
+
+        if (!second_in_range) {
+            return false;
+        }
+
+        idx += width;
+    }
+
+    return true;
+}
+
+test "validateUtf8 agrees with std.unicode" {
+    const cases = [_][]const u8{
+        "",
+        "plain ascii",
+        "gr\xc3\xbc\xc3\x9fe",
+        "\xe2\x82\xac",
+        "\xf0\x9f\x98\x80",
+        "\xc3", // truncated sequence
+        "\xc3A", // second byte is not a continuation byte
+        "\xe0\x80\x80", // overlong encoding
+        "\xed\xa0\x80", // surrogate
+        "\xf4\x90\x80\x80", // beyond U+10FFFF
+        "\xff", // invalid first byte
+    };
+
+    for (cases) |case| {
+        try std.testing.expectEqual(std.unicode.utf8ValidateSlice(case), validateUtf8(case));
+    }
+}