Skip to content

Commit

Permalink
chore: update for dataset rewrite (#83)
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville authored Nov 6, 2024
1 parent df4a8c9 commit ba040ce
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 155 deletions.
169 changes: 42 additions & 127 deletions datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package gptscript

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
)

type DatasetElementMeta struct {
Expand All @@ -15,7 +13,8 @@ type DatasetElementMeta struct {

// DatasetElement combines an element's metadata (the embedded
// DatasetElementMeta) with its payload.
//
// NOTE(review): this text is a rendered diff without +/- markers. The two
// Contents lines below look like the removed ([]byte) and added (string)
// versions of the same field; they cannot coexist in a compilable struct.
// Confirm against the repository before relying on either.
//
// NOTE(review): encoding/json does not define an inline tag option; embedded
// struct fields are flattened into the parent object by default.
type DatasetElement struct {
DatasetElementMeta `json:",inline"`
Contents []byte `json:"contents"`
Contents string `json:"contents"`
BinaryContents []byte `json:"binaryContents"` // encoding/json marshals []byte as a base64 string
}

type DatasetMeta struct {
Expand All @@ -24,34 +23,17 @@ type DatasetMeta struct {
Description string `json:"description"`
}

// Dataset is a dataset's metadata (the embedded DatasetMeta) together with an
// optional base directory and a map of element metadata. It is the value
// CreateDataset decodes from the command output.
//
// NOTE(review): the map keys are presumably element names; nothing visible
// here establishes that, so confirm against the dataset tool.
type Dataset struct {
DatasetMeta `json:",inline"`
BaseDir string `json:"baseDir,omitempty"` // left out of the JSON when empty
Elements map[string]DatasetElementMeta `json:"elements"`
}

// datasetRequest is the request value handed to runBasicCommand for the
// datasets endpoints.
//
// NOTE(review): these four fields look like the pre-change field set; the
// updated call sites in this diff populate Input, DatasetTool and Env instead,
// and those replacement field lines are rendered further down the page.
type datasetRequest struct {
Input string `json:"input"` // JSON-encoded arguments of the specific operation; callers pass {} when there are none
WorkspaceID string `json:"workspaceID"`
DatasetToolRepo string `json:"datasetToolRepo"` // filled from g.globalOpts.DatasetToolRepo by the callers shown
Env []string `json:"env"` // filled from g.globalOpts.Env by the callers shown
}

// createDatasetArgs is the payload that CreateDataset marshals to JSON and
// sends as datasetRequest.Input to the datasets/create command.
type createDatasetArgs struct {
Name string `json:"datasetName"`
Description string `json:"datasetDescription"`
}

// addDatasetElementArgs is the payload that AddDatasetElement marshals to JSON
// for the datasets/add-element command. ElementContent carries the element
// bytes encoded with base64.StdEncoding.
//
// NOTE(review): the last three lines (Input, DatasetTool, Env) match the fields
// that the updated call sites set on datasetRequest, not on this type. They
// appear to be added diff lines of datasetRequest rendered here without +/-
// markers. Confirm against the repository.
type addDatasetElementArgs struct {
DatasetID string `json:"datasetID"`
ElementName string `json:"elementName"`
ElementDescription string `json:"elementDescription"`
ElementContent string `json:"elementContent"`
Input string `json:"input"`
DatasetTool string `json:"datasetTool"`
Env []string `json:"env"`
}

// addDatasetElementsArgs is the payload that AddDatasetElements marshals to
// JSON for the datasets/add-elements command. Name and Description are filled
// from the DatasetOptions passed to AddDatasetElements.
//
// NOTE(review): DatasetID and Elements each appear twice below. This looks like
// the removed and re-added diff lines of the same two fields, shown without
// +/- markers; a compilable struct would declare each only once.
type addDatasetElementsArgs struct {
DatasetID string `json:"datasetID"`
Elements []DatasetElement `json:"elements"`
DatasetID string `json:"datasetID"`
Name string `json:"name"`
Description string `json:"description"`
Elements []DatasetElement `json:"elements"`
}

type listDatasetElementArgs struct {
Expand All @@ -60,19 +42,14 @@ type listDatasetElementArgs struct {

// getDatasetElementArgs is the payload that GetDatasetElement marshals to JSON
// for the datasets/get-element command; Element is set to the element name.
//
// NOTE(review): the Element field is listed twice because this is a rendered
// diff. The change shown is the JSON key moving from element to name; only one
// of the two lines exists in any real version of the file.
type getDatasetElementArgs struct {
DatasetID string `json:"datasetID"`
Element string `json:"element"`
Element string `json:"name"`
}

func (g *GPTScript) ListDatasets(ctx context.Context, workspaceID string) ([]DatasetMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

func (g *GPTScript) ListDatasets(ctx context.Context) ([]DatasetMeta, error) {
out, err := g.runBasicCommand(ctx, "datasets", datasetRequest{
Input: "{}",
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
Input: "{}",
DatasetTool: g.globalOpts.DatasetTool,
Env: g.globalOpts.Env,
})
if err != nil {
return nil, err
Expand All @@ -85,98 +62,42 @@ func (g *GPTScript) ListDatasets(ctx context.Context, workspaceID string) ([]Dat
return datasets, nil
}

// CreateDataset marshals the name and description into createDatasetArgs, runs
// the datasets/create command and decodes its output into a Dataset.
// An empty workspaceID falls back to the GPTSCRIPT_WORKSPACE_ID environment
// variable. On any failure the zero Dataset is returned with the error.
//
// NOTE(review): this span is a rendered diff without +/- markers. The function
// appears to be removed by the commit, and its closing brace is not shown: the
// DatasetOptions declaration near the end is an added line that shares the
// final brace. Confirm against the repository.
func (g *GPTScript) CreateDataset(ctx context.Context, workspaceID, name, description string) (Dataset, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := createDatasetArgs{
Name: name,
Description: description,
}
argsJSON, err := json.Marshal(args)
if err != nil {
return Dataset{}, fmt.Errorf("failed to marshal dataset args: %w", err)
}

out, err := g.runBasicCommand(ctx, "datasets/create", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return Dataset{}, err
}

var dataset Dataset
if err = json.Unmarshal([]byte(out), &dataset); err != nil {
return Dataset{}, err
}
return dataset, nil
// DatasetOptions carries the optional dataset name and description accepted by
// CreateDatasetWithElements and AddDatasetElements.
type DatasetOptions struct {
Name, Description string
}

// AddDatasetElement base64-encodes elementContent, marshals it with the element
// name and description into addDatasetElementArgs, runs the
// datasets/add-element command and decodes the output into a
// DatasetElementMeta. An empty workspaceID falls back to the
// GPTSCRIPT_WORKSPACE_ID environment variable.
//
// NOTE(review): this span is a rendered diff without +/- markers. The function
// appears to be removed by the commit, and its closing brace is not shown: the
// CreateDatasetWithElements lines near the end are added lines that share the
// final brace. Confirm against the repository.
func (g *GPTScript) AddDatasetElement(ctx context.Context, workspaceID, datasetID, elementName, elementDescription string, elementContent []byte) (DatasetElementMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := addDatasetElementArgs{
DatasetID: datasetID,
ElementName: elementName,
ElementDescription: elementDescription,
ElementContent: base64.StdEncoding.EncodeToString(elementContent),
}
argsJSON, err := json.Marshal(args)
if err != nil {
return DatasetElementMeta{}, fmt.Errorf("failed to marshal element args: %w", err)
}

out, err := g.runBasicCommand(ctx, "datasets/add-element", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return DatasetElementMeta{}, err
}

var element DatasetElementMeta
if err = json.Unmarshal([]byte(out), &element); err != nil {
return DatasetElementMeta{}, err
}
return element, nil
// CreateDatasetWithElements forwards to AddDatasetElements with an empty
// dataset ID and returns its string result unchanged.
// NOTE(review): the test in this diff stores that result as a dataset ID, so an
// empty ID presumably asks the tool to create a new dataset; confirm.
func (g *GPTScript) CreateDatasetWithElements(ctx context.Context, elements []DatasetElement, options ...DatasetOptions) (string, error) {
return g.AddDatasetElements(ctx, "", elements, options...)
}

// AddDatasetElements marshals the dataset ID and elements into
// addDatasetElementsArgs and runs the datasets/add-elements command.
//
// NOTE(review): this span is a rendered diff without +/- markers, so two
// versions are interleaved. The first signature takes a workspaceID (falling
// back to GPTSCRIPT_WORKSPACE_ID) and returns only an error. The second drops
// workspaceID, accepts variadic DatasetOptions and returns the string produced
// by runBasicCommand together with the error. Paired lines such as the two
// marshal-error returns and the two runBasicCommand calls are the old and new
// forms of the same statement. Confirm against the repository.
func (g *GPTScript) AddDatasetElements(ctx context.Context, workspaceID, datasetID string, elements []DatasetElement) error {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

func (g *GPTScript) AddDatasetElements(ctx context.Context, datasetID string, elements []DatasetElement, options ...DatasetOptions) (string, error) {
args := addDatasetElementsArgs{
DatasetID: datasetID,
Elements: elements,
}

// Later options win: every non-empty Name or Description overwrites the value
// set by an earlier option, while empty fields leave it untouched.
for _, opt := range options {
if opt.Name != "" {
args.Name = opt.Name
}
if opt.Description != "" {
args.Description = opt.Description
}
}

argsJSON, err := json.Marshal(args)
if err != nil {
return fmt.Errorf("failed to marshal element args: %w", err)
return "", fmt.Errorf("failed to marshal element args: %w", err)
}

_, err = g.runBasicCommand(ctx, "datasets/add-elements", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
return g.runBasicCommand(ctx, "datasets/add-elements", datasetRequest{
Input: string(argsJSON),
DatasetTool: g.globalOpts.DatasetTool,
Env: g.globalOpts.Env,
})
return err
}

func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datasetID string) ([]DatasetElementMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

func (g *GPTScript) ListDatasetElements(ctx context.Context, datasetID string) ([]DatasetElementMeta, error) {
args := listDatasetElementArgs{
DatasetID: datasetID,
}
Expand All @@ -186,10 +107,9 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datase
}

out, err := g.runBasicCommand(ctx, "datasets/list-elements", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
Input: string(argsJSON),
DatasetTool: g.globalOpts.DatasetTool,
Env: g.globalOpts.Env,
})
if err != nil {
return nil, err
Expand All @@ -202,11 +122,7 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datase
return elements, nil
}

func (g *GPTScript) GetDatasetElement(ctx context.Context, workspaceID, datasetID, elementName string) (DatasetElement, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

func (g *GPTScript) GetDatasetElement(ctx context.Context, datasetID, elementName string) (DatasetElement, error) {
args := getDatasetElementArgs{
DatasetID: datasetID,
Element: elementName,
Expand All @@ -217,10 +133,9 @@ func (g *GPTScript) GetDatasetElement(ctx context.Context, workspaceID, datasetI
}

out, err := g.runBasicCommand(ctx, "datasets/get-element", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
Input: string(argsJSON),
DatasetTool: g.globalOpts.DatasetTool,
Env: g.globalOpts.Env,
})
if err != nil {
return DatasetElement{}, err
Expand Down
74 changes: 48 additions & 26 deletions datasets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gptscript

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -11,66 +12,87 @@ func TestDatasets(t *testing.T) {
workspaceID, err := g.CreateWorkspace(context.Background(), "directory")
require.NoError(t, err)

client, err := NewGPTScript(GlobalOptions{
OpenAIAPIKey: os.Getenv("OPENAI_API_KEY"),
Env: append(os.Environ(), "GPTSCRIPT_WORKSPACE_ID="+workspaceID),
})
require.NoError(t, err)

defer func() {
_ = g.DeleteWorkspace(context.Background(), workspaceID)
}()

// Create a dataset
dataset, err := g.CreateDataset(context.Background(), workspaceID, "test-dataset", "This is a test dataset")
require.NoError(t, err)
require.Equal(t, "test-dataset", dataset.Name)
require.Equal(t, "This is a test dataset", dataset.Description)
require.Equal(t, 0, len(dataset.Elements))

// Add an element
elementMeta, err := g.AddDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element", "This is a test element", []byte("This is the content"))
datasetID, err := client.CreateDatasetWithElements(context.Background(), []DatasetElement{
{
DatasetElementMeta: DatasetElementMeta{
Name: "test-element-1",
Description: "This is a test element 1",
},
Contents: "This is the content 1",
},
}, DatasetOptions{
Name: "test-dataset",
Description: "this is a test dataset",
})
require.NoError(t, err)
require.Equal(t, "test-element", elementMeta.Name)
require.Equal(t, "This is a test element", elementMeta.Description)

// Add two more
err = g.AddDatasetElements(context.Background(), workspaceID, dataset.ID, []DatasetElement{
// Add three more elements
_, err = client.AddDatasetElements(context.Background(), datasetID, []DatasetElement{
{
DatasetElementMeta: DatasetElementMeta{
Name: "test-element-2",
Description: "This is a test element 2",
},
Contents: []byte("This is the content 2"),
Contents: "This is the content 2",
},
{
DatasetElementMeta: DatasetElementMeta{
Name: "test-element-3",
Description: "This is a test element 3",
},
Contents: []byte("This is the content 3"),
Contents: "This is the content 3",
},
{
DatasetElementMeta: DatasetElementMeta{
Name: "binary-element",
Description: "this element has binary contents",
},
BinaryContents: []byte("binary contents"),
},
})
require.NoError(t, err)

// Get the first element
element, err := g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element")
element, err := client.GetDatasetElement(context.Background(), datasetID, "test-element-1")
require.NoError(t, err)
require.Equal(t, "test-element", element.Name)
require.Equal(t, "This is a test element", element.Description)
require.Equal(t, []byte("This is the content"), element.Contents)
require.Equal(t, "test-element-1", element.Name)
require.Equal(t, "This is a test element 1", element.Description)
require.Equal(t, "This is the content 1", element.Contents)

// Get the third element
element, err = g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element-3")
element, err = client.GetDatasetElement(context.Background(), datasetID, "test-element-3")
require.NoError(t, err)
require.Equal(t, "test-element-3", element.Name)
require.Equal(t, "This is a test element 3", element.Description)
require.Equal(t, []byte("This is the content 3"), element.Contents)
require.Equal(t, "This is the content 3", element.Contents)

// Get the binary element
element, err = client.GetDatasetElement(context.Background(), datasetID, "binary-element")
require.NoError(t, err)
require.Equal(t, "binary-element", element.Name)
require.Equal(t, "this element has binary contents", element.Description)
require.Equal(t, []byte("binary contents"), element.BinaryContents)

// List elements in the dataset
elements, err := g.ListDatasetElements(context.Background(), workspaceID, dataset.ID)
elements, err := client.ListDatasetElements(context.Background(), datasetID)
require.NoError(t, err)
require.Equal(t, 3, len(elements))
require.Equal(t, 4, len(elements))

// List datasets
datasets, err := g.ListDatasets(context.Background(), workspaceID)
datasets, err := client.ListDatasets(context.Background())
require.NoError(t, err)
require.Equal(t, 1, len(datasets))
require.Equal(t, datasetID, datasets[0].ID)
require.Equal(t, "test-dataset", datasets[0].Name)
require.Equal(t, "This is a test dataset", datasets[0].Description)
require.Equal(t, dataset.ID, datasets[0].ID)
require.Equal(t, "this is a test dataset", datasets[0].Description)
}
4 changes: 2 additions & 2 deletions opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type GlobalOptions struct {
DefaultModelProvider string `json:"DefaultModelProvider"`
CacheDir string `json:"CacheDir"`
Env []string `json:"env"`
DatasetToolRepo string `json:"DatasetToolRepo"`
DatasetTool string `json:"DatasetTool"`
WorkspaceTool string `json:"WorkspaceTool"`
}

Expand Down Expand Up @@ -46,7 +46,7 @@ func completeGlobalOptions(opts ...GlobalOptions) GlobalOptions {
result.OpenAIBaseURL = firstSet(opt.OpenAIBaseURL, result.OpenAIBaseURL)
result.DefaultModel = firstSet(opt.DefaultModel, result.DefaultModel)
result.DefaultModelProvider = firstSet(opt.DefaultModelProvider, result.DefaultModelProvider)
result.DatasetToolRepo = firstSet(opt.DatasetToolRepo, result.DatasetToolRepo)
result.DatasetTool = firstSet(opt.DatasetTool, result.DatasetTool)
result.WorkspaceTool = firstSet(opt.WorkspaceTool, result.WorkspaceTool)
result.Env = append(result.Env, opt.Env...)
}
Expand Down

0 comments on commit ba040ce

Please sign in to comment.