diff --git a/core/UserMessagePBHelpers.h b/core/UserMessagePBHelpers.h index 71d2f9570..eb7fa4ef4 100644 --- a/core/UserMessagePBHelpers.h +++ b/core/UserMessagePBHelpers.h @@ -56,6 +56,15 @@ return false; \ } +#define CHECK_FIELD_TYPE3(type1, type2, type3) \ + protobuf::FieldDescriptor::CppType fieldType = field->cpp_type(); \ + if (fieldType != protobuf::FieldDescriptor::CPPTYPE_##type1 \ + && fieldType != protobuf::FieldDescriptor::CPPTYPE_##type2 \ + && fieldType != protobuf::FieldDescriptor::CPPTYPE_##type3) \ + { \ + return false; \ + } + #define CHECK_FIELD_REPEATED() \ if (field->label() != protobuf::FieldDescriptor::LABEL_REPEATED) \ { \ @@ -316,74 +325,114 @@ public: return true; } - inline bool GetInt32OrUnsigned(const char *pszFieldName, int32 *out) + inline bool GetInt32OrUnsignedOrEnum(const char *pszFieldName, int32 *out) { GETCHECK_FIELD(); - CHECK_FIELD_TYPE2(INT32, UINT32); + CHECK_FIELD_TYPE3(INT32, UINT32, ENUM); CHECK_FIELD_NOT_REPEATED(); if (fieldType == protobuf::FieldDescriptor::CPPTYPE_UINT32) *out = (int32)msg->GetReflection()->GetUInt32(*msg, field); - else + else if (fieldType == protobuf::FieldDescriptor::CPPTYPE_INT32) *out = msg->GetReflection()->GetInt32(*msg, field); + else // CPPTYPE_ENUM + *out = msg->GetReflection()->GetEnum(*msg, field)->number(); return true; } - inline bool SetInt32OrUnsigned(const char *pszFieldName, int32 value) + inline bool SetInt32OrUnsignedOrEnum(const char *pszFieldName, int32 value) { GETCHECK_FIELD(); - CHECK_FIELD_TYPE2(INT32, UINT32); + CHECK_FIELD_TYPE3(INT32, UINT32, ENUM); CHECK_FIELD_NOT_REPEATED(); if (fieldType == protobuf::FieldDescriptor::CPPTYPE_UINT32) + { msg->GetReflection()->SetUInt32(msg, field, (uint32)value); - else + } + else if (fieldType == protobuf::FieldDescriptor::CPPTYPE_INT32) + { msg->GetReflection()->SetInt32(msg, field, value); + } + else // CPPTYPE_ENUM + { + const protobuf::EnumValueDescriptor *pEnumValue = field->enum_type()->FindValueByNumber(value); + if (!pEnumValue) + return false; + + msg->GetReflection()->SetEnum(msg, field, pEnumValue); + } return true; } - inline bool GetRepeatedInt32OrUnsigned(const char *pszFieldName, int index, int32 *out) + inline bool GetRepeatedInt32OrUnsignedOrEnum(const char *pszFieldName, int index, int32 *out) { GETCHECK_FIELD(); - CHECK_FIELD_TYPE2(INT32, UINT32); + CHECK_FIELD_TYPE3(INT32, UINT32, ENUM); CHECK_FIELD_REPEATED(); CHECK_REPEATED_ELEMENT(index); if (fieldType == protobuf::FieldDescriptor::CPPTYPE_UINT32) *out = (int32)msg->GetReflection()->GetRepeatedUInt32(*msg, field, index); - else + else if (fieldType == protobuf::FieldDescriptor::CPPTYPE_INT32) *out = msg->GetReflection()->GetRepeatedInt32(*msg, field, index); + else // CPPTYPE_ENUM + *out = msg->GetReflection()->GetRepeatedEnum(*msg, field, index)->number(); return true; } - inline bool SetRepeatedInt32OrUnsigned(const char *pszFieldName, int index, int32 value) + inline bool SetRepeatedInt32OrUnsignedOrEnum(const char *pszFieldName, int index, int32 value) { GETCHECK_FIELD(); - CHECK_FIELD_TYPE2(INT32, UINT32); + CHECK_FIELD_TYPE3(INT32, UINT32, ENUM); CHECK_FIELD_REPEATED(); CHECK_REPEATED_ELEMENT(index); if (fieldType == protobuf::FieldDescriptor::CPPTYPE_UINT32) + { msg->GetReflection()->SetRepeatedUInt32(msg, field, index, (uint32)value); - else + } + else if (fieldType == protobuf::FieldDescriptor::CPPTYPE_INT32) + { msg->GetReflection()->SetRepeatedInt32(msg, field, index, value); + } + else // CPPTYPE_ENUM + { + const protobuf::EnumValueDescriptor *pEnumValue = field->enum_type()->FindValueByNumber(value); + if (!pEnumValue) + return false; + + msg->GetReflection()->SetRepeatedEnum(msg, field, index, pEnumValue); + } return true; } - inline bool AddInt32OrUnsigned(const char *pszFieldName, int32 value) + inline bool AddInt32OrUnsignedOrEnum(const char *pszFieldName, int32 value) { GETCHECK_FIELD(); - CHECK_FIELD_TYPE2(INT32, UINT32); + CHECK_FIELD_TYPE3(INT32, UINT32, ENUM); CHECK_FIELD_REPEATED(); if (fieldType == protobuf::FieldDescriptor::CPPTYPE_UINT32) + { msg->GetReflection()->AddUInt32(msg, field, (uint32)value); - else + } + else if (fieldType == protobuf::FieldDescriptor::CPPTYPE_INT32) + { msg->GetReflection()->AddInt32(msg, field, value); + } + else // CPPTYPE_ENUM + { + const protobuf::EnumValueDescriptor *pEnumValue = field->enum_type()->FindValueByNumber(value); + if (!pEnumValue) + return false; + + msg->GetReflection()->AddEnum(msg, field, pEnumValue); + } return true; } diff --git a/core/smn_protobuf.cpp b/core/smn_protobuf.cpp index b282bd404..081414b2a 100644 --- a/core/smn_protobuf.cpp +++ b/core/smn_protobuf.cpp @@ -70,14 +70,14 @@ static cell_t smn_PbReadInt(IPluginContext *pCtx, const cell_t *params) int index = params[0] >= 3 ? params[3] : -1; if (index < 0) { - if (!msg->GetInt32OrUnsigned(strField, &ret)) + if (!msg->GetInt32OrUnsignedOrEnum(strField, &ret)) { return pCtx->ThrowNativeError("Invalid field \"%s\" for message \"%s\"", strField, msg->GetProtobufMessage()->GetTypeName().c_str()); } } else { - if (!msg->GetRepeatedInt32OrUnsigned(strField, index, &ret)) + if (!msg->GetRepeatedInt32OrUnsignedOrEnum(strField, index, &ret)) { return pCtx->ThrowNativeError("Invalid field \"%s\"[%d] for message \"%s\"", strField, index, msg->GetProtobufMessage()->GetTypeName().c_str()); } @@ -314,7 +314,7 @@ static cell_t smn_PbReadRepeatedInt(IPluginContext *pCtx, const cell_t *params) GET_FIELD_NAME_OR_ERR(); int ret; - if (!msg->GetRepeatedInt32OrUnsigned(strField, params[3], &ret)) + if (!msg->GetRepeatedInt32OrUnsignedOrEnum(strField, params[3], &ret)) { return pCtx->ThrowNativeError("Invalid field \"%s\"[%d] for message \"%s\"", strField, params[3], msg->GetProtobufMessage()->GetTypeName().c_str()); } @@ -458,14 +458,14 @@ static cell_t smn_PbSetInt(IPluginContext *pCtx, const cell_t *params) int index = params[0] >= 4 ? params[4] : -1; if (index < 0) { - if (!msg->SetInt32OrUnsigned(strField, params[3])) + if (!msg->SetInt32OrUnsignedOrEnum(strField, params[3])) { return pCtx->ThrowNativeError("Invalid field \"%s\" for message \"%s\"", strField, msg->GetProtobufMessage()->GetTypeName().c_str()); } } else { - if (!msg->SetRepeatedInt32OrUnsigned(strField, index, params[3])) + if (!msg->SetRepeatedInt32OrUnsignedOrEnum(strField, index, params[3])) { return pCtx->ThrowNativeError("Invalid field \"%s\"[%d] for message \"%s\"", strField, index, msg->GetProtobufMessage()->GetTypeName().c_str()); } @@ -703,7 +703,7 @@ static cell_t smn_PbAddInt(IPluginContext *pCtx, const cell_t *params) GET_MSG_FROM_HANDLE_OR_ERR(); GET_FIELD_NAME_OR_ERR(); - if (!msg->AddInt32OrUnsigned(strField, params[3])) + if (!msg->AddInt32OrUnsignedOrEnum(strField, params[3])) { return pCtx->ThrowNativeError("Invalid field \"%s\" for message \"%s\"", strField, msg->GetProtobufMessage()->GetTypeName().c_str()); } diff --git a/plugins/include/protobuf.inc b/plugins/include/protobuf.inc index a24e80592..4cb06dbd3 100644 --- a/plugins/include/protobuf.inc +++ b/plugins/include/protobuf.inc @@ -38,7 +38,7 @@ #define PB_FIELD_NOT_REPEATED -1 /** - * Reads an int32, uint32, sint32, fixed32, or sfixed32 from a protobuf message. + * Reads an int32, uint32, sint32, fixed32, sfixed32, or enum value from a protobuf message. * * @param pb protobuf handle. * @param field Field name. @@ -244,7 +244,7 @@ native PbReadRepeatedVector(Handle:pb, const String:field[], index, Float:buffer native PbReadRepeatedVector2D(Handle:pb, const String:field[], index, Float:buffer[2]); /** - * Sets an int32, uint32, sint32, fixed32, or sfixed32 on a protobuf message. + * Sets an int32, uint32, sint32, fixed32, sfixed32, or enum value on a protobuf message. * * @param pb protobuf handle. * @param field Field name. @@ -340,7 +340,7 @@ native PbSetVector(Handle:pb, const String:field[], const Float:vec[3], index=PB native PbSetVector2D(Handle:pb, const String:field[], const Float:vec[2], index=PB_FIELD_NOT_REPEATED); /** - * Add an int32, uint32, sint32, fixed32, or sfixed32 to a protobuf message repeated field. + * Add an int32, uint32, sint32, fixed32, sfixed32, or enum value to a protobuf message repeated field. * * @param pb protobuf handle. * @param field Field name.