From 18e47795e598973195887893e7d77baddec53ebb Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 25 Jan 2021 19:00:43 +0100 Subject: [PATCH] device: allow pipelining UAPI requests The original spec ends with \n\n especially for this reason. Signed-off-by: Jason A. Donenfeld --- device/uapi.go | 66 +++++++++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/device/uapi.go b/device/uapi.go index c1ddb38..31fbdc7 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -380,9 +380,6 @@ func (device *Device) IpcSet(uapiConf string) error { } func (device *Device) IpcHandle(socket net.Conn) { - - // create buffered read/writer - defer socket.Close() buffered := func(s io.ReadWriter) *bufio.ReadWriter { @@ -391,34 +388,43 @@ func (device *Device) IpcHandle(socket net.Conn) { return bufio.NewReadWriter(reader, writer) }(socket) - defer buffered.Flush() + for { + op, err := buffered.ReadString('\n') + if err != nil { + return + } - op, err := buffered.ReadString('\n') - if err != nil { - return - } + // handle operation + switch op { + case "set=1\n": + err = device.IpcSetOperation(buffered.Reader) + case "get=1\n": + nextByte, err := buffered.ReadByte() + if err != nil { + return + } + if nextByte != '\n' { + err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %c", nextByte, err) + break + } + err = device.IpcGetOperation(buffered.Writer) + default: + device.log.Error.Println("invalid UAPI operation:", op) + return + } - // handle operation - switch op { - case "set=1\n": - err = device.IpcSetOperation(buffered.Reader) - case "get=1\n": - err = device.IpcGetOperation(buffered.Writer) - default: - device.log.Error.Println("invalid UAPI operation:", op) - return - } - - // write status - var status *IPCError - if err != nil && !errors.As(err, &status) { - // shouldn't happen - status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err) - } - if status != nil { - device.log.Error.Println(status) - fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") + // write status + var status *IPCError + if err != nil && !errors.As(err, &status) { + // shouldn't happen + status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err) + } + if status != nil { + device.log.Error.Println(status) + fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") + } + buffered.Flush() } }