百度360必应搜狗淘宝本站头条
当前位置:网站首页 > IT技术 > 正文

大白话讲nnvm

wptr33 2025-01-21 21:57 14 浏览

之前工作经验中,在某大厂,开发过机器学习框架,在和业务同学的合作下,取得还可以的成绩,但是一直觉得缺少了什么,最近在刷ai-system相关的公开课,才明白计算图的重要性,以往觉得不能理解的东西,现在突然都l理解了,工程能力可能真的要开始成为必备能力了,平时工作闲时加上周末时间,研究了nnvm这块的代码,和大家分享,有纰漏的地方,欢迎大家指出,谢谢。


NNVM是啥

还记得一到两年前突然,陈天奇突然推了个tinyflow,号称2000行的TensorFlow吗?(nnvm就是那个时候出来的,通常会带的title是“深度学习编译器”,由于本身背景原因,我其实不太理解深度学习编译器是个什么概念,因此经常看到相关的文章,尤其是pr文章时,经常云里雾里,这里呼吁下做pr的媒体同学,能在介绍知识时,不要高大上,介绍实质)。抛开其他一切的东西不谈,针对tinyflow,我认为(个人意见)主要就包括两个部分:一个是所有计算逻辑的计算(Op,如cnn、lstm、fc等等)这部分肯定不会包括在2000行里,是引入的torch lua的实现代码;另外一个就是nnvm,它的工作就是构建一个graph,这个graph就是现在深度学习相关框架都不能不提的计算图;计算图,顾名思义:就是将一个模型用有向无环图来表示,

那么,既然,模型可以用有向无环图表示,那么他的计算过程是否能够也由图的操作表示呢?当然是可以的,NNVM完成的就是这部分的意图。

图如何表示计算

计算包括前向计算和后向传播,前向计算的,把graph按post order DFS遍历一次, 就可以进行一次前向计算,很简单,比较复杂的就是后向传播,这里就要提下auto diff(这里要向弄清楚,可以看下auto diff,auto diff反向构造一个graiden的graph,去反向计算gradient, 比bp更有优势):

如上图,会根据model的graph,去构造gradient graph,这样反向的梯度计算同样可以表示为一个graph,可以通过post order DFS,就可以计算gradient了,具体实现是通过构造一个名为gradient的pass function来实现。

计算图的抽象(graph.h/cc, node.h/cc, op.h/cc)

NNVM在构造graph的代码,就在graph.h/cc里, 抽象出node, 而node代表一个operation,如matmul、add等等,op在nnvm表示所有和graph本身计算逻辑的一个数据结构,是计算图得以完成forward、gradient计算的的基础。op定义了基本的attr,其实我觉得这里理解为接口会更合适,如op_attr_types.h中声明的各种函数,相当于整个graph中某个node根据你的op空出来很多槽,当选择某个node时,会填入对应的实现逻辑:

如conv2d的op注册,会将conv2d的计算和梯度计算、以及各种接口函数的实现注册到conv2d中,conv2d和其他的op不一样,(比如infer shape, infer type,这里记起来纪念前刚开始接触mxnet的时候为mxnet写inception-resnetv2的model时,当时惊讶于有infer shape的接口,而tf是没有的,现在看这些算是联想起来了,infershape其实也是类似于auto diff的逻辑,每一个op会注册一个infer shape的逻辑,当对graph作一个infer shape的操作时,其实是对整个graph 作一个poster dfs,然后每一个node 去infer shape):

这里FGradient是通过新建一个conv2d_grad的op,conv2_grad的注册,注意TIsBackward为true:

稍微简单的一个op matmul的注册代码如下:

