Skip to content

Commit 1fa55ca

Browse files
removed all unnecessary calls to compute()
1 parent 1ac8e38 commit 1fa55ca

File tree

2 files changed

+47
-45
lines changed

2 files changed

+47
-45
lines changed

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,5 +122,3 @@ docs/source/auto_examples/
122122
docs/source/examples/mydask.png
123123

124124
dask-worker-space
125-
/.project
126-
/.pydevproject

dask_ml/feature_extraction/text.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,18 @@ def _hasher(self):
120120
return sklearn.feature_extraction.text.FeatureHasher
121121

122122

123+
def _n_samples(X):
124+
"""Count the number of samples in dask.array.Array X."""
125+
def chunk_n_samples(chunk, axis, keepdims):
126+
return np.array([chunk.shape[0]], dtype=np.int64)
127+
128+
return da.reduction(X,
129+
chunk=chunk_n_samples,
130+
aggregate=np.sum,
131+
concatenate=False,
132+
dtype=np.int64)
133+
134+
123135
def _document_frequency(X, dtype):
124136
"""Count the number of non-zero values for each feature in dask array X."""
125137
def chunk_doc_freq(chunk, axis, keepdims):
@@ -133,7 +145,7 @@ def chunk_doc_freq(chunk, axis, keepdims):
133145
aggregate=np.sum,
134146
axis=0,
135147
concatenate=False,
136-
dtype=dtype).compute().astype(dtype)
148+
dtype=dtype)
137149

138150

