From 308688565549f1e5874cca97c8592a4a1fa71eaf Mon Sep 17 00:00:00 2001 From: Matthew Rich Date: Sat, 17 Aug 2024 18:16:02 -0700 Subject: [PATCH] add protobuf support --- internal/codec/decoder_test.go | 27 ++++- internal/codec/encoder.go | 26 +++++ internal/codec/encoder_test.go | 52 ++++++++++ internal/codec/testuser.pb.go | 179 +++++++++++++++++++++++++++++++++ internal/codec/testuser.proto | 13 +++ internal/codec/types.go | 9 +- internal/codec/types_test.go | 36 +++++++ 7 files changed, 337 insertions(+), 5 deletions(-) create mode 100644 internal/codec/testuser.pb.go create mode 100644 internal/codec/testuser.proto diff --git a/internal/codec/decoder_test.go b/internal/codec/decoder_test.go index d4c9192..4e98623 100644 --- a/internal/codec/decoder_test.go +++ b/internal/codec/decoder_test.go @@ -83,7 +83,7 @@ func TestNewJSONStringDecoder(t *testing.T) { } func TestNewDecoder(t *testing.T) { - pbData, err := proto.Marshal(&TestUser{ Name: "pb", Uid: "15001", Group: "15005", Home: "/home/pb", State: "present" }) + pbData, err := proto.Marshal(&TestPBUser{ Name: "pb", Uid: "15001", Group: "15005", Home: "/home/pb", State: "present" }) assert.Nil(t, err) for _, v := range []struct{ reader io.Reader; format Format; expectedhome string } { { reader: strings.NewReader(`{ @@ -104,8 +104,31 @@ func TestNewDecoder(t *testing.T) { decoder := NewDecoder(v.reader, v.format) assert.NotNil(t, decoder) - u := &TestUser{} + u := &TestPBUser{} assert.Nil(t, decoder.Decode(u)) assert.Equal(t, v.expectedhome, u.Home ) } } + +func TestNewDecoderError(t *testing.T) { + pbData, err := proto.Marshal(&TestPBUser{ Name: "pb", Uid: "15001", Group: "15005", Home: "/home/pb", State: "present" }) + assert.Nil(t, err) + + decoder := NewDecoder(bytes.NewReader(pbData), Format("foo")) + assert.Nil(t, decoder) +} + +func TestNewStringDecoder(t *testing.T) { + jsonDoc := `{ + "name": "testuser", + "uid": "12001", + "group": "12001", + "home": "/home/testuser", + "state": "present" }` + decoder := NewStringDecoder(jsonDoc, FormatJson) + assert.NotNil(t, decoder) + u := &TestUser{} + assert.Nil(t, decoder.Decode(u)) + assert.Equal(t, "testuser", u.Name) + +} diff --git a/internal/codec/encoder.go b/internal/codec/encoder.go index 80c09bd..e12e05c 100644 --- a/internal/codec/encoder.go +++ b/internal/codec/encoder.go @@ -9,8 +9,12 @@ _ "github.com/xeipuuv/gojsonschema" "gopkg.in/yaml.v3" "io" _ "log" + "errors" + "google.golang.org/protobuf/proto" ) +var ErrInvalidWriter error = errors.New("Invalid writer") + type JSONEncoder json.Encoder type Encoder interface { @@ -38,7 +42,29 @@ func NewYAMLEncoder(w io.Writer) Encoder { return yaml.NewEncoder(w) } +type ProtoEncoder struct { + writer io.Writer +} + +func (p *ProtoEncoder) Encode(v any) (err error) { + var encoded []byte + encoded, err = proto.Marshal(v.(proto.Message)) + if err != nil { + return + } + + _, err = p.writer.Write(encoded) + return +} + +func (p *ProtoEncoder) Close() error { + return nil +} + func NewProtoBufEncoder(w io.Writer) Encoder { + if w != nil { + return &ProtoEncoder{ writer: w } + } return nil } diff --git a/internal/codec/encoder_test.go b/internal/codec/encoder_test.go index 4367db7..4fddea2 100644 --- a/internal/codec/encoder_test.go +++ b/internal/codec/encoder_test.go @@ -9,6 +9,9 @@ import ( "strings" "testing" "github.com/xeipuuv/gojsonschema" + "io" + "bytes" + "google.golang.org/protobuf/proto" ) type TestFile struct { @@ -56,3 +59,52 @@ schema:=` assert.True(t, result.Valid()) } + +func TestNewEncoder(t *testing.T) { + + pb := &TestPBUser{ Name: "pb", Uid: "15001", Group: "15005", Home: "/home/pb", State: "present" } + jx := &TestUser{ Name: "jx", Uid: "17001", Group: "17005", Home: "/home/jx", State: "present" } + + pbData, pbErr := proto.Marshal(pb) + assert.Nil(t, pbErr) + + for _, v := range []struct{ writer io.Writer; testuser any; format Format; expected []byte} { + { writer: &bytes.Buffer{}, testuser: jx, expected: []byte(`{"name":"jx","uid":"17001","group":"17005","home":"/home/jx","state":"present"} +`), format: FormatJson }, + { writer: &bytes.Buffer{}, testuser: jx, expected: []byte(`name: jx +uid: "17001" +group: "17005" +home: /home/jx +state: present +`), format: FormatYaml }, + { writer: &bytes.Buffer{}, testuser: pb, expected: pbData , format: FormatProtoBuf }, + } { + encoder := NewEncoder(v.writer, v.format) + assert.NotNil(t, encoder) + assert.Nil(t, encoder.Encode(v.testuser)) + assert.Equal(t, string(v.expected), string(v.writer.(*bytes.Buffer).Bytes())) + assert.Equal(t, v.expected, v.writer.(*bytes.Buffer).Bytes()) + assert.Nil(t, encoder.Close()) + } +} + +func TestNewEncoderError(t *testing.T) { + encoder := NewEncoder(&strings.Builder{}, Format("foo")) + assert.Nil(t, encoder) +} + +func TestNewProtobufError(t *testing.T) { + encoder := NewProtoBufEncoder(nil) + assert.Nil(t, encoder) +} + +/* +func TestProtobufEncodeError(t *testing.T) { + buf := &bytes.Buffer{} + buf.Write([]byte("broken input")) + + encoder := NewProtoBufEncoder(buf) + assert.NotNil(t, encoder) + assert.NotNil(t, encoder.Encode(&TestPBUser{})) +} +*/ diff --git a/internal/codec/testuser.pb.go b/internal/codec/testuser.pb.go new file mode 100644 index 0000000..b4a04d9 --- /dev/null +++ b/internal/codec/testuser.pb.go @@ -0,0 +1,179 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.34.2 +// protoc v3.12.4 +// source: testuser.proto + +package codec + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type TestPBUser struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"` + Uid string `protobuf:"bytes,2,opt,name=Uid,proto3" json:"Uid,omitempty"` + Group string `protobuf:"bytes,3,opt,name=Group,proto3" json:"Group,omitempty"` + Home string `protobuf:"bytes,4,opt,name=Home,proto3" json:"Home,omitempty"` + State string `protobuf:"bytes,5,opt,name=State,proto3" json:"State,omitempty"` +} + +func (x *TestPBUser) Reset() { + *x = TestPBUser{} + if protoimpl.UnsafeEnabled { + mi := &file_testuser_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TestPBUser) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TestPBUser) ProtoMessage() {} + +func (x *TestPBUser) ProtoReflect() protoreflect.Message { + mi := &file_testuser_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TestPBUser.ProtoReflect.Descriptor instead. +func (*TestPBUser) Descriptor() ([]byte, []int) { + return file_testuser_proto_rawDescGZIP(), []int{0} +} + +func (x *TestPBUser) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *TestPBUser) GetUid() string { + if x != nil { + return x.Uid + } + return "" +} + +func (x *TestPBUser) GetGroup() string { + if x != nil { + return x.Group + } + return "" +} + +func (x *TestPBUser) GetHome() string { + if x != nil { + return x.Home + } + return "" +} + +func (x *TestPBUser) GetState() string { + if x != nil { + return x.State + } + return "" +} + +var File_testuser_proto protoreflect.FileDescriptor + +var file_testuser_proto_rawDesc = []byte{ + 0x0a, 0x0e, 0x74, 0x65, 0x73, 0x74, 0x75, 0x73, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x12, 0x05, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x22, 0x72, 0x0a, 0x0a, 0x54, 0x65, 0x73, 0x74, 0x50, + 0x42, 0x55, 0x73, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x69, 0x64, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x69, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x47, + 0x72, 0x6f, 0x75, 0x70, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x47, 0x72, 0x6f, 0x75, + 0x70, 0x12, 0x12, 0x0a, 0x04, 0x48, 0x6f, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x48, 0x6f, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x42, 0x15, 0x5a, 0x13, 0x64, + 0x65, 0x63, 0x6c, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x63, 0x6f, 0x64, + 0x65, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_testuser_proto_rawDescOnce sync.Once + file_testuser_proto_rawDescData = file_testuser_proto_rawDesc +) + +func file_testuser_proto_rawDescGZIP() []byte { + file_testuser_proto_rawDescOnce.Do(func() { + file_testuser_proto_rawDescData = protoimpl.X.CompressGZIP(file_testuser_proto_rawDescData) + }) + return file_testuser_proto_rawDescData +} + +var file_testuser_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_testuser_proto_goTypes = []any{ + (*TestPBUser)(nil), // 0: codec.TestPBUser +} +var file_testuser_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_testuser_proto_init() } +func file_testuser_proto_init() { + if File_testuser_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_testuser_proto_msgTypes[0].Exporter = func(v any, i int) any { + switch v := v.(*TestPBUser); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_testuser_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_testuser_proto_goTypes, + DependencyIndexes: file_testuser_proto_depIdxs, + MessageInfos: file_testuser_proto_msgTypes, + }.Build() + File_testuser_proto = out.File + file_testuser_proto_rawDesc = nil + file_testuser_proto_goTypes = nil + file_testuser_proto_depIdxs = nil +} diff --git a/internal/codec/testuser.proto b/internal/codec/testuser.proto new file mode 100644 index 0000000..dbcdf9b --- /dev/null +++ b/internal/codec/testuser.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; +package codec; +option go_package = "decl/internal/codec"; + + +message TestPBUser { + string Name = 1; + string Uid = 2; + string Group = 3; + string Home = 4; + string State = 5; +} + diff --git a/internal/codec/types.go b/internal/codec/types.go index 232beb3..55e5f8e 100644 --- a/internal/codec/types.go +++ b/internal/codec/types.go @@ -11,6 +11,7 @@ import ( ) const ( + FormatYml Format = "yml" FormatYaml Format = "yaml" FormatJson Format = "json" FormatProtoBuf Format = "protobuf" @@ -22,7 +23,7 @@ type Format string func (f *Format) Validate() error { switch *f { - case FormatYaml, FormatJson, FormatProtoBuf: + case FormatYml, FormatYaml, FormatJson, FormatProtoBuf: return nil default: return fmt.Errorf("%w: %s", ErrInvalidFormat, *f) @@ -31,19 +32,21 @@ func (f *Format) Validate() error { func (f *Format) Set(value string) (err error) { if err = (*Format)(&value).Validate(); err == nil { - *f = Format(value) + err = f.UnmarshalValue(value) } return } func (f *Format) UnmarshalValue(value string) error { switch value { + case string(FormatYml): + *f = FormatYaml case string(FormatYaml), string(FormatJson), string(FormatProtoBuf): *f = Format(value) - return nil default: return ErrInvalidFormat } + return nil } func (f *Format) UnmarshalJSON(data []byte) error { diff --git a/internal/codec/types_test.go b/internal/codec/types_test.go index 7d15461..017095d 100644 --- a/internal/codec/types_test.go +++ b/internal/codec/types_test.go @@ -7,6 +7,8 @@ _ "fmt" "github.com/stretchr/testify/assert" _ "log" "testing" + "strings" + "encoding/json" ) type TestDec struct { @@ -39,3 +41,37 @@ formattype: foo assert.ErrorIs(t, ErrInvalidFormat, e) } + +func TestFormatValidate(t *testing.T) { + f := FormatYaml + assert.Nil(t, f.Validate()) + + var fail Format = Format("foo") + assert.ErrorIs(t, fail.Validate(), ErrInvalidFormat) + + var testFormatSet Format + assert.Nil(t, testFormatSet.Set("yaml")) + + assert.ErrorIs(t, testFormatSet.Set("yamlv3"), ErrInvalidFormat) +} + +func TestFormatCodec(t *testing.T) { + var output map[string]Format = make(map[string]Format) + var writer strings.Builder + encoder := FormatYaml.Encoder(&writer) + assert.NotNil(t, encoder) + + decoder := FormatYaml.Decoder(strings.NewReader("formattype: json")) + assert.Nil(t, decoder.Decode(output)) + assert.Equal(t, FormatJson, output["formattype"]) +} + +func TestFormatUnmarshal(t *testing.T) { + var f Format + assert.Nil(t, json.Unmarshal([]byte("\"yaml\""), &f)) + assert.Equal(t, FormatYaml, f) + assert.NotNil(t, json.Unmarshal([]byte("\"yaml"), &f)) + + assert.Nil(t, json.Unmarshal([]byte("\"yml\""), &f)) + assert.Equal(t, FormatYaml, f) +}