diff --git a/utils/broadcast/broadcast.go b/utils/broadcast/broadcast.go new file mode 100644 index 00000000..4d450120 --- /dev/null +++ b/utils/broadcast/broadcast.go @@ -0,0 +1,73 @@ +package broadcast + +type Broadcast struct { + listeners []chan interface{} + reg chan (chan interface{}) + unreg chan (chan interface{}) + in chan interface{} + stop chan int64 + stopStatus bool +} + +func NewBroadcast() *Broadcast { + b := &Broadcast{ + listeners: make([]chan interface{}, 0), + reg: make(chan (chan interface{})), + unreg: make(chan (chan interface{})), + in: make(chan interface{}), + stop: make(chan int64), + stopStatus: false, + } + + go func() { + for { + select { + case l := <-b.unreg: + // remove L from b.listeners + // this operation is slow: O(n) but not used frequently + // unlike iterating over listeners + oldListeners := b.listeners + b.listeners = make([]chan interface{}, 0, len(oldListeners)) + for _, oldL := range oldListeners { + if l != oldL { + b.listeners = append(b.listeners, oldL) + } + } + + case l := <-b.reg: + b.listeners = append(b.listeners, l) + + case item := <-b.in: + for _, l := range b.listeners { + l <- item + } + + case _ = <-b.stop: + b.stopStatus = true + break + } + } + }() + + return b +} + +func (b *Broadcast) In() chan interface{} { + return b.in +} + +func (b *Broadcast) Reg() chan interface{} { + listener := make(chan interface{}) + b.reg <- listener + return listener +} + +func (b *Broadcast) UnReg(listener chan interface{}) { + b.unreg <- listener +} + +func (b *Broadcast) Close() { + if b.stopStatus == false { + b.stop <- 1 + } +} diff --git a/utils/broadcast/broadcast_test.go b/utils/broadcast/broadcast_test.go new file mode 100644 index 00000000..3354adc8 --- /dev/null +++ b/utils/broadcast/broadcast_test.go @@ -0,0 +1,63 @@ +package broadcast + +import ( + "sync" + "testing" + "time" +) + +var ( + totalNum int = 5 + succNum int = 0 + mutex sync.Mutex +) + +func TestBroadcast(t *testing.T) { + b := NewBroadcast() + if b == nil { + t.Errorf("New Broadcast error, nil return") + } + defer b.Close() + + var wait sync.WaitGroup + wait.Add(totalNum) + for i := 0; i < totalNum; i++ { + go worker(b, &wait) + } + + time.Sleep(1e6 * 20) + msg := "test" + b.In() <- msg + + wait.Wait() + if succNum != totalNum { + t.Errorf("TotalNum %d, FailNum(timeout) %d", totalNum, totalNum-succNum) + } +} + +func worker(b *Broadcast, wait *sync.WaitGroup) { + defer wait.Done() + msgChan := b.Reg() + + // exit if nothing got in 2 seconds + timeout := make(chan bool, 1) + go func() { + time.Sleep(time.Duration(2) * time.Second) + timeout <- true + }() + + select { + case item := <-msgChan: + msg := item.(string) + if msg == "test" { + mutex.Lock() + succNum++ + mutex.Unlock() + } else { + break + } + + case <-timeout: + break + } +}