Skip to content

Commit

Permalink
Merge pull request #14 from graugans/feature/swupdater
Browse files Browse the repository at this point in the history
Implementation of the SWUpdate feature
  • Loading branch information
graugans authored Jun 14, 2024
2 parents 9347fc1 + d45a2f2 commit 2547bb1
Show file tree
Hide file tree
Showing 4 changed files with 302 additions and 0 deletions.
102 changes: 102 additions & 0 deletions cmd/ovp8xx/cmd/swupdate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
Copyright © 2023 Christian Ege <[email protected]>
*/
package cmd

import (
"fmt"
"path/filepath"
"sync"
"time"

"github.com/graugans/go-ovp8xx/v2/pkg/swupdater"
"github.com/spf13/cobra"
)

func swupdateCommand(cmd *cobra.Command, args []string) error {
var err error
host, err := rootCmd.PersistentFlags().GetString("ip")
if err != nil {
return fmt.Errorf("cannot get host: %w", err)
}

port, err := cmd.Flags().GetUint16("port")
if err != nil {
return fmt.Errorf("cannot get port: %w", err)
}

// Check if filename is provided as a positional argument
if len(args) < 1 {
return fmt.Errorf("no filename provided")
}
filename := args[0]

timeout, err := cmd.Flags().GetDuration("timeout")
if err != nil {
return fmt.Errorf("cannot get timeout: %w", err)
}

connectionTimeout, err := cmd.Flags().GetDuration("online")
if err != nil {
return fmt.Errorf("cannot get timeout: %w", err)
}

fmt.Printf("Updating firmware on %s:%d with file %s (%v)\n",
host,
port,
filepath.Base(filename),
timeout,
)

// notifications is a channel used to receive SWUpdaterNotification events.
// It has a buffer size of 10 to allow for asynchronous processing.
notifications := make(chan swupdater.SWUpdaterNotification, 10)

var wg sync.WaitGroup
wg.Add(1)

// Print the messages as they come
go func() {
for n := range notifications {
if value, ok := n["swupdater"]; ok {
fmt.Println(value)
}
if value, ok := n["text"]; ok && n["type"] == "message" {
fmt.Println(value)
}

}
wg.Done() // Decrease counter when goroutine completes
}()

// Create a new SWUpdater instance with the specified host, port, and notifications.
swu := swupdater.NewSWUpdater(host, port, notifications)
if err = swu.Update(filename,
connectionTimeout,
timeout,
); err != nil {
return fmt.Errorf("software update failed: %w", err)
}

wg.Wait() // Wait for all goroutines to finish
return nil
}

// swupdateCmd represents the swupdate command
var swupdateCmd = &cobra.Command{
Use: "swupdate [filename]",
Short: "Update the firmware on the device",
Long: `The swupdate command is used to update the firmware on the device.
It takes a filename as a positional argument, which is the path to the firmware file to be uploaded.
The command establishes a connection to the device, uploads the firmware file, and waits for the update process to complete.`,
RunE: swupdateCommand,
}

func init() {
rootCmd.AddCommand(swupdateCmd)
swupdateCmd.Flags().Uint16("port", 8080, "Port number for SWUpdate")
swupdateCmd.Flags().Duration("timeout", 5*time.Minute, "The timeout for the upload")
swupdateCmd.Flags().Duration("online", 2*time.Minute, "The time to wait for the device to become available")
}
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@ go 1.21

require (
alexejk.io/go-xmlrpc v0.4.0
github.com/gorilla/websocket v1.5.1
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.8.4
github.com/technoweenie/multipartstreamer v1.0.1
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/net v0.17.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ alexejk.io/go-xmlrpc v0.4.0/go.mod h1:M7f2OzqvZIWrN1LftR4uwW/bLpxrFoQYjWfm4gQB4J
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand All @@ -14,6 +16,10 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/technoweenie/multipartstreamer v1.0.1 h1:XRztA5MXiR1TIRHxH2uNxXxaIkKQDeX7m2XsSOlQEnM=
github.com/technoweenie/multipartstreamer v1.0.1/go.mod h1:jNVxdtShOxzAsukZwTSw6MDx5eUJoiEBsSvzDU9uzog=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
191 changes: 191 additions & 0 deletions pkg/swupdater/swupdater.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package swupdater

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"

"github.com/gorilla/websocket"
"github.com/technoweenie/multipartstreamer"
)

// SWUpdater represents a software updater.
type SWUpdater struct {
hostName string // The hostname of the updater.
port uint16 // The port number of the updater.
urlUpload string // The URL for uploading software updates.
urlStatus string // The URL for checking the status of software updates.
notifications chan SWUpdaterNotification // A channel for receiving notifications.
ws *websocket.Conn
}

