37 at::Tensor particles_position_tensor =
39 {wheel.num_particles, 3}, torch::kFloat32);
41 at::Tensor particles_velocity_tensor =
43 {wheel.num_particles, 3}, torch::kFloat32);
45 at::Tensor wheel_positions_tensor =
47 {3}, torch::kFloat32);
49 at::Tensor wheel_oritentation_tensor =
51 {4}, torch::kFloat32);
53 at::Tensor wheel_linear_velocity_tensor =
55 {3}, torch::kFloat32);
57 at::Tensor wheel_angular_velocity_tensor =
59 {3}, torch::kFloat32);
61 std::vector<torch::jit::IValue> Tuple
62 {particles_position_tensor, particles_velocity_tensor, wheel_positions_tensor,
63 wheel_oritentation_tensor, wheel_linear_velocity_tensor, wheel_angular_velocity_tensor};
64 return torch::ivalue::Tuple::create(Tuple);
68 const at::Tensor &particle_forces,
69 const at::Tensor &wheel_forces ) {
71 const float* wheel_forces_data = wheel_forces.data_ptr<
float>();
78 const float* particle_forces_data = particle_forces.data_ptr<
float>();
79 int num_dimensions = 3;
80 int num_particles = particle_forces.sizes()[0];
82 for (
int i = 0; i < num_particles; i++) {
84 particle_forces_data[i*num_dimensions + 0]);
86 particle_forces_data[i*num_dimensions + 1]);
88 particle_forces_data[i*num_dimensions + 2]);
94 const at::Tensor &particle_forces,
95 const at::Tensor &wheel_forces) {
97 const float* wheel_forces_data = wheel_forces.data_ptr<
float>();
101 const float* particle_forces_data = particle_forces.data_ptr<
float>();
102 int num_dimensions = 3;
103 int num_particles = particle_forces.sizes()[0];
105 for (
int i = 0; i < num_particles; i++) {
107 particle_forces_data[i*num_dimensions + 0]);
109 particle_forces_data[i*num_dimensions + 1]);
111 particle_forces_data[i*num_dimensions + 2]);
128 at::Tensor particles_position_tensor =
130 {wheel.num_particles, 3}, torch::kFloat32);
132 at::Tensor particles_velocity_tensor =
134 {wheel.num_particles, 3}, torch::kFloat32);
136 at::Tensor wheel_positions_tensor =
138 {3}, torch::kFloat32);
140 at::Tensor wheel_oritentation_tensor =
142 {4}, torch::kFloat32);
144 at::Tensor wheel_linear_velocity_tensor =
146 {3}, torch::kFloat32);
148 at::Tensor wheel_angular_velocity_tensor =
150 {3}, torch::kFloat32);
152 std::vector<torch::jit::IValue> Tuple
153 {particles_position_tensor.cuda(), particles_velocity_tensor.cuda(), wheel_positions_tensor.cuda(),
154 wheel_oritentation_tensor.cuda(), wheel_linear_velocity_tensor.cuda(), wheel_angular_velocity_tensor.cuda(),
156 return torch::ivalue::Tuple::create(Tuple);
163 torch::jit::setTensorExprFuserEnabled(
false);
164 std::string filename_str(filename);
165 std::cout <<
"loading " << filename_str << std::endl;
167 Model->module = torch::jit::load(filename_str);
168 std::string cuda_str =
"cuda:" + std::to_string(device);
171 }
catch (
const c10::Error& e) {
172 std::cout <<
"Error loading model: " << e.msg() << std::endl;
174 std::cout <<
"loaded " << filename_str << std::endl;
183 std::vector<torch::jit::IValue> TorchInputs;
188 auto drv_inputs = torch::tensor(
190 TorchInputs.push_back(drv_inputs);
196 torch::jit::IValue Output;
198 Output =
Model->module.forward(TorchInputs);
199 }
catch (
const c10::Error& e) {
200 std::cout <<
"Error running model: " << e.msg() << std::endl;
203 std::vector<torch::jit::IValue> Tensors = Output.toTuple()->elements();
205 Tensors[0].toTensor().cpu(), Tensors[4].toTensor().cpu() );
207 Tensors[1].toTensor().cpu(), Tensors[5].toTensor().cpu() );
209 Tensors[2].toTensor().cpu(), Tensors[6].toTensor().cpu() );
211 Tensors[3].toTensor().cpu(), Tensors[7].toTensor().cpu() );
217 std::vector<torch::jit::IValue> TorchInputs;
222 auto drv_inputs = torch::tensor(
224 TorchInputs.push_back(drv_inputs);
230 torch::jit::IValue Output;
232 Output =
Model->module.forward(TorchInputs);
233 }
catch (
const c10::Error& e) {
234 std::cout <<
"Error running model: " << e.msg() << std::endl;
237 std::vector<torch::jit::IValue> Tensors = Output.toTuple()->elements();
239 Tensors[0].toTensor().cpu(), Tensors[4].toTensor().cpu());
241 Tensors[1].toTensor().cpu(), Tensors[5].toTensor().cpu());
243 Tensors[2].toTensor().cpu(), Tensors[6].toTensor().cpu());
245 Tensors[3].toTensor().cpu(), Tensors[7].toTensor().cpu());
249 c10::cuda::CUDACachingAllocator::emptyCache();
255 std::vector<torch::jit::IValue> TorchInputs;
260 auto drv_inputs = torch::tensor(
262 TorchInputs.push_back(drv_inputs.cuda());
268 torch::jit::IValue Output;
270 Output =
Model->module.forward(TorchInputs);
271 }
catch (
const c10::Error& e) {
272 std::cout <<
"Error running model: " << e.msg() << std::endl;
275 std::vector<torch::jit::IValue> Tensors = Output.toTuple()->elements();
277 Tensors[0].toTensor().cpu(), Tensors[4].toTensor().cpu() );
279 Tensors[1].toTensor().cpu(), Tensors[5].toTensor().cpu() );
281 Tensors[2].toTensor().cpu(), Tensors[6].toTensor().cpu() );
283 Tensors[3].toTensor().cpu(), Tensors[7].toTensor().cpu() );