I think I discovered the issue: if the function is not decorated with `@task`, it runs serially!
The Proof
With the decorator, however, the task runner properly uses multiple processes.
With Dask, I checked htop instead of updating the logging formatter to print PIDs, which is a bit tedious. htop showed CPU usage above 100%, i.e. more than one core busy.
With Ray, the output automatically shows the PID of each worker:

```
(begin_task_run pid=93880) 2013-01-01T12:00:00.000000000
(begin_task_run pid=93881) 2013-01-01T00:00:00.000000000
(begin_task_run pid=93878) 2013-01-01T18:00:00.000000000
(begin_task_run pid=93879) 2013-01-01T06:00:00.000000000
```
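For the Dask run, the PIDs could also be logged rather than read off htop. A minimal sketch using only the standard `logging` module: the `%(process)d` field in a format string expands to the ID of the process that emitted the record. Wiring this format into Prefect's own logging configuration is not shown and is left as an assumption; the logger name `pid_demo` is hypothetical.

```python
import io
import logging
import os

# Hypothetical logger name, used only for this illustration.
logger = logging.getLogger("pid_demo")
logger.setLevel(logging.INFO)

# Log into an in-memory stream so the formatted line can be inspected.
stream = io.StringIO()
handler = logging.StreamHandler(stream)
# %(process)d is replaced with the PID of the process that logs the record.
handler.setFormatter(logging.Formatter("(pid=%(process)d) %(message)s"))
logger.addHandler(handler)

logger.info("2013-01-01T00:00:00.000000000")
line = stream.getvalue().strip()
print(line)
```

Run inside worker processes, each line would carry that worker's PID, similar to the Ray output above.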
Setup
```python
import dask
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.patches
import matplotlib.textpath
from matplotlib.font_manager import FontProperties
import numpy as np
from cartopy import crs as ccrs
from cartopy import feature as cfeature
from prefect import flow, task, get_run_logger
# Prefect 2 beta import path; later releases moved these to prefect_dask / prefect_ray
from prefect.task_runners import DaskTaskRunner, RayTaskRunner


@task  # dropped for the "Native Dask" and "Serial mode" runs below
def plot(ds, time):
    print(time)
    plt.figure()
    ax = plt.axes(projection=ccrs.Orthographic())
    ax.add_feature(cfeature.LAKES.with_scale("10m"))
    ax.add_feature(cfeature.OCEAN.with_scale("10m"))
    ax.add_feature(cfeature.LAND.with_scale("10m"))
    ax.add_feature(cfeature.STATES.with_scale("10m"))
    ax.add_feature(cfeature.COASTLINE.with_scale("10m"))
    ax.pcolormesh(ds["lon"], ds["lat"], ds["air"], transform=ccrs.PlateCarree())

    # generate a matplotlib path representing the word "cartopy"
    fp = FontProperties(family='Bitstream Vera Sans', weight='bold')
    logo_path = matplotlib.textpath.TextPath((-175, -35), 'cartopy',
                                             size=1, prop=fp)
    # add a background image
    im = ax.stock_img()
    # clip the image according to the logo_path. mpl v1.2.0 does not support
    # the transform API that cartopy makes use of, so we have to convert the
    # projection into a transform manually
    plate_carree_transform = ccrs.PlateCarree()._as_mpl_transform(ax)
    im.set_clip_path(logo_path, transform=plate_carree_transform)

    # add the path as a patch, drawing black outlines around the text
    patch = matplotlib.patches.PathPatch(logo_path,
                                         facecolor='none', edgecolor='black',
                                         transform=ccrs.PlateCarree())
    ax.add_patch(patch)
    ax.set_global()
    # file name like "2013-01-01T00:00"; matplotlib appends ".png" by default
    plt.savefig(str(time)[:16])
    plt.close()


ds = xr.tutorial.open_dataset('air_temperature').isel(time=slice(0, 100))
```
Benchmark
All runs below were timed in a Jupyter notebook.

ConcurrentTaskRunner (the default) - 2m 32.7s
```python
@flow()
def process():
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)

process()
```
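The ConcurrentTaskRunner keeps everything in one Python process, so its gain over serial mode presumably comes from overlapping work that does not hold the GIL (file I/O, some NumPy calls). A rough standard-library analogy, not Prefect's implementation: `concurrent.futures.ThreadPoolExecutor` stands in for the task runner, and `fake_plot` is a hypothetical stand-in for `plot()` that sleeps instead of drawing.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_plot(label: str) -> str:
    # Hypothetical stand-in for plot(): time.sleep releases the GIL,
    # so other threads can run while this one waits.
    time.sleep(0.2)
    return label


labels = [f"t{i}" for i in range(4)]

# Serial: each call finishes before the next one starts.
start = time.perf_counter()
serial_results = [fake_plot(label) for label in labels]
serial_elapsed = time.perf_counter() - start

# Concurrent: submit every call first, then collect results in order.
start = time.perf_counter()
with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(fake_plot, label) for label in labels]
    concurrent_results = [f.result() for f in futures]
concurrent_elapsed = time.perf_counter() - start

print(f"serial {serial_elapsed:.2f}s, concurrent {concurrent_elapsed:.2f}s")
```

The flows above call `plot(...)` directly. In later Prefect 2 releases a direct task call runs the task and returns its result, and `plot.submit(...)` is what hands it to the task runner, so the same loop may need `.submit` to reproduce these timings.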
DaskTaskRunner - 1m 29.3s
```python
@flow(task_runner=DaskTaskRunner(cluster_kwargs={"n_workers": 4, "processes": True}))
def process():
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)

process()
```
RayTaskRunner - 1m 35.2s
```python
@flow(task_runner=RayTaskRunner(init_kwargs={"num_cpus": 4}))
def process():
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)

process()
```
Native Dask, with the `@task` decorator dropped - 1m 31.4s (about the same as DaskTaskRunner!)
```python
def process():
    jobs = []
    for time in ds['time'].values:
        job = dask.delayed(plot)(ds.sel(time=time), time)
        jobs.append(job)
    dask.compute(jobs, scheduler="processes", num_workers=4)

process()
```
Serial mode, plain Python with no `@flow` or `@task` - 3m 15.9s. I think this could make an exciting blog post or announcement: just wrapping the code in `@flow` and `@task`, which defaults to the ConcurrentTaskRunner, speeds the run up by about 43 seconds!
```python
def process():
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)

process()
```
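To put the benchmark side by side, the reported wall-clock times can be converted to seconds and compared against serial mode. The timings are the ones listed above; `to_seconds` is a hypothetical helper for this arithmetic only.

```python
import re

# Wall-clock times reported in the benchmark above.
timings = {
    "Serial": "3m 15.9s",
    "ConcurrentTaskRunner": "2m 32.7s",
    "DaskTaskRunner": "1m 29.3s",
    "RayTaskRunner": "1m 35.2s",
    "Native Dask": "1m 31.4s",
}


def to_seconds(text: str) -> float:
    """Convert a duration like '2m 32.7s' to seconds."""
    match = re.fullmatch(r"(\d+)m ([\d.]+)s", text)
    if match is None:
        raise ValueError(f"unrecognized duration: {text!r}")
    return int(match.group(1)) * 60 + float(match.group(2))


serial = to_seconds(timings["Serial"])
for name, text in timings.items():
    seconds = to_seconds(text)
    saved = serial - seconds
    print(f"{name:22s} {seconds:6.1f}s  saved {saved:5.1f}s  "
          f"speedup {serial / seconds:.2f}x")
```

By this arithmetic, ConcurrentTaskRunner saves about 43 s over serial mode (roughly 1.28x), while DaskTaskRunner, RayTaskRunner and native Dask are each a little over 2x faster.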
Conclusion:
Wrap your functions with `@task` and call them inside a `@flow` so the task runner can parallelize them.