diff --git a/src/vanilla/store.ts b/src/vanilla/store.ts index afbd4c7641..7b04acdbae 100644 --- a/src/vanilla/store.ts +++ b/src/vanilla/store.ts @@ -191,6 +191,27 @@ const addPendingContinuablePromiseToDependency = ( } } +const addDependency = ( + pending: Pending | undefined, + atom: Atom, + atomState: AtomState, + a: AnyAtom, + aState: AtomState, +) => { + if (import.meta.env?.MODE !== 'production' && a === atom) { + throw new Error('[Bug] atom cannot depend on itself') + } + atomState.d.set(a, aState.n) + const continuablePromise = getPendingContinuablePromise(atomState) + if (continuablePromise) { + addPendingContinuablePromiseToDependency(atom, continuablePromise, aState) + } + aState.m?.t.add(atom) + if (pending) { + addPendingDependent(pending, a, atom) + } +} + // // Pending // @@ -244,9 +265,22 @@ const flushPending = (pending: Pending) => { } } +type GetAtomState = ( + atom: Atom, + originAtomState?: AtomState, +) => AtomState + +// internal & unstable type +type StoreArgs = readonly [ + getAtomState: GetAtomState, + // possible other arguments in the future +] + // for debugging purpose only type DevStoreRev4 = { - dev4_get_internal_weak_map: () => WeakMap + dev4_get_internal_weak_map: () => { + get: (atom: AnyAtom) => AtomState | undefined + } dev4_get_mounted_atoms: () => Set dev4_restore_atoms: (values: Iterable) => void } @@ -258,6 +292,7 @@ type PrdStore = { ...args: Args ) => Result sub: (atom: AnyAtom, listener: () => void) => () => void + unstable_derive: (fn: (...args: StoreArgs) => StoreArgs) => Store } type Store = PrdStore | (PrdStore & DevStoreRev4) @@ -265,8 +300,7 @@ type Store = PrdStore | (PrdStore & DevStoreRev4) export type INTERNAL_DevStoreRev4 = DevStoreRev4 export type INTERNAL_PrdStore = PrdStore -export const createStore = (): Store => { - const atomStateMap = new WeakMap() +const buildStore = (getAtomState: StoreArgs[0]): Store => { // for debugging purpose only let debugMountedAtoms: Set @@ -274,15 +308,6 @@ export const createStore = (): Store => { debugMountedAtoms = new Set() } - const getAtomState = (atom: Atom) => { - let atomState = atomStateMap.get(atom) as AtomState | undefined - if (!atomState) { - atomState = { d: new Map(), p: new Set(), n: 0 } - atomStateMap.set(atom, atomState) - } - return atomState - } - const setAtomStateValueOrPromise = ( atom: AnyAtom, atomState: AtomState, @@ -307,11 +332,10 @@ export const createStore = (): Store => { ) if (continuablePromise.status === PENDING) { for (const a of atomState.d.keys()) { - const aState = getAtomState(a) addPendingContinuablePromiseToDependency( atom, continuablePromise, - aState, + getAtomState(a, atomState), ) } } @@ -332,34 +356,14 @@ export const createStore = (): Store => { ++atomState.n } } - const addDependency = ( - pending: Pending | undefined, - atom: Atom, - a: AnyAtom, - aState: AtomState, - ) => { - if (import.meta.env?.MODE !== 'production' && a === atom) { - throw new Error('[Bug] atom cannot depend on itself') - } - const atomState = getAtomState(atom) - atomState.d.set(a, aState.n) - const continuablePromise = getPendingContinuablePromise(atomState) - if (continuablePromise) { - addPendingContinuablePromiseToDependency(atom, continuablePromise, aState) - } - aState.m?.t.add(atom) - if (pending) { - addPendingDependent(pending, a, atom) - } - } const readAtomState = ( pending: Pending | undefined, atom: Atom, + atomState: AtomState, force?: (a: AnyAtom) => boolean, ): AtomState => { // See if we can skip recomputing this atom. - const atomState = getAtomState(atom) if (!force?.(atom) && isAtomStateInitialized(atomState)) { // If the atom is mounted, we can use the cache. // because it should have been updated by dependencies. @@ -373,7 +377,8 @@ export const createStore = (): Store => { ([a, n]) => // Recursively, read the atom state of the dependency, and // check if the atom epoch number is unchanged - readAtomState(pending, a, force).n === n, + readAtomState(pending, a, getAtomState(a, atomState), force).n === + n, ) ) { return atomState @@ -384,7 +389,7 @@ export const createStore = (): Store => { let isSync = true const getter: Getter = (a: Atom) => { if (isSelfAtom(atom, a)) { - const aState = getAtomState(a) + const aState = getAtomState(a, atomState) if (!isAtomStateInitialized(aState)) { if (hasInitialValue(a)) { setAtomStateValueOrPromise(a, aState, a.init) @@ -396,12 +401,17 @@ export const createStore = (): Store => { return returnAtomValue(aState) } // a !== atom - const aState = readAtomState(pending, a, force) + const aState = readAtomState( + pending, + a, + getAtomState(a, atomState), + force, + ) if (isSync) { - addDependency(pending, atom, a, aState) + addDependency(pending, atom, atomState, a, aState) } else { const pending = createPending() - addDependency(pending, atom, a, aState) + addDependency(pending, atom, atomState, a, aState) mountDependencies(pending, atom, atomState) flushPending(pending) } @@ -463,20 +473,34 @@ export const createStore = (): Store => { } const readAtom = (atom: Atom): Value => - returnAtomValue(readAtomState(undefined, atom)) + returnAtomValue(readAtomState(undefined, atom, getAtomState(atom))) - const recomputeDependents = (pending: Pending, atom: AnyAtom) => { - const getDependents = (a: AnyAtom, aState: AtomState): Set => { - const dependents = new Set(aState.m?.t) - for (const atomWithPendingContinuablePromise of aState.p) { - dependents.add(atomWithPendingContinuablePromise) - } - getPendingDependents(pending, a)?.forEach((dependent) => { - dependents.add(dependent) - }) - return dependents + const getDependents = ( + pending: Pending, + atom: Atom, + atomState: AtomState, + ): Map => { + const dependents = new Map() + for (const a of atomState.m?.t || []) { + dependents.set(a, getAtomState(a, atomState)) + } + for (const atomWithPendingContinuablePromise of atomState.p) { + dependents.set( + atomWithPendingContinuablePromise, + getAtomState(atomWithPendingContinuablePromise, atomState), + ) } + getPendingDependents(pending, atom)?.forEach((dependent) => { + dependents.set(dependent, getAtomState(dependent, atomState)) + }) + return dependents + } + const recomputeDependents = ( + pending: Pending, + atom: Atom, + atomState: AtomState, + ) => { // This is a topological sort via depth-first search, slightly modified from // what's described here for simplicity and performance reasons: // https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search @@ -489,16 +513,14 @@ export const createStore = (): Store => { epochNumber: number, ])[] = [] const markedAtoms = new Set() - const visit = (a: AnyAtom) => { + const visit = (a: AnyAtom, aState: AtomState) => { if (markedAtoms.has(a)) { return } markedAtoms.add(a) - const aState = getAtomState(a) - for (const d of getDependents(a, aState)) { - // we shouldn't use isSelfAtom here. + for (const [d, s] of getDependents(pending, a, aState)) { if (a !== d) { - visit(d) + visit(d, s) } } // The algorithm calls for pushing onto the front of the list. For @@ -508,7 +530,7 @@ export const createStore = (): Store => { } // Visit the root atom. This is the only atom in the dependency graph // without incoming edges, which is one reason we can simplify the algorithm - visit(atom) + visit(atom, atomState) // Step 2: use the topsorted atom list to recompute all affected atoms // Track what's changed, so that we can short circuit when possible const changedAtoms = new Set([atom]) @@ -523,7 +545,7 @@ export const createStore = (): Store => { } } if (hasChangedDeps) { - readAtomState(pending, a, isMarked) + readAtomState(pending, a, aState, isMarked) mountDependencies(pending, a, aState) if (prevEpochNumber !== aState.n) { addPendingAtom(pending, a, aState) @@ -537,21 +559,22 @@ export const createStore = (): Store => { const writeAtomState = ( pending: Pending, atom: WritableAtom, + atomState: AtomState, ...args: Args ): Result => { const getter: Getter = (a: Atom) => - returnAtomValue(readAtomState(pending, a)) + returnAtomValue(readAtomState(pending, a, getAtomState(a, atomState))) const setter: Setter = ( a: WritableAtom, ...args: As ) => { + const aState = getAtomState(a, atomState) let r: R | undefined if (isSelfAtom(atom, a)) { if (!hasInitialValue(a)) { // NOTE technically possible but restricted as it may cause bugs throw new Error('atom not writable') } - const aState = getAtomState(a) const hasPrevValue = 'v' in aState const prevValue = aState.v const v = args[0] as V @@ -559,10 +582,10 @@ export const createStore = (): Store => { mountDependencies(pending, a, aState) if (!hasPrevValue || !Object.is(prevValue, aState.v)) { addPendingAtom(pending, a, aState) - recomputeDependents(pending, a) + recomputeDependents(pending, a, aState) } } else { - r = writeAtomState(pending, a as AnyWritableAtom, ...args) as R + r = writeAtomState(pending, a, aState, ...args) as R } flushPending(pending) return r as R @@ -576,7 +599,7 @@ export const createStore = (): Store => { ...args: Args ): Result => { const pending = createPending() - const result = writeAtomState(pending, atom, ...args) + const result = writeAtomState(pending, atom, getAtomState(atom), ...args) flushPending(pending) return result } @@ -589,7 +612,7 @@ export const createStore = (): Store => { if (atomState.m && !getPendingContinuablePromise(atomState)) { for (const a of atomState.d.keys()) { if (!atomState.m.d.has(a)) { - const aMounted = mountAtom(pending, a) + const aMounted = mountAtom(pending, a, getAtomState(a, atomState)) aMounted.t.add(atom) atomState.m.d.add(a) } @@ -597,21 +620,24 @@ export const createStore = (): Store => { for (const a of atomState.m.d || []) { if (!atomState.d.has(a)) { atomState.m.d.delete(a) - const aMounted = unmountAtom(pending, a) + const aMounted = unmountAtom(pending, a, getAtomState(a, atomState)) aMounted?.t.delete(atom) } } } } - const mountAtom = (pending: Pending, atom: AnyAtom): Mounted => { - const atomState = getAtomState(atom) + const mountAtom = ( + pending: Pending, + atom: Atom, + atomState: AtomState, + ): Mounted => { if (!atomState.m) { // recompute atom state - readAtomState(pending, atom) + readAtomState(pending, atom, atomState) // mount dependencies first for (const a of atomState.d.keys()) { - const aMounted = mountAtom(pending, a) + const aMounted = mountAtom(pending, a, getAtomState(a, atomState)) aMounted.t.add(atom) } // mount self @@ -628,7 +654,7 @@ export const createStore = (): Store => { const { onMount } = atom addPendingFunction(pending, () => { const onUnmount = onMount((...args) => - writeAtomState(pending, atom, ...args), + writeAtomState(pending, atom, atomState, ...args), ) if (onUnmount) { mounted.u = onUnmount @@ -639,15 +665,17 @@ export const createStore = (): Store => { return atomState.m } - const unmountAtom = ( + const unmountAtom = ( pending: Pending, - atom: AnyAtom, + atom: Atom, + atomState: AtomState, ): Mounted | undefined => { - const atomState = getAtomState(atom) if ( atomState.m && !atomState.m.l.size && - !Array.from(atomState.m.t).some((a) => getAtomState(a).m?.d.has(atom)) + !Array.from(atomState.m.t).some((a) => + getAtomState(a, atomState).m?.d.has(atom), + ) ) { // unmount self const onUnmount = atomState.m.u @@ -660,7 +688,7 @@ export const createStore = (): Store => { } // unmount dependencies for (const a of atomState.d.keys()) { - const aMounted = unmountAtom(pending, a) + const aMounted = unmountAtom(pending, a, getAtomState(a, atomState)) aMounted?.t.delete(atom) } // abort pending promise @@ -676,28 +704,41 @@ export const createStore = (): Store => { const subscribeAtom = (atom: AnyAtom, listener: () => void) => { const pending = createPending() - const mounted = mountAtom(pending, atom) + const atomState = getAtomState(atom) + const mounted = mountAtom(pending, atom, atomState) flushPending(pending) const listeners = mounted.l listeners.add(listener) return () => { listeners.delete(listener) const pending = createPending() - unmountAtom(pending, atom) + unmountAtom(pending, atom, atomState) flushPending(pending) } } + const unstable_derive = (fn: (...args: StoreArgs) => StoreArgs) => + buildStore(...fn(getAtomState)) + const store: Store = { get: readAtom, set: writeAtom, sub: subscribeAtom, + unstable_derive, } - if (import.meta.env?.MODE !== 'production') { const devStore: DevStoreRev4 = { // store dev methods (these are tentative and subject to change without notice) - dev4_get_internal_weak_map: () => atomStateMap, + dev4_get_internal_weak_map: () => ({ + get: (atom) => { + const atomState = getAtomState(atom) + if (atomState.n === 0) { + // for backward compatibility + return undefined + } + return atomState + }, + }), dev4_get_mounted_atoms: () => debugMountedAtoms, dev4_restore_atoms: (values) => { const pending = createPending() @@ -710,7 +751,7 @@ export const createStore = (): Store => { mountDependencies(pending, atom, atomState) if (!hasPrevValue || !Object.is(prevValue, atomState.v)) { addPendingAtom(pending, atom, atomState) - recomputeDependents(pending, atom) + recomputeDependents(pending, atom, atomState) } } } @@ -722,6 +763,19 @@ export const createStore = (): Store => { return store } +export const createStore = (): Store => { + const atomStateMap = new WeakMap() + const getAtomState = (atom: Atom) => { + let atomState = atomStateMap.get(atom) as AtomState | undefined + if (!atomState) { + atomState = { d: new Map(), p: new Set(), n: 0 } + atomStateMap.set(atom, atomState) + } + return atomState + } + return buildStore(getAtomState) +} + let defaultStore: Store | undefined export const getDefaultStore = (): Store => { diff --git a/tests/vanilla/unstable_derive.test.tsx b/tests/vanilla/unstable_derive.test.tsx new file mode 100644 index 0000000000..e7bfa468e6 --- /dev/null +++ b/tests/vanilla/unstable_derive.test.tsx @@ -0,0 +1,304 @@ +import { describe, expect, it, vi } from 'vitest' +import { atom, createStore } from 'jotai/vanilla' +import type { Atom } from 'jotai/vanilla' + +describe('unstable_derive for scoping atoms', () => { + /** + * a + * S1[a]: a1 + */ + it('primitive atom', async () => { + const a = atom('a') + a.onMount = (setSelf) => setSelf((v) => v + ':mounted') + const scopedAtoms = new Set>([a]) + + const store = createStore() + const derivedStore = store.unstable_derive((getAtomState) => { + const scopedAtomStateMap = new WeakMap() + return [ + (atom, originAtomState) => { + if (scopedAtoms.has(atom)) { + let atomState = scopedAtomStateMap.get(atom) + if (!atomState) { + atomState = { d: new Map(), p: new Set(), n: 0 } + scopedAtomStateMap.set(atom, atomState) + } + return atomState + } + return getAtomState(atom, originAtomState) + }, + ] + }) + + expect(store.get(a)).toBe('a') + expect(derivedStore.get(a)).toBe('a') + + derivedStore.sub(a, vi.fn()) + await new Promise((resolve) => setTimeout(resolve)) + expect(store.get(a)).toBe('a') + expect(derivedStore.get(a)).toBe('a:mounted') + + derivedStore.set(a, (v) => v + ':updated') + await new Promise((resolve) => setTimeout(resolve)) + expect(store.get(a)).toBe('a') + expect(derivedStore.get(a)).toBe('a:mounted:updated') + }) + + /** + * a, b, c(a + b) + * S1[a]: a1, b0, c0(a1 + b0) + */ + it('derived atom (scoping primitive)', async () => { + const a = atom('a') + const b = atom('b') + const c = atom((get) => get(a) + get(b)) + const scopedAtoms = new Set>([a]) + + const store = createStore() + const derivedStore = store.unstable_derive((getAtomState) => { + const scopedAtomStateMap = new WeakMap() + return [ + (atom, originAtomState) => { + if (scopedAtoms.has(atom)) { + let atomState = scopedAtomStateMap.get(atom) + if (!atomState) { + atomState = { d: new Map(), p: new Set(), n: 0 } + scopedAtomStateMap.set(atom, atomState) + } + return atomState + } + return getAtomState(atom, originAtomState) + }, + ] + }) + + expect(store.get(c)).toBe('ab') + expect(derivedStore.get(c)).toBe('ab') + + derivedStore.set(a, 'a2') + await new Promise((resolve) => setTimeout(resolve)) + expect(store.get(c)).toBe('ab') + expect(derivedStore.get(c)).toBe('a2b') + }) + + /** + * a, b(a) + * S1[b]: a0, b1(a1) + */ + it('derived atom (scoping derived)', async () => { + const a = atom('a') + const b = atom( + (get) => get(a), + (_get, set, v: string) => { + set(a, v) + }, + ) + const scopedAtoms = new Set>([b]) + + const store = createStore() + const derivedStore = store.unstable_derive((getAtomState) => { + const scopedAtomStateMap = new WeakMap() + const scopedAtomStateSet = new WeakSet() + return [ + (atom, originAtomState) => { + if ( + scopedAtomStateSet.has(originAtomState as never) || + scopedAtoms.has(atom) + ) { + let atomState = scopedAtomStateMap.get(atom) + if (!atomState) { + atomState = { d: new Map(), p: new Set(), n: 0 } + scopedAtomStateMap.set(atom, atomState) + scopedAtomStateSet.add(atomState) + } + return atomState + } + return getAtomState(atom, originAtomState) + }, + ] + }) + + expect(store.get(a)).toBe('a') + expect(store.get(b)).toBe('a') + expect(derivedStore.get(a)).toBe('a') + expect(derivedStore.get(b)).toBe('a') + + store.set(a, 'a2') + await new Promise((resolve) => setTimeout(resolve)) + expect(store.get(a)).toBe('a2') + expect(store.get(b)).toBe('a2') + expect(derivedStore.get(a)).toBe('a2') + expect(derivedStore.get(b)).toBe('a') + + store.set(b, 'a3') + await new Promise((resolve) => setTimeout(resolve)) + expect(store.get(a)).toBe('a3') + expect(store.get(b)).toBe('a3') + expect(derivedStore.get(a)).toBe('a3') + expect(derivedStore.get(b)).toBe('a') + + derivedStore.set(a, 'a4') + await new Promise((resolve) => setTimeout(resolve)) + expect(store.get(a)).toBe('a4') + expect(store.get(b)).toBe('a4') + expect(derivedStore.get(a)).toBe('a4') + expect(derivedStore.get(b)).toBe('a') + + derivedStore.set(b, 'a5') + await new Promise((resolve) => setTimeout(resolve)) + expect(store.get(a)).toBe('a4') + expect(store.get(b)).toBe('a4') + expect(derivedStore.get(a)).toBe('a4') + expect(derivedStore.get(b)).toBe('a5') + }) + + /** + * a, b, c(a), d(c), e(d + b) + * S1[d]: a0, b0, c0(a0), d1(c1(a1)), e0(d1(c1(a1)) + b0) + */ + it('derived atom (scoping derived chain)', async () => { + const a = atom('a') + const b = atom('b') + const c = atom( + (get) => get(a), + (_get, set, v: string) => set(a, v), + ) + const d = atom( + (get) => get(c), + (_get, set, v: string) => set(c, v), + ) + const e = atom( + (get) => get(d) + get(b), + (_get, set, av: string, bv: string) => { + set(d, av) + set(b, bv) + }, + ) + const scopedAtoms = new Set>([d]) + + function makeStores() { + const baseStore = createStore() + const deriStore = baseStore.unstable_derive((getAtomState) => { + const scopedAtomStateMap = new WeakMap() + const scopedAtomStateSet = new WeakSet() + return [ + (atom, originAtomState) => { + if ( + scopedAtomStateSet.has(originAtomState as never) || + scopedAtoms.has(atom) + ) { + let atomState = scopedAtomStateMap.get(atom) + if (!atomState) { + atomState = { d: new Map(), p: new Set(), n: 0 } + scopedAtomStateMap.set(atom, atomState) + scopedAtomStateSet.add(atomState) + } + return atomState + } + return getAtomState(atom, originAtomState) + }, + ] + }) + expect(getAtoms(baseStore)).toEqual(['a', 'b', 'a', 'a', 'ab']) + expect(getAtoms(deriStore)).toEqual(['a', 'b', 'a', 'a', 'ab']) + return { baseStore, deriStore } + } + type Store = ReturnType + function getAtoms(store: Store) { + return [ + store.get(a), + store.get(b), + store.get(c), + store.get(d), + store.get(e), + ] + } + + /** + * base[d]: a0, b0, c0(a0), d0(c0(a0)), e0(d0(c0(a0)) + b0) + * deri[d]: a0, b0, c0(a0), d1(c1(a1)), e0(d1(c1(a1)) + b0) + */ + { + // UPDATE a0 + // NOCHGE b0 and a1 + const { baseStore, deriStore } = makeStores() + baseStore.set(a, '*') + expect(getAtoms(baseStore)).toEqual(['*', 'b', '*', '*', '*b']) + expect(getAtoms(deriStore)).toEqual(['*', 'b', '*', 'a', 'ab']) + } + { + // UPDATE b0 + // NOCHGE a0 and a1 + const { baseStore, deriStore } = makeStores() + baseStore.set(b, '*') + expect(getAtoms(baseStore)).toEqual(['a', '*', 'a', 'a', 'a*']) + expect(getAtoms(deriStore)).toEqual(['a', '*', 'a', 'a', 'a*']) + } + { + // UPDATE c0, c0 -> a0 + // NOCHGE b0 and a1 + const { baseStore, deriStore } = makeStores() + baseStore.set(c, '*') + expect(getAtoms(baseStore)).toEqual(['*', 'b', '*', '*', '*b']) + expect(getAtoms(deriStore)).toEqual(['*', 'b', '*', 'a', 'ab']) + } + { + // UPDATE d0, d0 -> c0 -> a0 + // NOCHGE b0 and a1 + const { baseStore, deriStore } = makeStores() + baseStore.set(d, '*') + expect(getAtoms(baseStore)).toEqual(['*', 'b', '*', '*', '*b']) + expect(getAtoms(deriStore)).toEqual(['*', 'b', '*', 'a', 'ab']) + } + { + // UPDATE e0, e0 -> d0 -> c0 -> a0 + // └--------------> b0 + // NOCHGE a1 + const { baseStore, deriStore } = makeStores() + baseStore.set(e, '*', '*') + expect(getAtoms(baseStore)).toEqual(['*', '*', '*', '*', '**']) + expect(getAtoms(deriStore)).toEqual(['*', '*', '*', 'a', 'a*']) + } + { + // UPDATE a0 + // NOCHGE b0 and a1 + const { baseStore, deriStore } = makeStores() + deriStore.set(a, '*') + expect(getAtoms(baseStore)).toEqual(['*', 'b', '*', '*', '*b']) + expect(getAtoms(deriStore)).toEqual(['*', 'b', '*', 'a', 'ab']) + } + { + // UPDATE b0 + // NOCHGE a0 and a1 + const { baseStore, deriStore } = makeStores() + deriStore.set(b, '*') + expect(getAtoms(baseStore)).toEqual(['a', '*', 'a', 'a', 'a*']) + expect(getAtoms(deriStore)).toEqual(['a', '*', 'a', 'a', 'a*']) + } + { + // UPDATE c0, c0 -> a0 + // NOCHGE b0 and a1 + const { baseStore, deriStore } = makeStores() + deriStore.set(c, '*') + expect(getAtoms(baseStore)).toEqual(['*', 'b', '*', '*', '*b']) + expect(getAtoms(deriStore)).toEqual(['*', 'b', '*', 'a', 'ab']) + } + { + // UPDATE d1, d1 -> c1 -> a1 + // NOCHGE b0 and a0 + const { baseStore, deriStore } = makeStores() + deriStore.set(d, '*') + expect(getAtoms(baseStore)).toEqual(['a', 'b', 'a', 'a', 'ab']) + expect(getAtoms(deriStore)).toEqual(['a', 'b', 'a', '*', '*b']) + } + { + // UPDATE e0, e0 -> d1 -> c1 -> a1 + // └--------------> b0 + // NOCHGE a0 + const { baseStore, deriStore } = makeStores() + deriStore.set(e, '*', '*') + expect(getAtoms(baseStore)).toEqual(['a', '*', 'a', 'a', 'a*']) + expect(getAtoms(deriStore)).toEqual(['a', '*', 'a', '*', '**']) + } + }) +})