From 6a49d2bfe551abee5ca51c08bd8dbe035f04b02a Mon Sep 17 00:00:00 2001 From: Joe Schafer Date: Mon, 19 Apr 2021 15:06:05 -0700 Subject: [PATCH] Handle type resolution of composite type that has child array Introduce a placeholderType that we resolve on a second pass. Necessary because we resolve types sequentially by kind. For example, we resolve all composite types before resolving array types. This approach requires two passes for cases like when a composite type has an child type that's an array. --- internal/pg/query.sql | 1 + internal/pg/query.sql.go | 1 + internal/pg/type_fetcher.go | 77 ++++++++++++++++++++++++-------- internal/pg/type_fetcher_test.go | 34 ++++++++++++++ internal/pg/types.go | 31 +++++++++---- 5 files changed, 117 insertions(+), 27 deletions(-) diff --git a/internal/pg/query.sql b/internal/pg/query.sql index 11d34743..4bfb855a 100644 --- a/internal/pg/query.sql +++ b/internal/pg/query.sql @@ -112,6 +112,7 @@ WHERE typ.oid = ANY (pggen.arg('oids')::oid[]) -- types. -- name: FindDescendantOIDs :many WITH RECURSIVE oid_descs(oid) AS ( + -- Base case. SELECT oid FROM unnest(pggen.arg('oids')::oid[]) AS t(oid) UNION diff --git a/internal/pg/query.sql.go b/internal/pg/query.sql.go index e176e7b4..1684b343 100644 --- a/internal/pg/query.sql.go +++ b/internal/pg/query.sql.go @@ -420,6 +420,7 @@ func (q *DBQuerier) FindCompositeTypesScan(results pgx.BatchResults) ([]FindComp } const findDescendantOIDsSQL = `WITH RECURSIVE oid_descs(oid) AS ( + -- Base case. SELECT oid FROM unnest($1::oid[]) AS t(oid) UNION diff --git a/internal/pg/type_fetcher.go b/internal/pg/type_fetcher.go index 21035a9d..35c62db1 100644 --- a/internal/pg/type_fetcher.go +++ b/internal/pg/type_fetcher.go @@ -85,6 +85,11 @@ func (tf *TypeFetcher) FindTypesByOIDs(oids ...uint32) (map[pgtype.OID]Type, err delete(uncached, unk.ID) } + // Resolve all placeholder types now that we know all types. + if err := tf.resolvePlaceholderTypes(types); err != nil { + return nil, err + } + if len(uncached) > 0 { return nil, fmt.Errorf("had %d unclassified types: %v", len(uncached), uncached) } @@ -128,34 +133,24 @@ func (tf *TypeFetcher) findCompositeTypes(ctx context.Context, uncached map[pgty types := make([]CompositeType, 0, len(rows)) idx := -1 -outer: for len(types) < len(rows) { idx = (idx + 1) % len(rows) row := rows[idx] - // Check if we can resolve all columns for the composite type. - for i, colOID := range row.ColOIDs { - if _, isInCache := tf.cache.getOID(uint32(colOID)); !isInCache { - if _, isInComposite := allComposites[pgtype.OID(colOID)]; !isInComposite { - // We won't ever be able resolve this composite type. - return nil, fmt.Errorf("find type for composite column %s oid=%d", row.ColNames[i], row.ColOIDs[i]) - } - // We'll be able to resolve this after one of the for loop iteration - // adds another composite to the cache. - continue outer - } - } - colTypes := make([]Type, len(row.ColOIDs)) colNames := make([]string, len(row.ColOIDs)) // Build each column of the composite type. for i, colOID := range row.ColOIDs { - colType, ok := tf.cache.getOID(uint32(colOID)) - if !ok { - return nil, fmt.Errorf("find type for composite column %s oid=%d", row.ColNames[i], row.ColOIDs[i]) + if colType, ok := tf.cache.getOID(uint32(colOID)); ok { + colTypes[i] = colType + colNames[i] = row.ColNames[i] + } else { + // We might resolve this type in a future pass like findArrayTypes. At + // the end, we'll attempt to to replace the placeholder with the + // resolved type. + colTypes[i] = placeholderType{ID: pgtype.OID(colOID)} + colNames[i] = row.ColNames[i] } - colTypes[i] = colType - colNames[i] = row.ColNames[i] } typ := CompositeType{ ID: row.TableTypeOID, @@ -207,6 +202,50 @@ func (tf *TypeFetcher) findArrayTypes(ctx context.Context, uncached map[pgtype.O return types, nil } +// resolvePlaceholderTypes resolves all placeholder types or errors if we can't +// resolve a placeholderType using all known types. +func (tf *TypeFetcher) resolvePlaceholderTypes(knownTypes map[pgtype.OID]Type) error { + // resolveType walks down type, replacing placeholderType with a known type. + var resolveType func(typ Type) (Type, error) + resolveType = func(typ Type) (Type, error) { + switch typ := typ.(type) { + case CompositeType: + for i, colType := range typ.ColumnTypes { + newType, err := resolveType(colType) + if err != nil { + return nil, fmt.Errorf("composite child '%s.%s': %w", typ.Name, colType.String(), err) + } + typ.ColumnTypes[i] = newType + } + return typ, nil + case ArrayType: + newType, err := resolveType(typ.ElemType) + if err != nil { + return nil, fmt.Errorf("array %q elem: %w", typ.Name, err) + } + typ.ElemType = newType + return typ, nil + case placeholderType: + newType, ok := knownTypes[typ.ID] + if !ok { + return nil, fmt.Errorf("unresolved placeholder type oid=%d", typ.ID) + } + return newType, nil + default: + return typ, nil + } + } + + for oid, typ := range knownTypes { + newType, err := resolveType(typ) + if err != nil { + return fmt.Errorf("resolve placeholder type: %w", err) + } + knownTypes[oid] = newType + } + return nil +} + func oidKeys(os map[pgtype.OID]struct{}) []uint32 { oids := make([]uint32, 0, len(os)) for oid := range os { diff --git a/internal/pg/type_fetcher_test.go b/internal/pg/type_fetcher_test.go index 920dcbca..3818aa1a 100644 --- a/internal/pg/type_fetcher_test.go +++ b/internal/pg/type_fetcher_test.go @@ -13,6 +13,15 @@ import ( ) func TestNewTypeFetcher(t *testing.T) { + productImageType := CompositeType{ + Name: "product_image_type", + ColumnNames: []string{"pixel_width", "pixel_height"}, + ColumnTypes: []Type{Int4, Int4}, + } + productImageArrayType := ArrayType{ + Name: "_product_image_type", + ElemType: productImageType, + } tests := []struct { name string schema string @@ -117,6 +126,30 @@ func TestNewTypeFetcher(t *testing.T) { Text, }, }, + { + name: "composite type - depth 2 array", + fetchOID: "product_image_set_type", + wants: []Type{ + Int4, + CompositeType{ + Name: "product_image_set_type", + ColumnNames: []string{"name", "images"}, + ColumnTypes: []Type{Text, productImageArrayType}}, + productImageType, + productImageArrayType, + Text, + }, + schema: texts.Dedent(` + CREATE TYPE product_image_type AS ( + pixel_width int4, + pixel_height int4 + ); + CREATE TYPE product_image_set_type AS ( + name text, + images product_image_type[] + ); + `), + }, { name: "custom base type", schema: texts.Dedent(` @@ -203,6 +236,7 @@ func TestNewTypeFetcher(t *testing.T) { opts := cmp.Options{ cmpopts.IgnoreFields(EnumType{}, "ChildOIDs", "ID"), cmpopts.IgnoreFields(CompositeType{}, "ID"), + cmpopts.IgnoreFields(ArrayType{}, "ID"), } sortTypes(wantTypes) sortTypes(gotTypes) diff --git a/internal/pg/types.go b/internal/pg/types.go index deb07e6e..61af8160 100644 --- a/internal/pg/types.go +++ b/internal/pg/types.go @@ -3,6 +3,7 @@ package pg import ( "github.com/jackc/pgtype" "github.com/jschaf/pggen/internal/pg/pgoid" + "strconv" ) // Type is a Postgres type. @@ -12,16 +13,17 @@ type Type interface { Kind() TypeKind } -// TypeKinds is the pg_type.typtype column, describing the meta type of Type. +// TypeKind is the pg_type.typtype column, describing the meta type of Type. type TypeKind byte const ( - KindBaseType TypeKind = 'b' // includes array types - KindCompositeType TypeKind = 'c' - KindDomainType TypeKind = 'd' - KindEnumType TypeKind = 'e' - KindPseudoType TypeKind = 'p' - KindRangeType TypeKind = 'r' + KindBaseType TypeKind = 'b' // includes array types + KindCompositeType TypeKind = 'c' + KindDomainType TypeKind = 'd' + KindEnumType TypeKind = 'e' + KindPseudoType TypeKind = 'p' + KindRangeType TypeKind = 'r' + kindPlaceholderType TypeKind = '?' // pggen only, not part of postgres ) func (k TypeKind) String() string { @@ -51,7 +53,7 @@ type ( Name string // pg_type.typname: data type name } - // Void type is an empty type. A void type doesn't appear in output but it's + // VoidType is an empty type. A void type doesn't appear in output but it's // necessary to scan rows. VoidType struct{} @@ -111,6 +113,15 @@ type ( Name string // pg_type.typname: data type name PgKind TypeKind } + + // placeholderType is an internal, temporary type that we resolve in a second + // pass. Useful because we resolve types sequentially by kind. For example, we + // resolve all composite types before resolving array types. This approach + // requires two passes for cases like when a composite type has an child type + // that's an array. + placeholderType struct { + ID pgtype.OID // pg_type.oid: row identifier + } ) func (b BaseType) OID() pgtype.OID { return b.ID } @@ -140,3 +151,7 @@ func (e CompositeType) Kind() TypeKind { return KindCompositeType } func (e UnknownType) OID() pgtype.OID { return e.ID } func (e UnknownType) String() string { return e.Name } func (e UnknownType) Kind() TypeKind { return e.PgKind } + +func (p placeholderType) OID() pgtype.OID { return p.ID } +func (p placeholderType) String() string { return "placeholder-" + strconv.Itoa(int(p.ID)) } +func (p placeholderType) Kind() TypeKind { return kindPlaceholderType }