## 存在的问题

像 `split` 这样的运算会产生多个输出。我们分别对多个输出后续进行不同的运算，在反向传播时，它们的梯度应该不同。

## 重写Tensor#backward

```def backward(self, grad: "Tensor" = None, retain_grad=False, create_graph=False) -> None:
'''
实现Tensor的反向传播
Args:
create_graph: 是否整个计算梯度的过程也需要保留到计算图中，即double_backprop
Returns:
'''
if not Config.backprop:
return
if self.shape == ():
else:
raise RuntimeError("grad must be specified for non scalar")
else:
funcs = []  # 候选函数堆
seen_set = set()
if f not in seen_set:
# heapq是小顶堆，为了实现大顶堆的效果，需要加一个负号
heapq.heappush(funcs, (-f.generation, len(seen_set), f))
while funcs:
_, _, f = heapq.heappop(funcs)
# 获取输出对应的梯度，解决多个输出梯度不一致的问题
gys = [output().grad.data for output in f.outputs]  # output 是 weakref
with using_config('backprop', create_graph):
with OpWrapper(f.__class__.__name__, gys, backward=True):
gxs = f.backward(*gys)
if not isinstance(gxs, tuple):
gxs = (gxs,)
for x, gx in zip(f.inputs, gxs):
if x.requires_grad and gx is not None:
assert x.shape == gx.shape, f"grad shape must match tensor shape in {
f!r}, {
gx.shape!r} != {
x.shape!r}"
gx = Tensor(gx, device=self.device, dtype=self.dtype)
else:
if x.creator is not None:
for y in f.outputs:

其中 creator 指的是产生该输出的运算(function)。

## 重写Function

```class Function:
def __init__(self) -> None:
# 保存需要在backward()中使用的Tensor或其他对象(如Shape)
self.saved_tensors = []
def save_for_backward(self, *x: Any) -> None:
self.saved_tensors.extend(x)
def forward(self, *args: Any, **kwargs: Any) -> NdArray:
'''前向传播，进行真正运算的地方'''
raise NotImplementedError("You must implement the forward function for custom Function.")
def backward(self, grad: NdArray) -> Any:
'''实现反向传播，计算梯度'''
raise NotImplementedError("You must implement the backward method for your custom Function "
"to use it with backward mode AD.")
def __call__(self, *xs: "Tensor", **kwargs) -> "Tensor":
# [t.data for t in xs]遍历Tensor中的data(NdArray)值，参与实际计算的都是NumPy的数组。
ys = self.forward(*[t.data for t in xs], **kwargs)
if not isinstance(ys, tuple):
ys = (ys,)
if Config.backprop:
self.generation = max([x.generation for x in xs])
for output in outputs:  # 设定每个输出是由此函数得到的
output.set_creator(self)
self.inputs = xs  # 记录输入
self.outputs = [weakref.ref(output) for output in outputs]  # 通过弱引用记录输出
# 返回多个则通过元组
return tuple(outputs) if len(outputs) > 1 else outputs[0]```

## 重写Function子类

`Split`

```class Split(Function):
'''Stack的逆操作'''
def forward(ctx, inputs: NdArray, axis: int) -> NdArray:
xp = get_array_module(inputs)
xs = xp.split(inputs, inputs.shape[axis], axis)
ys = [xp.squeeze(y, axis) for y in xs]  # 去掉维度axis
ctx.save_for_backward(len(ys), axis)
return tuple(ys)
def backward(ctx, grad: NdArray) -> NdArray:
size, axis = ctx.saved_tensors

```class Split(Function):
'''Stack的逆操作'''
def forward(self, inputs: NdArray, axis: int) -> NdArray:
xp = get_array_module(inputs)
xs = xp.split(inputs, inputs.shape[axis], axis)
ys = [xp.squeeze(y, axis) for y in xs]  # 去掉维度axis
self.save_for_backward(xp, axis, ys[0].shape, inputs.dtype)
return tuple(ys)
def backward(self, *grad: List[NdArray]) -> NdArray:
xp, axis, shape, dtype = self.saved_tensors
grads = [Tensor(xp.zeros(shape, dtype)) if g is None else Tensor(g) for g in grad]

这里的 `ctx` 其实就是 `self`，只是换了一个名字，两种写法等价。

```def test_split():
x = np.arange(6).reshape((2, 3)).astype(np.float32)
# x = array([[0., 1., 2.],
#           [3., 4., 5.]], dtype=float32)
my = F.split(mx)
ty = torch.split(tx, 1)
# 这里返回的是元组
assert isinstance(my, tuple)
assert np.allclose(my[0].data, ty[0].data)
(my[0]).sum().backward()
(ty[0]).sum().backward()

```test_split.py::test_split PASSED                                         [100%]
[[1. 1. 1.]
[0. 0. 0.]]
tensor([[1., 1., 1.],
[0., 0., 0.]])```

由于只有第一个输出参与了求和运算（其余输出没有参与任何计算），因此反向传播后，也应该只有该元素会产生梯度，如上输出。