xwying / torchshow

Visualize PyTorch tensors with a single line of code.

License: MIT License

Python 100.00%
image-tensor notebook pytorch tensor visualization

torchshow's People

Contributors

genghisun, xwying


torchshow's Issues

Plot titles

Hi there, I am currently visualising the feature maps of each layer in YOLOv7, and I was wondering whether it is possible to give each image a subtitle, for example "Feature map 1".
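As a workaround that does not rely on any torchshow option, the channels of a feature map can be plotted directly with matplotlib, giving each subplot its own title. This is a minimal sketch assuming fmap is a single (C, H, W) tensor; plot_feature_maps and its title_fmt argument are hypothetical names, not part of torchshow.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import torch


def plot_feature_maps(fmap, ncols=4, path=None, title_fmt="Feature map {}"):
    """Plot each channel of a (C, H, W) feature map in a grid, one title per channel.

    Hypothetical helper. Returns the subplot titles ("" for unused grid cells).
    """
    fmap = fmap.detach().float().cpu()
    n = fmap.shape[0]
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)
    fig, axes = plt.subplots(nrows, ncols, squeeze=False,
                             figsize=(2.5 * ncols, 2.5 * nrows))
    for i, ax in enumerate(axes.flat):
        ax.axis("off")  # hide ticks; also blanks the unused cells
        if i < n:
            ax.imshow(fmap[i].numpy(), cmap="gray")
            ax.set_title(title_fmt.format(i + 1))  # e.g. "Feature map 1"
    if path is not None:
        fig.savefig(path, bbox_inches="tight")
    titles = [ax.get_title() for ax in axes.flat]
    plt.close(fig)  # release the figure held by pyplot
    return titles
```

For a layer output of shape (1, C, H, W), one would call plot_feature_maps(out[0], path="layer_3.png"), for example.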

torchshow to save mask

Hi, I recently ran into this bug when saving a mask. It is a bit strange because it does not always happen, only for certain images' ground-truth (GT) masks:

  File "/home/zmz/Mask2Former/mask2former/maskformer_model.py", line 316, in forward
    torchshow.save(masks[0]>0,"gt_mask.jpg")
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/torchshow/torchshow.py", line 19, in save
    show(x, save=True, file_path=path, **kwargs)
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/torchshow/torchshow.py", line 97, in show
    display_plt(plot_list, **kwargs)
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/torchshow/visualization.py", line 156, in display_plt
    fig.savefig(file_path, bbox_inches = 'tight', pad_inches=0)
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/figure.py", line 3046, in savefig
    self.canvas.print_figure(fname, **kwargs)
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/backend_bases.py", line 2295, in print_figure
    self.figure.draw(renderer)
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/artist.py", line 73, in draw_wrapper
    result = draw(artist, renderer, *args, **kwargs)
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/artist.py", line 50, in draw_wrapper
    return draw(artist, renderer)
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/figure.py", line 2837, in draw
    mimage._draw_list_compositing_images(
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/image.py", line 132, in _draw_list_compositing_images
    a.draw(renderer)
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/artist.py", line 50, in draw_wrapper
    return draw(artist, renderer)
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/axes/_base.py", line 3091, in draw
    mimage._draw_list_compositing_images(
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/image.py", line 132, in _draw_list_compositing_images
    a.draw(renderer)
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/artist.py", line 50, in draw_wrapper
    return draw(artist, renderer)
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/image.py", line 646, in draw
    im, l, b, trans = self.make_image(
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/image.py", line 956, in make_image
    return self._make_image(self._A, bbox, transformed_bbox, clip,
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/image.py", line 555, in _make_image
    output_alpha = _resample( # resample alpha channel
  File "/home/zmz/miniconda3/envs/py38/lib/python3.8/site-packages/matplotlib/image.py", line 193, in _resample
    _image.resample(data, out, transform,
ValueError: Unsupported dtype
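The traceback ends in matplotlib's alpha-channel resampling, which suggests (an assumption; the issue does not confirm the cause) that the bool tensor produced by masks[0] > 0 is the problem and that it only surfaces for some inputs. A possible workaround is to cast the mask to float32 before saving. mask_for_saving is a hypothetical helper name.

```python
import torch


def mask_for_saving(mask: torch.Tensor) -> torch.Tensor:
    """Threshold a mask at 0 and return float32 values of 0.0 / 1.0.

    Hypothetical helper: same thresholding as masks[0] > 0 in the traceback,
    but the result is float32 instead of bool, which is assumed (not verified
    in the issue) to avoid matplotlib's "Unsupported dtype" error.
    """
    return (mask > 0).to(torch.float32)
```

Usage would then be torchshow.save(mask_for_saving(masks[0]), "gt_mask.jpg").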

Got unsupported ScalarType BFloat16

torchshow.save(images)
Traceback (most recent call last):
  File "torchshow/utils.py", line 60, in tensor_to_array
    return x.detach().clone().cpu().numpy()
TypeError: Got unsupported ScalarType BFloat16

BFloat16 is currently not supported. It would be great if support for it could be added.
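NumPy has no bfloat16 dtype, so Tensor.numpy() fails on such tensors. One way to handle it is to upcast to float32 first, which represents every bfloat16 value exactly. The sketch below is a hypothetical variant of the tensor_to_array helper seen in the traceback, not torchshow's actual code.

```python
import numpy as np
import torch


def tensor_to_array(x: torch.Tensor) -> np.ndarray:
    """Convert a tensor to a NumPy array, upcasting bfloat16 to float32.

    Hypothetical variant of torchshow/utils.py's tensor_to_array.
    """
    x = x.detach().clone().cpu()
    if x.dtype == torch.bfloat16:
        # NumPy lacks bfloat16, so .numpy() would raise
        # "Got unsupported ScalarType BFloat16"; float32 holds bfloat16 values exactly.
        x = x.to(torch.float32)
    return x.numpy()
```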

[Bug] TypeError: 'str' object is not callable when unnormalizing with a custom mean and std

torchshow.set_image_mean([0., 0., 0.])
torchshow.set_image_std([1., 1., 1.])
torchshow.show(img)


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)

File /somepath/python3.8/site-packages/torchshow/visualization.py:289, in vis_image(x, unnormalize, **kwargs)
    [287]     if user_std == None:
    [288]         user_std = [1.] * x.shape[-1] # Initialize std to 1 if not specified.
--> [289]     x = unnormalize(x, user_mean, user_std)
    [291] elif unnormalize=='auto':
    [292]     x = auto_unnormalize_image(x)

TypeError: 'str' object is not callable

It seems to be caused by the parameter and the function sharing the same name, unnormalize: inside vis_image, the string argument shadows the function.
This case should also be added to the tests.
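One way to fix the shadowing without changing the public keyword is to give the internal helper a different name. In this minimal sketch, apply_unnormalize and the "manual" mode string are hypothetical; user_mean and user_std mirror the traceback, and the 'auto' branch is left out.

```python
import numpy as np


def apply_unnormalize(x, mean, std):
    """Invert x_norm = (x - mean) / std for a channels-last array (hypothetical helper)."""
    return x * np.asarray(std, dtype=x.dtype) + np.asarray(mean, dtype=x.dtype)


def vis_image(x, unnormalize="auto", user_mean=None, user_std=None):
    # 'unnormalize' is a mode string; because the helper has a different name,
    # the call below reaches the function instead of the string.
    if unnormalize == "manual":  # hypothetical mode name
        if user_mean is None:
            user_mean = [0.] * x.shape[-1]  # Initialize mean to 0 if not specified.
        if user_std is None:
            user_std = [1.] * x.shape[-1]  # Initialize std to 1 if not specified.
        x = apply_unnormalize(x, user_mean, user_std)
    return x
```

Using is None instead of == None also avoids surprises when a NumPy array is passed as the mean or std.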

Discussion about default unnormalize method for pytorch tensor

Currently, it seems that a PyTorch float tensor with values between -1 and 1 is rescaled with min-max normalization:

def auto_unnormalize_image(x):
    all_int = isinteger(np.unique(x)).all()
    range_0_1 = within_0_1(x)
    range_0_255 = within_0_255(x)
    has_negative = (x.min() < 0)
    if has_negative:
        logger.debug('Detects input has negative values, auto rescaling input to 0-1.')
        return rescale_0_1(x)
    if range_0_255 and all_int and (not range_0_1): # if image is all integer and between 0 - 255. Normalize it to 0-1.
        logger.debug('Detects input are all integers within range 0-255. Divided all values by 255.')
        return x / 255.
    if range_0_1:
        logger.debug('Inputs already within 0-1, no unnormalization is performed.')
        return x
    logger.debug('Auto rescaling input to 0-1.')
    return rescale_0_1(x)

def rescale_0_1(x):
    """
    Rescaling tensor to 0-1 using min-max normalization
    """
    return (x - x.min()) / (x.max() - x.min())

However, people usually normalize images with transforms.Normalize((0.5,), (0.5,)).
So I think that when x lies between -1 and 1, it should be mapped back with x * 0.5 + 0.5 to keep the image colors consistent.

I know this can be achieved by calling set_image_mean([0.5, 0.5, 0.5]) and set_image_std([0.5, 0.5, 0.5]), but those settings may be needed for other special cases. For the most common case, I think we should adopt the approach above.

Any idea?
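The difference matters for images that do not use the full 0-1 range. This small example compares the min-max path (taken whenever the input has negative values) with the proposed x * 0.5 + 0.5 inverse; unnormalize_half is a hypothetical name, and rescale_0_1 is copied from the snippet above.

```python
import torch


def rescale_0_1(x):
    """Min-max rescaling, as used by auto_unnormalize_image for negative inputs."""
    return (x - x.min()) / (x.max() - x.min())


def unnormalize_half(x):
    """Invert transforms.Normalize((0.5,), (0.5,)), i.e. x_orig = x * 0.5 + 0.5."""
    return x * 0.5 + 0.5


orig = torch.tensor([0.25, 0.5, 0.75])  # a low-contrast image: never reaches 0 or 1
normed = (orig - 0.5) / 0.5  # what Normalize((0.5,), (0.5,)) produces
print(rescale_0_1(normed).tolist())  # [0.0, 0.5, 1.0]: contrast is stretched
print(unnormalize_half(normed).tolist())  # [0.25, 0.5, 0.75]: original recovered
```

Min-max rescaling stretches every image to full contrast, so two images with different brightness end up looking alike, whereas the fixed inverse reproduces the pre-normalization pixel values.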

How to control the size of the output images?

I have a relatively large tensor I'd like to visualize: several iterations of a 2048x3292 tensor. When I call save on a list of 4 of these tensors, the image gets shrunk down to 609x87, which is far too small to see the details I need. Is there an option to make save keep the original image size? I tried ncols and nrows, but the output was still shrunk.

>>> print(len(tensors), tensors[0].shape)
4 torch.Size([2048, 3292])
>>> ts.save([tensors], args.tensor_name + ".png", nrows=2048, ncols=3292)

--> result is a 609x87 image

Also, as an aside, is there a convenient way to map positive numbers to green and negative numbers to red (or any other color scheme), or should I remap those values myself before passing them to torchshow?

Thanks!
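In the memory-warning traceback elsewhere on this page, torchshow passes nrows and ncols to plt.subplots, so those arguments appear to control the subplot grid rather than the pixel size. A workaround outside torchshow is matplotlib.pyplot.imsave, which writes one image pixel per array element. The sketch also handles the color question with the diverging 'RdYlGn' colormap and limits symmetric around zero (negative red, zero yellow, positive green). save_full_resolution is a hypothetical helper.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np
import torch


def save_full_resolution(tensors, path, cmap="RdYlGn"):
    """Stack 2-D tensors vertically and save with one image pixel per element.

    Hypothetical helper. Limits are symmetric around 0 so that, with 'RdYlGn',
    negative values are red, zero is yellow and positive values are green.
    Returns the (height, width) of the saved image.
    """
    arr = torch.cat([t.detach().float().cpu() for t in tensors], dim=0).numpy()
    lim = float(np.abs(arr).max()) or 1.0  # avoid vmin == vmax for all-zero input
    plt.imsave(path, arr, cmap=cmap, vmin=-lim, vmax=lim)
    return arr.shape
```

For four 2048x3292 tensors this would write an 8192x3292 PNG, e.g. save_full_resolution(tensors, args.tensor_name + ".png").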

Can we add colormap support to torchshow?

I'm visualizing my feature maps with torchshow, and it's very useful to me. However, torchshow only supports a gray colormap when visualizing tensors of shape (N, C, H, W) where C>=1. I need a colormap like cv2.COLORMAP_JET to show the information in feature tensors more clearly, which would make things easier. Thanks!
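Until something like this is built in, one option is to colorize a single channel before handing it to torchshow. This sketch uses matplotlib's 'jet' colormap, which is similar to (not guaranteed pixel-identical with) cv2.COLORMAP_JET, and returns a (3, H, W) float tensor in [0, 1]. apply_colormap is a hypothetical helper, and it assumes torchshow displays such a tensor as RGB.

```python
import matplotlib
import numpy as np
import torch


def apply_colormap(fmap: torch.Tensor, cmap: str = "jet") -> torch.Tensor:
    """Map an (H, W) tensor to a (3, H, W) float32 RGB tensor in [0, 1] (hypothetical helper)."""
    x = fmap.detach().float().cpu().numpy()
    x = (x - x.min()) / (x.max() - x.min() + 1e-8)  # min-max to [0, 1]
    rgba = matplotlib.colormaps[cmap](x)  # (H, W, 4) floats in [0, 1]
    rgb = np.ascontiguousarray(rgba[..., :3])  # drop alpha
    return torch.from_numpy(rgb).permute(2, 0, 1).float()  # channels first
```

For channel c of sample 0 this would be used as torchshow.show(apply_colormap(features[0, c])).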

torchshow.save memory problem

Hi, I get this warning when I use torchshow to save masks in a for loop. Is there any way to solve this problem?

envs/py38/lib/python3.8/site-packages/torchshow/visualization.py:108: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (matplotlib.pyplot.figure) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam figure.max_open_warning).
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
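The warning says pyplot figures stay alive until explicitly closed, and the quoted line shows torchshow creating them with plt.subplots. In user code, calling plt.close("all") right after each torchshow.save(...) should therefore keep the count bounded (an assumption based on the warning text). The sketch below reproduces the same pattern with plain matplotlib, so it runs without torchshow; save_masks is a hypothetical helper that renders to in-memory PNGs.

```python
import io

import matplotlib

matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import torch


def save_masks(masks, close_each=True):
    """Render each mask to an in-memory PNG, closing figures as it goes.

    Hypothetical helper. Returns the PNG buffers and the number of figures
    pyplot still holds afterwards.
    """
    buffers = []
    for mask in masks:
        fig, axes = plt.subplots(nrows=1, ncols=1, squeeze=False)  # same call style as the warning
        axes[0, 0].imshow(mask.float().numpy(), cmap="gray")
        axes[0, 0].axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        buffers.append(buf)
        if close_each:
            plt.close(fig)  # release the figure so pyplot stops tracking it
    return buffers, len(plt.get_fignums())
```

With close_each=False, 25 masks would leave 25 open figures and trigger the same "More than 20 figures" warning.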