NNVM_REGISTER_OP(matmul)
.describe(R"doc(Matrix multiplication of two arrays.

``dot``'s behavior depends on the input array dimensions:

- 1-D arrays: inner product of vectors
- 2-D arrays: matrix multiplication
- N-D arrays: a sum product over the last axis of the first input and the first
  axis of the second input

  For example, given 3-D ``x`` with shape `(n,m,k)` and ``y`` with shape `(k,r,s)`, the
  result array will have shape `(n,m,r,s)`. It is computed by::

    dot(x,y) = sum(x[i,j,:]*y[:,a,b])

)doc" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<MatMulParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<MatMulParam>)
.add_arguments(MatMulParam::__FIELDS__())
.add_argument("lhs", "NDArray-or-Symbol", "The first input")
.add_argument("rhs", "NDArray-or-Symbol", "The second input")
.set_attr<FInferShape>("FInferShape", DotShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", DotCorrectLayout)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed);
    return Array<Tensor>{
      topi::matmul(inputs[0], inputs[1], param.transpose_a, param.transpose_b)
    };
  })
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
    // z = x dot y
    // xshape (n,m,k), yshape (k,r,s)
    const MatMulParam& param = nnvm::get<MatMulParam>(n->attrs.parsed);
    bool Ta = param.transpose_a;
    bool Tb = param.transpose_b;
    // Ta = false, Tb = false
    // grad_x = grad_z dot y.T
    // grad_y = x.T dot grad_z
    if (!Ta && !Tb) {
      return std::vector<NodeEntry>{
        MakeNode("matmul", n->attrs.name + "_grad_0",
                 {ograds[0], n->inputs[1]},
                 {{"transpose_a", "false"},
                  {"transpose_b", "true"}}),
        MakeNode("matmul", n->attrs.name + "_grad_1",
                 {n->inputs[0], ograds[0]},
                 {{"transpose_a", "true"},
                  {"transpose_b", "false"}})
      };
    } else if (Ta && !Tb) {
      // Ta = true, Tb = false
      // grad_x = y dot grad_z.T
      // grad_y = x dot grad_z
      return std::vector<NodeEntry>{
        MakeNode("matmul", n->attrs.name + "_grad_0",
                 {n->inputs[1], ograds[0]},
                 {{"transpose_a", "false"},
                  {"transpose_b", "true"}}),
        MakeNode("matmul", n->attrs.name + "_grad_1",
                 {n->inputs[0], ograds[0]},
                 {{"transpose_a", "false"},
                  {"transpose_b", "false"}})
      };
    } else if (!Ta && Tb) {
      // Ta = false, Tb = true
      // grad_x = grad_z dot y
      // grad_y = grad_z.T dot x
      return std::vector<NodeEntry>{
        MakeNode("matmul", n->attrs.name + "_grad_0",
                 {ograds[0], n->inputs[1]},
                 {{"transpose_a", "false"},
                  {"transpose_b", "false"}}),
        MakeNode("matmul", n->attrs.name + "_grad_1",
                 {ograds[0], n->inputs[0]},
                 {{"transpose_a", "true"},
                  {"transpose_b", "false"}})
      };
    } else {
      // Ta = true, Tb = true
      // grad_x = y.T dot grad_z.T
      // grad_y = grad_z.T dot x.T
      return std::vector<NodeEntry>{
        MakeNode("matmul", n->attrs.name + "_grad_0",
                 {n->inputs[1], ograds[0]},
                 {{"transpose_a", "true"},
                  {"transpose_b", "true"}}),
        MakeNode("matmul", n->attrs.name + "_grad_1",
                 {ograds[0], n->inputs[0]},
                 {{"transpose_a", "true"},
                  {"transpose_b", "true"}})
      };
    }
});

由于nnvm后面是集成到tvm里面,可能部分代码并不是nnvm里面的实现,op 比如conv2d数学的计算,大家可以去tvm的代码去看,应该就是更高一层封装用来去注册到nnvm的graph中。

计算图的操作(pass.h/cc, pass_functions.h)

nnvm构造了一个所谓的pass(我也不知道怎么翻译), 主要就是负责整个graph的操作,每一个操作会定义一个pass_function,举个例子SaveJSON, 将graph保存为json的格式的,会定义SaveJson实现:

Graph SaveJSON(Graph src) {
  std::shared_ptr<Symbol> src_symbol = std::make_shared<Symbol>();
  src_symbol->outputs = src.outputs;
  JSONGraph jgraph;
  Symbol2JSONGraph(src_symbol, &jgraph);
  jgraph.attrs = src.attrs;
  std::ostringstream os;
  dmlc::JSONWriter writer(&os);
  jgraph.Save(&writer);
  Graph ret;
  ret.attrs["json"] = std::make_shared<any>(os.str());
  return ret;
}

