Skip to content

Commit

Permalink
Add go-tfdata integration (tests)
Browse files Browse the repository at this point in the history
  • Loading branch information
knopt committed Apr 30, 2020
1 parent 21f6b3b commit b4c0f7e
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 13 deletions.
16 changes: 15 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ bin/tarp: $(cmds) $(datapipes)
bin/tarp -h

test:
cd datapipes && go test -v
cd dpipes && go test -v

test-tfdata:
cd dpipes && go test -v --tags=gitlabnvidia

dtest:
cd datapipes && debug=stdout go test -v | tee ../test.log
Expand All @@ -20,3 +23,14 @@ coverage:
cd datapipes && go test -coverprofile=c.out
cd datapipes && go tool cover -html=c.out -o coverage.html
firefox datapipes/coverage.html

tfdata-access:
# use ssh instead of https for go get gitlab-master.nvidia.com
@git config --global url."ssh://[email protected]:12051/aistorage/go-tfdata".insteadOf "https://github.com/NVIDIA/go-tfdata"
# inform go that go-tfdata is a private repo
@go env -w GOPRIVATE="github.com/NVIDIA/go-tfdata"

tfdata-local:
# use local go-tfdata instead of downloading from gitlab-master.nvidia.com
@grep "replace github.com/NVIDIA/go-tfdata => ../../../NVIDIA/go-tfdata" -q dpipes/go.mod || (echo "" >> dpipes/go.mod && \
echo "replace github.com/NVIDIA/go-tfdata => ../../../NVIDIA/go-tfdata" >>dpipes/go.mod)
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,26 @@ Future work:
- TFRecord/tf.Example interoperability
- add JSON input to "tarp create"
- add separator option to "tarp create"

# Building with private repositories

