Skip to content

Commit

Permalink
cmd: spinner progress for transfer model data (ollama#6100)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshyan1 authored Aug 12, 2024
1 parent 980dd15 commit f7e3b91
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 7 deletions.
45 changes: 42 additions & 3 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"runtime"
"slices"
"strings"
"sync/atomic"
"syscall"
"time"

Expand Down Expand Up @@ -78,6 +79,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
status := "transferring model data"
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
defer p.Stop()

for i := range modelfile.Commands {
switch modelfile.Commands[i].Name {
Expand Down Expand Up @@ -112,7 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile
}

digest, err := createBlob(cmd, client, path)
digest, err := createBlob(cmd, client, path, spinner)
if err != nil {
return err
}
Expand Down Expand Up @@ -263,13 +265,20 @@ func tempZipFiles(path string) (string, error) {
return tempfile.Name(), nil
}

func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
bin, err := os.Open(path)
if err != nil {
return "", err
}
defer bin.Close()

// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()

hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return "", err
Expand All @@ -279,13 +288,43 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
return "", err
}

var pw progressWriter
status := "transferring model data 0%"
spinner.SetMessage(status)

done := make(chan struct{})
defer close(done)

go func() {
ticker := time.NewTicker(60 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize)))
case <-done:
spinner.SetMessage("transferring model data 100%")
return
}
}
}()

digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil
}

type progressWriter struct {
n atomic.Int64
}

func (w *progressWriter) Write(p []byte) (n int, err error) {
w.n.Add(int64(len(p)))
return len(p), nil
}

func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true

Expand Down
14 changes: 10 additions & 4 deletions progress/spinner.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package progress
import (
"fmt"
"strings"
"sync/atomic"
"time"
)

type Spinner struct {
message string
message atomic.Value
messageWidth int

parts []string
Expand All @@ -21,20 +22,25 @@ type Spinner struct {

func NewSpinner(message string) *Spinner {
s := &Spinner{
message: message,
parts: []string{
"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
},
started: time.Now(),
}
s.SetMessage(message)
go s.start()
return s
}

func (s *Spinner) SetMessage(message string) {
s.message.Store(message)
}

func (s *Spinner) String() string {
var sb strings.Builder
if len(s.message) > 0 {
message := strings.TrimSpace(s.message)

if message, ok := s.message.Load().(string); ok && len(message) > 0 {
message := strings.TrimSpace(message)
if s.messageWidth > 0 && len(message) > s.messageWidth {
message = message[:s.messageWidth]
}
Expand Down

0 comments on commit f7e3b91

Please sign in to comment.