void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) {
  std::unordered_map<Node*, uint32_t> node2index;
  jgraph->node_row_ptr.push_back(0);
  DFSVisit(src->outputs, [&node2index, jgraph](const NodePtr& n) {
    uint32_t nid = static_cast<uint32_t>(jgraph->nodes.size());
    node2index[n.get()] = nid;
    if (n->is_variable()) {
      jgraph->arg_nodes.push_back(nid);
    }
    JSONNode jnode;
    jnode.node = n;
    jnode.inputs.reserve(n->inputs.size());
    for (const NodeEntry& e : n->inputs) {
      jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version);
    }
    for (const NodePtr& c : n->control_deps) {
      jnode.control_deps.push_back(node2index.at(c.get()));
    }
    jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs());
    jgraph->nodes.emplace_back(std::move(jnode));
  });
  for (const NodeEntry& e : src->outputs) {
    jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version);
  }
  // recursively construct subgraphs
  for (JSONNode &jnode : jgraph->nodes) {
    // construct jnode's subgraphs
    const std::vector<std::shared_ptr<Symbol>> &subgraphs = jnode.node->attrs.subgraphs;
    std::vector<JSONGraph> &jsubgraphs = jnode.subgraphs;
    jsubgraphs.resize(subgraphs.size());
    for (uint32_t i = 0; i < subgraphs.size(); ++i) {
      Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]);
    }
  }
}

然后注册pass函数:

NNVM_REGISTER_PASS(SaveJSON)
.describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]")
.set_body(SaveJSON)
.set_change_graph(true)
.provide_graph_attr("json");

如何去使用:

inline std::string SaveJSON(Graph graph) {
  Graph ret = ApplyPass(std::move(graph), "SaveJSON");
  return ret.GetAttr<std::string>("json");
}

这样就可以对整个graph作一个SaveJson的操作,如何实现一个PassFunctionReg在pass.h中,感兴趣的可以仔细阅读下。

SaveJson是一个简单的pass function,gradient、InferShape、InferType也是类似的逻辑去理解,再gradient的代码,和SaveJson类似,都会去遍历graph,gradient会重构一个gradient computation graph来返回,如下:

