Comments (10)
from deepwave.
Thank you so much. I would like to help with testing 3D boundary conditions and maybe improving that piece of this code.
Please help me get started with the compiled propagator.
from deepwave.
I have written some code to call the compiled propagator directly so that we can access the wavefields at arbitrary time steps:
import torch
import numpy as np
import scipy.signal
import deepwave
import deepwave.base.propagator
from deepwave.scalar import scalar
class SteppingPropagator(deepwave.base.propagator.Propagator):
    """PyTorch Module for scalar wave propagator that can be stepped in time.

    Runs the regular deepwave scalar-propagator setup once in ``__init__``
    and then lets :meth:`step` call the compiled propagator directly, so the
    wavefield can be accessed between calls.
    See deepwave.base.propagator.Propagator for description.
    """

    def __init__(self, model, dx,
                 source_amplitudes, source_locations, receiver_locations, dt,
                 pml_width=None, survey_pad=None, vpmax=None):
        if list(model.keys()) != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    list(model.keys())
                )
            )
        super(SteppingPropagator, self).__init__(
            SteppingPropagatorFunction,
            model,
            dx,
            fd_width=4,  # also in Pml
            pml_width=pml_width,
            survey_pad=survey_pad,
        )
        self.model.extra_info["vpmax"] = vpmax
        if model["vp"].min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(model["vp"].min())
            )
        # Run the base-class forward only to capture the processed inputs it
        # passes to SteppingPropagatorFunction.forward, which returns them
        # unchanged. NOTE(review): newer PyTorch versions reject the non-Tensor
        # values in that tuple with a TypeError.
        (source_amplitudes,
         source_locations,
         receiver_locations,
         dt,
         model,
         property_names,
         vp) = self.forward(source_amplitudes, source_locations,
                            receiver_locations, dt)
        if property_names != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    property_names
                )
            )
        if vp.min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(vp.min())
            )
        device = model.device
        dtype = model.dtype
        num_steps, num_shots, num_sources_per_shot = source_amplitudes.shape
        num_receivers_per_shot = receiver_locations.shape[1]
        if model.extra_info["vpmax"] is None:
            max_vel = vp.max().item()
        else:
            max_vel = model.extra_info["vpmax"]
        timestep = scalar.Timestep(dt, model.dx, max_vel)
        model.add_properties(
            {
                "vp2dt2": vp ** 2 * timestep.inner_dt ** 2,
                "scaling": 2 / vp ** 3,
            }
        )
        source_model_locations = model.get_locations(source_locations)
        receiver_model_locations = model.get_locations(receiver_locations)
        scalar_wrapper = scalar._select_propagator(model.ndim, vp.dtype, vp.is_cuda)
        wavefield_save_strategy = scalar._set_wavefield_save_strategy(
            False, dt, timestep.inner_dt, scalar_wrapper
        )
        fd1, fd2 = scalar._set_finite_diff_coeffs(model.ndim, model.dx, device, dtype)
        wavefield, saved_wavefields = scalar._allocate_wavefields(
            wavefield_save_strategy,
            scalar_wrapper,
            model,
            num_steps,
            num_shots,
        )
        receiver_amplitudes = torch.zeros(
            num_steps,
            num_shots,
            num_receivers_per_shot,
            device=device,
            dtype=dtype,
        )
        inner_dt = torch.tensor([timestep.inner_dt]).to(dtype)
        pml = scalar.Pml(model, num_shots, max_vel)
        # Resample the source to the (possibly finer) inner time step.
        source_amplitudes_resampled = scipy.signal.resample(
            source_amplitudes.detach().cpu().numpy(),
            num_steps * timestep.step_ratio,
        )
        source_amplitudes_resampled = (
            torch.tensor(source_amplitudes_resampled)
            .to(dtype)
            .to(source_amplitudes.device)
        )
        source_amplitudes_resampled.requires_grad = (
            source_amplitudes.requires_grad
        )
        self.dtype = dtype
        self.scalar_wrapper = scalar_wrapper
        # Convert once here so that step() passes these exact tensors to the
        # compiled code, which updates them in place; a per-call
        # .to()/.contiguous() may return a temporary copy whose updates would
        # be lost.
        self.wavefield = wavefield.to(self.dtype).contiguous()
        self.pml = pml
        self.pml.aux = self.pml.aux.to(self.dtype).contiguous()
        self.pml.sigma = self.pml.sigma.to(self.dtype).contiguous()
        self.pml.pml_width = self.pml.pml_width.long().contiguous()
        self.receiver_amplitudes = receiver_amplitudes.to(self.dtype).contiguous()
        self.saved_wavefields = saved_wavefields.to(self.dtype).contiguous()
        self.model = model
        self.model.properties["vp2dt2"] = (
            self.model.properties["vp2dt2"].to(self.dtype).contiguous()
        )
        self.fd1 = fd1.to(self.dtype).contiguous()
        self.fd2 = fd2.to(self.dtype).contiguous()
        self.source_amplitudes_resampled = source_amplitudes_resampled
        self.source_model_locations = source_model_locations.long().contiguous()
        self.receiver_model_locations = receiver_model_locations.long().contiguous()
        self.inner_dt = inner_dt
        self.timestep = timestep
        self.num_shots = num_shots
        self.num_sources_per_shot = num_sources_per_shot
        self.num_receivers_per_shot = num_receivers_per_shot
        self.wavefield_save_strategy = wavefield_save_strategy
        self.total_num_steps = num_steps
        self.current_step = 0

    def step(self, num_steps):
        """Propagate the wavefield forward by ``num_steps`` time steps of ``dt``.

        Args:
            num_steps: Number of (outer) time steps to advance. The compiled
                code is asked to run ``num_steps`` steps with
                ``self.timestep.step_ratio`` inner steps each.

        Returns:
            ``self.wavefield[1]``, a view into the internal wavefield storage.
            It is overwritten by later calls, so copy it to keep a snapshot.
        """
        assert self.current_step + num_steps <= self.total_num_steps
        step_ratio = self.timestep.step_ratio
        first_step = self.current_step
        last_step = first_step + num_steps
        source_amplitudes_resampled_steps = self.source_amplitudes_resampled[
            first_step * step_ratio:last_step * step_ratio
        ]
        # Record receiver data for these steps in their own rows instead of
        # overwriting the first num_steps rows on every call. A slice along
        # dim 0 of a contiguous tensor is itself contiguous.
        receiver_amplitudes_steps = self.receiver_amplitudes[first_step:last_step]
        # Call compiled C code to do forward modeling
        self.scalar_wrapper.forward(
            self.wavefield,
            self.pml.aux,
            receiver_amplitudes_steps,
            self.saved_wavefields,
            self.pml.sigma,
            self.model.properties["vp2dt2"],
            self.fd1,
            self.fd2,
            source_amplitudes_resampled_steps.to(self.dtype).contiguous(),
            self.source_model_locations,
            self.receiver_model_locations,
            self.model.shape.contiguous(),
            self.pml.pml_width,
            self.inner_dt,
            num_steps,
            step_ratio,
            self.num_shots,
            self.num_sources_per_shot,
            self.num_receivers_per_shot,
            self.wavefield_save_strategy,
        )
        self.current_step = last_step
        num_inner_steps = num_steps * step_ratio
        wavefield_shift = num_inner_steps % 3
        if wavefield_shift != 0:
            # Put the three wavefield buffers back in the expected order.
            # Rotating the buffer indices once per inner step, k times, gives
            # new[i] = old[(i - k) % 3], which is torch.roll by k along dim 0.
            # torch.roll returns a new tensor, so copying it back in place is
            # safe; assigning views of self.wavefield into its own slots one
            # at a time would overwrite buffers before they are read.
            self.wavefield[:] = torch.roll(
                self.wavefield, shifts=wavefield_shift, dims=0
            )
        if num_inner_steps % 2 != 0:
            # Swap the two halves of the aux arrays so that they are in the
            # correct order (same aliasing concern as above).
            ndim = self.model.ndim
            if ndim == 1:
                aux_size = 1
            elif ndim == 2:
                aux_size = 2
            else:
                aux_size = 4
            assert len(self.pml.aux) == 2 * aux_size
            self.pml.aux[:] = torch.roll(self.pml.aux, shifts=aux_size, dims=0)
        return self.wavefield[1]
