@@ -120,6 +120,18 @@ def _hasher(self):
         return sklearn.feature_extraction.text.FeatureHasher
 
 
+def _n_samples(X):
+    """Count the number of samples in dask.array.Array X."""
+    def chunk_n_samples(chunk, axis, keepdims):
+        return np.array([chunk.shape[0]], dtype=np.int64)
+
+    return da.reduction(X,
+                        chunk=chunk_n_samples,
+                        aggregate=np.sum,
+                        concatenate=False,
+                        dtype=np.int64)
+
+
 def _document_frequency(X, dtype):
     """Count the number of non-zero values for each feature in dask array X."""
     def chunk_doc_freq(chunk, axis, keepdims):
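
Note: the new _n_samples helper mirrors the existing _document_frequency
reduction: each chunk reports its own row count and the counts are summed, so
the number of documents can be read off an array whose row chunk sizes are not
yet known. A minimal standalone sketch of the same da.reduction pattern (the
array x and the printed value are illustrative, not part of the patch):

    import numpy as np
    import dask.array as da

    def chunk_n_samples(chunk, axis, keepdims):
        # each chunk contributes its own number of rows
        return np.array([chunk.shape[0]], dtype=np.int64)

    x = da.ones((10, 3), chunks=(4, 3))        # row chunks of 4, 4 and 2
    n = da.reduction(x,
                     chunk=chunk_n_samples,    # per-chunk row counts
                     aggregate=np.sum,         # summed across chunks
                     concatenate=False,
                     dtype=np.int64)
    print(int(n.compute()))                    # 10

The result stays lazy until compute() is called, which is what the
TfidfTransformer changes further down rely on.
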
@@ -133,7 +145,7 @@ def chunk_doc_freq(chunk, axis, keepdims):
                         aggregate=np.sum,
                         axis=0,
                         concatenate=False,
-                        dtype=dtype).compute().astype(dtype)
+                        dtype=dtype)
 
 
 class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
@@ -203,17 +215,19 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
     ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
     """
 
-    def fit_transform(self, raw_documents, y=None):
+    def get_params(self):
         # Note that in general 'self' could refer to an instance of either this
         # class or a subclass of this class. Hence it is possible that
         # self.get_params() could get unexpected parameters of an instance of a
         # subclass. Such parameters need to be excluded here:
-        subclass_instance_params = self.get_params()
+        subclass_instance_params = super().get_params()
         excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
-        params = {key: subclass_instance_params[key]
-                  for key in subclass_instance_params
-                  if key not in excluded_keys}
+        return {key: subclass_instance_params[key]
+                for key in subclass_instance_params
+                if key not in excluded_keys}
 
+    def fit_transform(self, raw_documents, y=None):
+        params = self.get_params()
         vocabulary = params.pop("vocabulary")
         vocabulary_for_transform = vocabulary
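
Note: pulling the filtering into an overridden get_params removes the block
that was duplicated between fit_transform and transform; both now call
self.get_params(), which asks the scikit-learn base class for the parameters
and drops anything a subclass has listed in _non_CountVectorizer_params. A toy
sketch of just that filtering step (the dictionary values are made up):

    # stands in for super().get_params() on a hypothetical subclass
    subclass_instance_params = {"lowercase": True, "ngram_range": (1, 1),
                                "chunk_size": 1000}
    excluded_keys = ["chunk_size"]      # e.g. _non_CountVectorizer_params
    params = {key: subclass_instance_params[key]
              for key in subclass_instance_params
              if key not in excluded_keys}
    print(params)   # {'lowercase': True, 'ngram_range': (1, 1)}
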
@@ -227,12 +241,12 @@ def fit_transform(self, raw_documents, y=None):
             # Case 2: learn vocabulary from the data.
             vocabularies = raw_documents.map_partitions(_build_vocabulary, params)
             vocabulary = vocabulary_for_transform = (
-                _merge_vocabulary( *vocabularies.to_delayed() ))
+                _merge_vocabulary(*vocabularies.to_delayed()))
             vocabulary_for_transform = vocabulary_for_transform.persist()
             vocabulary_ = vocabulary.compute()
             n_features = len(vocabulary_)
 
-        meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
+        meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
         if isinstance(raw_documents, dd.Series):
             result = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform,
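
Note: the meta change is subtle. scipy.sparse.eye(0, ...) is a 0x0 matrix,
whereas an empty CSR matrix built from (0, n_features) carries the real column
count, so the example matrix that dask uses for type and shape inference now
matches the width of the actual chunks. A small illustration (the value of
n_features is arbitrary):

    import scipy.sparse

    n_features = 9
    old_meta = scipy.sparse.eye(0, format="csr")          # shape (0, 0)
    new_meta = scipy.sparse.csr_matrix((0, n_features))   # shape (0, 9)
    print(old_meta.shape, new_meta.shape)
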
@@ -241,23 +255,14 @@ def fit_transform(self, raw_documents, y=None):
             result = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform, params)
             result = build_array(result, n_features, meta)
-        result.compute_chunk_sizes()
 
         self.vocabulary_ = vocabulary_
         self.fixed_vocabulary_ = fixed_vocabulary
 
         return result
 
     def transform(self, raw_documents):
-        # Note that in general 'self' could refer to an instance of either this
-        # class or a subclass of this class. Hence it is possible that
-        # self.get_params() could get unexpected parameters of an instance of a
-        # subclass. Such parameters need to be excluded here:
-        subclass_instance_params = self.get_params()
-        excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
-        params = {key: subclass_instance_params[key]
-                  for key in subclass_instance_params
-                  if key not in excluded_keys}
+        params = self.get_params()
         vocabulary = params.pop("vocabulary")
 
         if vocabulary is None:
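
Note: dropping result.compute_chunk_sizes() keeps fit_transform (and
transform, below) fully lazy; the per-partition document counts are simply
left unknown instead of being computed eagerly, and code that needs the row
count can use the new _n_samples reduction instead. A sketch of what unknown
chunk sizes look like, built from made-up delayed partitions:

    import numpy as np
    import dask
    import dask.array as da

    parts = [dask.delayed(np.arange)(n) for n in (2, 3)]
    a = da.concatenate([da.from_delayed(p, shape=(np.nan,), dtype=int)
                        for p in parts])
    print(a.chunks)   # ((nan, nan),): row counts unknown until computed
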
@@ -271,14 +276,13 @@ def transform(self, raw_documents):
                 except ValueError:
                     vocabulary_for_transform = dask.delayed(vocabulary)
                 else:
-                    (vocabulary_for_transform,) = client.scatter(
-                        (vocabulary,), broadcast=True
-                    )
+                    (vocabulary_for_transform,) = client.scatter((vocabulary,),
+                                                                 broadcast=True)
             else:
                 vocabulary_for_transform = vocabulary
 
         n_features = vocabulary_length(vocabulary_for_transform)
-        meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
+        meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
         if isinstance(raw_documents, dd.Series):
             result = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform,
@@ -287,7 +291,6 @@ def transform(self, raw_documents):
             transformed = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform, params)
             result = build_array(transformed, n_features, meta)
-        result.compute_chunk_sizes()
         return result
 
 class TfidfTransformer(sklearn.feature_extraction.text.TfidfTransformer):
@@ -331,30 +334,23 @@ def fit(self, X, y=None):
         X : sparse matrix of shape (n_samples, n_features)
             A matrix of term/token counts.
         """
-        # X = check_array(X, accept_sparse=('csr', 'csc'))
-        # if not sp.issparse(X):
-        #     X = sp.csr_matrix(X)
-        dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
-
-        if self.use_idf:
-            n_samples, n_features = X.shape
+        def get_idf_diag(X, dtype):
+            n_samples = _n_samples(X)  # X.shape[0] is not yet known
+            n_features = X.shape[1]
             df = _document_frequency(X, dtype)
-            # df = df.astype(dtype, **_astype_copy_false(df))
 
             # perform idf smoothing if required
             df += int(self.smooth_idf)
             n_samples += int(self.smooth_idf)
 
             # log+1 instead of log makes sure terms with zero idf don't get
             # suppressed entirely.
-            idf = np.log(n_samples / df) + 1
-            self._idf_diag = scipy.sparse.diags(
-                idf,
-                offsets=0,
-                shape=(n_features, n_features),
-                format="csr",
-                dtype=dtype,
-            )
+            return np.log(n_samples / df) + 1
+
+        dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
+
+        if self.use_idf:
+            self._idf_diag = get_idf_diag(X, dtype)
 
         return self
 
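
Note: with _document_frequency and _n_samples both returning lazy results,
fit no longer materializes anything; self._idf_diag now holds the 1-D idf
vector as a lazy dask array rather than the eager scipy.sparse.diags matrix,
which is only built later in transform. The formula is unchanged; a toy sketch
of the smoothed idf evaluated lazily (the counts are illustrative):

    import numpy as np
    import dask.array as da

    smooth_idf = 1
    n_samples = 4 + smooth_idf                        # document count, smoothed
    df = da.from_array(np.array([3.0, 1.0, 4.0])) + smooth_idf
    idf = np.log(n_samples / df) + 1                  # still a lazy dask array
    print(idf.compute())                              # [1.22314355 1.91629073 1.]
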
@@ -404,8 +400,17 @@ def _dot_idf_diag(chunk):
             # idf_ being a property, the automatic attributes detection
             # does not work as usual and we need to specify the attribute
             # name:
-            check_is_fitted(self, attributes=["idf_"], msg="idf vector is not fitted")
-
+            check_is_fitted(self, attributes=["idf_"],
+                            msg="idf vector is not fitted")
+            if dask.is_dask_collection(self._idf_diag):
+                _idf_diag = self._idf_diag.compute()
+                n_features = len(_idf_diag)
+                self._idf_diag = scipy.sparse.diags(
+                    _idf_diag,
+                    offsets=0,
+                    shape=(n_features, n_features),
+                    format="csr",
+                    dtype=_idf_diag.dtype)
             X = X.map_blocks(_dot_idf_diag, dtype=np.float64, meta=meta)
 
         if self.norm:
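
Note: transform is where the diagonal idf matrix finally gets built; if
self._idf_diag is still a dask collection it is computed once, converted with
scipy.sparse.diags, and cached back onto the estimator, so repeated transforms
pay the cost only once. A standalone sketch of that conversion and of what
multiplying by it does (idf values and counts are made up):

    import numpy as np
    import scipy.sparse

    idf = np.array([1.2, 1.9, 1.0])                   # computed idf vector
    idf_diag = scipy.sparse.diags(idf, offsets=0,
                                  shape=(len(idf), len(idf)), format="csr")
    counts = scipy.sparse.csr_matrix(
        np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 1.0]]))
    print((counts @ idf_diag).toarray())              # column j scaled by idf[j]
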
@@ -619,8 +624,7 @@ def fit(self, raw_documents, y=None):
         """
         self._check_params()
         self._warn_for_unused_params()
-        X = super().fit_transform(raw_documents,
-                                  y=self._non_CountVectorizer_params)
+        X = super().fit_transform(raw_documents)
         self._tfidf.fit(X)
         return self
 