diff --git a/bindgen/ir/IR.cpp b/bindgen/ir/IR.cpp index f030923..0d762c9 100644 --- a/bindgen/ir/IR.cpp +++ b/bindgen/ir/IR.cpp @@ -164,7 +164,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &s, const IR &ir) { } for (const auto &u : ir.unions) { - if (ir.shouldOutput(u)) { + if (ir.shouldOutput(u) && u->hasHelperMethods()) { s << "\n" << u->generateHelperClass(); } } @@ -184,9 +184,10 @@ void IR::generate(const std::string &excludePrefix) { } bool IR::hasHelperMethods() const { - if (hasOutputtedDeclaration(unions)) { - /* all unions have helper methods */ - return true; + for (const auto &u : unions) { + if (shouldOutput(u) && u->hasHelperMethods()) { + return true; + } } for (const auto &s : structs) { diff --git a/bindgen/ir/Struct.cpp b/bindgen/ir/Struct.cpp index e8cd1d5..4f01b5a 100644 --- a/bindgen/ir/Struct.cpp +++ b/bindgen/ir/Struct.cpp @@ -1,12 +1,18 @@ #include "Struct.h" #include "../Utils.h" #include "types/ArrayType.h" +#include "types/PointerType.h" #include "types/PrimitiveType.h" #include Field::Field(std::string name, std::shared_ptr type) : TypeAndName(std::move(name), std::move(type)) {} +Field::Field(std::string name, std::shared_ptr type, uint64_t offset) + : TypeAndName(std::move(name), std::move(type)), offset(offset) {} + +uint64_t Field::getOffset() const { return offset; } + StructOrUnion::StructOrUnion(std::string name, std::vector> fields, std::shared_ptr location) @@ -41,6 +47,8 @@ std::shared_ptr StructOrUnion::getLocation() const { return location; } +bool StructOrUnion::hasHelperMethods() const { return !fields.empty(); } + Struct::Struct(std::string name, std::vector> fields, uint64_t typeSize, std::shared_ptr location) : StructOrUnion(std::move(name), std::move(fields), std::move(location)), @@ -64,19 +72,15 @@ std::shared_ptr Struct::generateTypeDef() { std::string Struct::generateHelperClass() const { assert(hasHelperMethods()); - /* struct is not empty and not represented as an array */ std::stringstream s; std::string type = getTypeAlias(); s << " implicit class " << type << "_ops(val p: native.Ptr[" << type << "])" << " extends AnyVal {\n"; - unsigned fieldIndex = 0; - for (const auto &field : fields) { - if (!field->getName().empty()) { - s << generateGetter(fieldIndex) << "\n"; - s << generateSetter(fieldIndex) << "\n"; - } - fieldIndex++; + if (fields.size() <= SCALA_NATIVE_MAX_STRUCT_FIELDS) { + s << generateHelperClassMethodsForStructRepresentation(); + } else { + s << generateHelperClassMethodsForArrayRepresentation(); } s << " }\n\n"; @@ -88,8 +92,26 @@ std::string Struct::generateHelperClass() const { return s.str(); } -bool Struct::hasHelperMethods() const { - return !fields.empty() && fields.size() < SCALA_NATIVE_MAX_STRUCT_FIELDS; +std::string Struct::generateHelperClassMethodsForStructRepresentation() const { + std::stringstream s; + for (unsigned fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { + if (!fields[fieldIndex]->getName().empty()) { + s << generateGetterForStructRepresentation(fieldIndex); + s << generateSetterForStructRepresentation(fieldIndex); + } + } + return s.str(); +} + +std::string Struct::generateHelperClassMethodsForArrayRepresentation() const { + std::stringstream s; + for (unsigned fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { + if (!fields[fieldIndex]->getName().empty()) { + s << generateGetterForArrayRepresentation(fieldIndex); + s << generateSetterForArrayRepresentation(fieldIndex); + } + } + return s.str(); } std::string Struct::getTypeAlias() const { return "struct_" + name; } @@ -127,7 +149,8 @@ bool Struct::operator==(const Type &other) const { return false; } -std::string Struct::generateSetter(unsigned fieldIndex) const { +std::string +Struct::generateSetterForStructRepresentation(unsigned fieldIndex) const { std::shared_ptr field = fields[fieldIndex]; std::string setter = handleReservedWords(field->getName(), "_="); std::string parameterType = field->getType()->str(); @@ -139,11 +162,12 @@ std::string Struct::generateSetter(unsigned fieldIndex) const { } std::stringstream s; s << " def " << setter << "(value: " + parameterType + "): Unit = !p._" - << std::to_string(fieldIndex + 1) << " = " << value; + << std::to_string(fieldIndex + 1) << " = " << value << "\n"; return s.str(); } -std::string Struct::generateGetter(unsigned fieldIndex) const { +std::string +Struct::generateGetterForStructRepresentation(unsigned fieldIndex) const { std::shared_ptr field = fields[fieldIndex]; std::string getter = handleReservedWords(field->getName()); std::string returnType = field->getType()->str(); @@ -156,10 +180,68 @@ std::string Struct::generateGetter(unsigned fieldIndex) const { methodBody = "!p._" + std::to_string(fieldIndex + 1); } std::stringstream s; - s << " def " << getter << ": " << returnType << " = " << methodBody; + s << " def " << getter << ": " << returnType << " = " << methodBody + << "\n"; return s.str(); } +std::string +Struct::generateSetterForArrayRepresentation(unsigned int fieldIndex) const { + std::shared_ptr field = fields[fieldIndex]; + std::string setter = handleReservedWords(field->getName(), "_="); + std::string parameterType; + std::string value = "value"; + std::string castedField = "p._1"; + + PointerType pointerToFieldType = PointerType(field->getType()); + if (field->getOffset() > 0) { + castedField = "(" + castedField + " + " + + std::to_string(field->getOffset()) + ")"; + } + castedField = "!" + castedField + ".cast[" + pointerToFieldType.str() + "]"; + if (isAliasForType(field->getType().get()) || + isAliasForType(field->getType().get())) { + parameterType = pointerToFieldType.str(); + value = "!" + value; + } else { + parameterType = field->getType()->str(); + } + std::stringstream s; + s << " def " << setter + << "(value: " + parameterType + "): Unit = " << castedField << " = " + << value << "\n"; + return s.str(); +} + +std::string +Struct::generateGetterForArrayRepresentation(unsigned fieldIndex) const { + std::shared_ptr field = fields[fieldIndex]; + std::string getter = handleReservedWords(field->getName()); + std::string returnType; + std::string methodBody; + + PointerType pointerToFieldType = PointerType(field->getType()); + if (field->getOffset() > 0) { + methodBody = "(p._1 + " + std::to_string(field->getOffset()) + ")"; + } else { + methodBody = "p._1"; + } + methodBody = methodBody + ".cast[" + pointerToFieldType.str() + "]"; + + if (isAliasForType(field->getType().get()) || + isAliasForType(field->getType().get())) { + returnType = pointerToFieldType.str(); + } else { + methodBody = "!" + methodBody; + returnType = field->getType()->str(); + } + std::stringstream s; + s << " def " << getter << ": " << returnType << " = " << methodBody + << "\n"; + return s.str(); + return ""; +} + Union::Union(std::string name, std::vector> fields, uint64_t maxSize, std::shared_ptr location) : StructOrUnion(std::move(name), std::move(fields), std::move(location)), @@ -171,6 +253,7 @@ std::shared_ptr Union::generateTypeDef() { } std::string Union::generateHelperClass() const { + assert(hasHelperMethods()); std::stringstream s; std::string type = getTypeAlias(); s << " implicit class " << type << "_pos" diff --git a/bindgen/ir/Struct.h b/bindgen/ir/Struct.h index 517abe3..0a8a1b6 100644 --- a/bindgen/ir/Struct.h +++ b/bindgen/ir/Struct.h @@ -12,6 +12,16 @@ class Field : public TypeAndName { public: Field(std::string name, std::shared_ptr type); + + Field(std::string name, std::shared_ptr type, uint64_t offset); + + uint64_t getOffset() const; + + protected: + /** + * Offset in bytes from address of struct/union. + */ + uint64_t offset = 0; }; class StructOrUnion { @@ -31,6 +41,8 @@ class StructOrUnion { virtual std::string getTypeAlias() const = 0; + bool hasHelperMethods() const; + protected: std::string name; std::vector> fields; @@ -50,11 +62,6 @@ class Struct : public StructOrUnion, std::string getTypeAlias() const override; - /** - * @return true if helper methods will be generated for this struct - */ - bool hasHelperMethods() const; - bool usesType(const std::shared_ptr &type, bool stopOnTypeDefs) const override; @@ -62,13 +69,29 @@ class Struct : public StructOrUnion, bool operator==(const Type &other) const override; - std::string generateSetter(unsigned fieldIndex) const; - - std::string generateGetter(unsigned fieldIndex) const; - private: /* type size is needed if number of fields is bigger than 22 */ uint64_t typeSize; + + /** + * @return helper class methods for struct that is represented as CStruct. + */ + std::string generateHelperClassMethodsForStructRepresentation() const; + + /** + * @return helper class methods for struct that is represented as CArray. + */ + std::string generateHelperClassMethodsForArrayRepresentation() const; + + std::string + generateSetterForStructRepresentation(unsigned fieldIndex) const; + + std::string + generateGetterForStructRepresentation(unsigned fieldIndex) const; + + std::string generateSetterForArrayRepresentation(unsigned fieldIndex) const; + + std::string generateGetterForArrayRepresentation(unsigned fieldIndex) const; }; class Union : public StructOrUnion, diff --git a/bindgen/visitor/TreeVisitor.cpp b/bindgen/visitor/TreeVisitor.cpp index ff2a5c8..1d54270 100644 --- a/bindgen/visitor/TreeVisitor.cpp +++ b/bindgen/visitor/TreeVisitor.cpp @@ -1,4 +1,5 @@ #include "TreeVisitor.h" +#include "clang/AST/RecordLayout.h" #include bool TreeVisitor::VisitFunctionDecl(clang::FunctionDecl *func) { @@ -127,18 +128,20 @@ void TreeVisitor::handleStruct(clang::RecordDecl *record, std::string name) { llvm::errs().flush(); } - int fieldCnt = 0; std::vector> fields; + const clang::ASTRecordLayout &recordLayout = + astContext->getASTRecordLayout(record); for (const clang::FieldDecl *field : record->fields()) { std::shared_ptr ftype = typeTranslator.translate(field->getType(), &name); - fields.push_back( - std::make_shared(field->getNameAsString(), ftype)); + uint64_t recordOffsetInBits = + recordLayout.getFieldOffset(field->getFieldIndex()); + assert(recordOffsetInBits % 8 == 0); + fields.push_back(std::make_shared( + field->getNameAsString(), ftype, recordOffsetInBits / 8)); cycleDetection.AddDependency(newName, field->getType()); - - fieldCnt++; } if (cycleDetection.isCyclic(newName)) { diff --git a/tests/samples/Struct.c b/tests/samples/Struct.c index bc6376f..cce03f8 100644 --- a/tests/samples/Struct.c +++ b/tests/samples/Struct.c @@ -37,3 +37,34 @@ char getCharFromAnonymousStruct(struct structWithAnonymousStruct *s) { char getIntFromAnonymousStruct(struct structWithAnonymousStruct *s) { return s->anonymousStruct.i; } + +int struct_test_long(struct bigStruct *s, enum struct_op op, long value) { + switch (op) { + case STRUCT_SET: + s->one = value; + return 1; + case STRUCT_TEST: + return s->one == value; + } +} + +int struct_test_double(struct bigStruct *s, enum struct_op op, double value) { + switch (op) { + case STRUCT_SET: + s->five = value; + return 1; + case STRUCT_TEST: + return s->five == value; + } +} + +int struct_test_point(struct bigStruct *s, enum struct_op op, + struct point *value) { + switch (op) { + case STRUCT_SET: + s->six = *value; + return 1; + case STRUCT_TEST: + return (s->six.x == value->x) && (s->six.y == value->y); + } +} diff --git a/tests/samples/Struct.h b/tests/samples/Struct.h index 2acd725..408a64d 100644 --- a/tests/samples/Struct.h +++ b/tests/samples/Struct.h @@ -26,8 +26,8 @@ struct bigStruct { int three; float four; double five; - point_s six; - int seven; + struct point six; + struct point *seven; int eight; int nine; int ten; @@ -59,3 +59,10 @@ struct structWithAnonymousStruct { char getCharFromAnonymousStruct(struct structWithAnonymousStruct *s); char getIntFromAnonymousStruct(struct structWithAnonymousStruct *s); + +enum struct_op { STRUCT_SET, STRUCT_TEST }; + +int struct_test_long(struct bigStruct *s, enum struct_op op, long value); +int struct_test_double(struct bigStruct *s, enum struct_op op, double value); +int struct_test_point(struct bigStruct *s, enum struct_op op, + struct point *value); diff --git a/tests/samples/Struct.scala b/tests/samples/Struct.scala index 1255ee7..3006d84 100644 --- a/tests/samples/Struct.scala +++ b/tests/samples/Struct.scala @@ -13,12 +13,16 @@ object Struct { type point_s = native.Ptr[struct_point] type struct_bigStruct = native.CArray[Byte, native.Nat.Digit[native.Nat._1, native.Nat.Digit[native.Nat._1, native.Nat._2]]] type struct_structWithAnonymousStruct = native.CStruct2[native.CInt, native.CArray[Byte, native.Nat._8]] + type enum_struct_op = native.CUnsignedInt def setPoints(points: native.Ptr[struct_points], x1: native.CInt, y1: native.CInt, x2: native.CInt, y2: native.CInt): Unit = native.extern def getPoint(points: native.Ptr[struct_points], pointIndex: enum_pointIndex): native.CInt = native.extern def createPoint(): native.Ptr[struct_point] = native.extern def getBigStructSize(): native.CInt = native.extern def getCharFromAnonymousStruct(s: native.Ptr[struct_structWithAnonymousStruct]): native.CChar = native.extern def getIntFromAnonymousStruct(s: native.Ptr[struct_structWithAnonymousStruct]): native.CChar = native.extern + def struct_test_long(s: native.Ptr[struct_bigStruct], op: enum_struct_op, value: native.CLong): native.CInt = native.extern + def struct_test_double(s: native.Ptr[struct_bigStruct], op: enum_struct_op, value: native.CDouble): native.CInt = native.extern + def struct_test_point(s: native.Ptr[struct_bigStruct], op: enum_struct_op, value: native.Ptr[struct_point]): native.CInt = native.extern } import Struct._ @@ -28,6 +32,9 @@ object StructEnums { final val enum_pointIndex_Y1: enum_pointIndex = 1.toUInt final val enum_pointIndex_X2: enum_pointIndex = 2.toUInt final val enum_pointIndex_Y2: enum_pointIndex = 3.toUInt + + final val enum_struct_op_STRUCT_SET: enum_struct_op = 0.toUInt + final val enum_struct_op_STRUCT_TEST: enum_struct_op = 1.toUInt } object StructHelpers { @@ -50,6 +57,57 @@ object StructHelpers { def struct_points()(implicit z: native.Zone): native.Ptr[struct_points] = native.alloc[struct_points] + implicit class struct_bigStruct_ops(val p: native.Ptr[struct_bigStruct]) extends AnyVal { + def one: native.CLong = !p._1.cast[native.Ptr[native.CLong]] + def one_=(value: native.CLong): Unit = !p._1.cast[native.Ptr[native.CLong]] = value + def two: native.CChar = !(p._1 + 8).cast[native.Ptr[native.CChar]] + def two_=(value: native.CChar): Unit = !(p._1 + 8).cast[native.Ptr[native.CChar]] = value + def three: native.CInt = !(p._1 + 12).cast[native.Ptr[native.CInt]] + def three_=(value: native.CInt): Unit = !(p._1 + 12).cast[native.Ptr[native.CInt]] = value + def four: native.CFloat = !(p._1 + 16).cast[native.Ptr[native.CFloat]] + def four_=(value: native.CFloat): Unit = !(p._1 + 16).cast[native.Ptr[native.CFloat]] = value + def five: native.CDouble = !(p._1 + 24).cast[native.Ptr[native.CDouble]] + def five_=(value: native.CDouble): Unit = !(p._1 + 24).cast[native.Ptr[native.CDouble]] = value + def six: native.Ptr[struct_point] = (p._1 + 32).cast[native.Ptr[struct_point]] + def six_=(value: native.Ptr[struct_point]): Unit = !(p._1 + 32).cast[native.Ptr[struct_point]] = !value + def seven: native.Ptr[struct_point] = !(p._1 + 40).cast[native.Ptr[native.Ptr[struct_point]]] + def seven_=(value: native.Ptr[struct_point]): Unit = !(p._1 + 40).cast[native.Ptr[native.Ptr[struct_point]]] = value + def eight: native.CInt = !(p._1 + 48).cast[native.Ptr[native.CInt]] + def eight_=(value: native.CInt): Unit = !(p._1 + 48).cast[native.Ptr[native.CInt]] = value + def nine: native.CInt = !(p._1 + 52).cast[native.Ptr[native.CInt]] + def nine_=(value: native.CInt): Unit = !(p._1 + 52).cast[native.Ptr[native.CInt]] = value + def ten: native.CInt = !(p._1 + 56).cast[native.Ptr[native.CInt]] + def ten_=(value: native.CInt): Unit = !(p._1 + 56).cast[native.Ptr[native.CInt]] = value + def eleven: native.CInt = !(p._1 + 60).cast[native.Ptr[native.CInt]] + def eleven_=(value: native.CInt): Unit = !(p._1 + 60).cast[native.Ptr[native.CInt]] = value + def twelve: native.CInt = !(p._1 + 64).cast[native.Ptr[native.CInt]] + def twelve_=(value: native.CInt): Unit = !(p._1 + 64).cast[native.Ptr[native.CInt]] = value + def thirteen: native.CInt = !(p._1 + 68).cast[native.Ptr[native.CInt]] + def thirteen_=(value: native.CInt): Unit = !(p._1 + 68).cast[native.Ptr[native.CInt]] = value + def fourteen: native.CInt = !(p._1 + 72).cast[native.Ptr[native.CInt]] + def fourteen_=(value: native.CInt): Unit = !(p._1 + 72).cast[native.Ptr[native.CInt]] = value + def fifteen: native.CInt = !(p._1 + 76).cast[native.Ptr[native.CInt]] + def fifteen_=(value: native.CInt): Unit = !(p._1 + 76).cast[native.Ptr[native.CInt]] = value + def sixteen: native.CInt = !(p._1 + 80).cast[native.Ptr[native.CInt]] + def sixteen_=(value: native.CInt): Unit = !(p._1 + 80).cast[native.Ptr[native.CInt]] = value + def seventeen: native.CInt = !(p._1 + 84).cast[native.Ptr[native.CInt]] + def seventeen_=(value: native.CInt): Unit = !(p._1 + 84).cast[native.Ptr[native.CInt]] = value + def eighteen: native.CInt = !(p._1 + 88).cast[native.Ptr[native.CInt]] + def eighteen_=(value: native.CInt): Unit = !(p._1 + 88).cast[native.Ptr[native.CInt]] = value + def nineteen: native.CInt = !(p._1 + 92).cast[native.Ptr[native.CInt]] + def nineteen_=(value: native.CInt): Unit = !(p._1 + 92).cast[native.Ptr[native.CInt]] = value + def twenty: native.CInt = !(p._1 + 96).cast[native.Ptr[native.CInt]] + def twenty_=(value: native.CInt): Unit = !(p._1 + 96).cast[native.Ptr[native.CInt]] = value + def twentyOne: native.CInt = !(p._1 + 100).cast[native.Ptr[native.CInt]] + def twentyOne_=(value: native.CInt): Unit = !(p._1 + 100).cast[native.Ptr[native.CInt]] = value + def twentyTwo: native.CInt = !(p._1 + 104).cast[native.Ptr[native.CInt]] + def twentyTwo_=(value: native.CInt): Unit = !(p._1 + 104).cast[native.Ptr[native.CInt]] = value + def twentyThree: native.CInt = !(p._1 + 108).cast[native.Ptr[native.CInt]] + def twentyThree_=(value: native.CInt): Unit = !(p._1 + 108).cast[native.Ptr[native.CInt]] = value + } + + def struct_bigStruct()(implicit z: native.Zone): native.Ptr[struct_bigStruct] = native.alloc[struct_bigStruct] + implicit class struct_structWithAnonymousStruct_ops(val p: native.Ptr[struct_structWithAnonymousStruct]) extends AnyVal { def a: native.CInt = !p._1 def a_=(value: native.CInt): Unit = !p._1 = value diff --git a/tests/samples/src/test/scala/org/scalanative/bindgen/samples/StructTests.scala b/tests/samples/src/test/scala/org/scalanative/bindgen/samples/StructTests.scala index a06f5a1..1fccacb 100644 --- a/tests/samples/src/test/scala/org/scalanative/bindgen/samples/StructTests.scala +++ b/tests/samples/src/test/scala/org/scalanative/bindgen/samples/StructTests.scala @@ -60,5 +60,49 @@ object StructTests extends TestSuite { assert(42 == Struct.getIntFromAnonymousStruct(structWithAnonymousStruct)) } } + + 'getFieldOfBigStruct - { + Zone { implicit zone: Zone => + val structPtr = alloc[Struct.struct_bigStruct] + for (value <- Seq(Long.MinValue, -1, 0, 1, Long.MaxValue)) { + Struct.struct_test_long(structPtr, enum_struct_op_STRUCT_SET, value) + assert(structPtr.one == value) + } + + for (value <- Seq(Double.MinValue, -1, 0, 1, Double.MaxValue)) { + Struct.struct_test_double(structPtr, enum_struct_op_STRUCT_SET, value) + assert(structPtr.five == value) + } + + val pointPtr = alloc[Struct.point] + pointPtr.x = 5 + pointPtr.y = 10 + + Struct.struct_test_point(structPtr, enum_struct_op_STRUCT_SET, pointPtr) + assert(structPtr.six.x == pointPtr.x) + assert(structPtr.six.y == pointPtr.y) + } + } + + 'setFieldOfBigStruct - { + Zone { implicit zone: Zone => + val structPtr = alloc[Struct.struct_bigStruct] + for (value <- Seq(Long.MinValue, -1, 0, 1, Long.MaxValue)) { + structPtr.one = value + assert(Struct.struct_test_long(structPtr, enum_struct_op_STRUCT_TEST, value) == 1) + } + + for (value <- Seq(Double.MinValue, -1, 0, 1, Double.MaxValue)) { + structPtr.five = value + assert(Struct.struct_test_double(structPtr, enum_struct_op_STRUCT_TEST, value) == 1) + } + + val pointPtr = alloc[Struct.point] + pointPtr.x = 5 + pointPtr.y = 10 + structPtr.six = pointPtr + assert(Struct.struct_test_point(structPtr, enum_struct_op_STRUCT_TEST, pointPtr) == 1) + } + } } }