本文共 1474 字，大约阅读时间需要 4 分钟。
#include#include #include using namespace std;/// /// LeNet实现类/// class LeNet :public torch::nn::Module { public: // 构造器 LeNet(int num_classes,int num_linear); // 前向传播 torch::Tensor forward(torch::Tensor x);private: // 具体实现放到构造器实现中 torch::nn::Conv2d conv1{ nullptr}; torch::nn::Conv2d conv2{ nullptr}; torch::nn::Linear fc1{ nullptr }; torch::nn::Linear fc2{ nullptr}; torch::nn::Linear fc3{ nullptr};};LeNet::LeNet(int num_classes, int num_linear){ conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 6, 5))); conv2 = register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, 5))); fc1 = register_module("fc1", torch::nn::Linear(torch::nn::LinearOptions(num_linear, 128))); fc2 = register_module("fc2", torch::nn::Linear(torch::nn::LinearOptions(128, 32))); fc3 = register_module("fc3", torch::nn::Linear(torch::nn::LinearOptions(32, num_classes)));}torch::Tensor LeNet::forward(torch::Tensor x){ auto out = torch::relu(conv1->forward(x)); out = torch::max_pool2d(out, 2); out = torch::relu(conv2(out)); out = torch::max_pool2d(out, 2); out = out.view({ 1, -1 }); out = torch::relu(fc1(out)); out = torch::relu(fc2(out)); out = fc3(out); return out;}int main(){ //step0:定义使用cuda auto device = torch::Device(torch::kCUDA, 0); // step1:生成测试数据 auto input = torch::ones({ 1,3,224,224}); cout << input.sizes() <

转载地址：http://xwwfk.baihongyu.com/