protobuf.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. package network
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "github.com/golang/protobuf/proto"
  7. "bet24.com/log"
  8. "math"
  9. "reflect"
  10. )
  11. // -------------------------
  12. // | id | protobuf message |
  13. // -------------------------
  14. type Processor struct {
  15. littleEndian bool
  16. msgInfo []*MsgInfo
  17. msgID map[reflect.Type]uint16
  18. }
  // MsgInfo is the registry entry for a single message ID.
type MsgInfo struct {
	// msgType is the pointer type of the registered proto.Message; Unmarshal
	// allocates new values of its element type.
	msgType reflect.Type
}
  // MsgRaw pairs a message ID with a raw byte payload.
// NOTE(review): not referenced in this file; presumably the payload is the
// still-encoded protobuf body used elsewhere in the package — confirm.
type MsgRaw struct {
	msgID      uint16 // message ID as assigned by Processor.Register
	msgRawData []byte // raw payload bytes
}
  26. func NewProcessor() *Processor {
  27. p := new(Processor)
  28. p.littleEndian = false
  29. p.msgID = make(map[reflect.Type]uint16)
  30. return p
  31. }
  32. // It's dangerous to call the method on routing or marshaling (unmarshaling)
  33. func (p *Processor) SetByteOrder(littleEndian bool) {
  34. p.littleEndian = littleEndian
  35. }
  36. // It's dangerous to call the method on routing or marshaling (unmarshaling)
  37. func (p *Processor) Register(msg proto.Message) uint16 {
  38. msgType := reflect.TypeOf(msg)
  39. if msgType == nil || msgType.Kind() != reflect.Ptr {
  40. log.Fatal("protobuf message pointer required")
  41. }
  42. if _, ok := p.msgID[msgType]; ok {
  43. log.Fatal("message %s is already registered", msgType)
  44. }
  45. if len(p.msgInfo) >= math.MaxUint16 {
  46. log.Fatal("too many protobuf messages (max = %v)", math.MaxUint16)
  47. }
  48. i := new(MsgInfo)
  49. i.msgType = msgType
  50. p.msgInfo = append(p.msgInfo, i)
  51. id := uint16(len(p.msgInfo) - 1)
  52. p.msgID[msgType] = id
  53. return id
  54. }
  55. // goroutine safe
  56. func (p *Processor) Unmarshal(data []byte) (interface{}, error) {
  57. if len(data) < 2 {
  58. return nil, errors.New("protobuf data too short")
  59. }
  60. // id
  61. var id uint16
  62. if p.littleEndian {
  63. id = binary.LittleEndian.Uint16(data)
  64. } else {
  65. id = binary.BigEndian.Uint16(data)
  66. }
  67. if id >= uint16(len(p.msgInfo)) {
  68. return nil, fmt.Errorf("message id %v not registered", id)
  69. }
  70. // msg
  71. i := p.msgInfo[id]
  72. msg := reflect.New(i.msgType.Elem()).Interface()
  73. return msg, proto.UnmarshalMerge(data[2:], msg.(proto.Message))
  74. }
  75. // goroutine safe
  76. func (p *Processor) Marshal(msg interface{}) ([][]byte, error) {
  77. msgType := reflect.TypeOf(msg)
  78. // id
  79. _id, ok := p.msgID[msgType]
  80. if !ok {
  81. err := fmt.Errorf("message %s not registered", msgType)
  82. return nil, err
  83. }
  84. id := make([]byte, 2)
  85. if p.littleEndian {
  86. binary.LittleEndian.PutUint16(id, _id)
  87. } else {
  88. binary.BigEndian.PutUint16(id, _id)
  89. }
  90. // data
  91. data, err := proto.Marshal(msg.(proto.Message))
  92. return [][]byte{id, data}, err
  93. }
  94. // goroutine safe
  95. func (p *Processor) Range(f func(id uint16, t reflect.Type)) {
  96. for id, i := range p.msgInfo {
  97. f(uint16(id), i.msgType)
  98. }
  99. }