diff --git a/produce.go b/produce.go index 46c0ec8..6df0da6 100644 --- a/produce.go +++ b/produce.go @@ -16,24 +16,26 @@ import ( ) type produceConfig struct { - topic string - brokers []string - batch int - timeout time.Duration - verbose bool - args struct { - topic string - brokers string - batch int - timeout time.Duration - verbose bool + topic string + brokers []string + batch int + timeout time.Duration + verbose bool + partitioner string + args struct { + topic string + brokers string + batch int + timeout time.Duration + verbose bool + partitioner string } } type message struct { Key *string `json:"key"` Value *string `json:"value"` - Partition int32 `json:"partition"` + Partition *int32 `json:"partition"` } func produceFlags() *flag.FlagSet { @@ -68,6 +70,12 @@ func produceFlags() *flag.FlagSet { false, "Verbose output", ) + flags.StringVar( + &config.produce.args.partitioner, + "partitioner", + "", + "Optional partitioner to use. Available: hashCode", + ) flags.Usage = func() { fmt.Fprintln(os.Stderr, "Usage of produce:") @@ -244,7 +252,7 @@ func produceCommand() command { var wg sync.WaitGroup wg.Add(4) go readInput(&wg, closer, stdin, lines) - go deserializeLines(&wg, lines, messages) + go deserializeLines(&wg, lines, messages, int32(len(leaders))) go batchRecords(&wg, messages, batchedMessages) go produce(&wg, leaders, batchedMessages) @@ -253,7 +261,7 @@ func produceCommand() command { } } -func deserializeLines(wg *sync.WaitGroup, in chan string, out chan message) { +func deserializeLines(wg *sync.WaitGroup, in chan string, out chan message, partitionCount int32) { defer func() { close(out) wg.Done() @@ -268,10 +276,23 @@ func deserializeLines(wg *sync.WaitGroup, in chan string, out chan message) { var msg message if err := json.Unmarshal([]byte(l), &msg); err != nil { if config.produce.verbose { - fmt.Printf("Failed to unmarshal input, falling back to defaults. err=%v\n", err) + fmt.Printf("Failed to unmarshal input [%v], falling back to defaults. err=%v\n", l, err) + } + var v *string = &l + if len(l) == 0 { + v = nil } - msg = message{Key: &l, Value: &l, Partition: 0} + msg = message{Key: nil, Value: v} } + + var p int32 = 0 + if msg.Key != nil && config.produce.partitioner == "hashCode" { + p = hashCodePartition(*msg.Key, partitionCount) + } + if msg.Partition == nil { + msg.Partition = &p + } + out <- msg } } @@ -327,9 +348,9 @@ func (m message) asSaramaMessage() *sarama.Message { func produceBatch(leaders map[int32]*sarama.Broker, batch []message) error { requests := map[*sarama.Broker]*sarama.ProduceRequest{} for _, msg := range batch { - broker, ok := leaders[msg.Partition] + broker, ok := leaders[*msg.Partition] if !ok { - err := fmt.Errorf("Non-configured partition %v", msg.Partition) + err := fmt.Errorf("Non-configured partition %v", *msg.Partition) fmt.Fprintf(os.Stderr, "%v.\n", err) return err } @@ -339,7 +360,7 @@ func produceBatch(leaders map[int32]*sarama.Broker, batch []message) error { requests[broker] = req } - req.AddMessage(config.produce.topic, msg.Partition, msg.asSaramaMessage()) + req.AddMessage(config.produce.topic, *msg.Partition, msg.asSaramaMessage()) } for broker, req := range requests { diff --git a/produce_test.go b/produce_test.go index 43b1ebc..908b75d 100644 --- a/produce_test.go +++ b/produce_test.go @@ -3,7 +3,9 @@ package main import ( "os" "reflect" + "sync" "testing" + "time" ) func TestHashCode(t *testing.T) { @@ -184,3 +186,68 @@ func TestProduceParseArgs(t *testing.T) { return } } + +func newMessage(key, value string, partition int32) message { + var k *string + if key != "" { + k = &key + } + + var v *string + if value != "" { + v = &value + } + + return message{ + Key: k, + Value: v, + Partition: &partition, + } +} + +func TestDeserializeLines(t *testing.T) { + config.produce.partitioner = "hashCode" + data := []struct { + in string + partitionCount int32 + expected message + }{ + { + in: "", + partitionCount: 1, + expected: newMessage("", "", 0), + }, + { + in: `{"key":"hans","value":"123"}`, + partitionCount: 4, + expected: newMessage("hans", "123", hashCodePartition("hans", 4)), + }, + { + in: `{"key":"hans","value":"123","partition":1}`, + partitionCount: 3, + expected: newMessage("hans", "123", 1), + }, + { + in: `so lange schon`, + partitionCount: 3, + expected: newMessage("", "so lange schon", 0), + }, + } + + for _, d := range data { + var wg sync.WaitGroup + in := make(chan string, 1) + out := make(chan message) + go deserializeLines(&wg, in, out, d.partitionCount) + in <- d.in + + select { + case <-time.After(50 * time.Millisecond): + t.Errorf("did not receive output in time") + case actual := <-out: + if !reflect.DeepEqual(d.expected, actual) { + t.Errorf("\nexpected %#v\nactual %#v", d.expected, actual) + } + } + } +}