class SteppingPropagatorFunction(torch.autograd.Function):
    """Pass-through autograd Function used to capture Propagator inputs.

    Not called by users. ``forward`` hands back every argument it receives,
    in the same order, so that ``SteppingPropagator.__init__`` can unpack the
    values the Propagator base class passes to it.

    NOTE: several of the returned values (for example ``dt``) are not
    Tensors. Newer PyTorch versions reject non-Tensor outputs from
    ``autograd.Function.forward`` with a TypeError.
    """

    @staticmethod
    def forward(ctx, source_amplitudes, source_locations, receiver_locations,
                dt, model, property_names, vp):
        passthrough = (source_amplitudes, source_locations, receiver_locations,
                       dt, model, property_names, vp)
        return passthrough
It is a bit hacky - it runs the setup for a regular propagator and then extracts the variables that are passed to the forward method of the propagator. It then uses these to run all of the code in the usual forward propagator up to the point where the compiled propagator gets called, and saves the arguments to this so that they can be used when you actually want to run forward time steps of the propagator. The benefit of doing all of this setup is that the actual stepping part is then quite easy - we just get the right bits of the source wavelet for the desired steps, run the compiled propagator, and then swap some memory around if necessary to make sure it is in the right place.
Here is an example of how to use it:
import matplotlib.pyplot as plt

