Introduction to Zustand Middleware
This article supplements the previous article, Zustand Core Source Code Analysis, and introduces zustand's middleware system using the persist plugin as an example.
The core implementation of zustand is very concise and offers relatively few features. If you need more, you have to implement them yourself or bring in existing libraries as middleware.
The Essence of Zustand Middleware
As described in the middleware section of the official documentation, a zustand middleware is simply a higher-order function. It takes the same kind of argument as the create function: a function of the createInitState type (the state creator). The difference is that it also returns a function of the createInitState type. In essence, it wraps createInitState, injects its own logic, and so produces a rewritten createInitState.
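import { create } from "zustand";
// A state creator (createInitState): receives set/get/api and returns the initial state
const createInitState = (set) => ({
  bees: false,
  // Return a partial update; mutating `state` in place would require the immer middleware
  setBees: (input) => set({ bees: input }),
});
// Placeholder middlewares: each takes a state creator and returns a new one
const middleware1 = (config) => (set, get, api) => config(set, get, api);
const middleware2 = (config) => (set, get, api) => config(set, get, api);
// no middleware
const useStore = create(createInitState);
// with middleware
const useStoreWithMiddleware = create(
  middleware1(middleware2(createInitState))
);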
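To make the "higher-order function" idea concrete, here is a self-contained TypeScript sketch that does not import zustand. `createStore`, `StateCreator`, and `logger` are hypothetical, simplified stand-ins (no `subscribe`, no React binding), written only to show how a middleware wraps `set` before handing it to the original `createInitState`.

```typescript
// Hypothetical, stripped-down stand-in for zustand's vanilla store (no subscribe/destroy).
type SetState<T> = (partial: Partial<T> | ((state: T) => Partial<T>), replace?: boolean) => void;
type GetState<T> = () => T;
interface StoreApi<T> {
  setState: SetState<T>;
  getState: GetState<T>;
}
// The createInitState shape: receives set/get/api, returns the initial state
type StateCreator<T> = (set: SetState<T>, get: GetState<T>, api: StoreApi<T>) => T;

function createStore<T extends object>(createInitState: StateCreator<T>): StoreApi<T> {
  let state: T;
  const setState: SetState<T> = (partial, replace) => {
    const next =
      typeof partial === "function" ? (partial as (state: T) => Partial<T>)(state) : partial;
    state = replace ? (next as T) : { ...state, ...next };
  };
  const getState: GetState<T> = () => state;
  const api: StoreApi<T> = { setState, getState };
  state = createInitState(setState, getState, api);
  return api;
}

// A middleware: takes a state creator and returns a new state creator of the same type
const logs: string[] = [];
const logger =
  <T extends object>(config: StateCreator<T>): StateCreator<T> =>
  (set, get, api) =>
    config(
      (partial, replace) => {
        set(partial, replace); // delegate to the real setState first
        logs.push(JSON.stringify(get())); // then record the new state (JSON omits functions)
      },
      get,
      api
    );

interface BeeState {
  bees: boolean;
  setBees: (input: boolean) => void;
}

const createInitState: StateCreator<BeeState> = (set) => ({
  bees: false,
  setBees: (input) => set({ bees: input }),
});

const store = createStore(logger(createInitState));
store.getState().setBees(true);
console.log(store.getState().bees); // true
console.log(logs[0]); // {"bees":true}
```

Because `logger` only wraps the `set` passed to the state creator, calling `store.setState(...)` directly would bypass it; the persist middleware discussed below deals with this by also replacing `api.setState`.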
Detailed Explanation of Persist Middleware Source Code
For a detailed introduction to the persist middleware, see the persist page in the official documentation.
Its parameters are createInitState and options; options.name is required and everything else is optional. The default values for options are as follows:
const defaultOptions = {
  storage: createJSONStorage<S>(() => localStorage),
  // Determines which part of the state is saved
  partialize: (state: S) => state,
  version: 0,
  // Determines how the persisted state is merged with the current state
  merge: (persistedState: unknown, currentState: S) => ({
    ...currentState,
    ...(persistedState as object),
  }),
};
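The default partialize and merge are easiest to understand by running them. The sketch below copies the two default functions from the options above and pairs them with a hypothetical custom partialize; the `TodoState` shape is invented for illustration.

```typescript
// Hypothetical state shape; `draft` stands for transient input we may not want to persist
interface TodoState {
  todos: string[];
  draft: string;
  addTodo: (todo: string) => void;
}

// Copies of the default partialize and merge shown above
const defaultPartialize = <S>(state: S): S => state;
const defaultMerge = <S>(persistedState: unknown, currentState: S): S => ({
  ...currentState,
  ...(persistedState as object),
});

// Hypothetical custom partialize: persist only the todo list
const partialize = (state: TodoState) => ({ todos: state.todos });

const current: TodoState = { todos: [], draft: "typing...", addTodo: () => {} };

// What would be serialized (JSON.stringify drops the addTodo function either way)
console.log(JSON.stringify(defaultPartialize(current))); // {"todos":[],"draft":"typing..."}
console.log(JSON.stringify(partialize(current))); // {"todos":[]}

// Rehydration: persisted fields win, but actions survive because current state is spread first
const restored = defaultMerge({ todos: ["buy milk"] }, current);
console.log(JSON.stringify(restored)); // {"todos":["buy milk"],"draft":"typing..."}
console.log(typeof restored.addTodo); // function
```

Because merge is a shallow spread, a nested object in the persisted state replaces the corresponding object in the current state wholesale; deep merging needs a custom merge.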
The defaults begin by calling createJSONStorage to create the storage. As I understand it, this function is a glue layer: it converts values to and from JSON strings and lets getItem support async storage. The core part is:
const str = (storage as StateStorage).getItem(name) ?? null;
// support async storage
if (str instanceof Promise) {
  return str.then(parse);
}
return parse(str);
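The sync/async branching can be tried in isolation. The hypothetical `getParsed` function below reproduces only the getItem part of createJSONStorage (no reviver) and runs it against a synchronous and an asynchronous Map-backed storage.

```typescript
// Hypothetical cut-down version of the getItem glue in createJSONStorage
interface StringStorage {
  getItem: (key: string) => string | null | Promise<string | null>;
}

function getParsed<T>(storage: StringStorage, key: string): T | null | Promise<T | null> {
  const parse = (str: string | null): T | null => (str === null ? null : (JSON.parse(str) as T));
  const str = storage.getItem(key) ?? null;
  // Async storage: parse once the promise resolves
  if (str instanceof Promise) {
    return str.then(parse);
  }
  // Sync storage: parse immediately, no promise involved
  return parse(str);
}

const data = new Map<string, string>([["count", '{"n":1}']]);
const syncStorage: StringStorage = { getItem: (key) => data.get(key) ?? null };
const asyncStorage: StringStorage = { getItem: async (key) => data.get(key) ?? null };

const fromSync = getParsed<{ n: number }>(syncStorage, "count");
const fromAsync = getParsed<{ n: number }>(asyncStorage, "count");
console.log(JSON.stringify(fromSync)); // {"n":1}
console.log(fromAsync instanceof Promise); // true
```

The caller therefore receives either a plain value or a promise, which is why hydrate later normalizes the result with toThenable.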
The createJSONStorage function suggests that the persist middleware started out as a wrapper around localStorage and was later generalized to any storage that stores strings. There is one inconsistency: although setItem may return a promise for async storage, the middleware never awaits it (the setState wrappers call it as void setItem()), so nothing guarantees that a write has finished before the next operation. This differs from getItem, whose async result is waited for. From another perspective, waiting for setItem is not strictly necessary, since a write is rarely followed immediately by a getItem. Awaiting would cost some performance, but I still think it is better to await.
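The fire-and-forget behaviour can be observed with a deliberately asynchronous storage. In this sketch, `slowStorage` and both save functions are hypothetical; each write is deferred to a microtask to imitate an async storage engine.

```typescript
// Hypothetical async storage backed by a Map: each write lands on a later microtask
const backing = new Map<string, string>();
const slowStorage = {
  setItem: (key: string, value: string): Promise<void> =>
    Promise.resolve().then(() => {
      backing.set(key, value);
    }),
};

// Fire-and-forget, the way the middleware calls `void setItem()`
function saveWithoutAwait(value: string): void {
  void slowStorage.setItem("state", value);
}

// The alternative the author prefers: callers can wait for the write to finish
async function saveAndWait(value: string): Promise<void> {
  await slowStorage.setItem("state", value);
}

saveWithoutAwait('{"n":1}');
console.log(backing.has("state")); // false: the write has not landed yet

saveAndWait('{"n":2}').then(() => {
  console.log(backing.get("state")); // {"n":2}
});
```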
The source code of createJSONStorage is as follows:
export interface StateStorage {
  getItem: (name: string) => string | null | Promise<string | null>;
  setItem: (name: string, value: string) => void | Promise<void>;
  removeItem: (name: string) => void | Promise<void>;
}
export function createJSONStorage<S>(
  getStorage: () => StateStorage,
  options?: JsonStorageOptions
): PersistStorage<S> | undefined {
  let storage: StateStorage | undefined;
  try {
    storage = getStorage();
  } catch (e) {
    // prevent error if the storage is not defined (e.g. when server side rendering a page)
    return;
  }
  const persistStorage: PersistStorage<S> = {
    getItem: (name) => {
      const parse = (str: string | null) => {
        if (str === null) {
          return null;
        }
        return JSON.parse(str, options?.reviver) as StorageValue<S>;
      };
      const str = (storage as StateStorage).getItem(name) ?? null;
      // support async storage
      if (str instanceof Promise) {
        return str.then(parse);
      }
      return parse(str);
    },
    setItem: (name, newValue) =>
      (storage as StateStorage).setItem(
        name,
        JSON.stringify(newValue, options?.replacer)
      ),
    removeItem: (name) => (storage as StateStorage).removeItem(name),
  };
  return persistStorage;
}
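createJSONStorage forwards `options.reviver` to JSON.parse and `options.replacer` to JSON.stringify. As a hypothetical example, a Set (which JSON.stringify would otherwise turn into `{}`) can be round-tripped by tagging it in the replacer and rebuilding it in the reviver. The Map-backed storage and the `jsonStorage` wrapper are simplified stand-ins, not zustand exports.

```typescript
// Hypothetical replacer/reviver pair: tag a Set on the way out, rebuild it on the way in
type JsonHook = (key: string, value: unknown) => unknown;

const replacer: JsonHook = (_key, value) =>
  value instanceof Set ? { __type: "Set", values: Array.from(value) } : value;

const reviver: JsonHook = (_key, value) => {
  const tagged = value as { __type?: unknown; values?: unknown[] } | null;
  if (tagged !== null && typeof tagged === "object" && tagged.__type === "Set") {
    return new Set(tagged.values);
  }
  return value;
};

// Simplified stand-ins for a string storage and the JSON glue around it
const memory = new Map<string, string>();
const jsonStorage = {
  setItem: (key: string, value: unknown): void => {
    memory.set(key, JSON.stringify(value, replacer));
  },
  getItem: (key: string): unknown => {
    const str = memory.get(key) ?? null;
    return str === null ? null : JSON.parse(str, reviver);
  },
};

jsonStorage.setItem("tags", { state: { tags: new Set(["a", "b"]) }, version: 0 });
console.log(memory.get("tags"));
// {"state":{"tags":{"__type":"Set","values":["a","b"]}},"version":0}

const restoredValue = jsonStorage.getItem("tags") as { state: { tags: Set<string> } };
console.log(restoredValue.state.tags instanceof Set, restoredValue.state.tags.has("b")); // true true
```

With the real middleware, the same two functions would go into the `options` argument of createJSONStorage shown above.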
Below is the core source code (the body of the hydrate function is omitted here and covered in the next step). It attaches an api.persist object to the original zustand API to expose the persist API, and returns the initial state: the return value of createInitState (i.e. config), or stateFromStorage if hydration has already completed synchronously.
Note options.skipHydration and the hydrate function. hydrate merges the saved state into the existing state; the name comes from hydration in SSR, where event handlers are attached to the server-rendered DOM.
const newImpl = (config, baseOptions) => (set, get, api) => {
  // Here, S is actually the return type of createInitState mentioned above
  type S = ReturnType<typeof config>;
  let options = {
    storage: createJSONStorage<S>(() => localStorage),
    partialize: (state: S) => state,
    version: 0,
    merge: (persistedState: unknown, currentState: S) => ({
      ...currentState,
      ...(persistedState as object),
    }),
    // above are default configs
    ...baseOptions,
  };
  let hasHydrated = false;
  const hydrationListeners = new Set<PersistListener<S>>();
  const finishHydrationListeners = new Set<PersistListener<S>>();
  let storage = options.storage;
  // Without storage nothing can be persisted: return config(...), i.e. the initial state,
  // with a set that still updates the store but warns on every call
  if (!storage) {
    return config(
      (...args) => {
        console.warn(
          `[zustand persist middleware] Unable to update item '${options.name}', the given storage is currently unavailable.`
        );
        set(...args);
      },
      get,
      api
    );
  }
  // set partialized item
  const setItem = (): void | Promise<void> => {
    // Only save the state that has been processed by partialize
    const state = options.partialize({ ...get() });
    return (storage as PersistStorage<S>).setItem(options.name, {
      state,
      version: options.version,
    });
  };
  const savedSetState = api.setState;
  // Replace api.setState so that every update is also written to storage
  api.setState = (state, replace) => {
    savedSetState(state, replace);
    void setItem();
  };
  const configResult = config(
    /**
     * This wrapper makes the set injected into createState persist updates, like the new api.setState.
     * createState receives the original setState function itself, so reassigning api.setState
     * above does not change the set it was given:
     *   const api = { setState, getState, subscribe, destroy }
     *   state = createState(setState, getState, api)
     */
    (...args) => {
      // Here, set === savedSetState
      set(...args);
      void setItem();
    },
    get,
    api
  );
  // A workaround to solve the issue of not storing rehydrated state in sync storage
  // The set(state) value would be later overridden with initial state by create()
  // To avoid this, we merge the state from localStorage into the initial state.
  let stateFromStorage: S | undefined;
  // Rehydrate initial state with existing stored state
  const hydrate = () => {
    // ...
  };
  (api as StoreApi<S> & StorePersist<S, S>).persist = {
    setOptions: (newOptions) => {
      options = {
        ...options,
        ...newOptions,
      };
      if (newOptions.storage) {
        storage = newOptions.storage;
      }
    },
    clearStorage: () => {
      storage?.removeItem(options.name);
    },
    getOptions: () => options,
    rehydrate: () => hydrate() as Promise<void>,
    hasHydrated: () => hasHydrated,
    onHydrate: (cb) => {
      hydrationListeners.add(cb);
      return () => {
        hydrationListeners.delete(cb);
      };
    },
    onFinishHydration: (cb) => {
      finishHydrationListeners.add(cb);
      return () => {
        finishHydrationListeners.delete(cb);
      };
    },
  };
  if (!options.skipHydration) {
    hydrate();
  }
  return stateFromStorage || configResult;
};
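The heart of newImpl (wrap both api.setState and the injected set, then merge the stored state over the initial state) can be reproduced in a short synchronous sketch. Everything here (`createMiniStore`, `miniPersist`, the Map-backed storage) is hypothetical and deliberately simplified: no async storage, partialize, version/migrate, or hydration listeners.

```typescript
// Simplified, synchronous-only sketch of the persist idea; NOT zustand's implementation.
type Setter<T> = (partial: Partial<T>, replace?: boolean) => void;
interface MiniApi<T> {
  setState: Setter<T>;
  getState: () => T;
}
type Creator<T> = (set: Setter<T>, get: () => T, api: MiniApi<T>) => T;

function createMiniStore<T extends object>(creator: Creator<T>): MiniApi<T> {
  let state: T;
  const setState: Setter<T> = (partial, replace) => {
    state = replace ? (partial as T) : { ...state, ...partial };
  };
  const getState = () => state;
  const api: MiniApi<T> = { setState, getState };
  // Like zustand, the creator receives the original setState function itself
  state = creator(setState, getState, api);
  return api;
}

interface SyncStringStorage {
  getItem: (key: string) => string | null;
  setItem: (key: string, value: string) => void;
}

const miniPersist =
  <T extends object>(
    config: Creator<T>,
    opts: { name: string; storage: SyncStringStorage }
  ): Creator<T> =>
  (set, get, api) => {
    const save = () =>
      opts.storage.setItem(opts.name, JSON.stringify({ state: get(), version: 0 }));

    // 1. Persist updates made through store.setState(...)
    const savedSetState = api.setState;
    api.setState = (partial, replace) => {
      savedSetState(partial, replace);
      save();
    };

    // 2. Persist updates made through the `set` given to the state creator; reassigning
    //    api.setState above does not affect this reference, so it is wrapped separately
    const configResult = config(
      (partial, replace) => {
        set(partial, replace);
        save();
      },
      get,
      api
    );

    // 3. Synchronous hydration: shallow-merge the stored state over the initial state
    const stored = opts.storage.getItem(opts.name);
    if (stored === null) return configResult;
    const persisted = (JSON.parse(stored) as { state: Partial<T> }).state;
    return { ...configResult, ...persisted };
  };

// Usage with a Map standing in for localStorage
const disk = new Map<string, string>();
const diskStorage: SyncStringStorage = {
  getItem: (key) => disk.get(key) ?? null,
  setItem: (key, value) => {
    disk.set(key, value);
  },
};

interface Counter {
  count: number;
  inc: () => void;
}
const counterCreator: Creator<Counter> = (set, get) => ({
  count: 0,
  inc: () => set({ count: get().count + 1 }),
});

const first = createMiniStore(miniPersist(counterCreator, { name: "counter", storage: diskStorage }));
first.getState().inc(); // saved through the wrapped `set`
first.setState({ count: 5 }); // saved through the wrapped `api.setState`
console.log(disk.get("counter")); // {"state":{"count":5},"version":0}

// A second store on the same storage starts from the persisted value
const second = createMiniStore(miniPersist(counterCreator, { name: "counter", storage: diskStorage }));
console.log(second.getState().count); // 5
```

If step 2 were removed, `first.getState().inc()` would still update the store but would never reach storage, which is exactly the case the comment around configResult describes.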
Next is the hydrate function. It reads the saved state from storage and uses the toThenable function to wrap values from synchronous storage in a promise-like (thenable) object, so that synchronous and asynchronous storage go through the same call chain. It first checks whether the retrieved value needs to be migrated, then merges it, and finally calls the hydration-related callbacks and listeners exposed through options and api.persist.
const hydrate = () => {
  if (!storage) return;
  // On the first invocation of 'hydrate', state will not yet be defined (this is
  // true for both the 'asynchronous' and 'synchronous' case). Pass 'configResult'
  // as a backup to 'get()' so listeners and 'onRehydrateStorage' are called with
  // the latest available state.
  hasHydrated = false;
  // When skipHydration is not set, hydrate runs before create() has stored the initial state,
  // so get() returns undefined and the previously computed configResult is used instead
  hydrationListeners.forEach((cb) => cb(get() ?? configResult));
  const postRehydrationCallback =
    options.onRehydrateStorage?.(get() ?? configResult) || undefined;
  // bind is used to avoid `TypeError: Illegal invocation` error
  return toThenable(storage.getItem.bind(storage))(options.name)
    .then((deserializedStorageValue) => {
      // This step migrates old data based on version
      if (deserializedStorageValue) {
        if (
          typeof deserializedStorageValue.version === "number" &&
          deserializedStorageValue.version !== options.version
        ) {
          if (options.migrate) {
            return options.migrate(
              deserializedStorageValue.state,
              deserializedStorageValue.version
            );
          }
          console.error(
            `State loaded from storage couldn't be migrated since no migrate function was provided`
          );
        } else {
          return deserializedStorageValue.state;
        }
      }
    })
    .then((migratedState) => {
      // This step performs the merge
      stateFromStorage = options.merge(
        migratedState as S,
        get() ?? configResult
      );
      set(stateFromStorage as S, true);
      return setItem();
    })
    .then(() => {
      // TODO: In the asynchronous case, it's possible that the state has changed
      // since it was set in the prior callback. As such, it would be better to
      // pass 'get()' to the 'postRehydrationCallback' to ensure the most up-to-date
      // state is used. However, this could be a breaking change, so this isn't being
      // done now.
      postRehydrationCallback?.(stateFromStorage, undefined);
      // It's possible that 'postRehydrationCallback' updated the state. To ensure
      // that isn't overwritten when returning 'stateFromStorage' below
      // (synchronous-case only), update 'stateFromStorage' to point to the latest
      // state. In the asynchronous case, 'stateFromStorage' isn't used after this
      // callback, so there's no harm in updating it to match the latest state.
      stateFromStorage = get();
      hasHydrated = true;
      finishHydrationListeners.forEach((cb) => cb(stateFromStorage as S));
    })
    .catch((e: Error) => {
      postRehydrationCallback?.(undefined, e);
    });
};
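The version check at the top of the hydrate chain can be isolated into a small synchronous function. `resolvePersisted` and the version 0 to version 1 rename are hypothetical; the branching mirrors the code above: migrate on a version mismatch, log an error if no migrate is provided, otherwise pass the stored state through.

```typescript
// Hypothetical, synchronous extraction of the version check in hydrate (merge not included)
interface StoredValue {
  state: unknown;
  version?: number;
}

function resolvePersisted(
  stored: StoredValue | null,
  currentVersion: number,
  migrate?: (persistedState: unknown, version: number) => unknown
): unknown {
  if (!stored) return undefined; // nothing stored: the default merge keeps the current state
  if (typeof stored.version === "number" && stored.version !== currentVersion) {
    if (migrate) return migrate(stored.state, stored.version);
    console.error(
      "State loaded from storage couldn't be migrated since no migrate function was provided"
    );
    return undefined;
  }
  return stored.state; // versions match (or no version stored): use the state as-is
}

// Hypothetical migration: version 0 stored { count }, version 1 renames it to { total }
const migrateCounter = (persistedState: unknown, version: number): unknown => {
  if (version === 0) {
    const old = persistedState as { count: number };
    return { total: old.count };
  }
  return persistedState;
};

const fromV0 = resolvePersisted({ state: { count: 3 }, version: 0 }, 1, migrateCounter);
const fromV1 = resolvePersisted({ state: { total: 7 }, version: 1 }, 1, migrateCounter);
console.log(JSON.stringify(fromV0)); // {"total":3}
console.log(JSON.stringify(fromV1)); // {"total":7}
```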
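The source code of the toThenable function follows. Unlike Promise.resolve, for a synchronous result it invokes the then callback immediately, in the same call stack. That is why, with synchronous storage, hydration completes and stateFromStorage is set before newImpl returns: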
const toThenable =
  <Result, Input>(
    fn: (input: Input) => Result | Promise<Result> | Thenable<Result>
  ) =>
  (input: Input): Thenable<Result> => {
    try {
      const result = fn(input);
      if (result instanceof Promise) {
        return result as Thenable<Result>;
      }
      return {
        then(onFulfilled) {
          return toThenable(onFulfilled)(result as Result);
        },
        catch(_onRejected) {
          return this as Thenable<any>;
        },
      };
    } catch (e: any) {
      return {
        then(_onFulfilled) {
          return this as Thenable<any>;
        },
        catch(onRejected) {
          return toThenable(onRejected)(e);
        },
      };
    }
  };
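A re-typed copy of toThenable makes the execution order visible. It records when each callback runs for a synchronous read, a throwing read, and a plain Promise. The `Thenable` interface here is a simplified assumption and may differ from zustand's own type; `readSync` and `readFailing` are hypothetical storage reads.

```typescript
// Simplified typing (an assumption; zustand's own Thenable type may differ)
interface Thenable<T> {
  then<U>(onFulfilled: (value: T) => U | Promise<U>): Thenable<U>;
  catch<U>(onRejected: (reason: unknown) => U | Promise<U>): Thenable<T | U>;
}

function toThenable<Result, Input>(
  fn: (input: Input) => Result | Promise<Result>
): (input: Input) => Thenable<Result> {
  return (input) => {
    try {
      const result = fn(input);
      if (result instanceof Promise) {
        // Async value: a real Promise already behaves like a thenable
        return result as unknown as Thenable<Result>;
      }
      // Sync value: `then` runs the next step immediately, in the same call stack
      const fulfilled: Thenable<Result> = {
        then<U>(onFulfilled: (value: Result) => U | Promise<U>): Thenable<U> {
          return toThenable(onFulfilled)(result as Result);
        },
        catch() {
          return fulfilled; // nothing was thrown, so there is nothing to catch
        },
      };
      return fulfilled;
    } catch (e) {
      // Sync error: skip `then` handlers and hand the error to the first `catch`
      const rejected: Thenable<Result> = {
        then<U>(): Thenable<U> {
          return rejected as unknown as Thenable<U>;
        },
        catch<U>(onRejected: (reason: unknown) => U | Promise<U>): Thenable<Result | U> {
          return toThenable(onRejected)(e);
        },
      };
      return rejected;
    }
  };
}

const order: string[] = [];
const readSync = (key: string): string => `value-of-${key}`;
const readFailing = (key: string): string => {
  throw new Error(`cannot read ${key}`);
};

toThenable(readSync)("a").then((v) => order.push(`then: ${v}`));
order.push("after sync call");

toThenable(readFailing)("b")
  .then(() => order.push("never runs"))
  .catch((e) => order.push(`caught: ${(e as Error).message}`));

Promise.resolve(readSync("c")).then((v) => order.push(`promise: ${v}`)); // deferred to a microtask
order.push("end of script");

console.log(order.join(" | "));
// then: value-of-a | after sync call | caught: cannot read b | end of script
```

Both the success and the error path complete synchronously, while the plain Promise callback only runs after the current script finishes.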