Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: chart generation badcases #80

Merged
merged 1 commit into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions packages/vmind/__tests__/performance/performanceTest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import {

const TEST_GPT = false;
const TEST_SKYLARK = true;
const ShowThoughts = false;
const EnableDataQuery = true;

const demoDataList: { [key: string]: any } = {
pie: mockUserInput2,
Expand Down Expand Up @@ -61,7 +63,7 @@ const modelResultMap = {
[Model.SKYLARK2]: { totalCount: 0, successCount: 0, totalTime: 0 }
};

const testPerformance = (model: Model, vmind: any) => {
const testPerformance = (model: Model, vmind: VMind) => {
dataList.some((dataName, index) => {
if (index >= START_INDEX) {
it(dataName, async done => {
Expand All @@ -70,9 +72,11 @@ const testPerformance = (model: Model, vmind: any) => {
const { fieldInfo, dataset } = vmind.parseCSVData(csv);
//const { fieldInfo, dataset } = await vmind.parseCSVDataWithLLM(csv, describe);
const startTime = new Date().getTime();
const { spec, time, chartSource } = await vmind.generateChart(input, fieldInfo, dataset);
const { spec, time, chartSource, chartType } = await vmind.generateChart(input, fieldInfo, dataset, {
enableDataQuery: EnableDataQuery
});
const endTime = new Date().getTime();
log('generated chart type: ' + spec.type);
log('generated chart type: ' + chartType);
if (chartSource !== 'chartAdvisor') {
const costTime = endTime - startTime;
log('time cost: ' + costTime / 1000 + 's');
Expand All @@ -97,7 +101,7 @@ if (gptKey && gptURL && TEST_GPT) {
url: gptURL,
model: Model.GPT3_5,
cache: true,
showThoughts: false,
showThoughts: ShowThoughts,
headers: {
'api-key': gptKey
}
Expand Down
3 changes: 2 additions & 1 deletion packages/vmind/jest.performance.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ module.exports = {
testTimeout: 60000,
moduleNameMapper: {
axios: 'axios/dist/node/axios.cjs',
'd3-hierarchy': 'd3-hierarchy/dist/d3-hierarchy.min.js'
'd3-hierarchy': 'd3-hierarchy/dist/d3-hierarchy.min.js',
'^src/(.*)$': '<rootDir>/src/$1',
},
verbose: true,
// 在测试之前设置环境变量
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import type { Transformer } from 'src/base/tools/transformer';
import type { ChartAdvisorContext, ChartAdvisorOutput } from './types';
import type { Cell } from '../../types';
import { isValidDataset } from 'src/common/dataProcess';

import { ChartType as VMindChartType } from 'src/common/typings';
/**
* call @visactor/chart-advisor to get the list of advised charts
* sorted by scores of each chart type
Expand Down Expand Up @@ -91,7 +91,7 @@ const getTop1AdvisedChart: Transformer<getAdvisedListOutput, ChartAdvisorOutput>
// call rule-based method to get recommended chart type and fieldMap(cell)
if (advisedList.length === 0) {
return {
chartType: 'BAR CHART',
chartType: VMindChartType.BarChart.toUpperCase() as VMindChartType,
cell: {},
dataset: undefined,
chartSource,
Expand All @@ -100,7 +100,7 @@ const getTop1AdvisedChart: Transformer<getAdvisedListOutput, ChartAdvisorOutput>
}
const result = advisedList[0];
return {
chartType: result.chartType,
chartType: result.chartType as VMindChartType,
cell: getCell(result.cell),
dataset: result.dataset,
chartSource,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { Transformer } from 'src/base/tools/transformer';
import type { GenerateFieldMapContext, GenerateFieldMapOutput } from '../../types';
import { isArray, isString } from 'lodash';
import { matchFieldWithoutPunctuation } from './utils';
import { ChartType } from 'src/common/typings';
import { DataType, ROLE } from 'src/common/typings';
import { calculateTokenUsage, foldDatasetByYField } from 'src/common/utils/utils';
import { FOLD_NAME, FOLD_VALUE } from '@visactor/chart-advisor';
Expand Down Expand Up @@ -62,8 +63,8 @@ const patchColorField: Transformer<PatchContext, Partial<GenerateFieldMapOutput>
cellNew.color = undefined;
if (['BAR CHART', 'LINE CHART', 'DUAL AXIS CHART'].includes(chartTypeNew)) {
cellNew.y = [cellNew.y, color].flat();
if (chartTypeNew === 'DUAL AXIS CHART' && cellNew.y.length > 2) {
chartTypeNew = 'BAR CHART';
if (chartTypeNew === ChartType.DualAxisChart.toUpperCase() && cellNew.y.length > 2) {
chartTypeNew = ChartType.BarChart.toUpperCase() as ChartType;
}
}
}
Expand All @@ -77,7 +78,7 @@ const patchColorField: Transformer<PatchContext, Partial<GenerateFieldMapOutput>
const patchRadarChart: Transformer<PatchContext, Partial<GenerateFieldMapOutput>> = (context: PatchContext) => {
const { chartType, cell } = context;

if (chartType === 'RADAR CHART') {
if (chartType === ChartType.RadarChart.toUpperCase()) {
const cellNew = {
x: cell.angle,
y: cell.value,
Expand All @@ -94,7 +95,7 @@ const patchRadarChart: Transformer<PatchContext, Partial<GenerateFieldMapOutput>
const patchBoxPlot: Transformer<PatchContext, Partial<GenerateFieldMapOutput>> = (context: PatchContext) => {
const { chartType, cell } = context;

if (chartType === 'BOX PLOT') {
if (chartType === ChartType.BoxPlot.toUpperCase()) {
const { x, min, q1, median, q3, max } = cell as any;
const cellNew = {
x,
Expand All @@ -107,12 +108,16 @@ const patchBoxPlot: Transformer<PatchContext, Partial<GenerateFieldMapOutput>> =
return context;
};

const patchBarChart: Transformer<PatchContext, Partial<GenerateFieldMapOutput>> = (context: PatchContext) => {
const patchFoldField: Transformer<PatchContext, Partial<GenerateFieldMapOutput>> = (context: PatchContext) => {
const { chartType, cell, fieldInfo, dataset } = context;
const chartTypeNew = chartType;
const cellNew = { ...cell };
let datasetNew = dataset;
if (chartTypeNew === 'BAR CHART' || chartTypeNew === 'LINE CHART') {
if (
chartTypeNew === ChartType.BarChart.toUpperCase() ||
chartTypeNew === ChartType.LineChart.toUpperCase() ||
chartTypeNew === ChartType.RadarChart.toUpperCase()
) {
if (isValidDataset(datasetNew) && isArray(cellNew.y) && cellNew.y.length > 1) {
datasetNew = foldDatasetByYField(datasetNew, cellNew.y, fieldInfo);
cellNew.y = FOLD_VALUE.toString();
Expand All @@ -131,7 +136,7 @@ const patchDualAxisChart: Transformer<PatchContext, Partial<GenerateFieldMapOutp
const cellNew: any = { ...cell };
//Dual-axis drawing yLeft and yRight

if (chartType === 'DUAL AXIS CHART') {
if (chartType === ChartType.DualAxisChart.toUpperCase()) {
cellNew.y = [
...(isArray(cellNew.y) ? cellNew.y : []),
cellNew.leftAxis,
Expand All @@ -151,7 +156,7 @@ const patchDynamicBarChart: Transformer<PatchContext, Partial<GenerateFieldMapOu
};
let chartTypeNew = chartType;

if (chartType === 'DYNAMIC BAR CHART') {
if (chartType === ChartType.DynamicBarChart.toUpperCase()) {
if (!cellNew.time || cellNew.time === '' || cellNew.time.length === 0) {
const flattenedXField = Array.isArray(cellNew.x) ? cellNew.x : [cellNew.x];
const usedFields = Object.values(cellNew).filter(f => !Array.isArray(f));
Expand All @@ -172,7 +177,7 @@ const patchDynamicBarChart: Transformer<PatchContext, Partial<GenerateFieldMapOu
cellNew.time = stringField.fieldName;
} else {
//no available field, set chart type to bar chart
chartTypeNew = 'BAR CHART';
chartTypeNew = ChartType.BarChart.toUpperCase() as ChartType;
}
}
}
Expand Down Expand Up @@ -217,7 +222,7 @@ export const patchPipelines: Transformer<PatchContext, Partial<GenerateFieldMapO
patchColorField,
patchRadarChart,
patchBoxPlot,
patchBarChart,
patchFoldField,
patchDualAxisChart,
patchDynamicBarChart,
patchArrayField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { isArray, isNil } from 'lodash';

import type { Transformer } from 'src/base/tools/transformer';
import { foldDatasetByYField, getFieldByDataType, getFieldByRole, getRemainedFields } from 'src/common/utils/utils';
import type { ChartType } from 'src/common/typings';
import { ChartType } from 'src/common/typings';
import { DataType, ROLE } from 'src/common/typings';
import type { GenerateChartAndFieldMapContext, GenerateChartAndFieldMapOutput } from '../../types';
import { isValidDataset } from 'src/common/dataProcess';
Expand Down Expand Up @@ -99,9 +99,10 @@ export const patchYField: Transformer<
}

if (
chartTypeNew === ('BAR CHART' as ChartType) ||
chartTypeNew === ('LINE CHART' as ChartType) ||
chartTypeNew === ('DUAL AXIS CHART' as ChartType)
chartTypeNew === ChartType.BarChart.toUpperCase() ||
chartTypeNew === ChartType.LineChart.toUpperCase() ||
chartTypeNew === ChartType.DualAxisChart.toUpperCase() ||
chartTypeNew === ChartType.RadarChart.toUpperCase()
) {
//use fold to visualize more than 2 y fields
if (isValidDataset(datasetNew)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import { generateRandomString } from './utils';

export const alasqlKeywordList = [
'ABSOLUTE',
'ACTION',
Expand Down Expand Up @@ -181,17 +179,3 @@ export const alasqlKeywordList = [
'WITH',
'WORK'
];

export const operatorList = [
['+', `_${generateRandomString(3)}_PLUS_${generateRandomString(3)}_`],
['-', `_${generateRandomString(3)}_DASH_${generateRandomString(3)}_`],
['*', `_${generateRandomString(3)}_ASTERISK_${generateRandomString(3)}_`],
['/', `_${generateRandomString(3)}_SLASH_${generateRandomString(3)}_`]
];

export const operators = operatorList.map(op => op[0]);

export const RESERVE_REPLACE_MAP = new Map<string, string>([
...operatorList,
...(alasqlKeywordList.map(keyword => [keyword, generateRandomString(10)]) as any)
]);
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import { DataType, ROLE } from '../../../../common/typings';
import dayjs from 'dayjs';
import { uniqArray } from '@visactor/vutils';
import alasql from 'alasql';
import { RESERVE_REPLACE_MAP, operators } from './constants';

import { replaceAll } from 'src/common/utils/utils';
import { alasqlKeywordList } from './constants';

export const readTopNLine = (csvFile: string, n: number) => {
// get top n lines of a csv file
Expand Down Expand Up @@ -134,6 +135,20 @@ export function generateRandomString(len: number) {
return result;
}

const operatorList = [
['+', `_${generateRandomString(3)}_PLUS_${generateRandomString(3)}_`],
['-', `_${generateRandomString(3)}_DASH_${generateRandomString(3)}_`],
['*', `_${generateRandomString(3)}_ASTERISK_${generateRandomString(3)}_`],
['/', `_${generateRandomString(3)}_SLASH_${generateRandomString(3)}_`]
];

const operators = operatorList.map(op => op[0]);

const RESERVE_REPLACE_MAP = new Map<string, string>([
...operatorList,
...(alasqlKeywordList.map(keyword => [keyword, generateRandomString(10)]) as any)
]);

export const swapMap = (map: Map<string, string>) => {
//swap the map
const swappedMap = new Map();
Expand All @@ -150,7 +165,7 @@ export const swapMap = (map: Map<string, string>) => {
* @param str
* @returns
*/
export const replaceNonASCIICharacters = (str: string) => {
const replaceNonASCIICharacters = (str: string) => {
const nonAsciiCharMap = new Map();

const newStr = str.replace(/([^\x00-\x7F]+)/g, m => {
Expand Down
6 changes: 3 additions & 3 deletions packages/vmind/src/core/VMind.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class VMind {
private _applicationMap: VMindApplicationMap;

constructor(options?: ILLMOptions) {
this._options = { ...(options ?? {}), showThoughts: options.showThoughts ?? true }; //apply default settings
this._model = options.model ?? Model.GPT3_5;
this._options = { ...(options ?? {}), showThoughts: options?.showThoughts ?? true }; //apply default settings
this._model = options?.model ?? Model.GPT3_5;
this.registerApplications();
}

Expand Down Expand Up @@ -133,7 +133,7 @@ class VMind {
let finalFieldInfo = fieldInfo;

let queryDatasetUsage;
const { enableDataQuery, colorPalette, animationDuration, chartTypeList } = options;
const { enableDataQuery, colorPalette, animationDuration, chartTypeList } = options ?? {};
try {
if (!isNil(dataset) && (isNil(enableDataQuery) || enableDataQuery) && modelType !== ModelType.CHART_ADVISOR) {
//run data aggregation first
Expand Down
Loading