diff --git a/build.gradle b/build.gradle index ab35c0fc80..ede820b334 100644 --- a/build.gradle +++ b/build.gradle @@ -533,9 +533,10 @@ task generateRustCarExample(type: JavaExec) { classpath = project(':sbe-all').sourceSets.main.runtimeClasspath systemProperties( 'sbe.output.dir': 'rust/car_example/src', + 'sbe.xinclude.aware': 'true', 'sbe.target.language': 'uk.co.real_logic.sbe.generation.rust.Rust', 'sbe.target.namespace': 'car_example_generated_codec') - args = ['sbe-tool/src/test/resources/example-schema.xml'] + args = ['sbe-samples/src/main/resources/example-schema.xml'] } task generateCarExampleDataFile(type: JavaExec) { diff --git a/rust/car_example/src/main.rs b/rust/car_example/src/main.rs index 5562a1eb5a..fcc980edb4 100644 --- a/rust/car_example/src/main.rs +++ b/rust/car_example/src/main.rs @@ -45,11 +45,11 @@ impl std::convert::From for IoError { fn decode_car_and_assert_expected_content(buffer: &[u8]) -> CodecResult<()> { let (h, dec_fields) = start_decoding_car(&buffer).header()?; - assert_eq!(49u16, h.block_length); + assert_eq!(45u16, {h.block_length}); assert_eq!(h.block_length as usize, ::std::mem::size_of::()); - assert_eq!(1u16, h.template_id); - assert_eq!(1u16, h.schema_id); - assert_eq!(0u16, h.version); + assert_eq!(1u16, {h.template_id}); + assert_eq!(1u16, {h.schema_id}); + assert_eq!(0u16, {h.version}); println!("Header read"); assert_eq!(Model::C, CarFields::discounted_model()); @@ -59,16 +59,16 @@ fn decode_car_and_assert_expected_content(buffer: &[u8]) -> CodecResult<()> { let mut found_fuel_figures = Vec::::with_capacity(EXPECTED_FUEL_FIGURES.len()); let (fields, dec_fuel_figures_header) = dec_fields.car_fields()?; - assert_eq!(1234, fields.serial_number); - assert_eq!(2013, fields.model_year); - assert_eq!(BooleanType::T, fields.available); - assert_eq!([97_i8, 98, 99, 100, 101, 102], fields.vehicle_code); // abcdef - assert_eq!([0_u32, 1, 2, 3, 4], fields.some_numbers); + assert_eq!(1234, {fields.serial_number}); + assert_eq!(2013, {fields.model_year}); + assert_eq!(BooleanType::T, {fields.available}); + assert_eq!([97_i8, 98, 99, 100, 101, 102], {fields.vehicle_code}); // abcdef + assert_eq!([1, 2, 3, 4], {fields.some_numbers}); assert_eq!(6, fields.extras.0); assert!(fields.extras.get_cruise_control()); assert!(fields.extras.get_sports_pack()); assert!(!fields.extras.get_sun_roof()); - assert_eq!(2000, fields.engine.capacity); + assert_eq!(2000, {fields.engine.capacity}); assert_eq!(4, fields.engine.num_cylinders); assert_eq!(BoostType::NITROUS, fields.engine.booster.boost_type); assert_eq!(200, fields.engine.booster.horse_power); @@ -82,12 +82,12 @@ fn decode_car_and_assert_expected_content(buffer: &[u8]) -> CodecResult<()> { let (usage_description, next_step) = dec_usage_description.usage_description()?; let usage_str = std::str::from_utf8(usage_description).unwrap(); println!("Fuel Figure: Speed: {0}, MPG: {1}, Usage: {2}", - ff_fields.speed, - ff_fields.mpg, + {ff_fields.speed}, + {ff_fields.mpg}, usage_str); found_fuel_figures.push(FuelFigure { - speed: ff_fields.speed, - mpg: ff_fields.mpg, + speed: {ff_fields.speed}, + mpg: {ff_fields.mpg}, usage_description: usage_str, }); match next_step { @@ -117,8 +117,8 @@ fn decode_car_and_assert_expected_content(buffer: &[u8]) -> CodecResult<()> { let (accel_slice, next_step) = dec_acceleration_header.acceleration_as_slice()?; for accel_fields in accel_slice { println!("Acceleration: MPH: {0}, Seconds: {1}", - accel_fields.mph, - accel_fields.seconds); + {accel_fields.mph}, + {accel_fields.seconds}); } match next_step { Either::Left(more_members) => dec_pf_members = more_members, @@ -165,7 +165,7 @@ fn encode_car_from_scratch() -> CodecResult> { fields.available = BooleanType::T; fields.code = Model::A; fields.vehicle_code = [97_i8, 98, 99, 100, 101, 102]; // abcdef - fields.some_numbers = [0_u32, 1, 2, 3, 4]; + fields.some_numbers = [1, 2, 3, 4]; fields.extras = OptionalExtras::new(); fields.extras.set_cruise_control(true) .set_sports_pack(true) @@ -274,4 +274,4 @@ const EXPECTED_PERF_FIXTURES: &'static [PerfFigure] = &[PerfFigure { mph: 100, seconds: 11.8, }], - }]; \ No newline at end of file + }]; diff --git a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustCodecType.java b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustCodecType.java index 9ae2e8139b..55f71d5bc4 100644 --- a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustCodecType.java +++ b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustCodecType.java @@ -44,12 +44,18 @@ void appendDirectCodeMethods( final String methodName, final String representationType, final String nextCoderType, - final int numBytes) throws IOException + final int numBytes, + final int trailingBytes) throws IOException { indent(appendable, 1, "pub fn %s(mut self) -> CodecResult<(&%s %s, %s)> {\n", methodName, DATA_LIFETIME, representationType, RustGenerator.withLifetime(nextCoderType)); indent(appendable, 2, "let v = self.%s.read_type::<%s>(%s)?;\n", RustCodecType.Decoder.scratchProperty(), representationType, numBytes); + if (trailingBytes > 0) + { + indent(appendable, 2, "self.%s.skip_bytes(%s)?;\n", + RustCodecType.Decoder.scratchProperty(), trailingBytes); + } indent(appendable, 2, "Ok((v, %s::wrap(self.%s)))\n", nextCoderType, RustCodecType.Decoder.scratchProperty()); indent(appendable).append("}\n"); @@ -78,15 +84,20 @@ void appendDirectCodeMethods( final String methodName, final String representationType, final String nextCoderType, - final int numBytes) throws IOException + final int numBytes, + final int trailingBytes) throws IOException { indent(appendable, 1, "\n/// Create a mutable struct reference overlaid atop the data buffer\n"); indent(appendable, 1, "/// such that changes to the struct directly edit the buffer. \n"); indent(appendable, 1, "/// Note that the initial content of the struct's fields may be garbage.\n"); indent(appendable, 1, "pub fn %s(mut self) -> CodecResult<(&%s mut %s, %s)> {\n", methodName, DATA_LIFETIME, representationType, RustGenerator.withLifetime(nextCoderType)); - indent(appendable, 2, "let v = self.%s.writable_overlay::<%s>(%s)?;\n", - RustCodecType.Encoder.scratchProperty(), representationType, numBytes); + if (trailingBytes > 0) + { + indent(appendable, 2, "// add trailing bytes to extend the end position of the scratch buffer\n"); + } + indent(appendable, 2, "let v = self.%s.writable_overlay::<%s>(%s+%s)?;\n", + RustCodecType.Encoder.scratchProperty(), representationType, numBytes, trailingBytes); indent(appendable, 2, "Ok((v, %s::wrap(self.%s)))\n", nextCoderType, RustCodecType.Encoder.scratchProperty()); indent(appendable).append("}\n\n"); @@ -97,6 +108,12 @@ void appendDirectCodeMethods( indent(appendable, 2) .append(format("self.%s.write_type::<%s>(t, %s)?;\n", RustCodecType.Encoder.scratchProperty(), representationType, numBytes)); + if (trailingBytes > 0) + { + indent(appendable, 2, "// fixed message length > sum of field lengths\n"); + indent(appendable, 2, "self.%s.skip_bytes(%s)?;\n", + RustCodecType.Decoder.scratchProperty(), trailingBytes); + } indent(appendable, 2).append(format("Ok(%s::wrap(self.%s))\n", nextCoderType, RustCodecType.Encoder.scratchProperty())); indent(appendable).append("}\n"); @@ -125,7 +142,8 @@ abstract void appendDirectCodeMethods( String methodName, String representationType, String nextCoderType, - int numBytes) throws IOException; + int numBytes, + int trailingBytes) throws IOException; abstract String gerund(); @@ -175,7 +193,7 @@ String generateMessageHeaderCoder( appendScratchWrappingStruct(writer, headerCoderType); RustGenerator.appendImplWithLifetimeHeader(writer, headerCoderType); appendWrapMethod(writer, headerCoderType); - appendDirectCodeMethods(writer, "header", messageHeaderRepresentation, topType, headerSize); + appendDirectCodeMethods(writer, "header", messageHeaderRepresentation, topType, headerSize, 0); writer.append("}\n"); } diff --git a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustGenerator.java b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustGenerator.java index 05f1975ce6..b7e32323fd 100644 --- a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustGenerator.java +++ b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustGenerator.java @@ -25,7 +25,6 @@ import java.io.IOException; import java.io.Writer; import java.util.*; -import java.util.stream.Collectors; import java.util.stream.IntStream; import static java.lang.String.format; @@ -69,16 +68,15 @@ public void generate() throws IOException final MessageComponents components = MessageComponents.collectMessageComponents(tokens); final String messageTypeName = formatTypeName(components.messageToken.name()); - final Optional fieldsRepresentation = - generateFieldsRepresentation(messageTypeName, components, outputManager); + final RustStruct fieldStruct = generateMessageFieldStruct(messageTypeName, components, outputManager); generateMessageHeaderDefault(ir, outputManager, components.messageToken); // Avoid the work of recomputing the group tree twice per message final List groupTree = buildGroupTrees(messageTypeName, components.groups); generateGroupFieldRepresentations(outputManager, groupTree); - generateMessageDecoder(outputManager, components, groupTree, fieldsRepresentation, headerSize); - generateMessageEncoder(outputManager, components, groupTree, fieldsRepresentation, headerSize); + generateMessageDecoder(outputManager, components, groupTree, fieldStruct, headerSize); + generateMessageEncoder(outputManager, components, groupTree, fieldStruct, headerSize); } } @@ -106,9 +104,10 @@ private void generateGroupFieldRepresentations( { for (final GroupTreeNode node : groupTree) { - appendStructHeader(appendable, node.contextualName + "Member", true); - appendStructFields(appendable, node.simpleNamedFields); - appendable.append("}\n"); + final RustStruct struct = RustStruct.fromTokens(node.contextualName + "Member", + node.simpleNamedFields, + EnumSet.of(RustStruct.Modifier.PACKED, RustStruct.Modifier.DEFAULT)); + struct.appendDefinitionTo(appendable); generateConstantAccessorImpl(appendable, node.contextualName + "Member", node.rawFields); @@ -116,19 +115,7 @@ private void generateGroupFieldRepresentations( } } - private static final class FieldsRepresentationSummary - { - final String typeName; - final int numBytes; - - private FieldsRepresentationSummary(final String typeName, final int numBytes) - { - this.typeName = typeName; - this.numBytes = numBytes; - } - } - - private static Optional generateFieldsRepresentation( + private static RustStruct generateMessageFieldStruct( final String messageTypeName, final MessageComponents components, final OutputManager outputManager) throws IOException @@ -136,47 +123,18 @@ private static Optional generateFieldsRepresentatio final List namedFieldTokens = NamedToken.gatherNamedNonConstantFieldTokens(components.fields); final String representationStruct = messageTypeName + "Fields"; - try (Writer writer = outputManager.createOutput(messageTypeName + " Fixed-size Fields")) - { - appendStructHeader(writer, representationStruct, true); - appendStructFields(writer, namedFieldTokens); - writer.append("}\n"); - - generateConstantAccessorImpl(writer, representationStruct, components.fields); - } + final RustStruct struct = RustStruct.fromTokens(representationStruct, namedFieldTokens, + EnumSet.of(RustStruct.Modifier.PACKED, RustStruct.Modifier.DEFAULT)); - // Compute the total static size in bytes of the fields representation - int numBytes = 0; - for (int i = 0, size = components.fields.size(); i < size;) + try (Writer writer = outputManager.createOutput( + messageTypeName + " Fixed-size Fields (" + struct.sizeBytes() + " bytes)")) { - final Token fieldToken = components.fields.get(i); - if (fieldToken.signal() == Signal.BEGIN_FIELD) - { - final int fieldEnd = i + fieldToken.componentTokenCount(); - if (!fieldToken.isConstantEncoding()) - { - for (int j = i; j < fieldEnd; j++) - { - final Token t = components.fields.get(j); - if (t.isConstantEncoding()) - { - continue; - } - if (t.signal() == ENCODING || t.signal() == BEGIN_ENUM || t.signal() == BEGIN_SET) - { - numBytes += t.encodedLength(); - } - } - } - i += fieldToken.componentTokenCount(); - } - else - { - throw new IllegalStateException("field tokens must include bounding BEGIN_FIELD and END_FIELD tokens"); - } + struct.appendDefinitionTo(writer); + writer.append("\n"); + generateConstantAccessorImpl(writer, representationStruct, components.fields); } - return Optional.of(new FieldsRepresentationSummary(representationStruct, numBytes)); + return struct; } private static void generateBitSets(final Ir ir, final OutputManager outputManager) throws IOException @@ -238,23 +196,52 @@ private static void generateSingleBitSet(final List tokens, final OutputM writer.append("}\n"); } + + try (Writer writer = outputManager.createOutput(setType + " bit set debug")) + { + indent(writer, 0, "impl core::fmt::Debug for %s {\n", setType); + indent(writer, 1, "fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result {\n"); + indent(writer, 2, "write!(fmt, \"%s[", setType); + + final StringBuilder string = new StringBuilder(); + final StringBuilder arguments = new StringBuilder(); + for (final Token token : tokens) + { + if (Signal.CHOICE != token.signal()) + { + continue; + } + + final String choiceName = formatMethodName(token.name()); + final String choiceBitIndex = token.encoding().constValue().toString(); + + string.append(choiceName + "(" + choiceBitIndex + ")={},"); + arguments.append("self.get_" + choiceName + "(),"); + } + + writer.append(string.toString() + "]\",\n"); + indent(writer, 3, arguments.toString() + ")\n"); + indent(writer, 1, "}\n"); + writer.append("}\n"); + } } private static void generateMessageEncoder( final OutputManager outputManager, final MessageComponents components, final List groupTree, - final Optional fieldsRepresentation, + final RustStruct fieldStruct, final int headerSize) throws IOException { final Token msgToken = components.messageToken; final String messageTypeName = formatTypeName(msgToken.name()); + final int msgLen = msgToken.encodedLength(); final RustCodecType codecType = RustCodecType.Encoder; String topType = codecType.generateDoneCoderType(outputManager, messageTypeName); topType = generateTopVarDataCoders(messageTypeName, components.varData, outputManager, topType, codecType); topType = generateGroupsCoders(groupTree, outputManager, topType, codecType); - topType = generateFixedFieldCoder(messageTypeName, outputManager, topType, fieldsRepresentation, codecType); + topType = generateFixedFieldCoder(messageTypeName, msgLen, outputManager, topType, fieldStruct, codecType); topType = codecType.generateMessageHeaderCoder(messageTypeName, outputManager, topType, headerSize); generateEntryPoint(messageTypeName, outputManager, topType, codecType); } @@ -263,17 +250,18 @@ private static void generateMessageDecoder( final OutputManager outputManager, final MessageComponents components, final List groupTree, - final Optional fieldsRepresentation, + final RustStruct fieldStruct, final int headerSize) throws IOException { final Token msgToken = components.messageToken; final String messageTypeName = formatTypeName(msgToken.name()); + final int msgLen = msgToken.encodedLength(); final RustCodecType codecType = RustCodecType.Decoder; String topType = codecType.generateDoneCoderType(outputManager, messageTypeName); topType = generateTopVarDataCoders(messageTypeName, components.varData, outputManager, topType, codecType); topType = generateGroupsCoders(groupTree, outputManager, topType, codecType); - topType = generateFixedFieldCoder(messageTypeName, outputManager, topType, fieldsRepresentation, codecType); + topType = generateFixedFieldCoder(messageTypeName, msgLen, outputManager, topType, fieldStruct, codecType); topType = codecType.generateMessageHeaderCoder(messageTypeName, outputManager, topType, headerSize); generateEntryPoint(messageTypeName, outputManager, topType, codecType); } @@ -304,29 +292,22 @@ static String withLifetime(final String typeName) private static String generateFixedFieldCoder( final String messageTypeName, + final int messageEncodedLength, final OutputManager outputManager, final String topType, - final Optional fieldsRepresentationOptional, + final RustStruct fieldStruct, final RustCodecType codecType) throws IOException { - if (!fieldsRepresentationOptional.isPresent()) - { - return topType; - } - - final FieldsRepresentationSummary fieldsRepresentation = fieldsRepresentationOptional.get(); try (Writer writer = outputManager.createOutput(messageTypeName + " Fixed fields " + codecType.name())) { - final String representationStruct = fieldsRepresentation.typeName; - final String decoderName = representationStruct + codecType.name(); + final String decoderName = fieldStruct.name + codecType.name(); codecType.appendScratchWrappingStruct(writer, decoderName); appendImplWithLifetimeHeader(writer, decoderName); codecType.appendWrapMethod(writer, decoderName); codecType.appendDirectCodeMethods(writer, formatMethodName(messageTypeName) + "_fields", - representationStruct, topType, fieldsRepresentation.numBytes); + fieldStruct.name, topType, fieldStruct.sizeBytes(), + messageEncodedLength - fieldStruct.sizeBytes()); writer.append("}\n"); - // TODO - Move read position further if in-message blockLength exceeds fixed fields representation size - // will require piping some data from the previously-read message header return decoderName; } } @@ -457,8 +438,8 @@ private static String writeGroupEncoderTopTypes( final String headerCoderType = node.contextualName + "HeaderEncoder"; try (Writer out = outputManager.createOutput(node.contextualName + " Encoder for fields and header")) { - appendStructHeader(out, withLifetime(memberCoderType), false); - final String rustCountType = rustTypeName(node.numInGroupType); + appendStructHeader(out, withLifetime(memberCoderType)); + final String rustCountType = rustTypeName(node.dimensionsNumInGroupType()); final String contentProperty; final String contentBearingType; if (node.parent.isPresent()) @@ -502,12 +483,12 @@ private static String writeGroupEncoderTopTypes( indent(out, 1, "pub fn done_with_%s(mut self) -> CodecResult<%s> {\n", formatMethodName(node.originalName), withLifetime(afterGroupCoderType)); indent(out, 2, "%s.write_at_position::<%s>(self.count_write_pos, &self.count, %s)?;\n", - scratchChain, rustCountType, node.numInGroupType.size()); + scratchChain, rustCountType, node.dimensionsNumInGroupType().size()); indent(out, 2, "Ok(%s)\n", atEndOfParent ? "self.parent" : format("%s::wrap(self.%s)", afterGroupCoderType, contentProperty)); indent(out).append("}\n").append("}\n"); - appendStructHeader(out, withLifetime(headerCoderType), false); + appendStructHeader(out, withLifetime(headerCoderType)); indent(out, 1, "%s: %s,\n", contentProperty, contentBearingType); out.append("}\n"); @@ -522,12 +503,12 @@ private static String writeGroupEncoderTopTypes( indent(out, 1, "pub fn %s_individually(mut self) -> CodecResult<%s> {\n", formatMethodName(node.originalName), withLifetime(memberCoderType)); indent(out, 2, "%s.write_type::<%s>(&%s, %s)?; // block length\n", - scratchChain, rustTypeName(node.blockLengthType), - generateRustLiteral(node.blockLengthType, Integer.toString(node.blockLength)), - node.blockLengthType.size()); + scratchChain, rustTypeName(node.dimensionsBlockLengthType()), + generateRustLiteral(node.dimensionsBlockLengthType(), Integer.toString(node.blockLength)), + node.dimensionsBlockLengthType().size()); indent(out, 2, "let count_pos = %s.pos;\n", scratchChain); indent(out, 2, "%s.write_type::<%s>(&0, %s)?; // preliminary group member count\n", - scratchChain, rustCountType, node.numInGroupType.size()); + scratchChain, rustCountType, node.dimensionsNumInGroupType().size()); indent(out, 2, "Ok(%s::new(self.%s, count_pos))\n", memberCoderType, contentProperty); indent(out, 1).append("}\n"); @@ -559,11 +540,11 @@ private static void appendFixedSizeMemberGroupEncoderMethods( formatMethodName(node.originalName), rustCountType, DATA_LIFETIME, fieldsType, withLifetime(afterGroupCoderType)); indent(out, 2, "%s.write_type::<%s>(&%s, %s)?; // block length\n", - scratchChain, rustTypeName(node.blockLengthType), - generateRustLiteral(node.blockLengthType, Integer.toString(node.blockLength)), - node.blockLengthType.size()); + scratchChain, rustTypeName(node.dimensionsBlockLengthType()), + generateRustLiteral(node.dimensionsBlockLengthType(), Integer.toString(node.blockLength)), + node.dimensionsBlockLengthType().size()); indent(out, 2, "%s.write_type::<%s>(&count, %s)?; // group count\n", - scratchChain, rustCountType, node.numInGroupType.size()); + scratchChain, rustCountType, node.dimensionsNumInGroupType().size()); indent(out, 2, "let c = count as usize;\n"); indent(out, 2, "let group_slice = %s.writable_slice::<%s>(c, %s)?;\n", scratchChain, fieldsType, node.blockLength); @@ -576,15 +557,15 @@ scratchChain, rustTypeName(node.blockLengthType), formatMethodName(node.originalName), fieldsType, withLifetime(afterGroupCoderType)); indent(out, 2, "%s.write_type::<%s>(&%s, %s)?; // block length\n", - scratchChain, rustTypeName(node.blockLengthType), - generateRustLiteral(node.blockLengthType, Integer.toString(node.blockLength)), - node.blockLengthType.size()); + scratchChain, rustTypeName(node.dimensionsBlockLengthType()), + generateRustLiteral(node.dimensionsBlockLengthType(), Integer.toString(node.blockLength)), + node.dimensionsBlockLengthType().size()); indent(out, 2, "let count = s.len();\n"); - indent(out, 2, "if count > %s {\n", node.numInGroupType.maxValue()); + indent(out, 2, "if count > %s {\n", node.dimensionsNumInGroupType().maxValue()); indent(out, 3).append("return Err(CodecErr::SliceIsLongerThanAllowedBySchema)\n"); indent(out, 2).append("}\n"); indent(out, 2, "%s.write_type::<%s>(&(count as %s), %s)?; // group count\n", - scratchChain, rustCountType, rustCountType, node.numInGroupType.size()); + scratchChain, rustCountType, rustCountType, node.dimensionsNumInGroupType().size()); indent(out, 2, "%s.write_slice_without_count::<%s>(s, %s)?;\n", scratchChain, fieldsType, node.blockLength); indent(out, 2, "Ok(%s)\n", atEndOfParent ? "self.parent" : @@ -605,8 +586,8 @@ private static void writeGroupDecoderTopTypes( { try (Writer out = outputManager.createOutput(node.contextualName + " Decoder for fields and header")) { - appendStructHeader(out, withLifetime(memberDecoderType), false); - final String rustCountType = rustTypeName(node.numInGroupType); + appendStructHeader(out, withLifetime(memberDecoderType)); + final String rustCountType = rustTypeName(node.dimensionsNumInGroupType()); final String contentProperty; final String contentBearingType; if (node.parent.isPresent()) @@ -656,7 +637,7 @@ private static void writeGroupDecoderTopTypes( format("%s::wrap(self.%s)", initialNextDecoderType, contentProperty))); indent(out, 2).append("}\n").append(INDENT).append("}\n").append("}\n"); - appendStructHeader(out, withLifetime(headerDecoderType), false); + appendStructHeader(out, withLifetime(headerDecoderType)); indent(out, 1, "%s: %s,\n", contentProperty, contentBearingType).append("}\n"); appendImplWithLifetimeHeader(out, headerDecoderType); @@ -666,12 +647,12 @@ private static void writeGroupDecoderTopTypes( indent(out, 1, "pub fn %s_individually(mut self) -> CodecResult<%s> {\n", formatMethodName(node.originalName), groupLevelNextDecoderType); - indent(out, 2, "%s.skip_bytes(%s)?; // Skip reading block length for now\n", - toScratchChain(node), node.blockLengthType.size()); - indent(out, 2, "let count = *%s.read_type::<%s>(%s)?;\n", - toScratchChain(node), rustTypeName(node.numInGroupType), node.numInGroupType.size()); - indent(out, 2).append("if count > 0 {\n"); - indent(out, 3, "Ok(Either::Left(%s::new(self.%s, count)))\n", + indent(out, 2, "let dim = %s.read_type::<%s>(%s)?;\n", + toScratchChain(node), + formatTypeName(node.dimensionsTypeName()), + node.dimensionsTypeSize()); + indent(out, 2).append("if dim.num_in_group > 0 {\n"); + indent(out, 3, "Ok(Either::Left(%s::new(self.%s, dim.num_in_group)))\n", memberDecoderType, contentProperty).append(INDENT).append(INDENT).append("} else {\n"); if (atEndOfParent) @@ -707,11 +688,11 @@ private static void appendFixedSizeMemberGroupDecoderMethods( formatMethodName(node.originalName), DATA_LIFETIME, node.contextualName + "Member", initialNextDecoderType.startsWith("Either") ? initialNextDecoderType : withLifetime(initialNextDecoderType)); - indent(out, 2, "%s.skip_bytes(%s)?; // Skip reading block length for now\n", toScratchChain(node), - node.blockLengthType.size()); - indent(out, 2, "let count = *%s.read_type::<%s>(%s)?;\n", - toScratchChain(node), rustTypeName(node.numInGroupType), node.numInGroupType.size()); - indent(out, 2, "let s = %s.read_slice::<%s>(count as usize, %s)?;\n", + indent(out, 2, "let dim = %s.read_type::<%s>(%s)?;\n", + toScratchChain(node), + formatTypeName(node.dimensionsTypeName()), + node.dimensionsTypeSize()); + indent(out, 2, "let s = %s.read_slice::<%s>(dim.num_in_group as usize, %s)?;\n", toScratchChain(node), node.contextualName + "Member", node.blockLength); indent(out, 2, "Ok((s,%s))\n", atEndOfParent ? "self.parent.after_member()" : format("%s::wrap(self.%s)", initialNextDecoderType, contentProperty)); @@ -787,10 +768,8 @@ private static List buildGroupTrees( final Token dimensionsToken = groupsTokens.get(i); final int groupHeaderTokenCount = dimensionsToken.componentTokenCount(); final List dimensionsTokens = groupsTokens.subList(i, i + groupHeaderTokenCount); - final PrimitiveType numInGroupType = findPrimitiveByTokenName(dimensionsTokens, "numInGroup"); - final Token blockLengthToken = findPrimitiveTokenByTokenName(dimensionsTokens, "blockLength"); + final GroupDimensions dimensions = GroupDimensions.ofTokens(dimensionsTokens); final int blockLength = groupToken.encodedLength(); - final PrimitiveType blockLengthType = blockLengthToken.encoding().primitiveType(); i += groupHeaderTokenCount; final List fields = new ArrayList<>(); @@ -807,8 +786,7 @@ private static List buildGroupTrees( parent, originalName, contextualName, - numInGroupType, - blockLengthType, + dimensions, blockLength, fields, varDataSummaries); @@ -840,13 +818,37 @@ private static Token findPrimitiveTokenByTokenName(final List tokens, fin throw new IllegalStateException(format("%s not specified for group", targetName)); } + private static final class GroupDimensions + { + final String typeName; + final int typeSize; + final PrimitiveType numInGroupType; + final PrimitiveType blockLengthType; + + private GroupDimensions(final String typeName, final int typeSize, + final PrimitiveType numInGroupType, final PrimitiveType blockLengthType) + { + this.typeName = typeName; + this.typeSize = typeSize; + this.numInGroupType = numInGroupType; + this.blockLengthType = blockLengthType; + } + + public static GroupDimensions ofTokens(final List dimensionsTokens) + { + final PrimitiveType numInGroupType = findPrimitiveByTokenName(dimensionsTokens, "numInGroup"); + final PrimitiveType blockLengthType = findPrimitiveByTokenName(dimensionsTokens, "blockLength"); + return new GroupDimensions(dimensionsTokens.get(0).name(), dimensionsTokens.get(0).encodedLength(), + numInGroupType, blockLengthType); + } + } + static class GroupTreeNode { final Optional parent; final String originalName; final String contextualName; - final PrimitiveType numInGroupType; - final PrimitiveType blockLengthType; + final GroupDimensions dimensions; final int blockLength; final List rawFields; final List simpleNamedFields; @@ -857,8 +859,7 @@ static class GroupTreeNode final Optional parent, final String originalName, final String contextualName, - final PrimitiveType numInGroupType, - final PrimitiveType blockLengthType, + final GroupDimensions dimensions, final int blockLength, final List fields, final List varData) @@ -866,8 +867,7 @@ static class GroupTreeNode this.parent = parent; this.originalName = originalName; this.contextualName = contextualName; - this.numInGroupType = numInGroupType; - this.blockLengthType = blockLengthType; + this.dimensions = dimensions; this.blockLength = blockLength; this.rawFields = fields; this.simpleNamedFields = NamedToken.gatherNamedNonConstantFieldTokens(fields); @@ -898,6 +898,26 @@ boolean hasFixedSizeMembers() { return groups.isEmpty() && varData.isEmpty(); } + + public PrimitiveType dimensionsNumInGroupType() + { + return dimensions.numInGroupType; + } + + public PrimitiveType dimensionsBlockLengthType() + { + return dimensions.blockLengthType; + } + + public String dimensionsTypeName() + { + return dimensions.typeName; + } + + public int dimensionsTypeSize() + { + return dimensions.typeSize; + } } static class VarDataSummary @@ -930,7 +950,7 @@ String generateVarDataEncoder( final String decoderType = parentContextualName + formatTypeName(name) + codecType.name(); try (Writer writer = outputManager.createOutput(name + " variable-length data")) { - appendStructHeader(writer, withLifetime(decoderType), false); + appendStructHeader(writer, withLifetime(decoderType)); final String contentPropertyName = groupDepth > 0 ? "parent" : codecType.scratchProperty(); indent(writer, 1, "%s: %s,\n", contentPropertyName, withLifetime(contentType)); writer.append("}\n"); @@ -981,7 +1001,7 @@ String generateVarDataDecoder( final String decoderType = parentContextualName + formatTypeName(name) + "Decoder"; try (Writer writer = outputManager.createOutput(name + " variable-length data")) { - appendStructHeader(writer, withLifetime(decoderType), false); + appendStructHeader(writer, withLifetime(decoderType)); final String contentPropertyName = groupDepth > 0 ? "parent" : SCRATCH_DECODER_PROPERTY; indent(writer, 1, "%s: %s,\n", contentPropertyName, withLifetime(contentType)); writer.append("}\n"); @@ -1165,6 +1185,18 @@ static void generateEncoderScratchStruct(final Ir ir, final OutputManager output indent(writer, 2).append("}\n"); indent(writer).append("}\n\n"); + indent(writer, 1, "/// Advances the `pos` index by a set number of bytes.\n"); + indent(writer).append("#[inline]\n"); + indent(writer).append("fn skip_bytes(&mut self, num_bytes: usize) -> CodecResult<()> {\n"); + indent(writer, 2).append("let end = self.pos + num_bytes;\n"); + indent(writer, 2).append("if end <= self.data.len() {\n"); + indent(writer, 3).append("self.pos = end;\n"); + indent(writer, 3).append("Ok(())\n"); + indent(writer, 2).append("} else {\n"); + indent(writer, 3).append("Err(CodecErr::NotEnoughBytes)\n"); + indent(writer, 2).append("}\n"); + indent(writer).append("}\n\n"); + indent(writer, 1, "/// Create a struct reference overlaid atop the data buffer\n"); indent(writer, 1, "/// such that changes to the struct directly edit the buffer. \n"); indent(writer, 1, "/// Note that the initial content of the struct's fields may be garbage.\n"); @@ -1367,6 +1399,11 @@ private static void generateEnum(final List enumTokens, final OutputManag .append(",\n"); writer.append("}\n"); + + // Default implementation to support Default in other structs + indent(writer, 0, "impl Default for %s {\n", enumRustName); + indent(writer, 1, "fn default() -> Self { %s::%s }\n", enumRustName, "NullVal"); + indent(writer, 0, "}\n"); } } @@ -1391,17 +1428,311 @@ private static void generateSingleComposite(final List tokens, final Outp try (Writer writer = outputManager.createOutput(formattedTypeName)) { - appendStructHeader(writer, formattedTypeName, true); - appendStructFields(writer, splitTokens.nonConstantEncodingTokens()); - writer.append("}\n"); + final RustStruct struct = RustStruct.fromTokens(formattedTypeName, + splitTokens.nonConstantEncodingTokens(), + EnumSet.of(RustStruct.Modifier.PACKED, RustStruct.Modifier.DEFAULT)); + struct.appendDefinitionTo(writer); generateConstantAccessorImpl(writer, formattedTypeName, getMessageBody(tokens)); } } - private static void appendStructFields(final Appendable appendable, final List namedTokens) - throws IOException + private interface RustTypeDescriptor + { + String DEFAULT_VALUE = "Default::default()"; + + String name(); + + String literalValue(String valueRep); + + int sizeBytes(); + + default String defaultValue() + { + return DEFAULT_VALUE; + } + } + + private static final class RustArrayType implements RustTypeDescriptor + { + private final RustTypeDescriptor componentType; + private final int length; + + private RustArrayType(final RustTypeDescriptor component, final int length) + { + this.componentType = component; + this.length = length; + } + + @Override + public String name() + { + return getRustStaticArrayString(componentType.name(), length); + } + + @Override + public String literalValue(final String valueRep) + { + return getRustStaticArrayString(valueRep + componentType.name(), length); + } + + @Override + public int sizeBytes() + { + return componentType.sizeBytes() * length; + } + + @Override + public String defaultValue() + { + final String defaultValue = RustTypeDescriptor.super.defaultValue(); + if (length <= 32) + { + return defaultValue; + } + else + { + final StringBuilder result = new StringBuilder(); + result.append('['); + for (int i = 0; i < length; i++) + { + result.append(defaultValue); + result.append(", "); + if (i % 4 == 0) // ~80 char lines + { + result.append('\n'); + } + } + result.append(']'); + return result.toString(); + } + } + } + + private static final class RustPrimitiveType implements RustTypeDescriptor + { + private final String name; + private final int sizeBytes; + + private RustPrimitiveType(final String name, final int sizeBytes) + { + this.name = name; + this.sizeBytes = sizeBytes; + } + + @Override + public String name() + { + return name; + } + + @Override + public String literalValue(final String valueRep) + { + return valueRep + name; + } + + @Override + public int sizeBytes() + { + return sizeBytes; + } + } + + private static final class AnyRustType implements RustTypeDescriptor + { + private final String name; + private final int sizeBytes; + + private AnyRustType(final String name, final int sizeBytes) + { + this.name = name; + this.sizeBytes = sizeBytes; + } + + @Override + public String name() + { + return name; + } + + @Override + public String literalValue(final String valueRep) + { + final String msg = String.format("Cannot produce a literal value %s of type %s!", valueRep, name); + throw new UnsupportedOperationException(msg); + } + + @Override + public int sizeBytes() + { + return sizeBytes; + } + } + + private static final class RustTypes + { + static final RustTypeDescriptor U_8 = new RustPrimitiveType("u8", 1); + + static RustTypeDescriptor ofPrimitiveToken(final Token token) + { + final PrimitiveType primitiveType = token.encoding().primitiveType(); + final String rustPrimitiveType = RustUtil.rustTypeName(primitiveType); + final RustPrimitiveType type = new RustPrimitiveType(rustPrimitiveType, primitiveType.size()); + if (token.arrayLength() > 1) + { + return new RustArrayType(type, token.arrayLength()); + } + return type; + } + + static RustTypeDescriptor ofGeneratedToken(final Token token) + { + return new AnyRustType(formatTypeName(token.applicableTypeName()), token.encodedLength()); + } + + static RustTypeDescriptor arrayOf(final RustTypeDescriptor type, final int len) + { + return new RustArrayType(type, len); + } + } + + private static final class RustStruct { + enum Modifier + { + PACKED, DEFAULT + } + + final String name; + final List fields; + final EnumSet modifiers; + + private RustStruct(final String name, final List fields, final EnumSet modifiers) + { + this.name = name; + this.fields = fields; + this.modifiers = modifiers; + } + + public int sizeBytes() + { + return fields.stream().mapToInt(v -> v.type.sizeBytes()).sum(); + } + + static RustStruct fromHeader(final HeaderStructure header) + { + final List tokens = header.tokens(); + final String originalTypeName = tokens.get(0).applicableTypeName(); + final String formattedTypeName = formatTypeName(originalTypeName); + final SplitCompositeTokens splitTokens = SplitCompositeTokens.splitInnerTokens(tokens); + return RustStruct.fromTokens(formattedTypeName, + splitTokens.nonConstantEncodingTokens(), + EnumSet.of(Modifier.PACKED, Modifier.DEFAULT)); + } + + static RustStruct fromTokens(final String name, final List tokens, + final EnumSet modifiers) + { + return new RustStruct(name, collectStructFields(tokens), modifiers); + } + + // No way to create struct with default values. + // Rust RFC: https://github.com/Centril/rfcs/pull/19 + // Used when struct contains a field which doesn't have a Default impl + void appendDefaultConstructorTo(final Appendable appendable) throws IOException + { + indent(appendable, 0, "impl Default for %s {\n", name); + indent(appendable, 1, "fn default() -> Self {\n"); + + appendInstanceTo(appendable, 2, Collections.emptyMap()); + + indent(appendable, 1, "}\n"); + + appendable.append("}\n"); + } + + void appendDefinitionTo(final Appendable appendable) throws IOException + { + final boolean needsDefault = modifiers.contains(Modifier.DEFAULT); + final boolean canDeriveDefault = fields.stream() + .allMatch(v -> v.type.defaultValue() == RustTypeDescriptor.DEFAULT_VALUE); + + final Set modifiers = this.modifiers.clone(); + if (needsDefault && !canDeriveDefault) + { + modifiers.remove(Modifier.DEFAULT); + } + + appendStructHeader(appendable, name, modifiers); + for (final RustStructField field: fields) + { + indent(appendable); + if (field.modifiers.contains(RustStructField.Modifier.PUBLIC)) + { + appendable.append("pub "); + } + appendable.append(field.name).append(":").append(field.type.name()).append(",\n"); + } + appendable.append("}\n"); + + if (needsDefault && !canDeriveDefault) + { + appendDefaultConstructorTo(appendable); + } + } + + void appendInstanceTo(final Appendable appendable, final int indent, + final Map values) throws IOException + { + indent(appendable, indent, "%s {\n", name); + for (final RustStructField field: fields) + { + final String value; + if (values.containsKey(field.name)) + { + value = field.type.literalValue(values.get(field.name)); + } + else + { + value = field.type.defaultValue(); + } + + indent(appendable, indent + 1, "%s: %s,\n", formatMethodName(field.name), value); + } + indent(appendable, indent, "}\n"); + } + + } + + private static final class RustStructField + { + enum Modifier + { + PUBLIC + } + + final String name; + final RustTypeDescriptor type; + final EnumSet modifiers; + + private RustStructField(final String name, final RustTypeDescriptor type, final EnumSet modifiers) + { + this.name = name; + this.type = type; + this.modifiers = modifiers; + } + + private RustStructField(final String name, final RustTypeDescriptor type) + { + this(name, type, EnumSet.noneOf(Modifier.class)); + } + } + + private static List collectStructFields(final List namedTokens) + { + final List fields = new ArrayList<>(); + int totalSize = 0; for (final NamedToken namedToken : namedTokens) { final Token typeToken = namedToken.typeToken(); @@ -1411,29 +1742,47 @@ private static void appendStructFields(final Appendable appendable, final List 0) + { + // split padding arrays to 32 as larger arrays do not have an `impl Default` + final int padding = Math.min(rem, 32); + final RustTypeDescriptor type = RustTypes.arrayOf(RustTypes.U_8, padding); + fields.add(new RustStructField(propertyName + "_padding_" + idx, type)); + + idx += 1; + rem -= padding; + } + } + totalSize = offset + typeToken.encodedLength(); + + final RustTypeDescriptor type; switch (typeToken.signal()) { case ENCODING: - final String rustPrimitiveType = RustUtil.rustTypeName(typeToken.encoding().primitiveType()); - final String rustFieldType = getRustTypeForPrimitivePossiblyArray(typeToken, rustPrimitiveType); - appendable.append(rustFieldType); + type = RustTypes.ofPrimitiveToken(typeToken); + fields.add(new RustStructField(propertyName, type, EnumSet.of(RustStructField.Modifier.PUBLIC))); break; case BEGIN_ENUM: case BEGIN_SET: case BEGIN_COMPOSITE: - appendable.append(formatTypeName(typeToken.applicableTypeName())); + type = RustTypes.ofGeneratedToken(typeToken); + fields.add(new RustStructField(propertyName, type, EnumSet.of(RustStructField.Modifier.PUBLIC))); break; default: throw new IllegalStateException( format("Unsupported struct property from %s", typeToken.toString())); } - - appendable.append(",\n"); } + return fields; } private void generateMessageHeaderDefault( @@ -1443,46 +1792,32 @@ private void generateMessageHeaderDefault( throws IOException { final HeaderStructure header = ir.headerStructure(); + final RustStruct rustHeader = RustStruct.fromHeader(header); + final String messageTypeName = formatTypeName(messageToken.name()); final String wrapperName = messageTypeName + "MessageHeader"; try (Writer writer = outputManager.createOutput(messageTypeName + " specific Message Header ")) { - appendStructHeader(writer, wrapperName, true); + appendStructHeader(writer, wrapperName, EnumSet.of(RustStruct.Modifier.PACKED)); indent(writer, 1, "pub message_header: MessageHeader\n"); writer.append("}\n"); - indent(writer, 1, "impl Default for %s {\n", wrapperName); + indent(writer, 0, "impl Default for %s {\n", wrapperName); indent(writer, 1, "fn default() -> %s {\n", wrapperName); indent(writer, 2, "%s {\n", wrapperName); - indent(writer, 3, "message_header: MessageHeader {\n"); - indent(writer, 4, "%s: %s,\n", formatMethodName("blockLength"), - generateRustLiteral(header.blockLengthType(), Integer.toString(messageToken.encodedLength()))); - indent(writer, 4, "%s: %s,\n", formatMethodName("templateId"), - generateRustLiteral(header.templateIdType(), Integer.toString(messageToken.id()))); - indent(writer, 4, "%s: %s,\n", formatMethodName("schemaId"), - generateRustLiteral(header.schemaIdType(), Integer.toString(ir.id()))); - indent(writer, 4, "%s: %s,\n", formatMethodName("version"), - generateRustLiteral(header.schemaVersionType(), Integer.toString(ir.version()))); - - // Technically the spec seems to allow non-standard fields in the message header, so we attempt - // to provide some sort of default for them - final Set reserved = new HashSet<>(Arrays.asList("blockLength", "templateId", "schemaId", - "version")); - - final List nonReservedNamedTokens = SplitCompositeTokens.splitInnerTokens(header.tokens()) - .nonConstantEncodingTokens() - .stream() - .filter((namedToken) -> !reserved.contains(namedToken.name())) - .collect(Collectors.toList()); - - for (final NamedToken namedToken : nonReservedNamedTokens) - { - indent(writer, 4, "%s: Default::default(),\n", formatMethodName(namedToken.name())); - } - - indent(writer, 3, "}\n"); + indent(writer, 3, "message_header: "); + rustHeader.appendInstanceTo(writer, 3, + new HashMap() + { + { + put("block_length", Integer.toString(messageToken.encodedLength())); + put("template_id", Integer.toString(messageToken.id())); + put("schema_id", Integer.toString(ir.id())); + put("version", Integer.toString(ir.version())); + } + }); indent(writer, 2, "}\n"); indent(writer, 1, "}\n"); @@ -1491,26 +1826,43 @@ private void generateMessageHeaderDefault( } } + private static void appendStructHeader(final Appendable appendable, final String structName) throws IOException + { + appendStructHeader(appendable, structName, EnumSet.noneOf(RustStruct.Modifier.class)); + } + private static void appendStructHeader( final Appendable appendable, final String structName, - final boolean packedCRepresentation) throws IOException + final Set modifiers) throws IOException { - if (packedCRepresentation) + if (!modifiers.isEmpty()) { - appendable.append("#[repr(C,packed)]\n"); + if (modifiers.contains(RustStruct.Modifier.PACKED)) + { + appendable.append("#[repr(C,packed)]\n"); + } + if (modifiers.contains(RustStruct.Modifier.DEFAULT)) + { + appendable.append("#[derive(Default)]\n"); + } } appendable.append(format("pub struct %s {\n", structName)); } + private static String getRustStaticArrayString(final String rustPrimitiveType, final int length) + { + return format("[%s;%s]", rustPrimitiveType, length); + } + private static String getRustTypeForPrimitivePossiblyArray( final Token encodingToken, final String rustPrimitiveType) { final String rustType; if (encodingToken.arrayLength() > 1) { - rustType = format("[%s;%s]", rustPrimitiveType, encodingToken.arrayLength()); + rustType = getRustStaticArrayString(rustPrimitiveType, encodingToken.arrayLength()); } else { diff --git a/sbe-tool/src/test/java/uk/co/real_logic/sbe/generation/rust/RustGeneratorTest.java b/sbe-tool/src/test/java/uk/co/real_logic/sbe/generation/rust/RustGeneratorTest.java index c4c564c5cb..9727aa35ed 100644 --- a/sbe-tool/src/test/java/uk/co/real_logic/sbe/generation/rust/RustGeneratorTest.java +++ b/sbe-tool/src/test/java/uk/co/real_logic/sbe/generation/rust/RustGeneratorTest.java @@ -187,7 +187,19 @@ private File writeCargoFolderWrapper(final String name, final String generatedRu return folder; } - private static boolean cargoCheckInDirectory(final File folder) throws IOException, InterruptedException + private static final class CargoCheckResult + { + final boolean isSuccess; + final String error; + + private CargoCheckResult(final boolean isSuccess, final String error) + { + this.isSuccess = isSuccess; + this.error = error; + } + } + + private static CargoCheckResult cargoCheckInDirectory(final File folder) throws IOException, InterruptedException { final ProcessBuilder builder = new ProcessBuilder("cargo", "check"); builder.directory(folder); @@ -195,6 +207,7 @@ private static boolean cargoCheckInDirectory(final File folder) throws IOExcepti process.waitFor(30, TimeUnit.SECONDS); final boolean success = process.exitValue() == 0; + final StringBuilder errorString = new StringBuilder(); if (!success) { // Include output as a debugging aid when things go wrong @@ -207,11 +220,16 @@ private static boolean cargoCheckInDirectory(final File folder) throws IOExcepti { break; } + else + { + errorString.append(line); + errorString.append('\n'); + } } } } - return success; + return new CargoCheckResult(success, errorString.toString()); } private static boolean cargoExists() @@ -234,7 +252,9 @@ private void assertRustBuildable(final String generatedRust, final Optional