Skip to content
This repository has been archived by the owner on Apr 13, 2023. It is now read-only.

Fix React.createContext in SSR #2304

Merged
merged 12 commits into from
Sep 27, 2018
76 changes: 48 additions & 28 deletions src/getDataFromTree.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ export interface Context {

interface PromiseTreeArgument {
rootElement: React.ReactNode;
rootContext?: Context;
rootContext: Context;
rootNewContext: Map<any, any>;
}
interface FetchComponent extends React.Component<any> {
fetchData(): Promise<void>;
Expand All @@ -16,6 +17,7 @@ interface PromiseTreeResult {
promise: Promise<any>;
context: Context;
instance: FetchComponent;
newContext: Map<any, any>;
}

interface PreactElement<P> {
Expand Down Expand Up @@ -49,12 +51,14 @@ export function walkTree(
visitor: (
element: React.ReactNode,
instance: React.Component<any> | null,
newContextMap: Map<any, any>,
context: Context,
childContext?: Context,
) => boolean | void,
newContext: Map<any, any>,
) {
if (Array.isArray(element)) {
element.forEach(item => walkTree(item, context, visitor));
element.forEach(item => walkTree(item, context, visitor, newContext));
return;
}

Expand Down Expand Up @@ -113,14 +117,14 @@ export function walkTree(
childContext = Object.assign({}, context, instance.getChildContext());
}

if (visitor(element, instance, context, childContext) === false) {
if (visitor(element, instance, newContext, context, childContext) === false) {
return;
}

child = instance.render();
} else {
// Just a stateless functional
if (visitor(element, null, context) === false) {
if (visitor(element, null, newContext, context) === false) {
return;
}

Expand All @@ -129,51 +133,55 @@ export function walkTree(

if (child) {
if (Array.isArray(child)) {
child.forEach(item => walkTree(item, childContext, visitor));
child.forEach(item => walkTree(item, childContext, visitor, newContext));
} else {
walkTree(child, childContext, visitor);
walkTree(child, childContext, visitor, newContext);
}
}
} else if ((element.type as any)._context || (element.type as any).Consumer) {
// A React context provider or consumer
if (visitor(element, null, context) === false) {
if (visitor(element, null, newContext, context) === false) {
return;
}

let child;
if ((element.type as any)._context) {
if (!!(element.type as any)._context) {
// A provider - sets the context value before rendering children
((element.type as any)._context as any)._currentValue = element.props.value;
// this needs to clone the map because this value should only apply to children of the provider
newContext = new Map(newContext.entries());
newContext.set(element.type, element.props.value);
child = element.props.children;
} else {
// A consumer
child = element.props.children((element.type as any)._currentValue);
child = element.props.children(
newContext.get((element.type as any).Provider) || (element.type as any)._currentValue,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to be more specific with the check here:

let currentValue = (element.type as any)._currentValue;
if (childContext.subContexts.has((element.type as any).Provider)) {
  currentValue = childContext.subContexts.get((element.type as any).Provider);
}
child = element.props.children(currentValue);

In theory, someone could set the value of context to a falsy value, and that would be the value that should be respected. From the React docs:

The defaultValue argument is only used by a Consumer when it does not have a matching Provider above it in the tree. This can be helpful for testing components in isolation without wrapping them. Note: passing undefined as a Provider value does not cause Consumers to use defaultValue.

);
}

if (child) {
if (Array.isArray(child)) {
child.forEach(item => walkTree(item, context, visitor));
child.forEach(item => walkTree(item, context, visitor, newContext));
} else {
walkTree(child, context, visitor);
walkTree(child, context, visitor, newContext);
}
}
} else {
// A basic string or dom element, just get children
if (visitor(element, null, context) === false) {
if (visitor(element, null, newContext, context) === false) {
return;
}

if (element.props && element.props.children) {
React.Children.forEach(element.props.children, (child: any) => {
if (child) {
walkTree(child, context, visitor);
walkTree(child, context, visitor, newContext);
}
});
}
}
} else if (typeof element === 'string' || typeof element === 'number') {
// Just visit these, they are leaves so we don't keep traversing.
visitor(element, null, context);
visitor(element, null, newContext, context);
}
// TODO: Portals?
}
Expand All @@ -188,37 +196,49 @@ function isPromise<T>(promise: Object): promise is Promise<T> {

function getPromisesFromTree({
rootElement,
rootContext = {},
rootContext,
rootNewContext,
}: PromiseTreeArgument): PromiseTreeResult[] {
const promises: PromiseTreeResult[] = [];

walkTree(rootElement, rootContext, (_, instance, context, childContext) => {
if (instance && hasFetchDataFunction(instance)) {
const promise = instance.fetchData();
if (isPromise<Object>(promise)) {
promises.push({ promise, context: childContext || context, instance });
return false;
walkTree(
rootElement,
rootContext,
(_, instance, newContext, context, childContext) => {
if (instance && hasFetchDataFunction(instance)) {
const promise = instance.fetchData();
if (isPromise<Object>(promise)) {
promises.push({
promise,
context: childContext || context,
instance,
newContext,
});
return false;
}
}
}
});
},
rootNewContext,
);

return promises;
}

function getDataAndErrorsFromTree(
rootElement: React.ReactNode,
rootContext: any = {},
rootContext: Object,
storeError: Function,
rootNewContext: Map<any, any> = new Map(),
): Promise<any> {
const promises = getPromisesFromTree({ rootElement, rootContext });
const promises = getPromisesFromTree({ rootElement, rootContext, rootNewContext });

if (!promises.length) {
return Promise.resolve();
}

const mappedPromises = promises.map(({ promise, context, instance }) => {
const mappedPromises = promises.map(({ promise, context, instance, newContext }) => {
return promise
.then(_ => getDataAndErrorsFromTree(instance.render(), context, storeError))
.then(_ => getDataAndErrorsFromTree(instance.render(), context, storeError, newContext))
.catch(e => storeError(e));
});

Expand Down
Loading