diff --git a/internal/generator/output/general/model_writer.go b/internal/generator/output/general/model_writer.go index d347b67..5c3f746 100644 --- a/internal/generator/output/general/model_writer.go +++ b/internal/generator/output/general/model_writer.go @@ -33,7 +33,7 @@ type ModelWriter struct { basePath string continueGeneration bool - numberColumnsToDiscard int + columnsToDiscard map[string]struct{} partitionColumnsIndexes []int orderedColumnNames []string @@ -61,12 +61,12 @@ func newModelWriter( orderedColumnNames = append(orderedColumnNames, column.Name) } - numberColumnsToDiscard := 0 + columnsToDiscard := make(map[string]struct{}) partitionOrderedColumnNames := make([]string, 0, len(model.PartitionColumns)) for _, column := range model.PartitionColumns { if !column.WriteToOutput { - numberColumnsToDiscard++ + columnsToDiscard[column.Name] = struct{}{} } partitionOrderedColumnNames = append(partitionOrderedColumnNames, column.Name) @@ -97,7 +97,7 @@ func newModelWriter( config: config, basePath: basePath, continueGeneration: continueGeneration, - numberColumnsToDiscard: numberColumnsToDiscard, + columnsToDiscard: columnsToDiscard, partitionColumnsIndexes: partitionColumnsIndexes, orderedColumnNames: orderedColumnNames, checkpointTicker: ticker, @@ -192,7 +192,7 @@ func (w *ModelWriter) WriteRows(ctx context.Context, rows []*models.DataRow) err // discard not writeable columns sendRow := &models.DataRow{ - Values: row.Values[:len(row.Values)-w.numberColumnsToDiscard], + Values: row.Values[:len(row.Values)-len(w.columnsToDiscard)], } if err := dataWriter.WriteRow(sendRow); err != nil { @@ -237,19 +237,30 @@ func (w *ModelWriter) newWriter(ctx context.Context, outPath string) (writer.Wri var dataWriter writer.Writer switch w.config.Type { + case "devnull": + dataWriter = devnull.NewWriter( + w.model, + w.config.DevNullParams, + ) case "csv": dataWriter = csv.NewWriter( ctx, w.model, w.config.CSVParams, + w.columnsToDiscard, outPath, w.continueGeneration, w.writtenRowsChan, ) - case "devnull": - dataWriter = devnull.NewWriter( + case "parquet": + dataWriter = parquet.NewWriter( w.model, - w.config.DevNullParams, + w.config.ParquetParams, + w.columnsToDiscard, + parquet.NewFileSystem(), + outPath, + w.continueGeneration, + w.writtenRowsChan, ) case "http": dataWriter = http.NewWriter( @@ -265,15 +276,6 @@ func (w *ModelWriter) newWriter(ctx context.Context, outPath string) (writer.Wri w.config.TCSParams, w.writtenRowsChan, ) - case "parquet": - dataWriter = parquet.NewWriter( - w.model, - w.config.ParquetParams, - parquet.NewFileSystem(), - outPath, - w.continueGeneration, - w.writtenRowsChan, - ) default: return nil, errors.Errorf("unknown output type: %q", w.config.Type) } diff --git a/internal/generator/output/general/writer/csv/csv.go b/internal/generator/output/general/writer/csv/csv.go index 1f5e66b..066e300 100644 --- a/internal/generator/output/general/writer/csv/csv.go +++ b/internal/generator/output/general/writer/csv/csv.go @@ -33,6 +33,7 @@ type Writer struct { ctx context.Context //nolint:containedctx model *models.Model + columnsToDiscard map[string]struct{} config *models.CSVConfig outputPath string continueGeneration bool @@ -58,6 +59,7 @@ func NewWriter( ctx context.Context, model *models.Model, config *models.CSVConfig, + columnsToDiscard map[string]struct{}, outputPath string, continueGeneration bool, writtenRowsChan chan<- uint64, @@ -65,6 +67,7 @@ func NewWriter( return &Writer{ ctx: ctx, model: model, + columnsToDiscard: columnsToDiscard, config: config, outputPath: outputPath, continueGeneration: continueGeneration, @@ -363,12 +366,7 @@ func (w *Writer) replaceFile(fileName string) error { w.fileDescriptor = file if !w.config.WithoutHeaders && (!w.continueGeneration || !fileExists) { - header := make([]string, len(w.model.Columns)) - for i, column := range w.model.Columns { - header[i] = column.Name - } - - err = w.csvWriter.Write(header) + err = w.csvWriter.Write(w.getHeaders()) if err != nil { return errors.New(err.Error()) } @@ -377,6 +375,20 @@ func (w *Writer) replaceFile(fileName string) error { return nil } +func (w *Writer) getHeaders() []string { + headers := make([]string, 0, len(w.model.Columns)-len(w.columnsToDiscard)) + + for _, column := range w.model.Columns { + if _, exists := w.columnsToDiscard[column.Name]; exists { + continue + } + + headers = append(headers, column.Name) + } + + return headers +} + // WriteRow function sends row to internal queue. func (w *Writer) WriteRow(row *models.DataRow) error { select { diff --git a/internal/generator/output/general/writer/csv/csv_test.go b/internal/generator/output/general/writer/csv/csv_test.go index 8d75741..d866399 100644 --- a/internal/generator/output/general/writer/csv/csv_test.go +++ b/internal/generator/output/general/writer/csv/csv_test.go @@ -130,7 +130,15 @@ func TestWriteRow(t *testing.T) { csvConfig.WithoutHeaders = tc.withoutHeaders - csvWriter := NewWriter(context.Background(), tc.model, csvConfig, "./", false, nil) + csvWriter := NewWriter( + context.Background(), + tc.model, + csvConfig, + getColumnsToDiscard(tc.model.PartitionColumns), + "./", + false, + nil, + ) err := csvWriter.Init() require.NoError(t, err) @@ -307,7 +315,15 @@ func TestWriteToCorrectFiles(t *testing.T) { } write := func(from, to int, continueGeneration bool) { - writer := NewWriter(context.Background(), model, config, dir, continueGeneration, nil) + writer := NewWriter( + context.Background(), + model, + config, + getColumnsToDiscard(model.PartitionColumns), + dir, + continueGeneration, + nil, + ) require.NoError(t, writer.Init()) for i := from; i < to; i++ { @@ -370,3 +386,15 @@ func getFileNumber(rows, rowsPerFile int) int { return fileNumber } + +func getColumnsToDiscard(partitionColumns []*models.PartitionColumn) map[string]struct{} { + columnsToDiscard := make(map[string]struct{}) + + for _, column := range partitionColumns { + if !column.WriteToOutput { + columnsToDiscard[column.Name] = struct{}{} + } + } + + return columnsToDiscard +} diff --git a/internal/generator/output/general/writer/parquet/parquet.go b/internal/generator/output/general/writer/parquet/parquet.go index 9d91399..da0f362 100644 --- a/internal/generator/output/general/writer/parquet/parquet.go +++ b/internal/generator/output/general/writer/parquet/parquet.go @@ -58,6 +58,7 @@ var _ writer.Writer = (*Writer)(nil) // Writer type is implementation of Writer to parquet file. type Writer struct { model *models.Model + columnsToDiscard map[string]struct{} config *models.ParquetConfig outputPath string continueGeneration bool @@ -91,6 +92,7 @@ type FileSystem interface { func NewWriter( model *models.Model, config *models.ParquetConfig, + columnsToDiscard map[string]struct{}, fs FileSystem, outputPath string, continueGeneration bool, @@ -98,6 +100,7 @@ func NewWriter( ) *Writer { return &Writer{ model: model, + columnsToDiscard: columnsToDiscard, config: config, outputPath: outputPath, continueGeneration: continueGeneration, @@ -122,14 +125,9 @@ func (w *Writer) generateModelSchema() (*arrow.Schema, []parquet.WriterProperty, arrowFields := make([]arrow.Field, 0, len(w.model.Columns)) - partitionColumnsByName := map[string]*models.PartitionColumn{} - for _, column := range w.model.PartitionColumns { - partitionColumnsByName[column.Name] = column - } - for _, column := range w.model.Columns { - colSettings, ok := partitionColumnsByName[column.Name] - if ok && !colSettings.WriteToOutput { // filter partition columns in schema + // filter partition columns in schema + if _, exists := w.columnsToDiscard[column.Name]; exists { continue } diff --git a/internal/generator/output/general/writer/parquet/parquet_test.go b/internal/generator/output/general/writer/parquet/parquet_test.go index e3545ba..311ccc1 100644 --- a/internal/generator/output/general/writer/parquet/parquet_test.go +++ b/internal/generator/output/general/writer/parquet/parquet_test.go @@ -460,10 +460,11 @@ func TestGetModelSchema(t *testing.T) { require.NotEqual(t, "", tc.model.Name) writer := &Writer{ - model: tc.model, - config: tc.cfg, - fs: fsMock, - outputPath: "./", + model: tc.model, + columnsToDiscard: getColumnsToDiscard(tc.model.PartitionColumns), + config: tc.cfg, + fs: fsMock, + outputPath: "./", } modelSchemaPointer, writerProperties, err := writer.generateModelSchema() @@ -616,7 +617,15 @@ func TestWriteRow(t *testing.T) { // WHEN fsMock := newFileSystemMock() - parquetWriter := NewWriter(tc.model, parquetConfig, fsMock, "./", false, nil) + parquetWriter := NewWriter( + tc.model, + parquetConfig, + getColumnsToDiscard(tc.model.PartitionColumns), + fsMock, + "./", + false, + nil, + ) err := parquetWriter.Init() require.NoError(t, err) @@ -825,7 +834,15 @@ func TestWriteToCorrectFiles(t *testing.T) { fsMock := newFileSystemMock() write := func(from, to int, continueGeneration bool) { - writer := NewWriter(model, config, fsMock, dir, continueGeneration, nil) + writer := NewWriter( + model, + config, + getColumnsToDiscard(model.PartitionColumns), + fsMock, + dir, + continueGeneration, + nil, + ) require.NoError(t, writer.Init()) for i := from; i < to; i++ { @@ -914,3 +931,15 @@ func getExpected(rows []*models.DataRow, rowsPerFile uint64, writersCount int) ( return expectedFiles, expectedData } + +func getColumnsToDiscard(partitionColumns []*models.PartitionColumn) map[string]struct{} { + columnsToDiscard := make(map[string]struct{}) + + for _, column := range partitionColumns { + if !column.WriteToOutput { + columnsToDiscard[column.Name] = struct{}{} + } + } + + return columnsToDiscard +}