dx = 5.0  # 5m in each dimension
dt = 0.004  # 4ms
nz = 200
ny = 400
nt = int(5 / dt)  # 5s (1250 time steps)
peak_freq = 4
peak_source_time = 1 / peak_freq

# constant 1500m/s model
model = torch.ones(nz, ny) * 1500

# one source and receiver at the same location
x_s = torch.Tensor([[[0, 20 * dx]]])
x_r = x_s.clone()
source_amplitudes = deepwave.wavelets.ricker(peak_freq, nt, dt,
                                             peak_source_time).reshape(-1, 1, 1)

prop = SteppingPropagator({'vp': model}, dx, source_amplitudes, x_s, x_r, dt)

# Advance 100 time steps (0.4 s) at a time. step() returns a view of the
# propagator's internal wavefield, so copy each snapshot before stepping again.
snapshots = [prop.step(100).detach().numpy().copy() for _ in range(3)]

_, ax = plt.subplots(1, 3, sharex=True, sharey=True)
for axis, snapshot in zip(ax, snapshots):
    axis.imshow(snapshot[0, :, :, 0], aspect='auto')
plt.show()
The CPU implementation of propagation in 3D is here. If I remember correctly, I used the same PML as PySIT.
from deepwave.
Thank you so much for getting me started. I will update you on my progress.
from deepwave.
The code fails with the following error:
"TypeError: SteppingPropagatorFunctionBackward.forward: expected Tensor or tuple of Tensor (got float) for return value 3"
from deepwave.
from deepwave.
from deepwave.
from deepwave.
From the message that you got, it sounds like your version of PyTorch is complaining about some of the return values from the forward function in SteppingPropagatorFunction not being Tensors. Perhaps you could try this version of the code instead:
import torch
import numpy as np
import scipy.signal
import deepwave
import deepwave.base.propagator
from deepwave.scalar import scalar
from deepwave.base.propagator import _check_locations_with_model
class SteppingPropagator(deepwave.base.propagator.Propagator):
    """PyTorch Module for scalar wave propagator that can be stepped in time.

    Performs the deepwave scalar-propagator setup (input checks, model
    extraction and padding, PML, finite-difference coefficients, source
    resampling) once in ``__init__``. :meth:`step` then calls the compiled
    propagator directly, so the wavefield can be accessed between calls.
    See deepwave.base.propagator.Propagator for description.
    """

    def __init__(self, model, dx,
                 source_amplitudes, source_locations, receiver_locations, dt,
                 pml_width=None, survey_pad=None, vpmax=None):
        if list(model.keys()) != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    list(model.keys())
                )
            )
        super(SteppingPropagator, self).__init__(
            SteppingPropagatorFunction,
            model,
            dx,
            fd_width=4,  # also in Pml
            pml_width=pml_width,
            survey_pad=survey_pad,
        )
        self.model.extra_info["vpmax"] = vpmax
        if model["vp"].min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(model["vp"].min())
            )
        # Check dt
        if not isinstance(dt, float):
            raise RuntimeError('dt must be a float, but has type {}'
                               .format(type(dt)))
        if dt <= 0.0:
            raise RuntimeError('dt must be > 0, but is {}'.format(dt))
        # Check same device as model
        if not (self.model.device == source_amplitudes.device ==
                source_locations.device == receiver_locations.device):
            raise RuntimeError('model, source amplitudes, source_locations, '
                               'and receiver_locations must all have the same '
                               'device, but got {} {} {} {}'
                               .format(self.model.device,
                                       source_amplitudes.device,
                                       source_locations.device,
                                       receiver_locations.device))
        # Check shapes
        if source_amplitudes.dim() != 3:
            raise RuntimeError('source_amplitude must have shape '
                               '[nt, num_shots, num_sources_per_shot]')
        if source_locations.dim() != 3:
            raise RuntimeError('source_locations must have shape '
                               '[num_shots, num_sources_per_shot, num_dims]')
        if receiver_locations.dim() != 3:
            raise RuntimeError('receiver_locations must have shape '
                               '[num_shots, num_receivers_per_shot, num_dims]')
        if not (source_amplitudes.shape[1] == source_locations.shape[0] ==
                receiver_locations.shape[0]):
            raise RuntimeError('Shape mismatch, expected '
                               'source_amplitudes.shape[1] '
                               '== source_locations.shape[0] '
                               '== receiver_locations.shape[0], but got '
                               '{} {} {}'.format(source_amplitudes.shape[1],
                                                 source_locations.shape[0],
                                                 receiver_locations.shape[0]))
        if not (source_amplitudes.shape[2] == source_locations.shape[1]):
            raise RuntimeError('Shape mismatch, expected '
                               'source_amplitudes.shape[2] '
                               '== source_locations.shape[1], but got '
                               '{} {}'.format(source_amplitudes.shape[2],
                                              source_locations.shape[1]))
        if not (self.model.ndim == source_locations.shape[2] ==
                receiver_locations.shape[2]):
            raise RuntimeError('Shape mismatch, expected '
                               'model num dims == source_locations.shape[2] '
                               '== receiver_locations.shape[2], but got '
                               '{} {} {}'.format(self.model.ndim,
                                                 source_locations.shape[2],
                                                 receiver_locations.shape[2]))
        # Check src/rec locations within model
        _check_locations_with_model(self.model, source_locations, 'source')
        _check_locations_with_model(self.model, receiver_locations, 'receiver')
        # Extract a region of the model around the sources/receivers
        model = self.extract(self.model, source_locations, receiver_locations)
        # Apply padding for the spatial finite difference and for the PML
        model = self.pad(model)
        property_names = list(model.properties.keys())
        vp = model.properties["vp"]
        if property_names != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    property_names
                )
            )
        if vp.min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(vp.min())
            )
        device = model.device
        dtype = model.dtype
        num_steps, num_shots, num_sources_per_shot = source_amplitudes.shape
        num_receivers_per_shot = receiver_locations.shape[1]
        if model.extra_info["vpmax"] is None:
            max_vel = vp.max().item()
        else:
            max_vel = model.extra_info["vpmax"]
        timestep = scalar.Timestep(dt, model.dx, max_vel)
        model.add_properties(
            {
                "vp2dt2": vp ** 2 * timestep.inner_dt ** 2,
                "scaling": 2 / vp ** 3,
            }
        )
        source_model_locations = model.get_locations(source_locations)
        receiver_model_locations = model.get_locations(receiver_locations)
        scalar_wrapper = scalar._select_propagator(model.ndim, vp.dtype, vp.is_cuda)
        wavefield_save_strategy = scalar._set_wavefield_save_strategy(
            False, dt, timestep.inner_dt, scalar_wrapper
        )
        fd1, fd2 = scalar._set_finite_diff_coeffs(model.ndim, model.dx, device, dtype)
        wavefield, saved_wavefields = scalar._allocate_wavefields(
            wavefield_save_strategy,
            scalar_wrapper,
            model,
            num_steps,
            num_shots,
        )
        receiver_amplitudes = torch.zeros(
            num_steps,
            num_shots,
            num_receivers_per_shot,
            device=device,
            dtype=dtype,
        )
        inner_dt = torch.tensor([timestep.inner_dt]).to(dtype)
        pml = scalar.Pml(model, num_shots, max_vel)
        # Resample the source to the (possibly finer) inner time step.
        source_amplitudes_resampled = scipy.signal.resample(
            source_amplitudes.detach().cpu().numpy(),
            num_steps * timestep.step_ratio,
        )
        source_amplitudes_resampled = (
            torch.tensor(source_amplitudes_resampled)
            .to(dtype)
            .to(source_amplitudes.device)
        )
        source_amplitudes_resampled.requires_grad = (
            source_amplitudes.requires_grad
        )
        self.dtype = dtype
        self.scalar_wrapper = scalar_wrapper
        # Convert once here so that step() passes these exact tensors to the
        # compiled code, which updates them in place.
        self.wavefield = wavefield.to(self.dtype).contiguous()
        self.pml = pml
        self.pml.aux = self.pml.aux.to(self.dtype).contiguous()
        self.pml.sigma = self.pml.sigma.to(self.dtype).contiguous()
        self.pml.pml_width = self.pml.pml_width.long().contiguous()
        self.receiver_amplitudes = receiver_amplitudes.to(self.dtype).contiguous()
        self.saved_wavefields = saved_wavefields.to(self.dtype).contiguous()
        self.model = model
        self.model.properties["vp2dt2"] = (
            self.model.properties["vp2dt2"].to(self.dtype).contiguous()
        )
        self.fd1 = fd1.to(self.dtype).contiguous()
        self.fd2 = fd2.to(self.dtype).contiguous()
        self.source_amplitudes_resampled = source_amplitudes_resampled
        self.source_model_locations = source_model_locations.long().contiguous()
        self.receiver_model_locations = receiver_model_locations.long().contiguous()
        self.inner_dt = inner_dt
        self.timestep = timestep
        self.num_shots = num_shots
        self.num_sources_per_shot = num_sources_per_shot
        self.num_receivers_per_shot = num_receivers_per_shot
        self.wavefield_save_strategy = wavefield_save_strategy
        self.total_num_steps = num_steps
        self.current_step = 0

    def step(self, num_steps):
        """Propagate the wavefield forward by ``num_steps`` time steps of ``dt``.

        Args:
            num_steps: Number of (outer) time steps to advance. The compiled
                code is asked to run ``num_steps`` steps with
                ``self.timestep.step_ratio`` inner steps each.

        Returns:
            ``self.wavefield[1]``, a view into the internal wavefield storage.
            It is overwritten by later calls, so copy it to keep a snapshot.
        """
        assert self.current_step + num_steps <= self.total_num_steps
        step_ratio = self.timestep.step_ratio
        first_step = self.current_step
        last_step = first_step + num_steps
        source_amplitudes_resampled_steps = self.source_amplitudes_resampled[
            first_step * step_ratio:last_step * step_ratio
        ]
        # Record receiver data for these steps in their own rows instead of
        # overwriting the first num_steps rows on every call. A slice along
        # dim 0 of a contiguous tensor is itself contiguous.
        receiver_amplitudes_steps = self.receiver_amplitudes[first_step:last_step]
        # Call compiled C code to do forward modeling
        self.scalar_wrapper.forward(
            self.wavefield,
            self.pml.aux,
            receiver_amplitudes_steps,
            self.saved_wavefields,
            self.pml.sigma,
            self.model.properties["vp2dt2"],
            self.fd1,
            self.fd2,
            source_amplitudes_resampled_steps.to(self.dtype).contiguous(),
            self.source_model_locations,
            self.receiver_model_locations,
            self.model.shape.contiguous(),
            self.pml.pml_width,
            self.inner_dt,
            num_steps,
            step_ratio,
            self.num_shots,
            self.num_sources_per_shot,
            self.num_receivers_per_shot,
            self.wavefield_save_strategy,
        )
        self.current_step = last_step
        num_inner_steps = num_steps * step_ratio
        wavefield_shift = num_inner_steps % 3
        if wavefield_shift != 0:
            # Put the three wavefield buffers back in the expected order.
            # Rotating the buffer indices once per inner step, k times, gives
            # new[i] = old[(i - k) % 3], which is torch.roll by k along dim 0.
            # torch.roll returns a new tensor, so copying it back in place is
            # safe; assigning views of self.wavefield into its own slots one
            # at a time would overwrite buffers before they are read.
            self.wavefield[:] = torch.roll(
                self.wavefield, shifts=wavefield_shift, dims=0
            )
        if num_inner_steps % 2 != 0:
            # Swap the two halves of the aux arrays so that they are in the
            # correct order (same aliasing concern as above).
            ndim = self.model.ndim
            if ndim == 1:
                aux_size = 1
            elif ndim == 2:
                aux_size = 2
            else:
                aux_size = 4
            assert len(self.pml.aux) == 2 * aux_size
            self.pml.aux[:] = torch.roll(self.pml.aux, shifts=aux_size, dims=0)
        return self.wavefield[1]
