diff --git a/internal/transport/file.go b/internal/transport/file.go index 335fcbb..c16b159 100644 --- a/internal/transport/file.go +++ b/internal/transport/file.go @@ -6,11 +6,13 @@ import ( _ "errors" "path/filepath" "io" + "io/fs" "os" "net/url" "strings" "fmt" "compress/gzip" + "log/slog" ) type File struct { @@ -25,13 +27,93 @@ type File struct { gzipReader io.ReadCloser } +type FileReader struct { + *File + readHandle *os.File + gzipReader io.ReadCloser +} + +type FileWriter struct { + *File + writeHandle *os.File + gzipWriter io.WriteCloser +} + func FilePath(u *url.URL) string { return filepath.Join(u.Hostname(), u.Path) } func FileExists(u *url.URL) bool { _, err := os.Stat(FilePath(u)) - return err == nil + return !os.IsNotExist(err) +} + +func NewFileReader(u *url.URL) (f *FileReader, err error) { + f = &FileReader { + File: &File { + uri: u, + path: FilePath(u), + }, + } + f.extension() + f.DetectGzip() + exists := FileExists(u) + slog.Info("transport.NewFileReader()", "uri", u, "path", f.Path(), "file", f, "error", err, "exists", exists) + + if f.Path() == "" || f.Path() == "-" { + f.readHandle = os.Stdin + } else { + if f.readHandle, err = os.Open(f.Path()); err == nil { + var fi fs.FileInfo + fi, err = f.readHandle.Stat() + if fi.IsDir() { + f.readHandle.Close() + f.readHandle = nil + err = fmt.Errorf("is a directory") + } + } + + if err != nil { + slog.Info("transport.NewFileReader()", "file", f, "path", f.Path(), "error", err) + return + } + } + if f.Gzip() { + if exists { + if f.gzipReader, err = gzip.NewReader(f.readHandle); err != nil { + return + } + } + } + slog.Info("transport.NewFileReader() - created reader transport", "uri", u, "file", f, "error", err) + return +} + +func NewFileWriter(u *url.URL) (f *FileWriter, err error) { + f = &FileWriter { + File: &File { + uri: u, + path: FilePath(u), + }, + } + f.extension() + f.DetectGzip() + exists := FileExists(u) + slog.Info("transport.NewFileWriter()", "file", f, "error", err, "exists", exists) + + if f.Path() == "" || f.Path() == "-" { + f.writeHandle = os.Stdout + } else { + if f.writeHandle, err = os.OpenFile(f.Path(), os.O_RDWR|os.O_CREATE, 0644); err != nil { + slog.Info("transport.NewFileWriter()", "file", f, "path", f.Path(), "error", err) + return + } + } + if f.Gzip() { + f.gzipWriter = gzip.NewWriter(f.writeHandle) + } + slog.Info("transport.NewFileWriter()", "file", f, "error", err) + return } func NewFile(u *url.URL) (f *File, err error) { @@ -41,12 +123,16 @@ func NewFile(u *url.URL) (f *File, err error) { } f.extension() f.DetectGzip() + exists := FileExists(u) + slog.Info("transport.NewFile()", "file", f, "error", err, "exists", exists) if f.path == "" || f.path == "-" { f.readHandle = os.Stdin f.writeHandle = os.Stdout } else { + if f.readHandle, err = os.OpenFile(f.Path(), os.O_RDWR|os.O_CREATE, 0644); err != nil { + slog.Info("transport.NewFile()", "file", f, "path", f.Path(), "error", err) return } f.writeHandle = f.readHandle @@ -54,10 +140,13 @@ func NewFile(u *url.URL) (f *File, err error) { if f.Gzip() { f.gzipWriter = gzip.NewWriter(f.writeHandle) - if f.gzipReader, err = gzip.NewReader(f.readHandle); err != nil { - return + if exists { + if f.gzipReader, err = gzip.NewReader(f.readHandle); err != nil { + return + } } } + slog.Info("transport.NewFile()", "file", f, "error", err) return } @@ -68,8 +157,9 @@ func (f *File) extension() { if numberOfElements > 2 { f.exttype = elements[numberOfElements - 2] f.fileext = elements[numberOfElements - 1] + } else { + f.exttype = elements[numberOfElements - 1] } - f.exttype = elements[numberOfElements - 1] } } @@ -101,10 +191,25 @@ func (f *File) Signature() (documentSignature string) { } func (f *File) ContentType() string { + var ext strings.Builder if f.uri.Scheme != "file" { return f.uri.Scheme } - return f.exttype + if f.fileext == "" { + return f.exttype + } + ext.WriteString(f.exttype) + ext.WriteRune('.') + ext.WriteString(f.fileext) + return ext.String() +} + +func (f *File) Stat() (fs.FileInfo, error) { + return f.FileInfo() +} + +func (f *File) FileInfo() (info fs.FileInfo, err error) { + return os.Lstat(f.Path()) } func (f *File) SetGzip(gzip bool) { @@ -115,16 +220,25 @@ func (f *File) Gzip() bool { return f.gzip } -func (f *File) Reader() io.ReadCloser { +func (f *FileReader) Reader() io.ReadCloser { if f.Gzip() { + var err error + if f.gzipReader, err = gzip.NewReader(f.readHandle); err != nil { + panic(err) + } return f.gzipReader } return f.readHandle } -func (f *File) Writer() io.WriteCloser { +func (f *FileWriter) Writer() io.WriteCloser { if f.Gzip() { + f.gzipWriter = gzip.NewWriter(f.writeHandle) return f.gzipWriter } return f.writeHandle } + +func (f *File) ReadWriter() io.ReadWriteCloser { + return f.writeHandle +} diff --git a/internal/transport/http.go b/internal/transport/http.go index 8ef26dc..2fe34ea 100644 --- a/internal/transport/http.go +++ b/internal/transport/http.go @@ -10,26 +10,34 @@ _ "os" "net/http" "strings" "fmt" - "bytes" "context" "path/filepath" + "log/slog" + "io/fs" ) -type BufferCloser struct { - stream io.Closer - *bytes.Buffer +type Pipe struct { + Reader io.ReadCloser + Writer io.WriteCloser +} + +type HTTPConnection struct { + stream *Pipe + request *http.Request + response *http.Response + Client *http.Client } type HTTP struct { uri *url.URL path string + gzip bool exttype string fileext string - buffer BufferCloser - getRequest *http.Request - getResponse *http.Response - postRequest *http.Request - postResponse *http.Response + + ctx context.Context + get *HTTPConnection + post *HTTPConnection Client *http.Client } @@ -37,26 +45,99 @@ func HTTPExists(u *url.URL) bool { return false } -func (b BufferCloser) Close() error { - if b.stream != nil { - return b.stream.Close() +func NewPipe() *Pipe { + r,w := io.Pipe() + return &Pipe{ Reader: r, Writer: w } +} + +func NewHTTPConnection(client *http.Client) *HTTPConnection { + return &HTTPConnection { + Client: client, } - return nil +} + +func (h *HTTPConnection) NewPostRequest(ctx context.Context, uri string) (err error) { + h.stream = NewPipe() + h.request, err = http.NewRequestWithContext(ctx, "POST", uri, h.Reader()) + return +} + +func (h *HTTPConnection) NewGetRequest(ctx context.Context, uri string) (err error) { + h.request, err = http.NewRequestWithContext(ctx, "GET", uri, nil) + return +} + +func (h *HTTPConnection) Request() *http.Request { + return h.request +} + +func (h *HTTPConnection) Response() *http.Response { + return h.response +} + +func (h *HTTPConnection) Writer() io.WriteCloser { + return h.stream.Writer +} + +func (h *HTTPConnection) Reader() io.ReadCloser { + return h.stream.Reader +} + +func (h *HTTPConnection) Do() (err error) { + slog.Info("transport.HTTPConnection.Do()", "connection", h) + h.response, err = h.Client.Do(h.request) + return +} + +func (h *HTTPConnection) Read(p []byte) (n int, err error) { + if h.response == nil { + if err = h.Do(); err != nil { + return + } + } + return h.response.Body.Read(p) +} + +func (h *HTTPConnection) Write(p []byte) (n int, err error) { + if h.response == nil { + if err = h.Do(); err != nil { + return + } + } + slog.Info("transport.HTTPConnection.Write()", "data", p, "connection", h) + return h.Writer().Write(p) +} + +func (h *HTTPConnection) ReadFrom(r io.Reader) (n int64, err error) { + h.request.Body = r.(io.ReadCloser) + if h.response == nil { + if err = h.Do(); err != nil { + return + } + } + return h.request.ContentLength, nil +} + +func (h *HTTPConnection) Close() (err error) { + if h.response != nil { + defer h.response.Body.Close() + } + if h.stream != nil { + err = h.Writer().Close() + } + return } func NewHTTP(u *url.URL, ctx context.Context) (h *HTTP, err error) { h = &HTTP { + ctx: ctx, uri: u, path: filepath.Join(u.Hostname(), u.RequestURI()), Client: http.DefaultClient, } - h.extension() - h.postRequest, err = http.NewRequestWithContext(ctx, "POST", u.String(), h.buffer) - if err != nil { - return - } - h.getRequest, err = http.NewRequestWithContext(ctx, "GET", u.String(), nil) + h.extension() + h.DetectGzip() return } @@ -70,6 +151,10 @@ func (h *HTTP) extension() { h.exttype = elements[numberOfElements - 1] } +func (h *HTTP) DetectGzip() { + h.gzip = (h.uri.Query().Get("gzip") == "true" || h.fileext == "gz") +} + func (h *HTTP) URI() *url.URL { return h.uri } @@ -79,8 +164,8 @@ func (h *HTTP) Path() string { } func (h *HTTP) Signature() (documentSignature string) { - if h.getResponse != nil { - documentSignature = h.getResponse.Header.Get("Signature") + if h.get.Response() != nil { + documentSignature = h.get.Response().Header.Get("Signature") if documentSignature == "" { signatureResp, signatureErr := h.Client.Get(fmt.Sprintf("%s.sig", h.uri.String())) if signatureErr == nil { @@ -95,8 +180,12 @@ func (h *HTTP) Signature() (documentSignature string) { return documentSignature } +func (h *HTTP) Stat() (info fs.FileInfo, err error) { + return +} + func (h *HTTP) ContentType() (contenttype string) { - contenttype = h.getResponse.Header.Get("Content-Type") + contenttype = h.get.Response().Header.Get("Content-Type") switch contenttype { case "application/octet-stream": return h.exttype @@ -105,22 +194,50 @@ func (h *HTTP) ContentType() (contenttype string) { return } +func (h *HTTP) SetGzip(gzip bool) { + h.gzip = gzip +} + func (h *HTTP) Gzip() bool { - return h.fileext == "gz" + return h.gzip } func (h *HTTP) Reader() io.ReadCloser { - var err error - if h.getResponse, err = h.Client.Do(h.getRequest); err != nil { - panic(err) + if h.get == nil { + h.get = NewHTTPConnection(h.Client) + if err := h.get.NewGetRequest(h.ctx, h.uri.String()); err != nil { + panic(err) + } } - return h.getResponse.Body + return h.get } func (h *HTTP) Writer() io.WriteCloser { - var err error - if h.postResponse, err = h.Client.Do(h.postRequest); err != nil { - panic(err) + if h.post == nil { + h.post = NewHTTPConnection(h.Client) + if err := h.post.NewPostRequest(h.ctx, h.uri.String()); err != nil { + panic(err) + } } - return h.buffer + return h.post +} + +func (h *HTTP) ReadWriter() io.ReadWriteCloser { + return nil +} + +func (h *HTTP) GetRequest() *http.Request { + return h.get.Request() +} + +func (h *HTTP) GetResponse() *http.Response { + return h.get.Response() +} + +func (h *HTTP) PostRequest() *http.Request { + return h.post.Request() +} + +func (h *HTTP) PostResponse() *http.Response { + return h.post.Response() } diff --git a/internal/transport/http_test.go b/internal/transport/http_test.go index 99298a3..7e161ef 100644 --- a/internal/transport/http_test.go +++ b/internal/transport/http_test.go @@ -5,17 +5,102 @@ package transport import ( "github.com/stretchr/testify/assert" "testing" -_ "fmt" + "fmt" _ "os" + "io" "net/url" _ "path/filepath" "context" + "net/http" + "net/http/httptest" ) func TestNewTransportHTTPReader(t *testing.T) { - u, urlErr := url.Parse("https://localhost/resource") + //ctx := context.Background() + + body := []byte(` +type: "user" +attributes: + name: "foo" + gecos: "foo user" +`) + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + assert.Equal(t, req.URL.String(), "/resource/user") + n,e := rw.Write(body) + assert.Nil(t, e) + assert.Greater(t, n, 0) + assert.Equal(t, "bar", req.Header.Get("foo")) + })) + defer server.Close() + + u, urlErr := url.Parse(fmt.Sprintf("%s/resource/user", server.URL)) assert.Nil(t, urlErr) h, err := NewHTTP(u, context.Background()) assert.Nil(t, err) assert.NotNil(t, h) + h.Reader() + h.GetRequest().Header.Add("foo", "bar") + + resData, readErr := io.ReadAll(h.Reader()) + assert.Nil(t, readErr) + assert.Greater(t, len(resData), 0) + assert.Equal(t, body, resData) +} + +func TestNewTransportHTTPWriter(t *testing.T) { + body := []byte(` +type: "user" +attributes: + name: "foo" + gecos: "foo user" +`) + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + assert.Equal(t, req.URL.String(), "/resource/user") + n, postBody := io.ReadAll(req.Body) + + assert.Greater(t, n, 0) + assert.Equal(t, "bar", req.Header.Get("foo")) + assert.Equal(t, body, postBody) + })) + defer server.Close() + + u, urlErr := url.Parse(fmt.Sprintf("%s/resource/user", server.URL)) + assert.Nil(t, urlErr) + h, err := NewHTTP(u, context.Background()) + assert.Nil(t, err) + assert.NotNil(t, h) + h.Writer() + h.PostRequest().Header.Add("foo", "bar") + +// _, writeErr := h.Writer().Write(body) +// assert.Nil(t, writeErr) +} + +func TestNewHTTPConnection(t *testing.T) { + ctx := context.Background() + h := NewHTTPConnection(http.DefaultClient) + assert.NotNil(t, h) + + body := []byte(` +type: "user" +attributes: + name: "foo" + gecos: "foo user" +`) + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + assert.Equal(t, req.URL.String(), "/resource/user") + n,e := rw.Write(body) + assert.Nil(t, e) + assert.Greater(t, n, 0) + assert.Equal(t, "bar", req.Header.Get("foo")) + })) + defer server.Close() + + uri := fmt.Sprintf("%s/resource/user", server.URL) + + assert.Nil(t, h.NewGetRequest(ctx, uri)) + h.Request().Header.Add("foo", "bar") + responseData, responseErr := io.ReadAll(h) + assert.Nil(t, responseErr) + assert.Equal(t, body, responseData) } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index f7bf40f..c1f0070 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -10,6 +10,7 @@ _ "net/http" _ "strings" _ "path/filepath" "io" + "io/fs" _ "os" "context" ) @@ -17,9 +18,17 @@ _ "os" type Handler interface { URI() *url.URL ContentType() string + SetGzip(bool) Gzip() bool Signature() string + Stat() (fs.FileInfo, error) +} + +type HandlerReader interface { Reader() io.ReadCloser +} + +type HandlerWriter interface { Writer() io.WriteCloser } @@ -31,7 +40,10 @@ type Reader struct { } func NewReader(u *url.URL) (reader *Reader, e error) { - ctx := context.Background() + return NewReaderWithContext(u, context.Background()) +} + +func NewReaderWithContext(u *url.URL, ctx context.Context) (reader *Reader, e error) { reader = &Reader{ uri: u } switch u.Scheme { case "http", "https": @@ -39,10 +51,10 @@ func NewReader(u *url.URL) (reader *Reader, e error) { case "file": fallthrough default: - reader.handle, e = NewFile(u) + reader.handle, e = NewFileReader(u) reader.exists = func() bool { return FileExists(u) } } - reader.SetStream(reader.handle.Reader()) + reader.SetStream(reader.handle.(HandlerReader).Reader()) return } @@ -62,18 +74,23 @@ type Writer struct { } func NewWriter(u *url.URL) (writer *Writer, e error) { - ctx := context.Background() + return NewWriterWithContext(u, context.Background()) +} + +func NewWriterWithContext(u *url.URL, ctx context.Context) (writer *Writer, e error) { writer = &Writer{ uri: u } + switch u.Scheme { case "http", "https": writer.handle, e = NewHTTP(u, ctx) case "file": fallthrough default: - writer.handle, e = NewFile(u) + writer.handle, e = NewFileWriter(u) writer.exists = func() bool { return FileExists(u) } } - writer.SetStream(writer.handle.Writer()) + + writer.SetStream(writer.handle.(HandlerWriter).Writer()) return writer, e } @@ -117,10 +134,18 @@ func (r *Reader) ContentType() string { return r.handle.ContentType() } +func (r *Reader) SetGzip(value bool) { + r.handle.SetGzip(value) +} + func (r *Reader) Gzip() bool { return r.handle.Gzip() } +func (r *Reader) Stat() (info fs.FileInfo, err error) { + return r.handle.Stat() +} + func (r *Reader) Signature() string { return r.handle.Signature() } @@ -129,6 +154,17 @@ func (r *Reader) SetStream(s io.ReadCloser) { r.stream = s } +func (r *Reader) AddHeader(name string, value string) { + r.handle.(*HTTP).GetRequest().Header.Add(name, value) +} + +func (r *Reader) Status() string { + return r.handle.(*HTTP).GetResponse().Status +} + +func (r *Reader) StatusCode() int { + return r.handle.(*HTTP).GetResponse().StatusCode +} func (w *Writer) Exists() bool { return w.exists() } @@ -136,6 +172,14 @@ func (w *Writer) Write(b []byte) (int, error) { return w.stream.Write(b) } +func (w *Writer) ReadFrom(r io.Reader) (n int64, e error) { + if v, ok := w.stream.(io.ReaderFrom); ok { + return v.ReadFrom(r) + } else { + panic("io.ReaderFrom interface not supported by writer") + } +} + func (w *Writer) Close() error { return w.stream.Close() } @@ -144,6 +188,10 @@ func (w *Writer) ContentType() string { return w.handle.ContentType() } +func (w *Writer) SetGzip(value bool) { + w.handle.SetGzip(value) +} + func (w *Writer) Gzip() bool { return w.handle.Gzip() } @@ -155,3 +203,15 @@ func (w *Writer) Signature() string { func (w *Writer) SetStream(s io.WriteCloser) { w.stream = s } + +func (w *Writer) AddHeader(name string, value string) { + w.handle.(*HTTP).PostRequest().Header.Add(name, value) +} + +func (w *Writer) Status() string { + return w.handle.(*HTTP).PostResponse().Status +} + +func (w *Writer) StatusCode() int { + return w.handle.(*HTTP).PostResponse().StatusCode +} diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 6aa7477..d24a477 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -62,7 +62,7 @@ func TestTransportReaderContentType(t *testing.T) { assert.True(t, reader.Exists()) assert.NotNil(t, reader) - assert.Equal(t, reader.ContentType(), "yaml") + assert.Equal(t, "jx.yaml", reader.ContentType()) } func TestTransportReaderDir(t *testing.T) {