The `dpipes` module uses external dependencies from private repositories hosted on `gitlab-master.nvidia.com`.
Access to them is not required, but enables additional tarp features. At the moment tarp integrates with
`go-tfdata` - a Go library helping to work with tar/tgz archives and files in
[TFRecord and tf.Example formats](https://www.tensorflow.org/tutorials/load_data/tfrecord).

If you have access to `gitlab-master.nvidia.com` and wish to enable the additional features,
you should build tarp with [build tags](https://golang.org/pkg/go/build/#hdr-Build_Constraints) and do the following:

```console
make tfdata-access
```

>Make sure *not* to include private dependencies in the `require` directives of go.mod files. `go mod tidy` will include them
>by default, so they have to be excluded (deleted) explicitly. More context in [this GitHub issue](https://github.com/golang/go/issues/35832).

If you wish `tarp` to use a local version of `go-tfdata` that you downloaded beforehand, do the following:
```console
make tfdata-local
```
It will modify `dpipes/go.mod` by adding a `replace` directive.
2 changes: 1 addition & 1 deletion dpipes/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ require (
github.com/Masterminds/squirrel v1.2.0
github.com/mattn/go-sqlite3 v2.0.3+incompatible
github.com/shamaton/msgpack v1.1.1
github.com/stretchr/testify v1.2.2
github.com/stretchr/testify v1.3.0
gopkg.in/zeromq/goczmq.v4 v4.1.0
)
145 changes: 145 additions & 0 deletions dpipes/gotfdata_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// +build gitlabnvidia

package dpipes

import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"os"
"testing"

"github.com/NVIDIA/go-tfdata/tfdata/core"
"github.com/NVIDIA/go-tfdata/tfdata/transform"
"github.com/stretchr/testify/assert"
)

// SamplesReader wraps a tarp Pipe so that the samples travelling through it
// can be consumed one at a time via its Read method.
type SamplesReader struct {
	pipe Pipe // channel the samples are received from
}

// Read receives the next sample from the wrapped pipe and returns it converted
// to a go-tfdata sample. Once the pipe has been closed and drained it returns
// (nil, io.EOF).
func (r *SamplesReader) Read() (*core.Sample, error) {
	if next, open := <-r.pipe; open {
		return tarpSampleToTfDataSample(next), nil
	}
	return nil, io.EOF
}

// TFRecordSink returns a Sink that reads every sample from its pipe, converts
// the samples to TFExamples and writes them to writer through a TFRecord
// writer. A write error is reported on t via assert.NoError.
func TFRecordSink(t *testing.T, writer io.Writer) Sink {
	return func(pipe Pipe) {
		recordWriter := core.NewTFRecordWriter(writer)
		examples := transform.NewSamplesToTFExample(&SamplesReader{pipe})

		assert.NoError(t, recordWriter.WriteMessages(examples))
	}
}

// TFRecordSource returns a Source that reads TFExamples from reader, converts
// each one to a tarp Sample and sends it on the pipe. The pipe is always
// closed when the source returns. Reading stops at the first error; any error
// other than io.EOF is reported as a test failure on t.
func TFRecordSource(t *testing.T, reader io.Reader) Source {
	return func(pipe Pipe) {
		defer close(pipe)

		var r core.TFExampleReader = core.NewTFRecordReader(reader)
		for {
			ex, err := r.Read()
			if err == io.EOF {
				return
			}
			if err != nil {
				// assert.Fail does not treat its message as a format string,
				// so the message has to be formatted before it is passed in.
				assert.Fail(t, fmt.Sprintf("expected to get io.EOF, got %v instead", err))
				return
			}
			pipe <- tfExampleTarpSample(ex)
		}
	}
}

// SamplesChecker returns a Process that forwards every sample from in to out
// while verifying its contents: the i-th sample (counting from 0) must carry
// "txt" equal to the decimal form of i and "__key__" equal to i zero-padded to
// six digits. After in is drained, out is closed and the number of samples
// seen is compared against target.
func SamplesChecker(t *testing.T, target int) Process {
	return func(in, out Pipe) {
		total := 0
		for s := range in {
			// assert.Equal takes (expected, actual); keeping that order makes
			// the failure output label the two values correctly.
			assert.Equal(t, Bytes(fmt.Sprintf("%d", total)), s["txt"])
			assert.Equal(t, Bytes(fmt.Sprintf("%06d", total)), s["__key__"])
			total++
			out <- s
		}
		close(out)
		assert.Equal(t, target, total)
	}
}

// tarpSampleToTfDataSample copies every key/value pair of a tarp Sample into
// the Entries of a freshly created go-tfdata sample and returns that sample.
func tarpSampleToTfDataSample(sample Sample) *core.Sample {
	converted := core.NewSample()
	for name, payload := range sample {
		converted.Entries[name] = payload
	}
	return converted
}

// tfExampleTarpSample converts a TFExample to a tarp Sample. For every feature
// it JSON-decodes the first element of the feature's bytes list into a Bytes
// value stored under the feature's name. It panics if decoding fails.
func tfExampleTarpSample(example *core.TFExample) Sample {
	features := example.GetFeatures().Feature
	result := make(map[string]Bytes, len(features))
	for name, feature := range features {
		// assume that all TFExample features are just a list of bytes
		var decoded Bytes
		if err := json.Unmarshal(feature.GetBytesList().Value[0], &decoded); err != nil {
			panic(err)
		}
		result[name] = decoded
	}
	return result
}

// PrepareTarSource returns a Source that emits a single sample whose
// "__key__" is "000000" and whose "txt" is "0", then closes the pipe.
func PrepareTarSource() Source {
	const sampleCount = 1
	return func(pipe Pipe) {
		for idx := 0; idx < sampleCount; idx++ {
			key := Bytes(fmt.Sprintf("%06d", idx))
			text := Bytes(fmt.Sprintf("%d", idx))
			pipe <- Sample{
				"__key__": key,
				"txt":     text,
			}
		}
		close(pipe)
	}
}

func prepareTar(t *testing.T) *os.File {
var (
sinkFd *os.File
err error
)
sinkFd, err = ioutil.TempFile("", "go-tfdata-*.tar")
assert.NoError(t, err)

sink := TarSink(sinkFd)
Processing(PrepareTarSource(), nil, sink)
return sinkFd
}

// TestGoTfData round-trips samples through go-tfdata: the tar file produced by
// prepareTar is converted to a temporary TFRecord file, which is then read
// back, verified by SamplesChecker and written out again to a discarding
// writer.
func TestGoTfData(t *testing.T) {
	tarFile := prepareTar(t)
	defer os.RemoveAll(tarFile.Name())
	// The writing handle is not reused; release it and reopen the file by name
	// for reading. Any problem with the file surfaces in os.Open below.
	_ = tarFile.Close()

	tarReader, err := os.Open(tarFile.Name())
	if err != nil {
		t.Fatalf("reopening tar file: %v", err)
	}
	defer tarReader.Close()

	recordFile, err := ioutil.TempFile("", "go-tfdata-*.tfrecord")
	if err != nil {
		t.Fatalf("creating temporary TFRecord file: %v", err)
	}
	defer os.RemoveAll(recordFile.Name())

	// Stage 1: tar -> TFRecord file.
	Processing(TarSource(tarReader), nil, TFRecordSink(t, recordFile))
	if err := recordFile.Close(); err != nil {
		t.Fatalf("closing TFRecord file: %v", err)
	}

	recordReader, err := os.Open(recordFile.Name())
	if err != nil {
		t.Fatalf("reopening TFRecord file: %v", err)
	}
	defer recordReader.Close()

	// Stage 2: TFRecord file -> checker -> discarded TFRecord output.
	Processing(TFRecordSource(t, recordReader), SamplesChecker(t, 1), TFRecordSink(t, ioutil.Discard))
}
12 changes: 8 additions & 4 deletions dpipes/rawtario.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package dpipes

import (
"archive/tar"
"time"
"bytes"
"fmt"
"io"
"regexp"
"time"
)

// Raw is a struct representing unaggregated data items (e.g., from a tar file).
Expand Down Expand Up @@ -97,15 +97,19 @@ func TarRawSource(stream io.Reader) func(RawPipe) {
tr := tar.NewReader(stream)
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
panic(err)
}
if header == nil {
break
}
if header.Typeflag != tar.TypeReg {
continue
}
if err != nil {
panic(err)
}

var buffer bytes.Buffer
io.Copy(&buffer, tr)
data := buffer.Bytes()
Expand Down
9 changes: 2 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@ module github.com/tmbdev/tarp

replace github.com/tmbdev/tarp/dpipes => ./dpipes

go 1.14
replace github.com/NVIDIA/go-tfdata => gitlab-master.nvidia.com/aistorage/go-tfdata v0.0.0-20200427194410-c20d8f9980a7

require (
github.com/bcicen/ctop v0.7.3 // indirect
github.com/jessevdk/go-flags v1.4.0
github.com/maruel/panicparse v1.3.0 // indirect
github.com/tmbdev/tarp/dpipes v0.0.0-20200330012711-53823ac810b9
)
go 1.14

0 comments on commit b4c0f7e

Please sign in to comment.