Graph Gradient(Graph src) {
  using nnvm::FGradient;
  using MirrorFun = std::function<int (const Node& node)>;
  using AttrHintFun = std::function<NodeEntry (const NodeEntry& src, const NodeEntry &like)>;

  CHECK_NE(src.attrs.count("grad_ys"), 0U)
      << "Gradient require grad_ys to be presented.";
  CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U)
      << "Gradient require grad_ys_out_grad to be presented.";
  CHECK_NE(src.attrs.count("grad_xs"), 0U)
      << "Gradient require grad_xs to be presented.";
  const std::vector<NodeEntry>& ys =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys");
  const std::vector<NodeEntry>& ys_out_grad =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
  const std::vector<NodeEntry>& xs =
      src.GetAttr<std::vector<NodeEntry> >("grad_xs");
  using AggFun = std::function<NodeEntry (std::vector<NodeEntry>&& inputs)>;
  AggFun agg_fun = DefaultAggregateGradient;
  if (src.attrs.count("grad_aggregate_fun") != 0) {
    agg_fun = src.GetAttr<AggFun>("grad_aggregate_fun");
  }
  MirrorFun mirror_fun = nullptr;
  if (src.attrs.count("grad_mirror_fun") != 0) {
    mirror_fun = src.GetAttr<MirrorFun>("grad_mirror_fun");
  }
  AttrHintFun attr_hint_fun = nullptr;
  if (src.attrs.count("attr_hint_fun") != 0) {
    attr_hint_fun = src.GetAttr<AttrHintFun>("attr_hint_fun");
  }
  std::vector<const Op*> zero_ops;
  if (src.attrs.count("zero_ops") != 0) {
    zero_ops = src.GetAttr<std::vector<const Op*> >("zero_ops");
  }
  const Op* copy_op = (src.attrs.count("copy_op") != 0) ?
      Op::Get(src.GetAttr<std::string>("copy_op")) :
      nullptr;

  // topo sort
  std::vector<NodePtr> topo_order;
  std::unordered_map<Node*, std::vector<GradEntry> > output_grads;

  DFSVisit(ys, [&](const NodePtr& node) {
      if (output_grads.count(node.get()) == 0) {
        output_grads[node.get()].resize(node->num_outputs());
      }
      topo_order.push_back(node);
    });

  CHECK_EQ(ys.size(), ys_out_grad.size());
  for (size_t i = 0; i < ys.size(); ++i) {
    NodeEntry ograd = ys_out_grad[i];
    output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
  }

  // Check that all xs are reachable from ys
  for (size_t i = 0; i < xs.size(); ++i) {
    CHECK(output_grads.find(xs[i].node.get()) != output_grads.end())
        << "Cannot differentiate with respect to the " << i+1 << "-th variable "
        << "because it is unreachable from the outputs.";
  }

  // construct mirror as memory reduction strategy if needed
  std::unordered_map<Node*, NodePtr> mirror_map;
  if (mirror_fun != nullptr) {
    for (const NodePtr& node_ptr : topo_order) {
      if (mirror_fun(*node_ptr)) {
        NodePtr new_node = Node::Create();
        *new_node = *node_ptr;
        new_node->attrs.name += "_mirror";
        for (auto& e : new_node->inputs) {
          e.node = mirror_map.at(e.node.get());
        }
        for (auto& n : new_node->control_deps) {
          n = mirror_map.at(n.get());
        }
        mirror_map[node_ptr.get()] = std::move(new_node);
      } else {
        mirror_map[node_ptr.get()] = node_ptr;
      }
    }
  }

  // traverse backward
  static auto& grad_fun_map = Op::GetAttr<FGradient>("FGradient");
  static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape");

  std::vector<NodeEntry> out_agg_grads;
  for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
    const NodePtr& ptr = *rit;
    if (ptr->is_variable()) continue;
    out_agg_grads.clear();
    auto& out_grad_vec = output_grads.at(ptr.get());
    for (uint32_t i = 0; i < out_grad_vec.size(); ++i) {
      GradEntry& e = out_grad_vec[i];
      e.sum = agg_fun(std::move(e.grads));
      if (e.need_attr_hint && attr_hint_fun != nullptr) {
        e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i});
      }
      out_agg_grads.push_back(e.sum);
    }
    if ((*rit)->inputs.size() != 0) {
      NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
      std::vector<NodeEntry> input_grads;
      // Check for FGradient
      if (grad_fun_map.contains(ptr->op())) {
        input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads);
        CHECK_EQ((*rit)->inputs.size(), input_grads.size())
            << "Gradient function not returning enough gradient";
      } else if (CheckGradAllZero(out_agg_grads, zero_ops)) {
        for (size_t i = 0; i < fwd_node->num_inputs(); ++i) {
          std::ostringstream os;
          if (1 == fwd_node->num_inputs()) {
            os << fwd_node->attrs.name << "_backward";
          } else {
            os << fwd_node->attrs.name << "_in" << i << "_backward";
          }
          auto p = Node::Create();
          p->attrs.op = zero_ops[0];
          p->attrs.name = os.str();
          p->inputs.push_back(fwd_node->inputs[i]);
          p->control_deps.emplace_back(fwd_node);
          if (p->op()->attr_parser != nullptr) {
            p->op()->attr_parser(&(p->attrs));
          }
          input_grads.emplace_back(p, 0, 0);
        }
      } else {
        LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable "
                   << "because it didn't register FGradient attribute.";
      }
      for (const auto& nodeEntry : input_grads)
        CHECK(nodeEntry.node);
      auto git = input_grads.begin();
      CHECK((*rit)->inputs.size() <= input_grads.size());
      for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
        auto& output_grad_entry = output_grads[it->node.get()][it->index];
        // if any of the backward op can do shape inference, the hint is not necessary.
        if (finfer_shape.contains(git->node->op())) {
          output_grad_entry.need_attr_hint = false;
        }
        output_grad_entry.grads.emplace_back(std::move(*git));
      }
    }
  }
  // take out the xs' grads
  Graph ret;
  ret.outputs.resize(xs.size());
  NodeEntryMap<std::pair<size_t, size_t> > unique_grads;
  size_t counter = 0;
  for (const NodeEntry& e : xs) {
    GradEntry& entry = output_grads[e.node.get()][e.index];
    // aggregate sum if there haven't been
    if (entry.sum.node.get() == nullptr) {
      entry.sum = agg_fun(std::move(entry.grads));
      if (entry.need_attr_hint && attr_hint_fun != nullptr) {
        entry.sum = attr_hint_fun(entry.sum, e);
      }
    }
    if (copy_op != nullptr) {
      auto kv = unique_grads.find(entry.sum);
      if (kv == unique_grads.end()) {
        unique_grads.emplace(std::move(entry.sum), std::make_pair(1, counter));
      } else {
        NodePtr copy_node = Node::Create();
        std::ostringstream os;
        os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy";
        kv->second.first++;
        copy_node->attrs.op = copy_op;
        copy_node->attrs.name = os.str();
        copy_node->inputs.emplace_back(entry.sum);
        if (copy_node->attrs.op->attr_parser != nullptr) {
            copy_node->attrs.op->attr_parser(&(copy_node->attrs));
        }
        unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter));
      }
    } else {
        ret.outputs[counter] = entry.sum;
    }
    ++counter;
  }
  if (copy_op != nullptr) {
    for (const auto& kv : unique_grads) {
      ret.outputs[kv.second.second] = kv.first;
    }
  }
  return ret;
}

