大白话讲nnvm
wptr33 2025-01-21 21:57 14 浏览
之前工作经验中,在某大厂,开发过机器学习框架,在和业务同学的合作下,取得还可以的成绩,但是一直觉得缺少了什么,最近在刷ai-system相关的公开课,才明白计算图的重要性,以往觉得不能理解的东西,现在突然都l理解了,工程能力可能真的要开始成为必备能力了,平时工作闲时加上周末时间,研究了nnvm这块的代码,和大家分享,有纰漏的地方,欢迎大家指出,谢谢。
NNVM是啥
还记得一到两年前突然,陈天奇突然推了个tinyflow,号称2000行的TensorFlow吗?(nnvm就是那个时候出来的,通常会带的title是“深度学习编译器”,由于本身背景原因,我其实不太理解深度学习编译器是个什么概念,因此经常看到相关的文章,尤其是pr文章时,经常云里雾里,这里呼吁下做pr的媒体同学,能在介绍知识时,不要高大上,介绍实质)。抛开其他一切的东西不谈,针对tinyflow,我认为(个人意见)主要就包括两个部分:一个是所有计算逻辑的计算(Op,如cnn、lstm、fc等等)这部分肯定不会包括在2000行里,是引入的torch lua的实现代码;另外一个就是nnvm,它的工作就是构建一个graph,这个graph就是现在深度学习相关框架都不能不提的计算图;计算图,顾名思义:就是将一个模型用有向无环图来表示,
那么,既然,模型可以用有向无环图表示,那么他的计算过程是否能够也由图的操作表示呢?当然是可以的,NNVM完成的就是这部分的意图。
图如何表示计算
计算包括前向计算和后向传播,前向计算的,把graph按post order DFS遍历一次, 就可以进行一次前向计算,很简单,比较复杂的就是后向传播,这里就要提下auto diff(这里要向弄清楚,可以看下auto diff,auto diff反向构造一个graiden的graph,去反向计算gradient, 比bp更有优势):
如上图,会根据model的graph,去构造gradient graph,这样反向的梯度计算同样可以表示为一个graph,可以通过post order DFS,就可以计算gradient了,具体实现是通过构造一个名为gradient的pass function来实现。
计算图的抽象(graph.h/cc, node.h/cc, op.h/cc)
NNVM在构造graph的代码,就在graph.h/cc里, 抽象出node, 而node代表一个operation,如matmul、add等等,op在nnvm表示所有和graph本身计算逻辑的一个数据结构,是计算图得以完成forward、gradient计算的的基础。op定义了基本的attr,其实我觉得这里理解为接口会更合适,如op_attr_types.h中声明的各种函数,相当于整个graph中某个node根据你的op空出来很多槽,当选择某个node时,会填入对应的实现逻辑:
如conv2d的op注册,会将conv2d的计算和梯度计算、以及各种接口函数的实现注册到conv2d中,conv2d和其他的op不一样,(比如infer shape, infer type,这里记起来纪念前刚开始接触mxnet的时候为mxnet写inception-resnetv2的model时,当时惊讶于有infer shape的接口,而tf是没有的,现在看这些算是联想起来了,infershape其实也是类似于auto diff的逻辑,每一个op会注册一个infer shape的逻辑,当对graph作一个infer shape的操作时,其实是对整个graph 作一个poster dfs,然后每一个node 去infer shape):
这里FGradient是通过新建一个conv2d_grad的op,conv2_grad的注册,注意TIsBackward为true:
稍微简单的一个op matmul的注册代码如下:
NNVM_REGISTER_OP(matmul)
.describe(R"doc(Matrix multiplication of two arrays.
``dot``'s behavior depends on the input array dimensions:
- 1-D arrays: inner product of vectors
- 2-D arrays: matrix multiplication
- N-D arrays: a sum product over the last axis of the first input and the first
axis of the second input
For example, given 3-D ``x`` with shape `(n,m,k)` and ``y`` with shape `(k,r,s)`, the
result array will have shape `(n,m,r,s)`. It is computed by::
dot(x,y) = sum(x[i,j,:]*y[:,a,b])
)doc" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<MatMulParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<MatMulParam>)
.add_arguments(MatMulParam::__FIELDS__())
.add_argument("lhs", "NDArray-or-Symbol", "The first input")
.add_argument("rhs", "NDArray-or-Symbol", "The second input")
.set_attr<FInferShape>("FInferShape", DotShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", DotCorrectLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed);
return Array<Tensor>{
topi::matmul(inputs[0], inputs[1], param.transpose_a, param.transpose_b)
};
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
// z = x dot y
// xshape (n,m,k), yshape (k,r,s)
const MatMulParam& param = nnvm::get<MatMulParam>(n->attrs.parsed);
bool Ta = param.transpose_a;
bool Tb = param.transpose_b;
// Ta = false, Tb = false
// grad_x = grad_z dot y.T
// grad_y = x.T dot grad_z
if (!Ta && !Tb) {
return std::vector<NodeEntry>{
MakeNode("matmul", n->attrs.name + "_grad_0",
{ograds[0], n->inputs[1]},
{{"transpose_a", "false"},
{"transpose_b", "true"}}),
MakeNode("matmul", n->attrs.name + "_grad_1",
{n->inputs[0], ograds[0]},
{{"transpose_a", "true"},
{"transpose_b", "false"}})
};
} else if (Ta && !Tb) {
// Ta = true, Tb = false
// grad_x = y dot grad_z.T
// grad_y = x dot grad_z
return std::vector<NodeEntry>{
MakeNode("matmul", n->attrs.name + "_grad_0",
{n->inputs[1], ograds[0]},
{{"transpose_a", "false"},
{"transpose_b", "true"}}),
MakeNode("matmul", n->attrs.name + "_grad_1",
{n->inputs[0], ograds[0]},
{{"transpose_a", "false"},
{"transpose_b", "false"}})
};
} else if (!Ta && Tb) {
// Ta = false, Tb = true
// grad_x = grad_z dot y
// grad_y = grad_z.T dot x
return std::vector<NodeEntry>{
MakeNode("matmul", n->attrs.name + "_grad_0",
{ograds[0], n->inputs[1]},
{{"transpose_a", "false"},
{"transpose_b", "false"}}),
MakeNode("matmul", n->attrs.name + "_grad_1",
{ograds[0], n->inputs[0]},
{{"transpose_a", "true"},
{"transpose_b", "false"}})
};
} else {
// Ta = true, Tb = true
// grad_x = y.T dot grad_z.T
// grad_y = grad_z.T dot x.T
return std::vector<NodeEntry>{
MakeNode("matmul", n->attrs.name + "_grad_0",
{n->inputs[1], ograds[0]},
{{"transpose_a", "true"},
{"transpose_b", "true"}}),
MakeNode("matmul", n->attrs.name + "_grad_1",
{ograds[0], n->inputs[0]},
{{"transpose_a", "true"},
{"transpose_b", "true"}})
};
}
});
由于nnvm后面是集成到tvm里面,可能部分代码并不是nnvm里面的实现,op 比如conv2d数学的计算,大家可以去tvm的代码去看,应该就是更高一层封装用来去注册到nnvm的graph中。
计算图的操作(pass.h/cc, pass_functions.h)
nnvm构造了一个所谓的pass(我也不知道怎么翻译), 主要就是负责整个graph的操作,每一个操作会定义一个pass_function,举个例子SaveJSON, 将graph保存为json的格式的,会定义SaveJson实现:
Graph SaveJSON(Graph src) {
std::shared_ptr<Symbol> src_symbol = std::make_shared<Symbol>();
src_symbol->outputs = src.outputs;
JSONGraph jgraph;
Symbol2JSONGraph(src_symbol, &jgraph);
jgraph.attrs = src.attrs;
std::ostringstream os;
dmlc::JSONWriter writer(&os);
jgraph.Save(&writer);
Graph ret;
ret.attrs["json"] = std::make_shared<any>(os.str());
return ret;
}
void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) {
std::unordered_map<Node*, uint32_t> node2index;
jgraph->node_row_ptr.push_back(0);
DFSVisit(src->outputs, [&node2index, jgraph](const NodePtr& n) {
uint32_t nid = static_cast<uint32_t>(jgraph->nodes.size());
node2index[n.get()] = nid;
if (n->is_variable()) {
jgraph->arg_nodes.push_back(nid);
}
JSONNode jnode;
jnode.node = n;
jnode.inputs.reserve(n->inputs.size());
for (const NodeEntry& e : n->inputs) {
jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version);
}
for (const NodePtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get()));
}
jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs());
jgraph->nodes.emplace_back(std::move(jnode));
});
for (const NodeEntry& e : src->outputs) {
jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version);
}
// recursively construct subgraphs
for (JSONNode &jnode : jgraph->nodes) {
// construct jnode's subgraphs
const std::vector<std::shared_ptr<Symbol>> &subgraphs = jnode.node->attrs.subgraphs;
std::vector<JSONGraph> &jsubgraphs = jnode.subgraphs;
jsubgraphs.resize(subgraphs.size());
for (uint32_t i = 0; i < subgraphs.size(); ++i) {
Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]);
}
}
}
然后注册pass函数:
NNVM_REGISTER_PASS(SaveJSON)
.describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]")
.set_body(SaveJSON)
.set_change_graph(true)
.provide_graph_attr("json");
如何去使用:
inline std::string SaveJSON(Graph graph) {
Graph ret = ApplyPass(std::move(graph), "SaveJSON");
return ret.GetAttr<std::string>("json");
}
这样就可以对整个graph作一个SaveJson的操作,如何实现一个PassFunctionReg在pass.h中,感兴趣的可以仔细阅读下。
SaveJson是一个简单的pass function,gradient、InferShape、InferType也是类似的逻辑去理解,再gradient的代码,和SaveJson类似,都会去遍历graph,gradient会重构一个gradient computation graph来返回,如下:
Graph Gradient(Graph src) {
using nnvm::FGradient;
using MirrorFun = std::function<int (const Node& node)>;
using AttrHintFun = std::function<NodeEntry (const NodeEntry& src, const NodeEntry &like)>;
CHECK_NE(src.attrs.count("grad_ys"), 0U)
<< "Gradient require grad_ys to be presented.";
CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U)
<< "Gradient require grad_ys_out_grad to be presented.";
CHECK_NE(src.attrs.count("grad_xs"), 0U)
<< "Gradient require grad_xs to be presented.";
const std::vector<NodeEntry>& ys =
src.GetAttr<std::vector<NodeEntry> >("grad_ys");
const std::vector<NodeEntry>& ys_out_grad =
src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
const std::vector<NodeEntry>& xs =
src.GetAttr<std::vector<NodeEntry> >("grad_xs");
using AggFun = std::function<NodeEntry (std::vector<NodeEntry>&& inputs)>;
AggFun agg_fun = DefaultAggregateGradient;
if (src.attrs.count("grad_aggregate_fun") != 0) {
agg_fun = src.GetAttr<AggFun>("grad_aggregate_fun");
}
MirrorFun mirror_fun = nullptr;
if (src.attrs.count("grad_mirror_fun") != 0) {
mirror_fun = src.GetAttr<MirrorFun>("grad_mirror_fun");
}
AttrHintFun attr_hint_fun = nullptr;
if (src.attrs.count("attr_hint_fun") != 0) {
attr_hint_fun = src.GetAttr<AttrHintFun>("attr_hint_fun");
}
std::vector<const Op*> zero_ops;
if (src.attrs.count("zero_ops") != 0) {
zero_ops = src.GetAttr<std::vector<const Op*> >("zero_ops");
}
const Op* copy_op = (src.attrs.count("copy_op") != 0) ?
Op::Get(src.GetAttr<std::string>("copy_op")) :
nullptr;
// topo sort
std::vector<NodePtr> topo_order;
std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
DFSVisit(ys, [&](const NodePtr& node) {
if (output_grads.count(node.get()) == 0) {
output_grads[node.get()].resize(node->num_outputs());
}
topo_order.push_back(node);
});
CHECK_EQ(ys.size(), ys_out_grad.size());
for (size_t i = 0; i < ys.size(); ++i) {
NodeEntry ograd = ys_out_grad[i];
output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
}
// Check that all xs are reachable from ys
for (size_t i = 0; i < xs.size(); ++i) {
CHECK(output_grads.find(xs[i].node.get()) != output_grads.end())
<< "Cannot differentiate with respect to the " << i+1 << "-th variable "
<< "because it is unreachable from the outputs.";
}
// construct mirror as memory reduction strategy if needed
std::unordered_map<Node*, NodePtr> mirror_map;
if (mirror_fun != nullptr) {
for (const NodePtr& node_ptr : topo_order) {
if (mirror_fun(*node_ptr)) {
NodePtr new_node = Node::Create();
*new_node = *node_ptr;
new_node->attrs.name += "_mirror";
for (auto& e : new_node->inputs) {
e.node = mirror_map.at(e.node.get());
}
for (auto& n : new_node->control_deps) {
n = mirror_map.at(n.get());
}
mirror_map[node_ptr.get()] = std::move(new_node);
} else {
mirror_map[node_ptr.get()] = node_ptr;
}
}
}
// traverse backward
static auto& grad_fun_map = Op::GetAttr<FGradient>("FGradient");
static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape");
std::vector<NodeEntry> out_agg_grads;
for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
const NodePtr& ptr = *rit;
if (ptr->is_variable()) continue;
out_agg_grads.clear();
auto& out_grad_vec = output_grads.at(ptr.get());
for (uint32_t i = 0; i < out_grad_vec.size(); ++i) {
GradEntry& e = out_grad_vec[i];
e.sum = agg_fun(std::move(e.grads));
if (e.need_attr_hint && attr_hint_fun != nullptr) {
e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i});
}
out_agg_grads.push_back(e.sum);
}
if ((*rit)->inputs.size() != 0) {
NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
std::vector<NodeEntry> input_grads;
// Check for FGradient
if (grad_fun_map.contains(ptr->op())) {
input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads);
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
<< "Gradient function not returning enough gradient";
} else if (CheckGradAllZero(out_agg_grads, zero_ops)) {
for (size_t i = 0; i < fwd_node->num_inputs(); ++i) {
std::ostringstream os;
if (1 == fwd_node->num_inputs()) {
os << fwd_node->attrs.name << "_backward";
} else {
os << fwd_node->attrs.name << "_in" << i << "_backward";
}
auto p = Node::Create();
p->attrs.op = zero_ops[0];
p->attrs.name = os.str();
p->inputs.push_back(fwd_node->inputs[i]);
p->control_deps.emplace_back(fwd_node);
if (p->op()->attr_parser != nullptr) {
p->op()->attr_parser(&(p->attrs));
}
input_grads.emplace_back(p, 0, 0);
}
} else {
LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable "
<< "because it didn't register FGradient attribute.";
}
for (const auto& nodeEntry : input_grads)
CHECK(nodeEntry.node);
auto git = input_grads.begin();
CHECK((*rit)->inputs.size() <= input_grads.size());
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
auto& output_grad_entry = output_grads[it->node.get()][it->index];
// if any of the backward op can do shape inference, the hint is not necessary.
if (finfer_shape.contains(git->node->op())) {
output_grad_entry.need_attr_hint = false;
}
output_grad_entry.grads.emplace_back(std::move(*git));
}
}
}
// take out the xs' grads
Graph ret;
ret.outputs.resize(xs.size());
NodeEntryMap<std::pair<size_t, size_t> > unique_grads;
size_t counter = 0;
for (const NodeEntry& e : xs) {
GradEntry& entry = output_grads[e.node.get()][e.index];
// aggregate sum if there haven't been
if (entry.sum.node.get() == nullptr) {
entry.sum = agg_fun(std::move(entry.grads));
if (entry.need_attr_hint && attr_hint_fun != nullptr) {
entry.sum = attr_hint_fun(entry.sum, e);
}
}
if (copy_op != nullptr) {
auto kv = unique_grads.find(entry.sum);
if (kv == unique_grads.end()) {
unique_grads.emplace(std::move(entry.sum), std::make_pair(1, counter));
} else {
NodePtr copy_node = Node::Create();
std::ostringstream os;
os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy";
kv->second.first++;
copy_node->attrs.op = copy_op;
copy_node->attrs.name = os.str();
copy_node->inputs.emplace_back(entry.sum);
if (copy_node->attrs.op->attr_parser != nullptr) {
copy_node->attrs.op->attr_parser(&(copy_node->attrs));
}
unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter));
}
} else {
ret.outputs[counter] = entry.sum;
}
++counter;
}
if (copy_op != nullptr) {
for (const auto& kv : unique_grads) {
ret.outputs[kv.second.second] = kv.first;
}
}
return ret;
}
// register pass
NNVM_REGISTER_PASS(Gradient)
.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]")
.set_body(Gradient)
.set_change_graph(true)
.depend_graph_attr("grad_ys")
.depend_graph_attr("grad_xs")
.depend_graph_attr("grad_ys_out_grad");
后话
nnvm现在集合在TVM,要更加深入的理解,现在看来必须得去硬抠tvm的代码了,现在对nnvm算是理解了,但理解和做出来差了十万八千里,后续想自己写个类似nnvm的东西,而转念一想,chentianqi这些牛人开源出这些东西,就是想减少我们的工作成本,直接参与进去可能是另一种学习一种“做出来”的方法,anyway,计算图这块还是得深入下去,接下来会自己写一些测试的程序去更深入学习nnvm,大家如果有对ai system 感兴趣的小伙伴,可以私信我,我们看看能不能一起学习学习,这里面确实涉及太多东西,一个人,尤其我这种算法出身的小伙伴,真的比较难
也欢迎大家关注我的同名微信公众号 小石头的码疯窝(xiaoshitou_ml_tech),或者通过公众号加我的个人微信进行讨论
相关推荐
- 【推荐】一款开源免费、美观实用的后台管理系统模版
-
如果您对源码&技术感兴趣,请点赞+收藏+转发+关注,大家的支持是我分享最大的动力!!!项目介绍...
- Android架构组件-App架构指南,你还不收藏嘛
-
本指南适用于那些已经拥有开发Android应用基础知识的开发人员,现在想了解能够开发出更加健壮、优质的应用程序架构。首先需要说明的是:AndroidArchitectureComponents翻...
- 高德地图经纬度坐标批量拾取(高德地图批量查询经纬度)
-
使用方法在桌面上新建一个index.txt文件,把下面的代码复制进去保存,再把文件名改成index.html保存,双击运行打开即可...
- flutter系列之:UI layout简介(flutter ui设计)
-
简介对于一个前端框架来说,除了各个组件之外,最重要的就是将这些组件进行连接的布局了。布局的英文名叫做layout,就是用来描述如何将组件进行摆放的一个约束。...
- Android开发基础入门(一):UI与基础控件
-
Android基础入门前言:...
- iOS的布局体系-流式布局MyFlowLayout
-
iOS布局体系的概览在我的CSDN博客中的几篇文章分别介绍MyLayout布局体系中的视图从一个方向依次排列的线性布局(MyLinearLayout)、视图层叠且停靠于父布局视图某个位置的框架布局(M...
- TDesign企业级开源设计系统越发成熟稳定,支持 Vue3 / 小程序
-
TDesing发展越来越好了,出了好几套组件库,很成熟稳定了,新项目完全可以考虑使用。...
- WinForm实现窗体自适应缩放(winform窗口缩放)
-
众所周知,...
- winform项目——仿QQ即时通讯程序03:搭建登录界面
-
上两篇文章已经对CIM仿QQ即时通讯项目进行了需求分析和数据库设计。winform项目——仿QQ即时通讯程序01:原理及项目分析...
- App自动化测试|原生app元素定位方法
-
元素定位方法介绍及应用Appium方法定位原生app元素...
- 61.C# TableLayoutPanel控件(c# tabcontrol)
-
摘要TableLayoutPanel在网格中排列内容,提供类似于HTML元素的功能。TableLayoutPanel控件允许你将控件放在网格布局中,而无需精确指定每个控件的位置。其单元格...
- 12个python数据处理常用内置函数(python 的内置函数)
-
在python数据分析中,经常需要对字符串进行各种处理,例如拼接字符串、检索字符串等。下面我将对python中常用的内置字符串操作函数进行介绍。1.计算字符串的长度-len()函数str1='我爱py...
- 如何用Python程序将几十个PDF文件合并成一个PDF?其实只要这四步
-
假定你有一个很无聊的任务,需要将几十个PDF文件合并成一个PDF文件。每一个文件都有一个封面作为第一页,但你不希望合并后的文件中重复出现这些封面。即使有许多免费的程序可以合并PDF,很多也只是简单的将...
- Python入门知识点总结,Python三大数据类型、数据结构、控制流
-
Python基础的重要性不言而喻,是每一个入门Python学习者所必备的知识点,作为Python入门,这部分知识点显得很庞杂,内容分支很多,大部分同学在刚刚学习时一头雾水。...
- 一周热门
-
-
C# 13 和 .NET 9 全知道 :13 使用 ASP.NET Core 构建网站 (1)
-
因果推断Matching方式实现代码 因果推断模型
-
面试官:git pull是哪两个指令的组合?
-
git pull命令使用实例 git pull--rebase
-
git 执行pull错误如何撤销 git pull fail
-
git pull 和git fetch 命令分别有什么作用?二者有什么区别?
-
git fetch 和git pull 的异同 git中fetch和pull的区别
-
git pull 之后本地代码被覆盖 解决方案
-
还可以这样玩?Git基本原理及各种骚操作,涨知识了
-
git命令之pull git.pull
-
- 最近发表
- 标签列表
-
- git pull (33)
- git fetch (35)
- mysql insert (35)
- mysql distinct (37)
- concat_ws (36)
- java continue (36)
- jenkins官网 (37)
- mysql 子查询 (37)
- python元组 (33)
- mysql max (33)
- vba instr (33)
- mybatis 分页 (35)
- vba split (37)
- redis watch (34)
- python list sort (37)
- nvarchar2 (34)
- mysql not null (36)
- hmset (35)
- python telnet (35)
- python readlines() 方法 (36)
- munmap (35)
- docker network create (35)
- redis 集合 (37)
- python sftp (37)
- setpriority (34)