Beatmup
inference_task.h
/*
    Beatmup image and signal processing library
    Copyright (C) 2020, lnstadrum

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#pragma once

#include "model.h"
#include "../gpu/gpu_task.h"
#include <map>

namespace Beatmup {
    namespace NNets {

        /**
            Task running inference of a Model.
            During the first run of this task with a given model, the shader programs are built and the memory is allocated.
            The subsequent runs are much faster.
        */
        class InferenceTask : public GpuTask, private BitmapContentLock {
        private:
            std::map<std::pair<AbstractOperation*, int>, AbstractBitmap*> inputImages;

            void beforeProcessing(ThreadIndex threadCount, ProcessingTarget target, GraphicPipeline* gpu) override;
            void afterProcessing(ThreadIndex threadCount, GraphicPipeline* gpu, bool aborted) override;
            bool processOnGPU(GraphicPipeline& gpu, TaskThread& thread) override;
            bool process(TaskThread& thread) override;
            ThreadIndex getMaxThreads() const override { return MAX_THREAD_INDEX; }

        protected:
            Model& model;
            ChunkCollection& data;

        public:
            inline InferenceTask(Model& model, ChunkCollection& data) : model(model), data(data) {}

            /**
                Connects an image to a specific operation input.
                Ensures the image content is up-to-date in GPU memory by the time the inference is run.
                \param[in] image       The image
                \param[in] operation   The operation
                \param[in] inputIndex  The input index of the operation
            */
            void connect(AbstractBitmap& image, AbstractOperation& operation, int inputIndex = 0);

            inline void connect(AbstractBitmap& image, const std::string& operation, int inputIndex = 0) {
                connect(image, model.getOperation(operation), inputIndex);
            }
        };

    }
}
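
Usage sketch (not part of the header): the snippet below illustrates how this class is typically driven, under the assumption that the task is run through a Beatmup::Context via performTask(), that the Model graph is already assembled, that a ChunkCollection holds the trained weights, and that an input AbstractBitmap is available. The operation name "input" and all variable and function names are placeholders, not identifiers defined by this header.

void runInferenceTwice(Beatmup::Context& context,
                       Beatmup::NNets::Model& model,
                       Beatmup::ChunkCollection& modelData,
                       Beatmup::AbstractBitmap& inputImage)
{
    // The task keeps references to the model and its data chunks.
    Beatmup::NNets::InferenceTask inference(model, modelData);

    // Bind the input image to a model operation by name ("input" is a placeholder), input index 0.
    inference.connect(inputImage, "input", 0);

    // First run: shader programs are built and memory is allocated.
    context.performTask(inference);

    // Subsequent runs reuse the prepared programs and buffers and are much faster.
    context.performTask(inference);
}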