type SWUpdaterNotification map[string]string

// NewSWUpdater creates a new instance of SWUpdater with the specified host name and port.
func NewSWUpdater(hostName string, port uint16, notifications chan SWUpdaterNotification) *SWUpdater {
return &SWUpdater{
hostName: hostName,
port: port,
urlUpload: fmt.Sprintf("http://%s:%d/upload", hostName, port),
urlStatus: fmt.Sprintf("ws://%s:%d/ws", hostName, port),
notifications: notifications,
}
}

// Upload performs the upload of the specified file.
// The filename parameter specifies the name of the file to be uploaded.
// Returns an error if the upload fails.
func (s *SWUpdater) upload(filename string) error {
s.statusUpdate(fmt.Sprintf("Uploading software image to %s\n", s.urlUpload))
const fieldname string = "file"

file, err := os.Open(filename)
if err != nil {
return fmt.Errorf("cannot open file: %w", err)
}
defer file.Close()

fileInfo, err := file.Stat()
if err != nil {
return fmt.Errorf("cannot get file info: %w", err)
}

ms := multipartstreamer.New()
err = ms.WriteReader(fieldname, filename, fileInfo.Size(), file)
if err != nil {
return fmt.Errorf("cannot write reader: %w", err)
}

req, _ := http.NewRequest("POST", s.urlUpload, nil)
ms.SetupRequest(req)

resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("cannot send request: %w", err)
}
defer resp.Body.Close()
return err
}

// waitForFinished waits for the SWUpdater process to finish by listening to a WebSocket connection.
// It continuously reads messages from the WebSocket and checks for specific conditions to determine
// if the SWUpdater process has completed successfully or has failed.
//
// Parameters:
// - done: A channel used to signal the completion of the SWUpdater process. If the process finishes
// successfully, nil is sent to the channel. If the process fails, an error is sent to the channel.
//
// Returns:
//
// None
//
// Example usage:
//
// done := make(chan error)
// go s.waitForFinished(done)
// err := <-done
// if err != nil {
// // Handle error
// } else {
// // SWUpdater process completed successfully
// }
func (s *SWUpdater) waitForFinished(done chan error) {

for {
_, message, err := s.ws.ReadMessage()
if err != nil {
done <- fmt.Errorf("cannot read message from websocket: %w", err)
return
}

data := make(SWUpdaterNotification)
err = json.Unmarshal(message, &data)
if err != nil {
done <- fmt.Errorf("cannot unmarshal message: %w", err)
return
}
// Send notification to channel
if s.notifications != nil {
s.notifications <- data
}
if data["type"] != "message" {
continue
}
if strings.Contains(data["text"], "SWUPDATE successful") {
done <- nil
return
}
if strings.Contains(data["text"], "Installation failed") {
done <- errors.New("installation failed")
return
}
}
}

func (s *SWUpdater) connect() error {
var err error
s.ws, _, err = websocket.DefaultDialer.Dial(s.urlStatus, nil)
if err != nil {
return fmt.Errorf("unable to connect to the status socket: %w", err)
}
return err
}

func (s *SWUpdater) disconnect() {
s.ws.Close()
}

// statusUpdate updates the status of the SWUpdater.
// It sends a notification to the channel with the provided status.
func (s *SWUpdater) statusUpdate(status string) {
notification := make(SWUpdaterNotification)
notification["swupdater"] = status
// Send notification to channel
if s.notifications != nil {
s.notifications <- notification
}
}

// Update uploads a software image and waits for the update process to finish.
// It takes a filename string and a timeout duration as parameters.
// It returns an error if the upload fails, or if the operation times out.
func (s *SWUpdater) Update(filename string, connectionTimeout, timeout time.Duration) error {
done := make(chan error)
start := time.Now()
s.statusUpdate("Waiting for the Device to become ready...")
// Retry connection until successful or connectionTimeout occurs
for {
err := s.connect()
if err == nil {
s.statusUpdate("Device is ready now")
break
}
if time.Since(start) > connectionTimeout {
return fmt.Errorf("connection timeout: %w", err)
}
time.Sleep(3 * time.Second) // wait for a second before retrying
}
defer s.disconnect()

s.statusUpdate("Starting the Software Update process...")
go s.waitForFinished(done)
err := s.upload(filename)
if err != nil {
return fmt.Errorf("cannot upload software image: %w", err)
}

select {
case err := <-done:
close(s.notifications) // Close the channel to signal the end of notifications
if err != nil {
return fmt.Errorf("update failed: %w", err)
}
return nil
case <-time.After(timeout):
return errors.New("a timeout occurred while waiting for the update to finish")
}
}

0 comments on commit 2547bb1

Please sign in to comment.