// register pass
NNVM_REGISTER_PASS(Gradient)
.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]")
.set_body(Gradient)
.set_change_graph(true)
.depend_graph_attr("grad_ys")
.depend_graph_attr("grad_xs")
.depend_graph_attr("grad_ys_out_grad");

后话

nnvm现在集合在TVM,要更加深入的理解,现在看来必须得去硬抠tvm的代码了,现在对nnvm算是理解了,但理解和做出来差了十万八千里,后续想自己写个类似nnvm的东西,而转念一想,chentianqi这些牛人开源出这些东西,就是想减少我们的工作成本,直接参与进去可能是另一种学习一种“做出来”的方法,anyway,计算图这块还是得深入下去,接下来会自己写一些测试的程序去更深入学习nnvm,大家如果有对ai system 感兴趣的小伙伴,可以私信我,我们看看能不能一起学习学习,这里面确实涉及太多东西,一个人,尤其我这种算法出身的小伙伴,真的比较难


也欢迎大家关注我的同名微信公众号 小石头的码疯窝(xiaoshitou_ml_tech),或者通过公众号加我的个人微信进行讨论

相关推荐

【推荐】一款开源免费、美观实用的后台管理系统模版

如果您对源码&技术感兴趣,请点赞+收藏+转发+关注,大家的支持是我分享最大的动力!!!项目介绍...

Android架构组件-App架构指南,你还不收藏嘛

本指南适用于那些已经拥有开发Android应用基础知识的开发人员,现在想了解能够开发出更加健壮、优质的应用程序架构。首先需要说明的是:AndroidArchitectureComponents翻...

高德地图经纬度坐标批量拾取(高德地图批量查询经纬度)

使用方法在桌面上新建一个index.txt文件,把下面的代码复制进去保存,再把文件名改成index.html保存,双击运行打开即可...

flutter系列之:UI layout简介(flutter ui设计)

简介对于一个前端框架来说,除了各个组件之外,最重要的就是将这些组件进行连接的布局了。布局的英文名叫做layout,就是用来描述如何将组件进行摆放的一个约束。...

Android开发基础入门(一):UI与基础控件

Android基础入门前言:...

iOS的布局体系-流式布局MyFlowLayout

iOS布局体系的概览在我的CSDN博客中的几篇文章分别介绍MyLayout布局体系中的视图从一个方向依次排列的线性布局(MyLinearLayout)、视图层叠且停靠于父布局视图某个位置的框架布局(M...

TDesign企业级开源设计系统越发成熟稳定,支持 Vue3 / 小程序

TDesing发展越来越好了,出了好几套组件库,很成熟稳定了,新项目完全可以考虑使用。...

WinForm实现窗体自适应缩放(winform窗口缩放)

众所周知,...

winform项目——仿QQ即时通讯程序03:搭建登录界面

上两篇文章已经对CIM仿QQ即时通讯项目进行了需求分析和数据库设计。winform项目——仿QQ即时通讯程序01:原理及项目分析...

App自动化测试|原生app元素定位方法

元素定位方法介绍及应用Appium方法定位原生app元素...

61.C# TableLayoutPanel控件(c# tabcontrol)

摘要TableLayoutPanel在网格中排列内容,提供类似于HTML元素的功能。TableLayoutPanel控件允许你将控件放在网格布局中,而无需精确指定每个控件的位置。其单元格...

想要深入学习Android性能优化?看完这篇直接让你一步到位

...

12个python数据处理常用内置函数(python 的内置函数)

在python数据分析中,经常需要对字符串进行各种处理,例如拼接字符串、检索字符串等。下面我将对python中常用的内置字符串操作函数进行介绍。1.计算字符串的长度-len()函数str1='我爱py...

如何用Python程序将几十个PDF文件合并成一个PDF?其实只要这四步

假定你有一个很无聊的任务,需要将几十个PDF文件合并成一个PDF文件。每一个文件都有一个封面作为第一页,但你不希望合并后的文件中重复出现这些封面。即使有许多免费的程序可以合并PDF,很多也只是简单的将...

Python入门知识点总结,Python三大数据类型、数据结构、控制流

Python基础的重要性不言而喻,是每一个入门Python学习者所必备的知识点,作为Python入门,这部分知识点显得很庞杂,内容分支很多,大部分同学在刚刚学习时一头雾水。...