-
Notifications
You must be signed in to change notification settings - Fork 12
/
spatial_grid.py
131 lines (115 loc) · 5.23 KB
/
spatial_grid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
from eureka_ml_insights.configs.experiment_config import ExperimentConfig
from eureka_ml_insights.core import EvalReporting, Inference, PromptProcessing
from eureka_ml_insights.data_utils import (
HFDataReader,
MMDataLoader,
ColumnRename,
DataLoader,
DataReader,
ExtractAnswerGrid,
PrependStringTransform,
SequenceTransform,
)
from eureka_ml_insights.metrics import CaseInsensitiveMatch, CountAggregator
from ..config import (
AggregatorConfig,
DataSetConfig,
EvalReportingConfig,
InferenceConfig,
MetricConfig,
ModelConfig,
PipelineConfig,
PromptProcessingConfig,
)
"""This file contains example user-defined configuration classes for the grid counting task.
To define a new configuration, create a new class that directly or indirectly
inherits from ExperimentConfig and implement its configure_pipeline method.
You can inherit from one of the existing classes below and override only the necessary
attributes to reduce the amount of code you need to write.
The configuration classes define your desired *pipeline*, which can include
any number of *components*. Find the available *component* options in the core module.
Pass the name of the class to the main.py script to run the pipeline.
"""
class SPATIAL_GRID_PIPELINE(ExperimentConfig):
    """Evaluation pipeline for the grid counting (``spatial_grid``) task.

    The pipeline has three stages:

    1. Prompt processing: reads the ``spatial_grid`` task from the ``val`` split of
       ``microsoft/VISION_LANGUAGE`` with HFDataReader.
    2. Inference: runs the supplied model over the processed prompts via MMDataLoader.
    3. Eval reporting:
       - renames the raw model output and runs ExtractAnswerGrid on it;
       - scores the result with the CaseInsensitiveMatch metric;
       - counts the results overall and grouped by ``task`` using CountAggregator.
    """

    def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) -> PipelineConfig:
        """Create the three components, store them on ``self``, and return the pipeline.

        The components are kept as the attributes ``data_processing_comp``,
        ``inference_comp`` and ``evalreporting_comp``. Subclasses can then adjust them
        after calling this method.

        Args:
            model_config: model configuration handed to the inference component.
            resume_from: forwarded unchanged to the inference component's ``resume_from``.

        Returns:
            A PipelineConfig that runs the three components in order and logs to
            ``self.log_dir``.
        """
        # ---- Stage 1: dataset loading / prompt processing -----------------------
        hf_reader_args = {
            "path": "microsoft/VISION_LANGUAGE",
            "split": "val",
            "tasks": "spatial_grid",
        }
        self.data_processing_comp = PromptProcessingConfig(
            component_type=PromptProcessing,
            data_reader_config=DataSetConfig(HFDataReader, hf_reader_args),
            output_dir=os.path.join(self.log_dir, "data_processing_output"),
        )

        # ---- Stage 2: inference over the processed prompts ----------------------
        processed_prompts = os.path.join(self.data_processing_comp.output_dir, "transformed_data.jsonl")
        self.inference_comp = InferenceConfig(
            component_type=Inference,
            model_config=model_config,
            data_loader_config=DataSetConfig(MMDataLoader, {"path": processed_prompts}),
            output_dir=os.path.join(self.log_dir, "inference_result"),
            resume_from=resume_from,
        )

        # ---- Stage 3: answer extraction, scoring and aggregation ----------------
        inference_results = os.path.join(self.inference_comp.output_dir, "inference_result.jsonl")
        answer_extraction = SequenceTransform(
            [
                # Move the raw text aside so the extracted answer can be written to "model_output".
                ColumnRename(name_mapping={"model_output": "model_output_raw"}),
                ExtractAnswerGrid(
                    answer_column_name="model_output_raw",
                    extracted_answer_column_name="model_output",
                    question_type_column_name="question_type",
                    mode="animal",
                ),
            ],
        )
        score_column = "CaseInsensitiveMatch_result"
        overall_counts = AggregatorConfig(
            CountAggregator,
            {"column_names": [score_column], "normalize": True},
        )
        per_task_counts = AggregatorConfig(
            CountAggregator,
            {"column_names": [score_column], "group_by": "task", "normalize": True},
        )
        self.evalreporting_comp = EvalReportingConfig(
            component_type=EvalReporting,
            data_reader_config=DataSetConfig(
                DataReader,
                {"path": inference_results, "format": ".jsonl", "transform": answer_extraction},
            ),
            metric_config=MetricConfig(CaseInsensitiveMatch),
            aggregator_configs=[overall_counts, per_task_counts],
            output_dir=os.path.join(self.log_dir, "eval_report"),
        )

        stages = [self.data_processing_comp, self.inference_comp, self.evalreporting_comp]
        return PipelineConfig(stages, self.log_dir)
class SPATIAL_GRID_TEXTONLY_PIPELINE(SPATIAL_GRID_PIPELINE):
    """Text-only variant of SPATIAL_GRID_PIPELINE.

    Only the dataset task read during prompt processing changes, to
    ``spatial_grid_text_only``. Every other component is inherited as configured by the
    parent.
    """

    def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) -> PipelineConfig:
        """Build the parent pipeline, then point the prompt-processing reader at the text-only task.

        The reader arguments are edited in place after the parent has built the pipeline.
        The parent passes these same component objects to PipelineConfig, so the edit is
        assumed to be visible in the returned pipeline. This relies on PipelineConfig not
        copying its components.

        NOTE(review): the inherited inference component still uses MMDataLoader for this
        text-only data. Confirm that this is intended.
        """
        pipeline = super().configure_pipeline(model_config, resume_from)
        reader_args = self.data_processing_comp.data_reader_config.init_args
        reader_args["tasks"] = "spatial_grid_text_only"
        return pipeline
class SPATIAL_GRID_REPORTING_PIPELINE(SPATIAL_GRID_PIPELINE):
    """Reporting-only variant of SPATIAL_GRID_PIPELINE.

    Data processing and inference are not run. Only the eval reporting component is
    executed, against an existing inference output supplied through ``resume_from``.
    """

    def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) -> PipelineConfig:
        """Return a pipeline that contains only the eval reporting component.

        Args:
            model_config: forwarded to the parent configuration. The inference component
                it configures is not part of the returned pipeline.
            resume_from: becomes the eval reporting reader's ``path``, replacing the path
                the parent derived from the inference output directory. It is presumably an
                existing inference result file.
                NOTE(review): no validation is done here; passing None leaves the reader
                path as None.

        Returns:
            A PipelineConfig with just the eval reporting component, logging to
            ``self.log_dir``.
        """
        # Let the parent build every component; only the reporting one is reused below.
        super().configure_pipeline(model_config, resume_from)
        reporting = self.evalreporting_comp
        reporting.data_reader_config.init_args["path"] = resume_from
        return PipelineConfig([reporting], self.log_dir)