139151
class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
@@ -203,17 +215,19 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
203215
['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
204216
"""
205217

206-
def fit_transform(self, raw_documents, y=None):
218+
def get_params(self):
207219
# Note that in general 'self' could refer to an instance of either this
208220
# class or a subclass of this class. Hence it is possible that
209221
# self.get_params() could get unexpected parameters of an instance of a
210222
# subclass. Such parameters need to be excluded here:
211-
subclass_instance_params = self.get_params()
223+
subclass_instance_params = super().get_params()
212224
excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
213-
params = {key: subclass_instance_params[key]
214-
for key in subclass_instance_params
215-
if key not in excluded_keys}
225+
return {key: subclass_instance_params[key]
226+
for key in subclass_instance_params
227+
if key not in excluded_keys}
216228

229+
def fit_transform(self, raw_documents, y=None):
230+
params = self.get_params()
217231
vocabulary = params.pop("vocabulary")
218232
vocabulary_for_transform = vocabulary
219233

@@ -227,12 +241,12 @@ def fit_transform(self, raw_documents, y=None):
227241
# Case 2: learn vocabulary from the data.
228242
vocabularies = raw_documents.map_partitions(_build_vocabulary, params)
229243
vocabulary = vocabulary_for_transform = (
230-
_merge_vocabulary( *vocabularies.to_delayed() ))
244+
_merge_vocabulary(*vocabularies.to_delayed()))
231245
vocabulary_for_transform = vocabulary_for_transform.persist()
232246
vocabulary_ = vocabulary.compute()
233247
n_features = len(vocabulary_)
234248

235-
meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
249+
meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
236250
if isinstance(raw_documents, dd.Series):
237251
result = raw_documents.map_partitions(
238252
_count_vectorizer_transform, vocabulary_for_transform,
@@ -241,23 +255,14 @@ def fit_transform(self, raw_documents, y=None):
241255
result = raw_documents.map_partitions(
242256
_count_vectorizer_transform, vocabulary_for_transform, params)
243257
result = build_array(result, n_features, meta)
244-
result.compute_chunk_sizes()
245258

246259
self.vocabulary_ = vocabulary_
247260
self.fixed_vocabulary_ = fixed_vocabulary
248261

249262
return result
250263

251264
def transform(self, raw_documents):
252-
# Note that in general 'self' could refer to an instance of either this
253-
# class or a subclass of this class. Hence it is possible that
254-
# self.get_params() could get unexpected parameters of an instance of a
255-
# subclass. Such parameters need to be excluded here:
256-
subclass_instance_params = self.get_params()
257-
excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
258-
params = {key: subclass_instance_params[key]
259-
for key in subclass_instance_params
260-
if key not in excluded_keys}
265+
params = self.get_params()
261266
vocabulary = params.pop("vocabulary")
262267

263268
if vocabulary is None:
@@ -271,14 +276,13 @@ def transform(self, raw_documents):
271276
except ValueError:
272277
vocabulary_for_transform = dask.delayed(vocabulary)
273278
else:
274-
(vocabulary_for_transform,) = client.scatter(
275-
(vocabulary,), broadcast=True
276-
)
279+
(vocabulary_for_transform,) = client.scatter((vocabulary,),
280+
broadcast=True)
277281
else:
278282
vocabulary_for_transform = vocabulary
279283

280284
n_features = vocabulary_length(vocabulary_for_transform)
281-
meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
285+
meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
282286
if isinstance(raw_documents, dd.Series):
283287
result = raw_documents.map_partitions(
284288
_count_vectorizer_transform, vocabulary_for_transform,
@@ -287,7 +291,6 @@ def transform(self, raw_documents):
287291
transformed = raw_documents.map_partitions(
288292
_count_vectorizer_transform, vocabulary_for_transform, params)
289293
result = build_array(transformed, n_features, meta)
290-
result.compute_chunk_sizes()
291294
return result
292295

293296
class TfidfTransformer(sklearn.feature_extraction.text.TfidfTransformer):
@@ -331,30 +334,23 @@ def fit(self, X, y=None):
331334
X : sparse matrix of shape n_samples, n_features)
332335
A matrix of term/token counts.
333336
"""
334-
# X = check_array(X, accept_sparse=('csr', 'csc'))
335-
# if not sp.issparse(X):
336-
# X = sp.csr_matrix(X)
337-
dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
338-
339-
if self.use_idf:
340-
n_samples, n_features = X.shape
337+
def get_idf_diag(X, dtype):
338+
n_samples = _n_samples(X) # X.shape[0] is not yet known
339+
n_features = X.shape[1]
341340
df = _document_frequency(X, dtype)
342-
# df = df.astype(dtype, **_astype_copy_false(df))
343341

344342
# perform idf smoothing if required
345343
df += int(self.smooth_idf)
346344
n_samples += int(self.smooth_idf)
347345

348346
# log+1 instead of log makes sure terms with zero idf don't get
349347
# suppressed entirely.
350-
idf = np.log(n_samples / df) + 1
351-
self._idf_diag = scipy.sparse.diags(
352-
idf,
353-
offsets=0,
354-
shape=(n_features, n_features),
355-
format="csr",
356-
dtype=dtype,
357-
)
348+
return np.log(n_samples / df) + 1
349+
350+
dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
351+
352+
if self.use_idf:
353+
self._idf_diag = get_idf_diag(X, dtype)
358354

359355
return self
360356

@@ -404,8 +400,17 @@ def _dot_idf_diag(chunk):
404400
# idf_ being a property, the automatic attributes detection
405401
# does not work as usual and we need to specify the attribute
406402
# name:
407-
check_is_fitted(self, attributes=["idf_"], msg="idf vector is not fitted")
408-
403+
check_is_fitted(self, attributes=["idf_"],
404+
msg="idf vector is not fitted")
405+
if dask.is_dask_collection(self._idf_diag):
406+
_idf_diag = self._idf_diag.compute()
407+
n_features = len(_idf_diag)
408+
self._idf_diag = scipy.sparse.diags(
409+
_idf_diag,
410+
offsets=0,
411+
shape=(n_features, n_features),
412+
format="csr",
413+
dtype=_idf_diag.dtype)
409414
X = X.map_blocks(_dot_idf_diag, dtype=np.float64, meta=meta)
410415

411416
if self.norm:
@@ -619,8 +624,7 @@ def fit(self, raw_documents, y=None):
619624
"""
620625
self._check_params()
621626
self._warn_for_unused_params()
622-
X = super().fit_transform(raw_documents,
623-
y=self._non_CountVectorizer_params)
627+
X = super().fit_transform(raw_documents)
624628
self._tfidf.fit(X)
625629
return self
626630

0 commit comments

Comments
 (0)