package network import ( "encoding/binary" "errors" "fmt" "github.com/golang/protobuf/proto" "bet24.com/log" "math" "reflect" ) // ------------------------- // | id | protobuf message | // ------------------------- type Processor struct { littleEndian bool msgInfo []*MsgInfo msgID map[reflect.Type]uint16 } type MsgInfo struct { msgType reflect.Type } type MsgRaw struct { msgID uint16 msgRawData []byte } func NewProcessor() *Processor { p := new(Processor) p.littleEndian = false p.msgID = make(map[reflect.Type]uint16) return p } // It's dangerous to call the method on routing or marshaling (unmarshaling) func (p *Processor) SetByteOrder(littleEndian bool) { p.littleEndian = littleEndian } // It's dangerous to call the method on routing or marshaling (unmarshaling) func (p *Processor) Register(msg proto.Message) uint16 { msgType := reflect.TypeOf(msg) if msgType == nil || msgType.Kind() != reflect.Ptr { log.Fatal("protobuf message pointer required") } if _, ok := p.msgID[msgType]; ok { log.Fatal("message %s is already registered", msgType) } if len(p.msgInfo) >= math.MaxUint16 { log.Fatal("too many protobuf messages (max = %v)", math.MaxUint16) } i := new(MsgInfo) i.msgType = msgType p.msgInfo = append(p.msgInfo, i) id := uint16(len(p.msgInfo) - 1) p.msgID[msgType] = id return id } // goroutine safe func (p *Processor) Unmarshal(data []byte) (interface{}, error) { if len(data) < 2 { return nil, errors.New("protobuf data too short") } // id var id uint16 if p.littleEndian { id = binary.LittleEndian.Uint16(data) } else { id = binary.BigEndian.Uint16(data) } if id >= uint16(len(p.msgInfo)) { return nil, fmt.Errorf("message id %v not registered", id) } // msg i := p.msgInfo[id] msg := reflect.New(i.msgType.Elem()).Interface() return msg, proto.UnmarshalMerge(data[2:], msg.(proto.Message)) } // goroutine safe func (p *Processor) Marshal(msg interface{}) ([][]byte, error) { msgType := reflect.TypeOf(msg) // id _id, ok := p.msgID[msgType] if !ok { err := fmt.Errorf("message %s not registered", msgType) return nil, err } id := make([]byte, 2) if p.littleEndian { binary.LittleEndian.PutUint16(id, _id) } else { binary.BigEndian.PutUint16(id, _id) } // data data, err := proto.Marshal(msg.(proto.Message)) return [][]byte{id, data}, err } // goroutine safe func (p *Processor) Range(f func(id uint16, t reflect.Type)) { for id, i := range p.msgInfo { f(uint16(id), i.msgType) } }