| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- package network
- import (
- "encoding/binary"
- "errors"
- "fmt"
- "github.com/golang/protobuf/proto"
- "bet24.com/log"
- "math"
- "reflect"
- )
- // -------------------------
- // | id | protobuf message |
- // -------------------------
- type Processor struct {
- littleEndian bool
- msgInfo []*MsgInfo
- msgID map[reflect.Type]uint16
- }
// MsgInfo describes one registered message. It records the pointer type
// that was passed to Register; Unmarshal uses it to allocate fresh instances.
type MsgInfo struct {
	msgType reflect.Type
}
// MsgRaw pairs a message id with a still-encoded payload.
//
// NOTE(review): nothing in this section constructs or consumes MsgRaw;
// confirm it is used elsewhere before relying on it.
type MsgRaw struct {
	msgID      uint16
	msgRawData []byte
}
- func NewProcessor() *Processor {
- p := new(Processor)
- p.littleEndian = false
- p.msgID = make(map[reflect.Type]uint16)
- return p
- }
- // It's dangerous to call the method on routing or marshaling (unmarshaling)
- func (p *Processor) SetByteOrder(littleEndian bool) {
- p.littleEndian = littleEndian
- }
- // It's dangerous to call the method on routing or marshaling (unmarshaling)
- func (p *Processor) Register(msg proto.Message) uint16 {
- msgType := reflect.TypeOf(msg)
- if msgType == nil || msgType.Kind() != reflect.Ptr {
- log.Fatal("protobuf message pointer required")
- }
- if _, ok := p.msgID[msgType]; ok {
- log.Fatal("message %s is already registered", msgType)
- }
- if len(p.msgInfo) >= math.MaxUint16 {
- log.Fatal("too many protobuf messages (max = %v)", math.MaxUint16)
- }
- i := new(MsgInfo)
- i.msgType = msgType
- p.msgInfo = append(p.msgInfo, i)
- id := uint16(len(p.msgInfo) - 1)
- p.msgID[msgType] = id
- return id
- }
- // goroutine safe
- func (p *Processor) Unmarshal(data []byte) (interface{}, error) {
- if len(data) < 2 {
- return nil, errors.New("protobuf data too short")
- }
- // id
- var id uint16
- if p.littleEndian {
- id = binary.LittleEndian.Uint16(data)
- } else {
- id = binary.BigEndian.Uint16(data)
- }
- if id >= uint16(len(p.msgInfo)) {
- return nil, fmt.Errorf("message id %v not registered", id)
- }
- // msg
- i := p.msgInfo[id]
- msg := reflect.New(i.msgType.Elem()).Interface()
- return msg, proto.UnmarshalMerge(data[2:], msg.(proto.Message))
- }
- // goroutine safe
- func (p *Processor) Marshal(msg interface{}) ([][]byte, error) {
- msgType := reflect.TypeOf(msg)
- // id
- _id, ok := p.msgID[msgType]
- if !ok {
- err := fmt.Errorf("message %s not registered", msgType)
- return nil, err
- }
- id := make([]byte, 2)
- if p.littleEndian {
- binary.LittleEndian.PutUint16(id, _id)
- } else {
- binary.BigEndian.PutUint16(id, _id)
- }
- // data
- data, err := proto.Marshal(msg.(proto.Message))
- return [][]byte{id, data}, err
- }
- // goroutine safe
- func (p *Processor) Range(f func(id uint16, t reflect.Type)) {
- for id, i := range p.msgInfo {
- f(uint16(id), i.msgType)
- }
- }
|