Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add Add Elements method and update for workspace provider #78

Merged
merged 2 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 57 additions & 23 deletions datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ type Dataset struct {
}

type datasetRequest struct {
Input string `json:"input"`
Workspace string `json:"workspace"`
DatasetToolRepo string `json:"datasetToolRepo"`
Input string `json:"input"`
WorkspaceID string `json:"workspaceID"`
DatasetToolRepo string `json:"datasetToolRepo"`
Env []string `json:"env"`
}

type createDatasetArgs struct {
Expand All @@ -47,6 +48,11 @@ type addDatasetElementArgs struct {
ElementContent string `json:"elementContent"`
}

type addDatasetElementsArgs struct {
DatasetID string `json:"datasetID"`
Elements []DatasetElement `json:"elements"`
}

type listDatasetElementArgs struct {
DatasetID string `json:"datasetID"`
}
Expand All @@ -56,15 +62,16 @@ type getDatasetElementArgs struct {
Element string `json:"element"`
}

func (g *GPTScript) ListDatasets(ctx context.Context, workspace string) ([]DatasetMeta, error) {
if workspace == "" {
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
func (g *GPTScript) ListDatasets(ctx context.Context, workspaceID string) ([]DatasetMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

out, err := g.runBasicCommand(ctx, "datasets", datasetRequest{
Input: "{}",
Workspace: workspace,
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return nil, err
Expand All @@ -77,9 +84,9 @@ func (g *GPTScript) ListDatasets(ctx context.Context, workspace string) ([]Datas
return datasets, nil
}

func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, description string) (Dataset, error) {
if workspace == "" {
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
func (g *GPTScript) CreateDataset(ctx context.Context, workspaceID, name, description string) (Dataset, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := createDatasetArgs{
Expand All @@ -93,8 +100,9 @@ func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, descript

out, err := g.runBasicCommand(ctx, "datasets/create", datasetRequest{
Input: string(argsJSON),
Workspace: workspace,
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return Dataset{}, err
Expand All @@ -107,9 +115,9 @@ func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, descript
return dataset, nil
}

func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID, elementName, elementDescription, elementContent string) (DatasetElementMeta, error) {
if workspace == "" {
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
func (g *GPTScript) AddDatasetElement(ctx context.Context, workspaceID, datasetID, elementName, elementDescription, elementContent string) (DatasetElementMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := addDatasetElementArgs{
Expand All @@ -125,8 +133,9 @@ func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID,

out, err := g.runBasicCommand(ctx, "datasets/add-element", datasetRequest{
Input: string(argsJSON),
Workspace: workspace,
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return DatasetElementMeta{}, err
Expand All @@ -139,9 +148,32 @@ func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID,
return element, nil
}

func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetID string) ([]DatasetElementMeta, error) {
if workspace == "" {
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
func (g *GPTScript) AddDatasetElements(ctx context.Context, workspaceID, datasetID string, elements []DatasetElement) error {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := addDatasetElementsArgs{
DatasetID: datasetID,
Elements: elements,
}
argsJSON, err := json.Marshal(args)
if err != nil {
return fmt.Errorf("failed to marshal element args: %w", err)
}

_, err = g.runBasicCommand(ctx, "datasets/add-elements", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
return err
}

func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datasetID string) ([]DatasetElementMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := listDatasetElementArgs{
Expand All @@ -154,8 +186,9 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetI

out, err := g.runBasicCommand(ctx, "datasets/list-elements", datasetRequest{
Input: string(argsJSON),
Workspace: workspace,
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return nil, err
Expand All @@ -168,9 +201,9 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetI
return elements, nil
}

func (g *GPTScript) GetDatasetElement(ctx context.Context, workspace, datasetID, elementName string) (DatasetElement, error) {
if workspace == "" {
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
func (g *GPTScript) GetDatasetElement(ctx context.Context, workspaceID, datasetID, elementName string) (DatasetElement, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := getDatasetElementArgs{
Expand All @@ -184,8 +217,9 @@ func (g *GPTScript) GetDatasetElement(ctx context.Context, workspace, datasetID,

out, err := g.runBasicCommand(ctx, "datasets/get-element", datasetRequest{
Input: string(argsJSON),
Workspace: workspace,
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return DatasetElement{}, err
Expand Down
48 changes: 36 additions & 12 deletions datasets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,72 @@ package gptscript

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"
)

func TestDatasets(t *testing.T) {
workspace, err := os.MkdirTemp("", "go-gptscript-test")
workspaceID, err := g.CreateWorkspace(context.Background(), "directory")
require.NoError(t, err)

defer func() {
_ = os.RemoveAll(workspace)
_ = g.DeleteWorkspace(context.Background(), DeleteWorkspaceOptions{WorkspaceID: workspaceID})
}()

// Create a dataset
dataset, err := g.CreateDataset(context.Background(), workspace, "test-dataset", "This is a test dataset")
dataset, err := g.CreateDataset(context.Background(), workspaceID, "test-dataset", "This is a test dataset")
require.NoError(t, err)
require.Equal(t, "test-dataset", dataset.Name)
require.Equal(t, "This is a test dataset", dataset.Description)
require.Equal(t, 0, len(dataset.Elements))

// Add an element
elementMeta, err := g.AddDatasetElement(context.Background(), workspace, dataset.ID, "test-element", "This is a test element", "This is the content")
elementMeta, err := g.AddDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element", "This is a test element", "This is the content")
require.NoError(t, err)
require.Equal(t, "test-element", elementMeta.Name)
require.Equal(t, "This is a test element", elementMeta.Description)

// Get the element
element, err := g.GetDatasetElement(context.Background(), workspace, dataset.ID, "test-element")
// Add two more
err = g.AddDatasetElements(context.Background(), workspaceID, dataset.ID, []DatasetElement{
{
DatasetElementMeta: DatasetElementMeta{
Name: "test-element-2",
Description: "This is a test element 2",
},
Contents: "This is the content 2",
},
{
DatasetElementMeta: DatasetElementMeta{
Name: "test-element-3",
Description: "This is a test element 3",
},
Contents: "This is the content 3",
},
})
require.NoError(t, err)

// Get the first element
element, err := g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element")
require.NoError(t, err)
require.Equal(t, "test-element", element.Name)
require.Equal(t, "This is a test element", element.Description)
require.Equal(t, "This is the content", element.Contents)

// Get the third element
element, err = g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element-3")
require.NoError(t, err)
require.Equal(t, "test-element-3", element.Name)
require.Equal(t, "This is a test element 3", element.Description)
require.Equal(t, "This is the content 3", element.Contents)

// List elements in the dataset
elements, err := g.ListDatasetElements(context.Background(), workspace, dataset.ID)
elements, err := g.ListDatasetElements(context.Background(), workspaceID, dataset.ID)
require.NoError(t, err)
require.Equal(t, 1, len(elements))
require.Equal(t, "test-element", elements[0].Name)
require.Equal(t, "This is a test element", elements[0].Description)
require.Equal(t, 3, len(elements))

// List datasets
datasets, err := g.ListDatasets(context.Background(), workspace)
datasets, err := g.ListDatasets(context.Background(), workspaceID)
require.NoError(t, err)
require.Equal(t, 1, len(datasets))
require.Equal(t, "test-dataset", datasets[0].Name)
Expand Down