class SteppingPropagatorFunction(torch.autograd.Function):
    """Placeholder autograd Function handed to the Propagator base class.

    Not called by users. SteppingPropagator performs its own setup and calls
    the compiled propagator directly from ``step``. ``forward`` therefore
    just hands back the ``vp`` argument it receives.
    """

    @staticmethod
    def forward(ctx, source_amplitudes, source_locations, receiver_locations,
                dt, model, property_names, vp):
        # All arguments other than vp are ignored.
        result = vp
        return result
The example code to run the propagator should be the same.
from deepwave.
This code works. Thank you so much. I will keep you posted.
from deepwave.
Related Issues (20)
- Error in executing deepwave in MAC HOT 17
- How to calculate RTM using deepwave HOT 11
- Try the first-order acoustic equation propagation HOT 2
- scalar_born memory issue HOT 4
- 3D forward modelling HOT 5
- Incorrect output from DistributedDataParallel HOT 6
- It seams the scalar function cannot generate the ground roll when setting the free surface HOT 4
- Calculated Hessian for the elastic example. It gives zero values HOT 2
- I was unable to complete compilation HOT 5
- Apply deepwave to ultrasound HOT 13
- Generate the waveform data HOT 3
- How can I get the file called scalar2d_gpu_iso_4_float and scalar2d_gpu_iso_4_float.cp38-win_amd64 HOT 3
- How to write a propagator by scalar with the newest version HOT 3
- looked at the source code HOT 8
- how to simulate a source that is not point source, but has an arbitrarily spatial distribution? HOT 6
- Elastic FWI parameterization (Impedance) HOT 2
- Distributed (multi-GPU) execution HOT 12
- How to generate reverse time migration HOT 6
- Elastic wave gradient calculation HOT 7
- clarify openMP (and torch and CUDA and...) version requirements HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from deepwave.